|
@@ -486,8 +486,52 @@ def print_model_size(model, config, rank: int = 0) -> None:
|
|
|
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
|
|
|
|
|
|
+def print_frozen_model_status(model, config, rank: int = 0) -> None:
|
|
|
+ """
|
|
|
+    Print the frozen status of the model's modules and the number of trainable parameters after freezing.
|
|
|
|
|
|
-
|
|
|
+ Args:
|
|
|
+ model: The PyTorch model.
|
|
|
+        config: Model configuration; its ``model_name`` attribute is used in the printout.
|
|
|
+ rank (int, optional): Current process's rank. Defaults to 0.
|
|
|
+ """
|
|
|
+ if rank == 0:
|
|
|
+ trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
|
|
+ print("After freezing the model:")
|
|
|
+ print(f"--> {config.model_name} has {trainable_params / 1e6} Million trainable params\n")
|
|
|
+
|
|
|
+ module_states = {}
|
|
|
+ # Iterate over all parameters
|
|
|
+ for name, param in model.named_parameters():
|
|
|
+ # Extract the top-level module name (e.g., "vision_model", "language_model")
|
|
|
+ top_module = name.split(".")[0]
|
|
|
+
|
|
|
+ # Initialize a record for the top-level module
|
|
|
+ if top_module not in module_states:
|
|
|
+ module_states[top_module] = {"frozen": [], "unfrozen": []}
|
|
|
+
|
|
|
+ # Group parameters into frozen or unfrozen
|
|
|
+ if param.requires_grad:
|
|
|
+ module_states[top_module]["unfrozen"].append(name)
|
|
|
+ else:
|
|
|
+ module_states[top_module]["frozen"].append(name)
|
|
|
+
|
|
|
+ print("--> Model state after freezing:")
|
|
|
+ # Analyze and print the results
|
|
|
+ for module, states in module_states.items():
|
|
|
+ frozen_params = states["frozen"]
|
|
|
+ unfrozen_params = states["unfrozen"]
|
|
|
+
|
|
|
+ if frozen_params and unfrozen_params:
|
|
|
+ # Mixed state: both frozen and unfrozen parameters
|
|
|
+ print(f" {module}: Mixed")
|
|
|
+ elif frozen_params:
|
|
|
+ # All parameters are frozen
|
|
|
+ print(f" {module}: Frozen")
|
|
|
+ else:
|
|
|
+ # All parameters are unfrozen
|
|
|
+ print(f" {module}: Unfrozen")
|
|
|
+ print("")
|
|
|
|
|
|
def get_policies(cfg, rank):
|
|
|
"""Get the policies for mixed precision and fsdp wrapping"""
|