@@ -1,7 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from torch.distributed._tensor.device_mesh import init_device_mesh
 import os
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+
+from accelerate.utils import is_xpu_available
+from llama_recipes.policies import fpSixteen, bfSixteen, get_llama_wrapper
+from torch.distributed._tensor.device_mesh import init_device_mesh
 
 def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
@@ -79,3 +85,38 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
         raise RuntimeError("Failed to create a valid device mesh.")
 
     return device_mesh
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+
+    verify_bfloat_support = ((
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and torch.version.cuda >= "11.0"
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    ) or
+        (is_xpu_available()))
+
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen
+            if rank == 0:
+                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print(f"FP16 enabled")
+        else:
+            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
+        wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
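Note on the objects imported from llama_recipes.policies in the first hunk: fpSixteen, bfSixteen, and get_llama_wrapper are defined elsewhere in the package and are not part of this patch. Below is a minimal sketch of what they are assumed to look like, built on the standard torch.distributed.fsdp MixedPrecision dataclass and transformer_auto_wrap_policy; the exact dtype fields and the LlamaDecoderLayer import are assumptions, not taken from this diff:

    # Sketch only: assumed shape of the llama_recipes.policies objects used by get_policies.
    import functools

    import torch
    from torch.distributed.fsdp import MixedPrecision
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer

    # fp16 for parameters, gradient reduction, and buffers.
    fpSixteen = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )

    # bf16 everywhere; selected by get_policies when the capability check passes.
    bfSixteen = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.bfloat16,
        buffer_dtype=torch.bfloat16,
    )

    def get_llama_wrapper():
        # Auto-wrap policy that shards each LlamaDecoderLayer as its own FSDP unit.
        return functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={LlamaDecoderLayer},
        )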
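A hedged usage sketch for the new get_policies helper. The cfg attributes mixed_precision and use_fp16 are the ones the function actually reads; the module path, the placeholder model, and the process-group setup are illustrative assumptions rather than part of this change:

    # Sketch only: launch with torchrun so RANK / WORLD_SIZE / LOCAL_RANK are set.
    import os

    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    from llama_recipes.utils.fsdp_utils import get_policies  # module path assumed

    class FsdpCfg:                # hypothetical stand-in for the real fsdp config
        mixed_precision = True    # ask for a mixed precision policy
        use_fp16 = False          # prefer bf16 when the capability check passes

    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    mixed_precision_policy, wrapping_policy = get_policies(FsdpCfg(), rank)

    model = torch.nn.Linear(4096, 4096)  # placeholder; a Llama model in practice
    model = FSDP(
        model,
        auto_wrap_policy=wrapping_policy,        # per-decoder-layer wrapping
        mixed_precision=mixed_precision_policy,  # None means full fp32
        device_id=torch.cuda.current_device(),
    )

Note that wrapping_policy is only assigned inside the if cfg.mixed_precision: branch, so a config with mixed precision disabled gets None for both return values and the caller has to supply its own auto-wrap policy.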