@@ -409,7 +409,17 @@ def freeze_transformer_layers(model, num_layer):
         if i < num_layer:
             for param in layer.parameters():
                 param.requires_grad = False
-
+
+def freeze_LLM_only(model):
+    """
+    Freeze the self-attention layers of the language_model; the vision_model, multi_modal_projector, and cross-attention layers will be fine-tuned.
+    """
+    for name, param in model.language_model.named_parameters():
+        param.requires_grad = False
+    for i, layer in enumerate(model.language_model.model.layers):
+        if i in model.language_model.model.cross_attention_layers:
+            for param in layer.parameters():
+                param.requires_grad = True
 
 def check_frozen_layers_peft_model(model):
     for i, layer in enumerate(model.base_model.model.model.layers):
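
(Not part of the patch: a minimal usage sketch of freeze_LLM_only, assuming a transformers version whose MllamaForConditionalGeneration exposes vision_model, language_model, and multi_modal_projector directly, which is the layout the helper expects; the checkpoint id is illustrative only, and freeze_LLM_only is the function added above.)

    from transformers import MllamaForConditionalGeneration

    # Hypothetical checkpoint; any Llama 3.2 Vision variant with this
    # submodule layout would behave the same way.
    model = MllamaForConditionalGeneration.from_pretrained("meta-llama/Llama-3.2-11B-Vision")
    freeze_LLM_only(model)
    # After the call, only the cross-attention decoder layers of language_model,
    # plus vision_model and multi_modal_projector, keep requires_grad=True.
    trainable = [n for n, p in model.named_parameters() if p.requires_grad]
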
@@ -476,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
         total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
         print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
 
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
+    """
+    Print the frozen status of the model's top-level modules and the number of trainable parameters after freezing.
 
-
+    Args:
+        model: The PyTorch model.
+        config: Config object whose model_name is used in the printed summary.
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print("After freezing the model:")
+        print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
+
+        module_states = {}
+        # Iterate over all parameters
+        for name, param in model.named_parameters():
+            # Extract the top-level module name (e.g., "vision_model", "language_model")
+            top_module = name.split(".")[0]
+
+            # Initialize a record for the top-level module
+            if top_module not in module_states:
+                module_states[top_module] = {"frozen": [], "unfrozen": []}
+
+            # Group parameters into frozen or unfrozen
+            if param.requires_grad:
+                module_states[top_module]["unfrozen"].append(name)
+            else:
+                module_states[top_module]["frozen"].append(name)
+
+        print("--> Model state after freezing:")
+        # Analyze and print the results
+        for module, states in module_states.items():
+            frozen_params = states["frozen"]
+            unfrozen_params = states["unfrozen"]
+
+            if frozen_params and unfrozen_params:
+                # Mixed state: both frozen and unfrozen parameters
+                print(f"    {module}: Mixed")
+            elif frozen_params:
+                # All parameters are frozen
+                print(f"    {module}: Frozen")
+            else:
+                # All parameters are unfrozen
+                print(f"    {module}: Unfrozen")
+        print("")
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
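
(Also not part of the patch: a self-contained toy check of the Frozen/Unfrozen/Mixed classification performed by print_frozen_model_status. The module names below are made up; only the grouping logic mirrors the code above.)

    import torch.nn as nn

    toy = nn.ModuleDict({
        "vision_model": nn.Linear(4, 4),                                    # left trainable
        "language_model": nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4)),  # partially frozen
    })
    for p in toy["language_model"][0].parameters():
        p.requires_grad = False

    states = {}
    for name, param in toy.named_parameters():
        top = name.split(".")[0]
        states.setdefault(top, {"frozen": [], "unfrozen": []})
        states[top]["unfrozen" if param.requires_grad else "frozen"].append(name)

    for module, s in states.items():
        label = "Mixed" if s["frozen"] and s["unfrozen"] else ("Frozen" if s["frozen"] else "Unfrozen")
        print(f"    {module}: {label}")  # vision_model: Unfrozen, language_model: Mixed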