
add freeze_LLM_only option for mllama finetuning

JimChienTW committed 5 months ago
commit 21e8368c7e
3 files changed, 25 insertions(+), 3 deletions(-)
  1. src/llama_recipes/configs/training.py (+1, -0)
  2. src/llama_recipes/finetuning.py (+13, -2)
  3. src/llama_recipes/utils/train_utils.py (+11, -1)

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -35,6 +35,7 @@ class train_config:
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False
     num_freeze_layers: int = 1
+    freeze_LLM_only: bool = False # Freeze the language_model except its cross-attention layers; vision_model, multi_modal_projector, and cross-attention layers remain trainable
     quantization: str = None
     one_gpu: bool = False
     save_model: bool = True
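
The new option is just another field on train_config, next to the existing freeze_layers and num_freeze_layers switches. Below is a minimal sketch of setting it programmatically; it assumes train_config is a dataclass importable from llama_recipes.configs.training as the file path suggests, and the use_peft, enable_fsdp, and freeze_layers fields are taken from the surrounding code rather than from this hunk.

```python
# Sketch: enable the new freeze_LLM_only flag on a training config instance.
from llama_recipes.configs.training import train_config

cfg = train_config()
cfg.freeze_LLM_only = True   # freeze language_model except its cross-attention layers
cfg.use_peft = False         # the flag is only checked on the non-PEFT path
cfg.enable_fsdp = True       # the freeze call in finetuning.py sits on the FSDP code path
cfg.freeze_layers = False    # independent of the existing freeze_layers option

print(cfg.freeze_LLM_only)   # True
```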

+ 13 - 2
src/llama_recipes/finetuning.py

@@ -38,6 +38,7 @@ from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
     clear_gpu_cache,
     freeze_transformer_layers,
+    freeze_LLM_only,
     get_policies,
     print_model_size,
     setup,
@@ -193,8 +194,6 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -235,6 +234,10 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+
+        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
+            freeze_LLM_only(model)
+
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, and MllamaVisionEncoderLayer in vision models
@@ -255,6 +258,11 @@ def main(**kwargs):
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
+        # FSDP needs use_orig_params=True here because freeze_LLM_only leaves frozen and trainable parameters mixed inside the wrapped modules
+        if train_config.freeze_LLM_only:
+            use_orig_params = True
+        else:
+            use_orig_params = False
         model = FSDP(
             model,
             auto_wrap_policy=(
@@ -282,6 +290,7 @@ def main(**kwargs):
                 if train_config.low_cpu_fsdp and rank != 0
                 else None
             ),
+            use_orig_params=use_orig_params,
         )
         if fsdp_config.fsdp_activation_checkpointing:
             model.enable_input_require_grads()
@@ -298,6 +307,8 @@ def main(**kwargs):
     else:
         dataset_processer = tokenizer
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+
     # Load and preprocess the dataset for training and validation
 
     dataset_train = get_preprocessed_dataset(
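
The use_orig_params switch added above is what makes the partial freeze work under FSDP: with the default use_orig_params=False, FSDP flattens each wrapped module's parameters into a single FlatParameter and requires a uniform requires_grad across it, so a model that is frozen except for its cross-attention layers would fail to wrap. A standalone sketch of that rule follows; the wrap_model helper is hypothetical (not part of this commit) and assumes a torch.distributed process group is already initialized.

```python
# Sketch: decide use_orig_params based on whether the model mixes frozen and
# trainable parameters, mirroring the freeze_LLM_only logic above.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def wrap_model(model: torch.nn.Module, **fsdp_kwargs) -> FSDP:
    # Keep the original (unflattened) parameters whenever requires_grad is not
    # uniform across the model; otherwise the default flat-parameter path works.
    grad_flags = {p.requires_grad for p in model.parameters()}
    return FSDP(model, use_orig_params=len(grad_flags) > 1, **fsdp_kwargs)
```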

+ 11 - 1
src/llama_recipes/utils/train_utils.py

@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
             if i < num_layer:
                 for param in layer.parameters():
                     param.requires_grad = False
-
+
+def freeze_LLM_only(model):
+    """
+    Freeze the language_model except its cross-attention layers; vision_model, multi_modal_projector, and the cross-attention layers remain trainable.
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
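
A quick sanity check after calling freeze_LLM_only(model) is to count trainable versus frozen elements per top-level submodule. The summarize_trainable helper below is a hypothetical addition, not part of this commit; it only assumes the Mllama layout used above (language_model, vision_model, multi_modal_projector).

```python
# Sketch: report trainable vs. frozen parameter counts per top-level submodule
# so the effect of freeze_LLM_only(model) can be verified at a glance.
from collections import defaultdict

def summarize_trainable(model):
    counts = defaultdict(lambda: [0, 0])  # prefix -> [trainable, frozen] element counts
    for name, param in model.named_parameters():
        prefix = name.split(".")[0]       # e.g. language_model, vision_model, multi_modal_projector
        counts[prefix][0 if param.requires_grad else 1] += param.numel()
    for prefix, (trainable, frozen) in sorted(counts.items()):
        print(f"{prefix}: trainable={trainable:,} frozen={frozen:,}")
```

With the freeze applied, vision_model and multi_modal_projector should report only trainable elements, while language_model should show trainable elements coming solely from its cross-attention layers.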