浏览代码

Fix save metric FileNotFoundError when finetuning (#499)

Hamid Shojanazeri 11 月之前
父节点
当前提交
88324236e7
共有 1 个文件被更改,包括 3 次插入4 次删除
  1. 3 4
      src/llama_recipes/utils/train_utils.py

+ 3 - 4
src/llama_recipes/utils/train_utils.py

@@ -103,10 +103,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     val_loss =[]
 
     if train_config.save_metrics:
-        output_dir = train_config.output_dir
-        if not os.path.exists(output_dir):
-            os.makedirs(output_dir)
-        metrics_filename = f"{output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        if not os.path.exists(train_config.output_dir):
+            os.makedirs(train_config.output_dir, exist_ok=True)
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
         train_step_perplexity = []
         train_step_loss = []
         val_step_loss = []