瀏覽代碼

epoch instead of epochs

khare19yash 1 月之前
父節點
當前提交
1952c8e1cd
共有 1 個文件被更改,包括 2 次插入1 次删除
  1. 2 1
      src/finetune_pipeline/run_pipeline.py

+ 2 - 1
src/finetune_pipeline/run_pipeline.py

@@ -123,7 +123,7 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
         # Get the path to the latest chekpoint of the fine-tuned model
         model_output_dir = finetuning_config.get("output_dir", config.get("output_dir"))
         epochs = finetuning_config.get("epochs", 1)
-        checkpoint_path = os.path.join(model_output_dir, f"epochs_{epochs-1}")
+        checkpoint_path = os.path.join(model_output_dir, f"epoch_{epochs-1}")
         logger.info(
             f"Fine-tuning complete. Latest checkpoint saved to {checkpoint_path}"
         )
@@ -365,6 +365,7 @@ def run_pipeline(
         output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
         model_path = os.path.join(output_dir, "finetuned_model")
 
+    time.sleep(5)
     # Step 3: Inference
     if not skip_inference:
         try: