# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import MixedPrecision

# fp16 policy: requires a gradient scaler (e.g. ShardedGradScaler) in the
# training loop, since float16 gradients can underflow.
fpSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

# bf16 policy: no gradient scaler needed, since bfloat16 has the same
# exponent range as float32.
bfSixteen = MixedPrecision(
    # Parameter precision.
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    # Cast floating-point inputs to the wrapped module's forward to bfloat16.
    cast_forward_inputs=True,
)

# Mixed policy: keep parameters in fp32, but reduce gradients and keep
# buffers in bfloat16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)

# Full fp32 policy: effectively disables mixed precision; useful as a
# debugging baseline.
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
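
# Usage sketch (not part of the original file): these policies plug into FSDP
# via its `mixed_precision` argument. The helper below is hypothetical and
# assumes torch.distributed is already initialized (e.g. via torchrun).
def wrap_with_fsdp_mixed_precision(model: torch.nn.Module):
    """Hypothetical example: shard `model` with FSDP using a policy above."""
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    # Prefer bf16 where the hardware supports it; fall back to fp16 otherwise
    # (fp16 then needs a ShardedGradScaler in the training loop, per the
    # comment on fpSixteen above).
    policy = bfSixteen if torch.cuda.is_bf16_supported() else fpSixteen
    return FSDP(model, mixed_precision=policy)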
 |