
pipeline updated

khare19yash, 2 months ago
commit f4aca9959b
2 files changed, 101 insertions(+), 86 deletions(-)
  1. src/finetune_pipeline/inference/run_inference.py (+3, -3)
  2. src/finetune_pipeline/run_pipeline.py (+98, -83)

+ 3 - 3
src/finetune_pipeline/inference/run_inference.py

@@ -215,8 +215,8 @@ def main():
     inference_data_path = inference_config.get("inference_data", None)
     if inference_data_path is None:
         raise ValueError("Inference data path must be specified in config")
-    output_dir = inference_config.get("output_dir")
-    output_path = f"{output_dir}/inference_results.json"
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+    results_path = f"{output_dir}/inference_results.json"
 
     # Performance parameters
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
@@ -252,7 +252,7 @@ def main():
         column_mapping,
     )
 
-    save_inference_results(results, output_path)
+    save_inference_results(results, results_path)
 
 
 if __name__ == "__main__":
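
Note on this change: output_dir is now read from the top level of the config (with a "/tmp/finetune-pipeline/" default) rather than from the inference section, and the results file path is held in results_path. A minimal sketch of the config shape main() now expects, based only on the keys visible in this hunk; the data path and any extra keys are placeholders, not the repo's actual config:

    # Minimal sketch of the expected config shape; values are placeholders.
    config = {
        "output_dir": "/tmp/finetune-pipeline/",        # top-level key, with this default
        "inference": {
            "inference_data": "/path/to/eval_data.json",  # required, else ValueError
            "gpu_memory_utilization": 0.95,
        },
    }

    inference_config = config.get("inference", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
    results_path = f"{output_dir}/inference_results.json"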

+ 98 - 83
src/finetune_pipeline/run_pipeline.py

@@ -35,7 +35,11 @@ logger = logging.getLogger(__name__)
 # Import modules from the finetune_pipeline package
 from finetune_pipeline.data.data_loader import load_and_format_data, read_config
 from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
-# from finetune_pipeline.inference.run_inference import run_inference_on_eval_data
+
+from finetune_pipeline.inference.run_inference import (
+    run_vllm_batch_inference_on_dataset,
+)
+from finetune_pipeline.inference.save_inference_results import save_inference_results
 from finetune_pipeline.inference.start_vllm_server import start_vllm_server
 
 
@@ -108,6 +112,7 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
     # Create an args object to pass to run_torch_tune
     class Args:
         pass
+
     args = Args()
     args.kwargs = kwargs
 
@@ -143,7 +148,9 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
     config = read_config(config_path)
     inference_config = config.get("inference", {})
 
-    model_path = inference_config.get("model_path","/home/ubuntu/yash-workspace/medgemma-4b-it")
+    model_path = inference_config.get(
+        "model_path", "/home/ubuntu/yash-workspace/medgemma-4b-it"
+    )
 
     # # Update the model path in the inference config
     # inference_config["model_path"] = model_path
@@ -158,19 +165,20 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
     enforce_eager = inference_config.get("enforce_eager", False)
 
-
     # Start the server in a separate process
     try:
         logger.info(f"Starting vLLM server with model {model_path}")
-        result = start_vllm_server(model_path,
-                                    port,
-                                    host,
-                                    tensor_parallel_size,
-                                    max_model_len,
-                                    max_num_seqs,
-                                    quantization,
-                                    gpu_memory_utilization,
-                                    enforce_eager)
+        result = start_vllm_server(
+            model_path,
+            port,
+            host,
+            tensor_parallel_size,
+            max_model_len,
+            max_num_seqs,
+            quantization,
+            gpu_memory_utilization,
+            enforce_eager,
+        )
         if result.returncode == 0:
             server_url = f"http://{host}:{port}/v1"
             logger.info(f"vLLM server started at {server_url}")
@@ -183,75 +191,82 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
         raise
 
 
-# def run_inference(
-#     config_path: str, server_url: str, formatted_data_paths: List[str]
-# ) -> str:
-#     """
-#     Run inference on the fine-tuned model.
-
-#     Args:
-#         config_path: Path to the configuration file
-#         server_url: URL of the vLLM server
-#         formatted_data_paths: Paths to the formatted data
-
-#     Returns:
-#         Path to the inference results
-#     """
-#     logger.info("=== Step 4: Running Inference ===")
-
-#     # Read the configuration
-#     config = read_config(config_path)
-#     inference_config = config.get("inference", {})
-#     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
-
-#     # Get the path to the formatted data for the validation or test split
-#     eval_data_path = inference_config.get("eval_data")
-#     if not eval_data_path:
-#         # Try to find a validation or test split in the formatted data
-#         for path in formatted_data_paths:
-#             if "validation_" in path or "test_" in path:
-#                 eval_data_path = path
-#                 break
-
-#         if not eval_data_path:
-#             logger.warning(
-#                 "No validation or test split found in formatted data. Using the first file."
-#             )
-#             eval_data_path = formatted_data_paths[0]
-
-#     # Extract inference parameters
-#     model_name = inference_config.get("model_name", "default")
-#     temperature = inference_config.get("temperature", 0.0)
-#     top_p = inference_config.get("top_p", 1.0)
-#     max_tokens = inference_config.get("max_tokens", 100)
-#     seed = inference_config.get("seed")
-
-#     # Run inference
-#     try:
-#         logger.info(
-#             f"Running inference on {eval_data_path} using server at {server_url}"
-#         )
-#         results = run_inference_on_eval_data(
-#             eval_data_path=eval_data_path,
-#             server_url=server_url,
-#             is_local=True,  # Assuming the formatted data is local
-#             model_name=model_name,
-#             temperature=temperature,
-#             top_p=top_p,
-#             max_tokens=max_tokens,
-#             seed=seed,
-#         )
-
-#         # Save the results
-#         results_path = os.path.join(output_dir, "inference_results.json")
-#         with open(results_path, "w") as f:
-#             json.dump(results, f, indent=2)
-
-#         logger.info(f"Inference complete. Results saved to {results_path}")
-#         return results_path
-#     except Exception as e:
-#         logger.error(f"Error during inference: {e}")
-#         raise
+def run_inference(config_path: str, formatted_data_paths: List[str]) -> str:
+    """
+    Run inference on the fine-tuned model.
+
+    Args:
+        config_path: Path to the configuration file
+        formatted_data_paths: Paths to the formatted data
+
+    Returns:
+        Path to the inference results
+    """
+    logger.info("=== Step 4: Running Inference ===")
+
+    config = read_config(config_path)
+    inference_config = config.get("inference", {})
+    formatter_config = config.get("formatter", {})
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+
+    # Model parameters
+    model_path = inference_config.get("model_path", None)
+    if model_path is None:
+        raise ValueError("model_path must be specified in the config")
+
+    # Get data path from parameters or config
+    inference_data_path = inference_config.get("inference_data", None)
+    if inference_data_path is None:
+        raise ValueError("Inference data path must be specified in config")
+
+    # Performance parameters
+    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
+    max_model_len = inference_config.get("max_model_len", 512)
+    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
+    dtype = inference_config.get("dtype", "auto")
+    trust_remote_code = inference_config.get("trust_remote_code", False)
+
+    # Generation parameters
+    max_tokens = inference_config.get("max_tokens", 100)
+    temperature = inference_config.get("temperature", 0.0)
+    top_p = inference_config.get("top_p", 1.0)
+    seed = inference_config.get("seed")
+    structured = inference_config.get("structured", False)
+
+    # Data parameters
+    is_local = formatter_config.get("is_local", False)
+    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
+    column_mapping = formatter_config.get("column_mapping", {})
+
+    # Run inference
+    try:
+        logger.info(f"Running inference on {inference_data_path}")
+        results = run_vllm_batch_inference_on_dataset(
+            inference_data_path,
+            model_path,
+            is_local,
+            temperature,
+            top_p,
+            max_tokens,
+            seed,
+            structured,
+            gpu_memory_utilization,
+            max_model_len,
+            dataset_kwargs,
+            column_mapping,
+        )
+
+        # Save the results
+        results_path = os.path.join(output_dir, "inference_results.json")
+        save_inference_results(results, results_path)
+
+        logger.info(f"Inference complete. Results saved to {results_path}")
+        return results_path
+    except Exception as e:
+        logger.error(f"Error during inference: {e}")
+        raise
 
 
 def run_pipeline(
@@ -362,7 +377,7 @@ def run_pipeline(
     # Step 4: Inference
     if not skip_inference:
         try:
-            results_path = run_inference(config_path, server_url, formatted_data_paths)
+            results_path = run_inference(config_path, formatted_data_paths)
             logger.info(
                 f"Pipeline completed successfully. Results saved to {results_path}"
             )
@@ -448,4 +463,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
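
For orientation, a hedged usage sketch of the re-enabled run_inference step, assuming the package is importable and a pipeline config containing the sections read above; the config path and values are placeholders, not part of this commit:

    # Hedged usage sketch; "configs/pipeline.yaml" is a placeholder path.
    from finetune_pipeline.run_pipeline import run_inference

    # run_inference reads:
    #   output_dir (top level), inference.model_path and inference.inference_data
    #   (both required), the generation/performance keys under "inference", and
    #   is_local / dataset_kwargs / column_mapping under "formatter".
    # The formatted_data_paths argument is accepted but not used by the new body.
    results_path = run_inference("configs/pipeline.yaml", formatted_data_paths=[])
    print(f"Inference results saved to {results_path}")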