
add rank to save_train_params

Hamid Shojanazeri, 1 year ago
commit 668c364f6b
1 changed file with 5 additions and 4 deletions

utils/train_utils.py (+5, -4)

@@ -79,7 +79,6 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             data_set_len = 0
-            
             for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
@@ -177,7 +176,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and fsdp_config:
-        save_train_params(train_config, fsdp_config)
+        save_train_params(train_config, fsdp_config, rank)
         
     return results

@@ -328,7 +327,7 @@ def get_policies(cfg, rank):
     wrapping_policy = get_llama_wrapper()
     return mixed_precision_policy, wrapping_policy

-def save_train_params(train_config, fsdp_config):
+def save_train_params(train_config, fsdp_config, rank):
     """
     This function saves the train_config and FSDP config into a train_params.yaml.
     This will be used by converter script in the inference folder to fetch the HF model name or path.
@@ -363,4 +362,6 @@ def save_train_params(train_config, fsdp_config):
     else:
         # Write the YAML string to the file
         with open(file_name, 'w') as f:
-            f.write(config_yaml)
+            f.write(config_yaml)
+        if rank==0:
+            print(f"training params are saved in {file_name}")
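
For context on what this change accomplishes: under FSDP, train() executes once per process, so every rank reaches the save_train_params call at the end of training. The commit threads the process rank through and gates the new confirmation message on rank 0, so the save path is printed once per job rather than once per GPU. Below is a minimal, self-contained sketch of the resulting pattern, assuming PyYAML is installed; the two placeholder dataclasses and the checkpoint folder name are illustrative stand-ins, not the real llama-recipes configs.

import os
from dataclasses import asdict, dataclass

import yaml  # PyYAML; the real save_train_params also serializes to YAML

# Placeholder config dataclasses, hypothetical stand-ins for the
# real ones defined elsewhere in the repo.
@dataclass
class TrainConfig:
    model_name: str = "meta-llama/Llama-2-7b-hf"
    dist_checkpoint_root_folder: str = "model_checkpoints"

@dataclass
class FSDPConfig:
    pure_bf16: bool = True

def save_train_params(train_config, fsdp_config, rank):
    # Merge both configs into one dict and serialize it to YAML.
    train_params = {**asdict(train_config), **asdict(fsdp_config)}
    folder = train_config.dist_checkpoint_root_folder
    os.makedirs(folder, exist_ok=True)
    file_name = os.path.join(folder, "train_params.yaml")
    with open(file_name, "w") as f:
        f.write(yaml.dump(train_params, indent=4))
    # Every FSDP rank runs this function; print from rank 0 only so the
    # job log carries a single confirmation line, not one per GPU.
    if rank == 0:
        print(f"training params are saved in {file_name}")

if __name__ == "__main__":
    # Single-process demo; in a real run, rank would come from
    # torch.distributed.get_rank() on each worker.
    save_train_params(TrainConfig(), FSDPConfig(), rank=0)

At the call site in train(), rank is typically obtained from torch.distributed.get_rank() or the RANK environment variable set by torchrun, which is why the function already has the value available to forward.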