fsdp.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

from dataclasses import dataclass

from torch.distributed.fsdp import ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType


@dataclass
class fsdp_config:
    mixed_precision: bool = True
    use_fp16: bool = False
    sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
    # SHARDED_STATE_DICT saves one checkpoint file per rank and allows resuming
    # with a different world size; FULL_STATE_DICT is the single-file alternative.
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT
    fsdp_activation_checkpointing: bool = True
    pure_bf16: bool = True
    optimizer: str = "AdamW"
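
Below is a minimal sketch of how these fields might be consumed when wrapping a model with FSDP. It is illustrative only: the `wrap_model` helper is hypothetical and not part of this file, and the mapping of `pure_bf16` onto a plain `.to(torch.bfloat16)` cast is an assumption about intent rather than the repo's actual training loop.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision


def wrap_model(model: torch.nn.Module, cfg: fsdp_config) -> FSDP:
    # Hypothetical helper: maps the config fields onto FSDP constructor
    # arguments. Assumes torch.distributed has already been initialized
    # (FSDP requires an active process group).
    mp_policy = None
    if cfg.mixed_precision:
        # use_fp16 selects float16; otherwise compute in bfloat16.
        dtype = torch.float16 if cfg.use_fp16 else torch.bfloat16
        mp_policy = MixedPrecision(
            param_dtype=dtype,
            reduce_dtype=dtype,
            buffer_dtype=dtype,
        )
    if cfg.pure_bf16:
        # With pure bf16, the parameters themselves are cast to bfloat16,
        # so no mixed-precision policy is layered on top.
        model = model.to(torch.bfloat16)
        mp_policy = None
    return FSDP(
        model,
        sharding_strategy=cfg.sharding_strategy,
        mixed_precision=mp_policy,
    )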