
Fix model parameter mismatch by printing parameters before FSDP

JimChienTW 5 months ago
parent commit d1195a6fd8
1 changed file with 2 additions and 3 deletions
      src/llama_recipes/finetuning.py

+2 −3
src/llama_recipes/finetuning.py

@@ -237,7 +237,8 @@ def main(**kwargs):
             
         if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
             freeze_LLM_only(model)
-            
+        
+        print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer in vision models
@@ -306,8 +307,6 @@ def main(**kwargs):
         dataset_processer = processor
     else:
         dataset_processer = tokenizer
-
-    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
     
     # Load and preprocess the dataset for training and validation
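
The reason the call moves: once the model is wrapped by FSDP, `model.parameters()` on each rank yields only that rank's flattened shards, so counting elements after wrapping under-reports the true model size. A minimal sketch of the idea, assuming standard PyTorch; the `count_parameters` helper is illustrative, not part of the repo:

```python
# Minimal sketch of why the call order matters, assuming standard PyTorch.
# The count_parameters helper below is illustrative, not from the repo.
import torch

def count_parameters(model: torch.nn.Module) -> int:
    """Sum the number of elements across all parameters of the model."""
    return sum(p.numel() for p in model.parameters())

model = torch.nn.Linear(1024, 1024)
# Before wrapping: the full parameter count is visible on every rank.
print(f"unwrapped: {count_parameters(model):,} parameters")  # 1,049,600

# After wrapping with torch.distributed.fsdp.FullyShardedDataParallel
# (full sharding), model.parameters() yields each rank's flattened shard,
# so the same sum would report roughly total / world_size per rank,
# which is why print_model_size must run before the FSDP wrapper is applied.
```

With FSDP's default full sharding, repeating the count after wrapping on, say, an 8-rank job would report roughly one eighth of the total on each rank, which appears to be the parameter mismatch the commit title refers to.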