
add freeze_LLM_only option for mllama finetuning (#791)

Kai Wu 5 months ago
parent
commit
e5662e5804

+ 1 - 0
recipes/quickstart/finetuning/README.md

@@ -54,6 +54,7 @@ It lets us specify the training settings for everything from `model_name` to `da
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze the self-attention layers in the language_model; the vision model, multi_modal_projector, and cross-attention layers will still be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True

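A hedged usage sketch (not part of the commit): the new flag is an ordinary train_config field, so it can be set through the finetuning entrypoint's keyword overrides just like the existing freeze options. The model name below is an assumption chosen for illustration; any mllama checkpoint would do, and the option only takes effect when PEFT is disabled.

    # sketch: enabling the new option via llama_recipes.finetuning.main keyword overrides
    from llama_recipes.finetuning import main

    main(
        model_name="meta-llama/Llama-3.2-11B-Vision-Instruct",  # assumed mllama checkpoint
        freeze_LLM_only=True,   # new flag added by this commit
        use_peft=False,         # the freeze options apply only to full fine-tuning
        enable_fsdp=True,       # typically launched under torchrun for multi-GPU runs
        output_dir="PATH/to/save/model",
    )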
+ 6 - 0
recipes/quickstart/finetuning/finetune_vision_model.md

The file diff has been suppressed because it is too large


+ 1 - 0
src/llama_recipes/configs/training.py

@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze the self-attention layers in the language_model; the vision model, multi_modal_projector, and cross-attention layers will still be fine-tuned
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True

+ 18 - 3
src/llama_recipes/finetuning.py

@@ -38,8 +38,10 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,7 +196,7 @@ def main(**kwargs):
         model.resize_token_embeddings(len(tokenizer))
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
+    
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,7 +237,14 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
-
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+            
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
+        
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer in vision models
         if is_vision:
@@ -255,6 +264,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+        
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +296,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
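Why use_orig_params is toggled above: with the default use_orig_params=False, FSDP flattens each wrapped module into a single FlatParameter and requires all of the original parameters inside it to share the same requires_grad value, so a model that mixes frozen and trainable weights (as freeze_LLM_only produces) cannot be wrapped. Setting use_orig_params=True keeps per-parameter requires_grad intact. A minimal single-process sketch, assuming a PyTorch build where FSDP runs on CPU with the gloo backend (illustration only, not code from this repo):

    import os
    import torch.nn as nn
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=0, world_size=1)

    model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
    for p in model[0].parameters():
        p.requires_grad = False  # partially frozen, like a model after freeze_LLM_only

    # with use_orig_params=False this wrap would fail ("FlatParameter requires uniform requires_grad")
    wrapped = FSDP(model, use_orig_params=True)
    print([p.requires_grad for p in wrapped.parameters()])  # per-parameter flags are preserved

    dist.destroy_process_group()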
@@ -297,7 +312,7 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-
+    
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(

+ 56 - 2
src/llama_recipes/utils/train_utils.py

@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+                    
+def freeze_LLM_only(model):
+    """
+    Freeze self-attention layers in the language_model. vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
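A toy sketch of what freeze_LLM_only does (the module names below are hypothetical stand-ins that only mirror the attribute paths the function touches, not the real Mllama classes): every parameter under language_model is frozen first, then the decoder layers listed in cross_attention_layers are unfrozen again, leaving the vision tower and projector untouched and therefore trainable.

    import torch.nn as nn
    from llama_recipes.utils.train_utils import freeze_LLM_only  # added by this commit

    class ToyTextModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.ModuleList([nn.Linear(4, 4) for _ in range(4)])
            self.cross_attention_layers = [1, 3]  # indices of cross-attention decoder layers

    class ToyLanguageModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ToyTextModel()

    class ToyMllama(nn.Module):
        def __init__(self):
            super().__init__()
            self.language_model = ToyLanguageModel()
            self.vision_model = nn.Linear(4, 4)           # stays trainable
            self.multi_modal_projector = nn.Linear(4, 4)  # stays trainable

    model = ToyMllama()
    freeze_LLM_only(model)
    for name, p in model.named_parameters():
        print(name, p.requires_grad)
    # only language_model.model.layers.1 / .3, vision_model and multi_modal_projector remain trainable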
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's modules and the number of trainable parameters after freezing.
 
-
+    Args:
+        model: The PyTorch model.
+        config: The training config, used for its model_name.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print("After freezing the model:")
+        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
+
+        module_states = {}
+        # Iterate over all parameters
+        for name, param in model.named_parameters():
+            # Extract the top-level module name (e.g., "vision_model", "language_model")
+            top_module = name.split(".")[0]
+
+            # Initialize a record for the top-level module
+            if top_module not in module_states:
+                module_states[top_module] = {"frozen": [], "unfrozen": []}
+
+            # Group parameters into frozen or unfrozen
+            if param.requires_grad:
+                module_states[top_module]["unfrozen"].append(name)
+            else:
+                module_states[top_module]["frozen"].append(name)
+
+        print("--> Model state after freezing:")
+        # Analyze and print the results
+        for module, states in module_states.items():
+            frozen_params = states["frozen"]
+            unfrozen_params = states["unfrozen"]
+
+            if frozen_params and unfrozen_params:
+                # Mixed state: both frozen and unfrozen parameters
+                print(f"    {module}: Mixed")
+            elif frozen_params:
+                # All parameters are frozen
+                print(f"    {module}: Frozen")
+            else:
+                # All parameters are unfrozen
+                print(f"    {module}: Unfrozen")
+        print("")
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""