#!/usr/bin/env python
"""
End-to-end pipeline for data loading, fine-tuning, and inference.

This script integrates all the modules in the finetune_pipeline package:

1. Data loading and formatting
2. Model fine-tuning
3. vLLM server startup
4. Inference on the fine-tuned model

Example usage:
    python run_pipeline.py --config config.yaml
    python run_pipeline.py --config config.yaml --skip-finetuning --skip-server
    python run_pipeline.py --config config.yaml --only-inference
"""

import argparse
import json
import logging
import os
import sys
from typing import List, Tuple

# Configure logging
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Import modules from the finetune_pipeline package
from finetune_pipeline.data.data_loader import load_and_format_data, read_config
from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
from finetune_pipeline.inference.run_inference import run_inference_on_eval_data
from finetune_pipeline.inference.start_vllm_server import start_vllm_server


def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
    """
    Run the data loading and formatting step.

    Args:
        config_path: Path to the configuration file

    Returns:
        Tuple containing lists of paths to the formatted data and conversation data
    """
    logger.info("=== Step 1: Data Loading and Formatting ===")

    # Read the configuration
    config = read_config(config_path)
    formatter_config = config.get("formatter", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")

    # Create the output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load and format the data
    try:
        formatted_data_paths, conversation_data_paths = load_and_format_data(
            formatter_config, output_dir
        )
        logger.info(f"Data loading and formatting complete. Saved to {output_dir}")
        logger.info(f"Formatted data paths: {formatted_data_paths}")
        logger.info(f"Conversation data paths: {conversation_data_paths}")
        return formatted_data_paths, conversation_data_paths
    except Exception as e:
        logger.error(f"Error during data loading and formatting: {e}")
        raise


def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
    """
    Run the fine-tuning step.

    Args:
        config_path: Path to the configuration file
        formatted_data_paths: Paths to the formatted data

    Returns:
        Path to the fine-tuned model
    """
    logger.info("=== Step 2: Model Fine-tuning ===")

    # Read the configuration
    config = read_config(config_path)
    finetuning_config = config.get("finetuning", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")

    # Get the path to the formatted data for the train split
    train_data_path = None
    for path in formatted_data_paths:
        if "train_" in path:
            train_data_path = path
            break

    if not train_data_path:
        logger.warning("No train split found in formatted data. Using the first file.")
        train_data_path = formatted_data_paths[0]

    # Additional "key=value" overrides pointing the fine-tuning recipe at the
    # custom SFT dataset and the formatted training data
    kwargs = (
        "dataset=finetune_pipeline.finetuning.custom_sft_dataset "
        "dataset.train_on_input=True "
        f"dataset.dataset_path={train_data_path}"
    )

    # run_torch_tune expects an argparse-style namespace carrying the overrides
    args = argparse.Namespace(kwargs=kwargs)

    # Run the fine-tuning
    try:
        logger.info(f"Starting fine-tuning with data from {train_data_path}")
        run_torch_tune(finetuning_config, args=args)

        # Get the path to the fine-tuned model.
        # This is a simplification; in reality, you'd need to extract the
        # actual path from the fine-tuning output.
        model_path = os.path.join(output_dir, "finetuned_model")
        logger.info(f"Fine-tuning complete. Model saved to {model_path}")
        return model_path
    except Exception as e:
        logger.error(f"Error during fine-tuning: {e}")
        raise
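
# For reference, the overrides above mirror torchtune's CLI override syntax,
# where "key=value" pairs appended to "tune run" override config fields.
# The recipe and config names here are illustrative, not from this repo:
#
#     tune run full_finetune_single_device --config my_config.yaml \
#         dataset=finetune_pipeline.finetuning.custom_sft_dataset \
#         dataset.train_on_input=True \
#         dataset.dataset_path=/tmp/finetune-pipeline/data/train.json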


def run_vllm_server(config_path: str, model_path: str) -> str:
    """
    Start the vLLM server.

    Args:
        config_path: Path to the configuration file
        model_path: Path to the fine-tuned model

    Returns:
        URL of the vLLM server
    """
    logger.info("=== Step 3: Starting vLLM Server ===")

    # Read the configuration
    config = read_config(config_path)
    inference_config = config.get("inference", {})

    # A "model_path" entry in the inference config overrides the fine-tuned
    # model path passed in from the previous step
    model_path = inference_config.get("model_path", model_path)

    # Extract server parameters
    port = inference_config.get("port", 8000)
    host = inference_config.get("host", "0.0.0.0")
    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
    max_model_len = inference_config.get("max_model_len", 4096)
    max_num_seqs = inference_config.get("max_num_seqs", 256)
    quantization = inference_config.get("quantization")
    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
    enforce_eager = inference_config.get("enforce_eager", False)

    # Start the server in a separate process
    try:
        logger.info(f"Starting vLLM server with model {model_path}")
        result = start_vllm_server(
            model_path,
            port,
            host,
            tensor_parallel_size,
            max_model_len,
            max_num_seqs,
            quantization,
            gpu_memory_utilization,
            enforce_eager,
        )
        if result.returncode == 0:
            server_url = f"http://{host}:{port}/v1"
            logger.info(f"vLLM server started at {server_url}")
            return server_url
        else:
            logger.error("vLLM server failed to start")
            raise RuntimeError("vLLM server failed to start")
    except Exception as e:
        logger.error(f"Error starting vLLM server: {e}")
        raise
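
# A minimal readiness-check sketch (not wired into the pipeline). vLLM's
# OpenAI-compatible server exposes a /v1/models endpoint, so polling it is one
# way to confirm the server is up before sending requests. The helper name and
# the use of urllib/time here are assumptions for illustration:
#
#     import time
#     import urllib.request
#
#     def wait_for_server(server_url: str, timeout: float = 120.0) -> bool:
#         deadline = time.monotonic() + timeout
#         while time.monotonic() < deadline:
#             try:
#                 with urllib.request.urlopen(f"{server_url}/models", timeout=5):
#                     return True
#             except OSError:
#                 time.sleep(2)
#         return False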


def run_inference(
    config_path: str, server_url: str, formatted_data_paths: List[str]
) -> str:
    """
    Run inference on the fine-tuned model.

    Args:
        config_path: Path to the configuration file
        server_url: URL of the vLLM server
        formatted_data_paths: Paths to the formatted data

    Returns:
        Path to the inference results
    """
    logger.info("=== Step 4: Running Inference ===")

    # Read the configuration
    config = read_config(config_path)
    inference_config = config.get("inference", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")

    # Get the path to the formatted data for the validation or test split
    eval_data_path = inference_config.get("eval_data")
    if not eval_data_path:
        # Try to find a validation or test split in the formatted data
        for path in formatted_data_paths:
            if "validation_" in path or "test_" in path:
                eval_data_path = path
                break

    if not eval_data_path:
        logger.warning(
            "No validation or test split found in formatted data. Using the first file."
        )
        eval_data_path = formatted_data_paths[0]

    # Extract inference parameters
    model_name = inference_config.get("model_name", "default")
    temperature = inference_config.get("temperature", 0.0)
    top_p = inference_config.get("top_p", 1.0)
    max_tokens = inference_config.get("max_tokens", 100)
    seed = inference_config.get("seed")

    # Run inference
    try:
        logger.info(
            f"Running inference on {eval_data_path} using server at {server_url}"
        )
        results = run_inference_on_eval_data(
            eval_data_path=eval_data_path,
            server_url=server_url,
            is_local=True,  # Assuming the formatted data is local
            model_name=model_name,
            temperature=temperature,
            top_p=top_p,
            max_tokens=max_tokens,
            seed=seed,
        )

        # Save the results
        results_path = os.path.join(output_dir, "inference_results.json")
        with open(results_path, "w") as f:
            json.dump(results, f, indent=2)

        logger.info(f"Inference complete. Results saved to {results_path}")
        return results_path
    except Exception as e:
        logger.error(f"Error during inference: {e}")
        raise
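
# For illustration, a single request against the server's OpenAI-compatible
# chat endpoint looks roughly like this (payload values are assumptions;
# run_inference_on_eval_data is expected to handle this in batch):
#
#     import json
#     import urllib.request
#
#     payload = {
#         "model": "default",
#         "messages": [{"role": "user", "content": "Hello"}],
#         "temperature": 0.0,
#         "max_tokens": 100,
#     }
#     req = urllib.request.Request(
#         "http://0.0.0.0:8000/v1/chat/completions",
#         data=json.dumps(payload).encode(),
#         headers={"Content-Type": "application/json"},
#     )
#     with urllib.request.urlopen(req) as resp:
#         print(json.load(resp)["choices"][0]["message"]["content"])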


def run_pipeline(
    config_path: str,
    skip_data_loading: bool = False,
    skip_finetuning: bool = False,
    skip_server: bool = False,
    skip_inference: bool = False,
    only_data_loading: bool = False,
    only_finetuning: bool = False,
    only_server: bool = False,
    only_inference: bool = False,
) -> None:
    """
    Run the end-to-end pipeline.

    Args:
        config_path: Path to the configuration file
        skip_data_loading: Whether to skip the data loading step
        skip_finetuning: Whether to skip the fine-tuning step
        skip_server: Whether to skip starting the vLLM server
        skip_inference: Whether to skip the inference step
        only_data_loading: Whether to run only the data loading step
        only_finetuning: Whether to run only the fine-tuning step
        only_server: Whether to run only the vLLM server step
        only_inference: Whether to run only the inference step
    """
    logger.info(f"Starting pipeline with config {config_path}")

    # Check that the config file exists
    if not os.path.exists(config_path):
        logger.error(f"Config file {config_path} does not exist")
        sys.exit(1)

    # An "only" flag skips every other step
    if only_data_loading:
        skip_finetuning = True
        skip_server = True
        skip_inference = True
    elif only_finetuning:
        skip_data_loading = True
        skip_server = True
        skip_inference = True
    elif only_server:
        skip_data_loading = True
        skip_finetuning = True
        skip_inference = True
    elif only_inference:
        skip_data_loading = True
        skip_finetuning = True
        skip_server = True

    # Step 1: Data Loading and Formatting
    formatted_data_paths = []
    conversation_data_paths = []
    if not skip_data_loading:
        try:
            formatted_data_paths, conversation_data_paths = run_data_loading(
                config_path
            )
        except Exception as e:
            logger.error(f"Pipeline failed at data loading step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping data loading step")
        # Try to infer the paths from the config
        config = read_config(config_path)
        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
        formatter_type = config.get("formatter", {}).get("type", "torchtune")
        # This is a simplification; in reality, you'd need to know the exact paths
        formatted_data_paths = [
            os.path.join(output_dir, f"train_{formatter_type}_formatted_data.json")
        ]

    # Step 2: Fine-tuning
    model_path = ""
    if not skip_finetuning:
        try:
            model_path = run_finetuning(config_path, formatted_data_paths)
        except Exception as e:
            logger.error(f"Pipeline failed at fine-tuning step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping fine-tuning step")
        # Try to infer the model path from the config
        config = read_config(config_path)
        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
        model_path = os.path.join(output_dir, "finetuned_model")

    # Step 3: Start vLLM Server
    server_url = ""
    if not skip_server:
        try:
            server_url = run_vllm_server(config_path, model_path)
        except Exception as e:
            logger.error(f"Pipeline failed at vLLM server step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping vLLM server step")
        # Try to infer the server URL from the config
        config = read_config(config_path)
        inference_config = config.get("inference", {})
        host = inference_config.get("host", "0.0.0.0")
        port = inference_config.get("port", 8000)
        server_url = f"http://{host}:{port}/v1"

    # Step 4: Inference
    if not skip_inference:
        try:
            results_path = run_inference(config_path, server_url, formatted_data_paths)
            logger.info(
                f"Pipeline completed successfully. Results saved to {results_path}"
            )
        except Exception as e:
            logger.error(f"Pipeline failed at inference step: {e}")
            sys.exit(1)
    else:
        logger.info("Skipping inference step")

    logger.info("Pipeline execution complete")


def main():
    """Main function."""
    parser = argparse.ArgumentParser(description="Run the end-to-end pipeline")

    # Configuration
    parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file",
    )

    # Skip flags
    parser.add_argument(
        "--skip-data-loading",
        action="store_true",
        help="Skip the data loading step",
    )
    parser.add_argument(
        "--skip-finetuning",
        action="store_true",
        help="Skip the fine-tuning step",
    )
    parser.add_argument(
        "--skip-server",
        action="store_true",
        help="Skip starting the vLLM server",
    )
    parser.add_argument(
        "--skip-inference",
        action="store_true",
        help="Skip the inference step",
    )

    # Only flags
    parser.add_argument(
        "--only-data-loading",
        action="store_true",
        help="Run only the data loading step",
    )
    parser.add_argument(
        "--only-finetuning",
        action="store_true",
        help="Run only the fine-tuning step",
    )
    parser.add_argument(
        "--only-server",
        action="store_true",
        help="Run only the vLLM server step",
    )
    parser.add_argument(
        "--only-inference",
        action="store_true",
        help="Run only the inference step",
    )

    args = parser.parse_args()

    # Run the pipeline
    run_pipeline(
        config_path=args.config,
        skip_data_loading=args.skip_data_loading,
        skip_finetuning=args.skip_finetuning,
        skip_server=args.skip_server,
        skip_inference=args.skip_inference,
        only_data_loading=args.only_data_loading,
        only_finetuning=args.only_finetuning,
        only_server=args.only_server,
        only_inference=args.only_inference,
    )


if __name__ == "__main__":
    main()