Kai Wu, 7 months ago
Parent commit ee204ccb98

src/llama_recipes/finetuning.py (+27 -7)

@@ -14,7 +14,11 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
@@ -29,7 +33,7 @@ from transformers import (
 
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-
+from transformers.models.clip.modeling_clip import CLIPEncoderLayer
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config  as QUANTIZATION_CONFIG
@@ -121,11 +125,11 @@ def main(**kwargs):
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    #use_cache = False if train_config.enable_fsdp else None
+    use_cache = False if train_config.enable_fsdp else None
     model = LlavaNextForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-    #    use_cache=use_cache,
+        #use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -172,16 +176,25 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        print("FSDP is enabled", my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
         model = FSDP(
             model,
-            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+            # Wrap each CLIP vision encoder block and each LLaMA decoder block in its own FSDP unit
+            auto_wrap_policy=ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]),
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
@@ -192,6 +205,7 @@ def main(**kwargs):
             param_init_fn=(lambda module: module.to_empty(device=torch.device("cuda"), recurse=False))
             if train_config.low_cpu_fsdp and rank != 0 else None,
         )
         if fsdp_config.fsdp_activation_checkpointing:            
             model.enable_input_require_grads()
             model.gradient_checkpointing_enable()
@@ -205,6 +219,11 @@ def main(**kwargs):
     dataset_config = generate_dataset_config(train_config, kwargs)
 
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
         processor,
         dataset_config,
@@ -272,6 +291,7 @@ def main(**kwargs):
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
     # Start the training process
+   
     results = train(
         model,
         train_dataloader,

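The FSDP wrapping in this commit relies on ModuleWrapPolicy, which puts every instance of the listed module classes into its own FSDP unit, so both the CLIP vision tower's encoder blocks and the LLaMA decoder blocks are sharded separately. The sketch below is illustrative only: it uses toy stand-in classes rather than the real CLIPEncoderLayer / LlamaDecoderLayer, plus a hypothetical ToyLlava container, to show which modules such a policy selects.

```python
import torch.nn as nn
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

class VisionBlock(nn.Module):      # stand-in for CLIPEncoderLayer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

class DecoderBlock(nn.Module):     # stand-in for LlamaDecoderLayer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 16)

class ToyLlava(nn.Module):         # hypothetical container mixing the two block types
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Sequential(*[VisionBlock() for _ in range(2)])
        self.language_model = nn.Sequential(*[DecoderBlock() for _ in range(4)])

model = ToyLlava()
policy = ModuleWrapPolicy([VisionBlock, DecoderBlock])

# ModuleWrapPolicy wraps exactly the modules that are instances of the listed classes;
# this isinstance scan mirrors that rule.
to_wrap = [name for name, mod in model.named_modules()
           if isinstance(mod, (VisionBlock, DecoderBlock))]
print(to_wrap)  # ['vision_tower.0', 'vision_tower.1', 'language_model.0', ...]

# In finetuning.py the policy is passed to the wrapper itself:
#   FSDP(model, auto_wrap_policy=ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer]), ...)
```

Listing both classes makes the sharding boundary explicit for the vision and language towers, which the previous LlamaDecoderLayer-only wrap could not express.
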
src/llama_recipes/utils/fsdp_utils.py (+13 -10)

@@ -3,7 +3,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os 
 
-def fsdp_auto_wrap_policy(model, transformer_layer_name):
+def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -16,16 +16,19 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
         ):
             return True
         return False
-
+    transformer_wrap_policies = []
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
-    transformer_wrap_policy = functools.partial(
-        transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            transformer_layer_name,
-        ),
-    )
-
-    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
+    for transformer_layer_name in transformer_layer_names:
+        transformer_wrap_policy = functools.partial(
+            transformer_auto_wrap_policy,
+            transformer_layer_cls=(
+                transformer_layer_name,
+            ),
+        )
+        transformer_wrap_policies.append(transformer_wrap_policy)
+    # Combine the per-layer policies with the PEFT lambda policy defined above
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, *transformer_wrap_policies])
     return auto_wrap_policy
 
 

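fsdp_auto_wrap_policy now accepts a list of transformer layer classes, builds one transformer_auto_wrap_policy per class, and combines them (together with the PEFT lambda policy) through _or_policy. The sketch below is an illustration under stated assumptions, not repo code: it uses toy layer classes, the same private torch.distributed.fsdp.wrap helpers the file already imports, and the (module, recurse, nonwrapped_numel) policy signature of recent PyTorch releases, which older versions name differently.

```python
import functools

import torch.nn as nn
from torch.distributed.fsdp.wrap import _or_policy, transformer_auto_wrap_policy

class VisionBlock(nn.Module):   # toy stand-in for CLIPEncoderLayer
    pass

class DecoderBlock(nn.Module):  # toy stand-in for LlamaDecoderLayer
    pass

# One policy per layer class, mirroring the loop in fsdp_auto_wrap_policy.
per_class_policies = [
    functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={cls})
    for cls in (VisionBlock, DecoderBlock)
]
combined = functools.partial(_or_policy, policies=per_class_policies)

# FSDP queries a policy as policy(module, recurse, nonwrapped_numel); with recurse=False
# it answers "should this exact module become its own FSDP unit?"
print(combined(VisionBlock(), recurse=False, nonwrapped_numel=0))    # True
print(combined(DecoderBlock(), recurse=False, nonwrapped_numel=0))   # True
print(combined(nn.Linear(4, 4), recurse=False, nonwrapped_numel=0))  # False
```

An equivalent, arguably simpler form is a single transformer_auto_wrap_policy with transformer_layer_cls=set(transformer_layer_names); the per-class loop keeps the change closer to the original single-class code.
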
src/llama_recipes/utils/train_utils.py (+1 -1)

@@ -358,7 +358,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
-                outputs = model(**batch)
+                outputs = model(**batch, use_cache=False)
                 loss = outputs.loss
                 if train_config.save_metrics:
                     val_step_loss.append(loss.detach().float().item())
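
Passing use_cache=False on the evaluation forward keeps the model from building past_key_values, which only matter for incremental decoding and would just add memory pressure in a loss-only pass under FSDP. A minimal sketch of the same pattern; the helper name and batch layout here are illustrative, not from the repo:

```python
import torch

@torch.no_grad()
def eval_step(model, batch):
    # batch is assumed to carry input_ids / attention_mask / labels
    # (plus pixel_values for LLaVA-style models).
    outputs = model(**batch, use_cache=False)  # skip KV-cache allocation
    return outputs.loss.detach().float()
```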