
Use auto instead of bf16 to determine model dtype; Add comments to describe the various use_fp16 and pure_bf16 options

Igor Kasianenko 2 months ago
parent
commit
c51d3af6c9

+ 2 - 2
src/llama_cookbook/configs/fsdp.py

@@ -9,7 +9,7 @@ from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
 @dataclass
 class fsdp_config:
     mixed_precision: bool=True
-    use_fp16: bool=False
+    use_fp16: bool=False # use fp16 for all fsdp.MixedPrecision dtypes (param, reduce, buffer, see https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD # HYBRID_SHARD "Full Shard within a node DDP cross Nodes", SHARD_GRAD_OP "Shard only Gradients and Optimizer States", NO_SHARD "Similar to DDP".
     hsdp : bool =False # Require HYBRID_SHARD to be set. This flag can extend the HYBRID_SHARD by allowing sharding a model on customized number of GPUs (Sharding_group) and Replicas over Sharding_group.
     sharding_group_size : int=0 # requires hsdp to be set. This specifies the sharding group size, the number of GPUs that your model can fit into to form a replica of the model.
@@ -17,6 +17,6 @@ class fsdp_config:
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively FULL_STATE_DICT can be used. SHARDED_STATE_DICT saves one file with sharded weights per rank while FULL_STATE_DICT will collect all weights on rank 0 and save them in a single file.
     fsdp_activation_checkpointing: bool=True
     fsdp_cpu_offload: bool=False
-    pure_bf16: bool = False
+    pure_bf16: bool = False  # disables mixed precision, and runs in pure bfloat16
     optimizer: str= "AdamW"
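
For context on the new use_fp16 comment above: the flag decides which dtypes feed torch's fsdp.MixedPrecision policy. A minimal sketch of the two choices, mirroring the bfSixteen / fpSixteen policies referenced in the linked docs and shipped in the repo's policies module (names below are illustrative):

```python
import torch
from torch.distributed.fsdp import MixedPrecision

# bf16 policy: what you get when use_fp16 is False and bf16 is supported
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters gathered/computed in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduce-scatter runs in bf16
    buffer_dtype=torch.bfloat16,  # module buffers cast to bf16
)

# fp16 policy: what use_fp16=True selects instead
fp16_policy = MixedPrecision(
    param_dtype=torch.float16,
    reduce_dtype=torch.float16,
    buffer_dtype=torch.float16,
)
```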
     

+ 1 - 1
src/llama_cookbook/configs/training.py

@@ -25,7 +25,7 @@ class train_config:
     weight_decay: float=0.0
     gamma: float= 0.85 # multiplicatively decay the learning rate by gamma after each epoch
     seed: int=42
-    use_fp16: bool=False
+    use_fp16: bool=False  # load model parameters in torch.float16 dtype (not recommended)
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
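
The "(not recommended)" note is because fp16 training generally needs loss scaling to keep gradients from underflowing, extra machinery that bf16 (or the new "auto" default below) avoids. A generic PyTorch AMP sketch of what fp16 implies, not this repo's actual training loop:

```python
import torch

scaler = torch.cuda.amp.GradScaler()  # only needed for fp16, not for bf16

def train_step(model, batch, optimizer):
    optimizer.zero_grad()
    # autocast the forward pass to fp16
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(**batch).loss
    scaler.scale(loss).backward()  # scale the loss so fp16 grads don't underflow
    scaler.step(optimizer)         # unscales grads, skips the step on inf/nan
    scaler.update()                # adapts the scale factor for the next step
```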

+ 3 - 2
src/llama_cookbook/finetuning.py

@@ -40,6 +40,7 @@ from llama_cookbook.utils.train_utils import (
     freeze_transformer_layers,
     freeze_LLM_only,
     get_policies,
+    hsdp_device_mesh,
     print_model_size,
     print_frozen_model_status,
     setup,
@@ -150,7 +151,7 @@ def main(**kwargs):
                 if train_config.quantization and not train_config.enable_fsdp
                 else None
             ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
         )
         processor = AutoProcessor.from_pretrained(
             train_config.model_name
@@ -172,7 +173,7 @@ def main(**kwargs):
                 if train_config.quantization and not train_config.enable_fsdp
                 else None
             ),
-            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
         )
     else:
         raise ValueError(
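
The net effect of the torch_dtype change in both branches: "auto" lets Transformers load the model in the dtype recorded with the checkpoint instead of unconditionally casting to bfloat16, while use_fp16 still forces float16. A minimal sketch; the checkpoint name is just a placeholder:

```python
import torch
from transformers import AutoModelForCausalLM

use_fp16 = False  # mirrors train_config.use_fp16

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder checkpoint for illustration
    # "auto" reads the dtype stored with the checkpoint rather than forcing bf16
    torch_dtype=torch.float16 if use_fp16 else "auto",
)
print(model.dtype)  # e.g. torch.bfloat16 for a bf16 checkpoint
```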

+ 1 - 1
src/llama_cookbook/utils/__init__.py

@@ -3,5 +3,5 @@
 
 from llama_cookbook.utils.memory_utils import MemoryTrace
 from llama_cookbook.utils.dataset_utils import *
-from llama_cookbook.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh
+from llama_cookbook.utils.fsdp_utils import fsdp_auto_wrap_policy, hsdp_device_mesh, get_policies
 from llama_cookbook.utils.train_utils import *
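
With get_policies now living in fsdp_utils and re-exported here, it can be imported from the utils package as well as from the module directly; a quick sanity check, assuming the package is installed:

```python
from llama_cookbook.utils import get_policies
from llama_cookbook.utils.fsdp_utils import get_policies as get_policies_direct

assert get_policies is get_policies_direct  # both paths resolve to the same function
```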

+ 42 - 1
src/llama_cookbook/utils/fsdp_utils.py

@@ -1,7 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-from torch.distributed._tensor.device_mesh import init_device_mesh
 import os 
+import torch
+import torch.cuda.nccl as nccl
+import torch.distributed as dist
+
+from accelerate.utils import is_xpu_available
+from llama_cookbook.policies import fpSixteen, bfSixteen, get_llama_wrapper
+from torch.distributed._tensor.device_mesh import init_device_mesh
 
 def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
@@ -79,3 +84,38 @@ def hsdp_device_mesh(replica_group_size, sharding_group_size, device=None):
         raise RuntimeError("Failed to create a valid device mesh.")
 
     return device_mesh
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    # bf16 mixed precision needs CUDA >= 11.0 with NCCL >= 2.10, or an XPU device
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and torch.version.cuda >= "11.0"
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    ) or is_xpu_available()
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen
+            if rank == 0:
+                print("bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print("FP16 enabled")
+        else:
+            print("bFloat16 support not present. Using FP32, and not mixed precision")
+
+    # Transformer auto-wrap policy for the Llama decoder layers
+    wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
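
A rough sketch of how the returned pair is typically consumed when wrapping the model, simplified from the finetuning flow; it assumes an initialized process group, and cfg, rank, and model below are placeholders:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from llama_cookbook.configs.fsdp import fsdp_config
from llama_cookbook.utils.fsdp_utils import get_policies

cfg = fsdp_config()        # the dataclass edited above, with default values
rank = 0                   # placeholder; normally the process's distributed rank
model = nn.Linear(16, 16)  # placeholder module standing in for the Llama model

mixed_precision_policy, wrapping_policy = get_policies(cfg, rank)

# pure_bf16 runs the whole model in bf16, so no MixedPrecision policy is passed then
model = FSDP(
    model,
    auto_wrap_policy=wrapping_policy,
    mixed_precision=mixed_precision_policy if not cfg.pure_bf16 else None,
    sharding_strategy=cfg.sharding_strategy,
    device_id=torch.cuda.current_device(),
)
```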

+ 0 - 34
src/llama_cookbook/utils/train_utils.py

@@ -11,7 +11,6 @@ import contextlib
 
 
 import torch
-import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
@@ -533,39 +532,6 @@ def print_frozen_model_status(model, config, rank: int = 0) -> None:
                 print(f"    {module}: Unfrozen")
         print("")
 
-def get_policies(cfg, rank):
-    """Get the policies for mixed precision and fsdp wrapping"""
-
-
-    verify_bfloat_support = ((
-    torch.version.cuda
-    and torch.cuda.is_bf16_supported()
-    and torch.version.cuda >= "11.0"
-    and dist.is_nccl_available()
-    and nccl.version() >= (2, 10)
-    ) or
-    (is_xpu_available()))
-
-
-    mixed_precision_policy = None
-    wrapping_policy = None
-
-    # Mixed precision
-    if cfg.mixed_precision:
-        bf16_ready = verify_bfloat_support
-
-        if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen
-            if rank == 0:
-                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
-        elif cfg.use_fp16:
-            mixed_precision_policy = fpSixteen
-            if rank == 0:
-                print(f"FP16 enabled")
-        else:
-            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
-    wrapping_policy = get_llama_wrapper()
-    return mixed_precision_policy, wrapping_policy
 
 def save_train_params(train_config, fsdp_config, rank):
     """