
add rank to save_train_params

Hamid Shojanazeri 1 year ago
parent commit 668c364f6b
1 changed file with 5 additions and 4 deletions:
  1. utils/train_utils.py

utils/train_utils.py (+5 −4)

@@ -79,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -177,7 +176,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
        
    #saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and fsdp_config:
-        save_train_params(train_config, fsdp_config)
+        save_train_params(train_config, fsdp_config, rank)
        
    return results

@@ -328,7 +327,7 @@ def get_policies(cfg, rank):
    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy

-def save_train_params(train_config, fsdp_config):
+def save_train_params(train_config, fsdp_config, rank):
    """
    This function saves the train_config and FSDP config into a train_params.yaml.
    This will be used by converter script in the inference folder to fetch the HF model name or path.
@@ -363,4 +362,6 @@ def save_train_params(train_config, fsdp_config):
    else:
        # Write the YAML string to the file
        with open(file_name, 'w') as f:
-            f.write(config_yaml)
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
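
For reference, below is a minimal, self-contained sketch of what save_train_params looks like after this change. It assumes attribute-style train_config/fsdp_config objects and PyYAML; the "checkpoints" output directory is illustrative, and only the signature, the YAML write, and the rank-guarded print come from the diff above.

    import os
    import yaml  # PyYAML, assumed available in the training environment

    def save_train_params(train_config, fsdp_config, rank):
        """Dump train_config and fsdp_config to train_params.yaml for later reference."""
        # Stringify every attribute so arbitrary config values serialize cleanly to YAML.
        train_params_dict = {
            **{k: str(v) for k, v in vars(train_config).items()},
            **{k: str(v) for k, v in vars(fsdp_config).items()},
        }
        save_dir = "checkpoints"  # hypothetical output directory, not from the diff
        os.makedirs(save_dir, exist_ok=True)
        file_name = os.path.join(save_dir, "train_params.yaml")
        config_yaml = yaml.dump(train_params_dict, indent=4)
        with open(file_name, "w") as f:
            f.write(config_yaml)
        # The new rank parameter gates only the log line, so a multi-process FSDP
        # run prints the confirmation exactly once (from rank 0) per launch.
        if rank == 0:
            print(f"training params are saved in {file_name}")

Note the design choice: every rank still writes (or overwrites) the same YAML file; only the print is gated, which keeps a torchrun launch with N workers from emitting N identical log lines.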