# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import torch

from torch.distributed.fsdp import (
    # FullyShardedDataParallel as FSDP,
    # CPUOffload,
    MixedPrecision,
    # BackwardPrefetch,
    # ShardingStrategy,
)

# requires grad scaler in main loop
fpSixteen = MixedPrecision(
    param_dtype=torch.float16,
    # Gradient communication precision.
    reduce_dtype=torch.float16,
    # Buffer precision.
    buffer_dtype=torch.float16,
)

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,
    # Gradient communication precision.
    reduce_dtype=torch.bfloat16,
    # Buffer precision.
    buffer_dtype=torch.bfloat16,
    cast_forward_inputs=True,
)
 
# Keep parameters in full fp32; reduce gradients and keep buffers in bf16.
bfSixteen_mixed = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
 
# Full fp32 for parameters, gradient communication, and buffers (no mixed precision).
fp32_policy = MixedPrecision(
    param_dtype=torch.float32,
    reduce_dtype=torch.float32,
    buffer_dtype=torch.float32,
)
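

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original policy definitions): how one
# of the policies above might be passed to FSDP when wrapping a model. The
# function name and wrapping details are assumptions for demonstration, not
# this repo's actual training setup.
# ---------------------------------------------------------------------------
def _example_fsdp_wrap(model: torch.nn.Module, use_fp16: bool = False):
    """Sketch: wrap an already-initialized model with one of the policies above.

    Assumes a distributed process group has been initialized elsewhere.
    """
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    if use_fp16:
        # fp16 reduces gradients in float16, so the training loop should also
        # use a gradient scaler (see the "requires grad scaler" note above).
        wrapped = FSDP(model, mixed_precision=fpSixteen)
        scaler = ShardedGradScaler()
        return wrapped, scaler

    # bf16 shares fp32's dynamic range, so no gradient scaler is needed.
    return FSDP(model, mixed_precision=bfSixteen), None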
 
 