# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

"""FSDP mixed-precision policy presets.

Module-level ``MixedPrecision`` configurations for use with
``torch.distributed.fsdp.FullyShardedDataParallel``. Each policy sets the
dtype used for parameters, gradient reduce-scatter communication, and
buffers.
"""

import torch
from torch.distributed.fsdp import (
    MixedPrecision,
)

# Pure fp16 policy. NOTE: fp16 training requires a gradient scaler
# (e.g. ShardedGradScaler) in the main training loop to avoid underflow.
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

# Pure bf16 policy; no grad scaler needed (bf16 has fp32's exponent range).
bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Cast root-module forward inputs to bf16 as well.
    cast_forward_inputs=True,
)

# Mixed policy: keep parameters in fp32, but communicate gradients and
# keep buffers in bf16 to reduce bandwidth/memory.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full-precision policy: everything stays in fp32 (useful as a baseline
# or for debugging numerical issues).
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
 |