浏览代码

Print model's frozen status after freezing

JimChienTW 5 月之前
父节点
当前提交
d31ee18e4f
共有 2 个文件被更改,包括 52 次插入3 次删除
  1. 7 2
      src/llama_recipes/finetuning.py
  2. 45 1
      src/llama_recipes/utils/train_utils.py

+ 7 - 2
src/llama_recipes/finetuning.py

@@ -41,6 +41,7 @@ from llama_recipes.utils.train_utils import (
     freeze_LLM_only,
     get_policies,
     print_model_size,
+    print_frozen_model_status,
     setup,
     setup_environ_flags,
     train,
@@ -194,6 +195,8 @@ def main(**kwargs):
         )
         model.resize_token_embeddings(len(tokenizer))
 
+    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
+    
     # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
     if (
         train_config.enable_fsdp
@@ -234,12 +237,14 @@ def main(**kwargs):
 
         if not train_config.use_peft and train_config.freeze_layers:
             freeze_transformer_layers(model, train_config.num_freeze_layers)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
             
         if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
             freeze_LLM_only(model)
+            # print model size and frozen layers after freezing layers
+            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)
         
-        print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
-
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer in vision models
         if is_vision:

+ 45 - 1
src/llama_recipes/utils/train_utils.py

@@ -486,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's and the number of trainable parameters after frozen.
 
-
+    Args:
+        model: The PyTorch model.
+        model_name (str): Name of the model.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print("After freezing the model:")
+        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
+
+        module_states = {}
+        # Iterate over all parameters
+        for name, param in model.named_parameters():
+            # Extract the top-level module name (e.g., "vision_model", "language_model")
+            top_module = name.split(".")[0]
+
+            # Initialize a record for the top-level module
+            if top_module not in module_states:
+                module_states[top_module] = {"frozen": [], "unfrozen": []}
+
+            # Group parameters into frozen or unfrozen
+            if param.requires_grad:
+                module_states[top_module]["unfrozen"].append(name)
+            else:
+                module_states[top_module]["frozen"].append(name)
+
+        print("--> Model state after freezing:")
+        # Analyze and print the results
+        for module, states in module_states.items():
+            frozen_params = states["frozen"]
+            unfrozen_params = states["unfrozen"]
+
+            if frozen_params and unfrozen_params:
+                # Mixed state: both frozen and unfrozen parameters
+                print(f"    {module}: Mixed")
+            elif frozen_params:
+                # All parameters are frozen
+                print(f"    {module}: Frozen")
+            else:
+                # All parameters are unfrozen
+                print(f"    {module}: Unfrozen")
+        print("")
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""