@@ -35,7 +35,11 @@ logger = logging.getLogger(__name__)
 # Import modules from the finetune_pipeline package
 from finetune_pipeline.data.data_loader import load_and_format_data, read_config
 from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
-# from finetune_pipeline.inference.run_inference import run_inference_on_eval_data
+
+from finetune_pipeline.inference.run_inference import (
+    run_vllm_batch_inference_on_dataset,
+)
+from finetune_pipeline.inference.save_inference_results import save_inference_results
 from finetune_pipeline.inference.start_vllm_server import start_vllm_server
 
 
@@ -108,6 +112,7 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
     # Create an args object to pass to run_torch_tune
     class Args:
         pass
+
     args = Args()
     args.kwargs = kwargs
 
@@ -143,7 +148,9 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
     config = read_config(config_path)
     inference_config = config.get("inference", {})
 
-    model_path = inference_config.get("model_path","/home/ubuntu/yash-workspace/medgemma-4b-it")
+    model_path = inference_config.get(
+        "model_path", "/home/ubuntu/yash-workspace/medgemma-4b-it"
+    )
 
     # # Update the model path in the inference config
     # inference_config["model_path"] = model_path
@@ -158,19 +165,20 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
     enforce_eager = inference_config.get("enforce_eager", False)
 
-
     # Start the server in a separate process
     try:
         logger.info(f"Starting vLLM server with model {model_path}")
-        result = start_vllm_server(model_path,
-                                   port,
-                                   host,
-                                   tensor_parallel_size,
-                                   max_model_len,
-                                   max_num_seqs,
-                                   quantization,
-                                   gpu_memory_utilization,
-                                   enforce_eager)
+        result = start_vllm_server(
+            model_path,
+            port,
+            host,
+            tensor_parallel_size,
+            max_model_len,
+            max_num_seqs,
+            quantization,
+            gpu_memory_utilization,
+            enforce_eager,
+        )
         if result.returncode == 0:
             server_url = f"http://{host}:{port}/v1"
             logger.info(f"vLLM server started at {server_url}")
@@ -183,75 +191,82 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
         raise
 
 
-# def run_inference(
-#     config_path: str, server_url: str, formatted_data_paths: List[str]
-# ) -> str:
-#     """
-#     Run inference on the fine-tuned model.
-
-#     Args:
-#         config_path: Path to the configuration file
-#         server_url: URL of the vLLM server
-#         formatted_data_paths: Paths to the formatted data
-
-#     Returns:
-#         Path to the inference results
-#     """
-#     logger.info("=== Step 4: Running Inference ===")
-
-#     # Read the configuration
-#     config = read_config(config_path)
-#     inference_config = config.get("inference", {})
-#     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
-
-#     # Get the path to the formatted data for the validation or test split
-#     eval_data_path = inference_config.get("eval_data")
-#     if not eval_data_path:
-#         # Try to find a validation or test split in the formatted data
-#         for path in formatted_data_paths:
-#             if "validation_" in path or "test_" in path:
-#                 eval_data_path = path
-#                 break
-
-#         if not eval_data_path:
-#             logger.warning(
-#                 "No validation or test split found in formatted data. Using the first file."
-#             )
-#             eval_data_path = formatted_data_paths[0]
-
-#     # Extract inference parameters
-#     model_name = inference_config.get("model_name", "default")
-#     temperature = inference_config.get("temperature", 0.0)
-#     top_p = inference_config.get("top_p", 1.0)
-#     max_tokens = inference_config.get("max_tokens", 100)
-#     seed = inference_config.get("seed")
-
-#     # Run inference
-#     try:
-#         logger.info(
-#             f"Running inference on {eval_data_path} using server at {server_url}"
-#         )
-#         results = run_inference_on_eval_data(
-#             eval_data_path=eval_data_path,
-#             server_url=server_url,
-#             is_local=True, # Assuming the formatted data is local
-#             model_name=model_name,
-#             temperature=temperature,
-#             top_p=top_p,
-#             max_tokens=max_tokens,
-#             seed=seed,
-#         )
-
-#         # Save the results
-#         results_path = os.path.join(output_dir, "inference_results.json")
-#         with open(results_path, "w") as f:
-#             json.dump(results, f, indent=2)
-
-#         logger.info(f"Inference complete. Results saved to {results_path}")
-#         return results_path
-#     except Exception as e:
-#         logger.error(f"Error during inference: {e}")
-#         raise
+def run_inference(config_path: str, formatted_data_paths: List[str]) -> str:
+    """
+    Run inference on the fine-tuned model.
+
+    Args:
+        config_path: Path to the configuration file
+        formatted_data_paths: Paths to the formatted data (currently unused;
+            the inference data path is read from the config)
+
+    Returns:
+        Path to the inference results
+    """
+    logger.info("=== Step 4: Running Inference ===")
+
+    config = read_config(config_path)
+    inference_config = config.get("inference", {})
+    formatter_config = config.get("formatter", {})
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+
+    # Model parameters
+    model_path = inference_config.get("model_path", None)
+    if model_path is None:
+        raise ValueError("model_path must be specified in the config")
+
+    # Get the inference data path from the config
+    inference_data_path = inference_config.get("inference_data", None)
+    if inference_data_path is None:
+        raise ValueError("Inference data path must be specified in config")
+    output_path = f"{output_dir}/inference_results.json"
+
+    # Performance parameters
+    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
+    max_model_len = inference_config.get("max_model_len", 512)
+    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
+    dtype = inference_config.get("dtype", "auto")
+    trust_remote_code = inference_config.get("trust_remote_code", False)
+
+    # Generation parameters
+    max_tokens = inference_config.get("max_tokens", 100)
+    temperature = inference_config.get("temperature", 0.0)
+    top_p = inference_config.get("top_p", 1.0)
+    seed = inference_config.get("seed")
+    structured = inference_config.get("structured", False)
+
+    # Data parameters
+    is_local = formatter_config.get("is_local", False)
+    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
+    column_mapping = formatter_config.get("column_mapping", {})
+
+    # Run inference
+    try:
+        logger.info(f"Running inference on {inference_data_path}")
+        results = run_vllm_batch_inference_on_dataset(
+            inference_data_path,
+            model_path,
+            is_local,
+            temperature,
+            top_p,
+            max_tokens,
+            seed,
+            structured,
+            gpu_memory_utilization,
+            max_model_len,
+            dataset_kwargs,
+            column_mapping,
+        )
+
+        # Save the results
+        results_path = os.path.join(output_dir, "inference_results.json")
+        save_inference_results(results, results_path)
+
+        logger.info(f"Inference complete. Results saved to {results_path}")
+        return results_path
+    except Exception as e:
+        logger.error(f"Error during inference: {e}")
+        raise
 
 
 def run_pipeline(
@@ -362,7 +377,7 @@ def run_pipeline(
     # Step 4: Inference
     if not skip_inference:
        try:
-            results_path = run_inference(config_path, server_url, formatted_data_paths)
+            results_path = run_inference(config_path, formatted_data_paths)
             logger.info(
                 f"Pipeline completed successfully. Results saved to {results_path}"
             )
@@ -448,4 +463,4 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()