
added run_pipeline to run end to end

Ubuntu, 2 months ago
Parent
Current commit
67aa33cbee

+ 5 - 5
src/finetune_pipeline/config.yaml

@@ -82,7 +82,7 @@ formatter:
   column_mapping:
     input: "query"             # Field containing the input text
     output: "response"              # Field containing the output text
-    image: "images"           # Field containing the image path (optional)
+    image: null           # Field containing the image path (optional)
 
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
@@ -101,7 +101,7 @@ finetuning:
 # vLLM Inference configuration
 inference:
   # Model configuration
-  model_path: "your/model/path" # Path to the model checkpoint
+  model_path: "/home/ubuntu/yash-workspace/medgemma-4b-it" # Path to the model checkpoint
   quantization: null            # Quantization method (awq, gptq, squeezellm)
   dtype: "auto"                 # Data type for model weights (half, float, bfloat16, auto)
   trust_remote_code: false      # Trust remote code when loading the model
@@ -112,12 +112,12 @@ inference:
 
   # Performance configuration
   tensor_parallel_size: 1       # Number of GPUs to use for tensor parallelism
-  max_model_len: 1024           # Maximum sequence length
-  max_num_seqs: 16              # Maximum number of sequences
+  max_model_len: 32           # Maximum sequence length
+  max_num_seqs: 1              # Maximum number of sequences
   gpu_memory_utilization: 0.9   # Fraction of GPU memory to use
   enforce_eager: false          # Enforce eager execution
 
-  eval_data: "your/eval/dataset/path" # Path to the evaluation dataset (optional)
+  data: "your/eval/dataset/path" # Path to the inference dataset (optional)
 
   # Additional vLLM parameters (optional)
   # swap_space: 4               # Size of CPU swap space in GiB
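
With `eval_data` renamed to `data`, code that reads the inference section needs to look up the new key. A minimal sketch of such a lookup, assuming the config is parsed with PyYAML; the fallback to the old key name is illustrative only and not part of this commit:

    import yaml  # assumption: config.yaml is parsed with PyYAML

    with open("src/finetune_pipeline/config.yaml") as f:
        config = yaml.safe_load(f)

    inference_config = config.get("inference", {})
    # "data" is the key introduced by this change; falling back to the old
    # "eval_data" name is an illustrative convenience, not part of the commit.
    inference_data = inference_config.get("data") or inference_config.get("eval_data")
    print(inference_data)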

+ 4 - 2
src/finetune_pipeline/finetuning/custom_sft_dataset.py

@@ -11,6 +11,8 @@ def custom_sft_dataset(
     model_transform: Transform,
     dataset_path: str = "/tmp/train.json",
     train_on_input: bool = False,
+    split: str = "train",
+
 ) -> SFTDataset:
     """
     Creates a custom SFT dataset for fine-tuning.
@@ -18,6 +20,7 @@ def custom_sft_dataset(
     Args:
         dataset_path: Path to the formatted data JSON file
         train_on_input: Whether to train on input tokens
+        split: Dataset split to use
 
     Returns:
         SFTDataset: A dataset ready for fine-tuning with TorchTune
@@ -27,9 +30,8 @@ def custom_sft_dataset(
     ds = SFTDataset(
         source="json",
         data_files=dataset_path,
-        split="train",
+        split=split,
         message_transform=openaitomessage,
         model_transform=model_transform,
     )
     return ds
-
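
Exposing `split` as a parameter means it can be selected through the same dotted override string that `run_pipeline.run_finetuning` builds (see the new file below). A hypothetical sketch of that string with the new knob added, assuming torchtune-style `key=value` overrides:

    # Hypothetical override string mirroring the one built in run_pipeline.run_finetuning;
    # dataset.split is the knob exposed by this change.
    train_data_path = "/tmp/finetune-pipeline/data/train_torchtune_formatted_data.json"
    kwargs = (
        "dataset=finetune_pipeline.finetuning.custom_sft_dataset "
        "dataset.train_on_input=True "
        f"dataset.dataset_path={train_data_path} "
        "dataset.split=train"
    )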

+ 4 - 4
src/finetune_pipeline/finetuning/run_finetuning.py

@@ -63,7 +63,7 @@ def read_config(config_path: str) -> Dict:
     return config
 
 
-def run_torch_tune(config_path: str, args=None):
+def run_torch_tune(training_config: Dict, args=None):
     """
     Run torch tune command with parameters from config file.
 
@@ -71,11 +71,11 @@ def run_torch_tune(config_path: str, args=None):
         config_path: Path to the configuration file
         args: Command line arguments that may include additional kwargs to pass to the command
     """
-    # Read the configuration
-    config = read_config(config_path)
+    # # Read the configuration
+    # config = read_config(config_path)
 
     # Extract parameters from config
-    training_config = config.get("finetuning", {})
+    # training_config = config.get("finetuning", {})
 
     # Initialize base_cmd to avoid "possibly unbound" error
     base_cmd = []
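
Since `run_torch_tune` now takes the finetuning sub-config directly, the caller reads the YAML first and passes the extracted dict. A minimal sketch of the new call pattern, mirroring `run_pipeline.run_finetuning` below (the config path is illustrative):

    from finetune_pipeline.data.data_loader import read_config
    from finetune_pipeline.finetuning.run_finetuning import run_torch_tune

    config = read_config("src/finetune_pipeline/config.yaml")  # illustrative path
    training_config = config.get("finetuning", {})
    run_torch_tune(training_config, args=None)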

+ 2 - 1
src/finetune_pipeline/inference/start_vllm_server.py

@@ -156,7 +156,8 @@ def start_vllm_server(
 
     # Run the command
     try:
-        subprocess.run(cmd, check=True)
+        result = subprocess.run(cmd, check=True)
+        return result
     except subprocess.CalledProcessError as e:
         logger.error(f"Failed to start vLLM server: {e}")
         sys.exit(1)
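
Returning the `CompletedProcess` lets callers inspect the exit status, as `run_pipeline.run_vllm_server` does below. A condensed sketch; the positional argument order mirrors that call, and the values shown are the defaults it pulls from the inference config:

    from finetune_pipeline.inference.start_vllm_server import start_vllm_server

    # Argument order mirrors the call in run_pipeline.run_vllm_server;
    # values are the defaults read from the inference config there.
    result = start_vllm_server(
        "/home/ubuntu/yash-workspace/medgemma-4b-it",  # model_path
        8000,        # port
        "0.0.0.0",   # host
        1,           # tensor_parallel_size
        4096,        # max_model_len
        256,         # max_num_seqs
        None,        # quantization
        0.9,         # gpu_memory_utilization
        False,       # enforce_eager
    )
    if result.returncode == 0:
        print("vLLM server process exited cleanly")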

+ 451 - 0
src/finetune_pipeline/run_pipeline.py

@@ -0,0 +1,451 @@
+#!/usr/bin/env python
+"""
+End-to-end pipeline for data loading, fine-tuning, and inference.
+
+This script integrates all the modules in the finetune_pipeline package:
+1. Data loading and formatting
+2. Model fine-tuning
+3. vLLM server startup
+4. Inference on the fine-tuned model
+
+Example usage:
+    python run_pipeline.py --config config.yaml
+    python run_pipeline.py --config config.yaml --skip-finetuning --skip-server
+    python run_pipeline.py --config config.yaml --only-inference
+"""
+
+import argparse
+import json
+import logging
+import os
+import subprocess
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+# Import modules from the finetune_pipeline package
+from finetune_pipeline.data.data_loader import load_and_format_data, read_config
+from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
+# from finetune_pipeline.inference.run_inference import run_inference_on_eval_data
+from finetune_pipeline.inference.start_vllm_server import start_vllm_server
+
+
+def run_data_loading(config_path: str) -> Tuple[List[str], List[str]]:
+    """
+    Run the data loading and formatting step.
+
+    Args:
+        config_path: Path to the configuration file
+
+    Returns:
+        Tuple containing lists of paths to the formatted data and conversation data
+    """
+    logger.info("=== Step 1: Data Loading and Formatting ===")
+
+    # Read the configuration
+    config = read_config(config_path)
+    formatter_config = config.get("formatter", {})
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
+
+    # Create the output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Load and format the data
+    try:
+        formatted_data_paths, conversation_data_paths = load_and_format_data(
+            formatter_config, output_dir
+        )
+        logger.info(f"Data loading and formatting complete. Saved to {output_dir}")
+        logger.info(f"Formatted data paths: {formatted_data_paths}")
+        logger.info(f"Conversation data paths: {conversation_data_paths}")
+        return formatted_data_paths, conversation_data_paths
+    except Exception as e:
+        logger.error(f"Error during data loading and formatting: {e}")
+        raise
+
+
+def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
+    """
+    Run the fine-tuning step.
+
+    Args:
+        config_path: Path to the configuration file
+        formatted_data_paths: Paths to the formatted data
+
+    Returns:
+        Path to the fine-tuned model
+    """
+    logger.info("=== Step 2: Model Fine-tuning ===")
+
+    # Read the configuration
+    config = read_config(config_path)
+    finetuning_config = config.get("finetuning", {})
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+
+    # Get the path to the formatted data for the train split
+    train_data_path = None
+    for path in formatted_data_paths:
+        if "train_" in path:
+            train_data_path = path
+            break
+
+    if not train_data_path:
+        logger.warning("No train split found in formatted data. Using the first file.")
+        train_data_path = formatted_data_paths[0]
+
+    # Prepare additional kwargs for the fine-tuning
+    kwargs = f"dataset=finetune_pipeline.finetuning.custom_sft_dataset dataset.train_on_input=True dataset.dataset_path={train_data_path}"
+
+    # Create an args object to pass to run_torch_tune
+    class Args:
+        pass
+    args = Args()
+    args.kwargs = kwargs
+
+    # Run the fine-tuning
+    try:
+        logger.info(f"Starting fine-tuning with data from {train_data_path}")
+        run_torch_tune(finetuning_config, args=args)
+
+        # Get the path to the fine-tuned model
+        # This is a simplification; in reality, you'd need to extract the actual path from the fine-tuning output
+        model_path = os.path.join(output_dir, "finetuned_model")
+        logger.info(f"Fine-tuning complete. Model saved to {model_path}")
+        return model_path
+    except Exception as e:
+        logger.error(f"Error during fine-tuning: {e}")
+        raise
+
+
+def run_vllm_server(config_path: str, model_path: str) -> str:
+    """
+    Start the vLLM server.
+
+    Args:
+        config_path: Path to the configuration file
+        model_path: Path to the fine-tuned model
+
+    Returns:
+        URL of the vLLM server
+    """
+    logger.info("=== Step 3: Starting vLLM Server ===")
+
+    # Read the configuration
+    config = read_config(config_path)
+    inference_config = config.get("inference", {})
+
+    model_path = inference_config.get("model_path","/home/ubuntu/yash-workspace/medgemma-4b-it")
+
+    # # Update the model path in the inference config
+    # inference_config["model_path"] = model_path
+
+    # Extract server parameters
+    port = inference_config.get("port", 8000)
+    host = inference_config.get("host", "0.0.0.0")
+    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
+    max_model_len = inference_config.get("max_model_len", 4096)
+    max_num_seqs = inference_config.get("max_num_seqs", 256)
+    quantization = inference_config.get("quantization")
+    gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.9)
+    enforce_eager = inference_config.get("enforce_eager", False)
+
+
+    # Start the server in a separate process
+    try:
+        logger.info(f"Starting vLLM server with model {model_path}")
+        result = start_vllm_server(model_path,
+                                    port,
+                                    host,
+                                    tensor_parallel_size,
+                                    max_model_len,
+                                    max_num_seqs,
+                                    quantization,
+                                    gpu_memory_utilization,
+                                    enforce_eager)
+        if result.returncode == 0:
+            server_url = f"http://{host}:{port}/v1"
+            logger.info(f"vLLM server started at {server_url}")
+            return server_url
+        else:
+            logger.error(f"vLLM server failed to start")
+            raise RuntimeError("vLLM server failed to start")
+    except Exception as e:
+        logger.error(f"Error starting vLLM server: {e}")
+        raise
+
+
+# def run_inference(
+#     config_path: str, server_url: str, formatted_data_paths: List[str]
+# ) -> str:
+#     """
+#     Run inference on the fine-tuned model.
+
+#     Args:
+#         config_path: Path to the configuration file
+#         server_url: URL of the vLLM server
+#         formatted_data_paths: Paths to the formatted data
+
+#     Returns:
+#         Path to the inference results
+#     """
+#     logger.info("=== Step 4: Running Inference ===")
+
+#     # Read the configuration
+#     config = read_config(config_path)
+#     inference_config = config.get("inference", {})
+#     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+
+#     # Get the path to the formatted data for the validation or test split
+#     eval_data_path = inference_config.get("eval_data")
+#     if not eval_data_path:
+#         # Try to find a validation or test split in the formatted data
+#         for path in formatted_data_paths:
+#             if "validation_" in path or "test_" in path:
+#                 eval_data_path = path
+#                 break
+
+#         if not eval_data_path:
+#             logger.warning(
+#                 "No validation or test split found in formatted data. Using the first file."
+#             )
+#             eval_data_path = formatted_data_paths[0]
+
+#     # Extract inference parameters
+#     model_name = inference_config.get("model_name", "default")
+#     temperature = inference_config.get("temperature", 0.0)
+#     top_p = inference_config.get("top_p", 1.0)
+#     max_tokens = inference_config.get("max_tokens", 100)
+#     seed = inference_config.get("seed")
+
+#     # Run inference
+#     try:
+#         logger.info(
+#             f"Running inference on {eval_data_path} using server at {server_url}"
+#         )
+#         results = run_inference_on_eval_data(
+#             eval_data_path=eval_data_path,
+#             server_url=server_url,
+#             is_local=True,  # Assuming the formatted data is local
+#             model_name=model_name,
+#             temperature=temperature,
+#             top_p=top_p,
+#             max_tokens=max_tokens,
+#             seed=seed,
+#         )
+
+#         # Save the results
+#         results_path = os.path.join(output_dir, "inference_results.json")
+#         with open(results_path, "w") as f:
+#             json.dump(results, f, indent=2)
+
+#         logger.info(f"Inference complete. Results saved to {results_path}")
+#         return results_path
+#     except Exception as e:
+#         logger.error(f"Error during inference: {e}")
+#         raise
+
+
+def run_pipeline(
+    config_path: str,
+    skip_data_loading: bool = False,
+    skip_finetuning: bool = False,
+    skip_server: bool = False,
+    skip_inference: bool = False,
+    only_data_loading: bool = False,
+    only_finetuning: bool = False,
+    only_server: bool = False,
+    only_inference: bool = False,
+) -> None:
+    """
+    Run the end-to-end pipeline.
+
+    Args:
+        config_path: Path to the configuration file
+        skip_data_loading: Whether to skip the data loading step
+        skip_finetuning: Whether to skip the fine-tuning step
+        skip_server: Whether to skip starting the vLLM server
+        skip_inference: Whether to skip the inference step
+        only_data_loading: Whether to run only the data loading step
+        only_finetuning: Whether to run only the fine-tuning step
+        only_server: Whether to run only the vLLM server step
+        only_inference: Whether to run only the inference step
+    """
+    logger.info(f"Starting pipeline with config {config_path}")
+
+    # Check if the config file exists
+    if not os.path.exists(config_path):
+        logger.error(f"Config file {config_path} does not exist")
+        sys.exit(1)
+
+    # Handle "only" flags
+    if only_data_loading:
+        skip_finetuning = True
+        skip_server = True
+        skip_inference = True
+    elif only_finetuning:
+        skip_data_loading = True
+        skip_server = True
+        skip_inference = True
+    elif only_server:
+        skip_data_loading = True
+        skip_finetuning = True
+        skip_inference = True
+    elif only_inference:
+        skip_data_loading = True
+        skip_finetuning = True
+        skip_server = True
+
+    # Step 1: Data Loading and Formatting
+    formatted_data_paths = []
+    conversation_data_paths = []
+    if not skip_data_loading:
+        try:
+            formatted_data_paths, conversation_data_paths = run_data_loading(
+                config_path
+            )
+        except Exception as e:
+            logger.error(f"Pipeline failed at data loading step: {e}")
+            sys.exit(1)
+    else:
+        logger.info("Skipping data loading step")
+        # Try to infer the paths from the config
+        config = read_config(config_path)
+        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
+        formatter_type = config.get("formatter", {}).get("type", "torchtune")
+        # This is a simplification; in reality, you'd need to know the exact paths
+        formatted_data_paths = [
+            os.path.join(output_dir, f"train_{formatter_type}_formatted_data.json")
+        ]
+
+    # Step 2: Fine-tuning
+    model_path = ""
+    if not skip_finetuning:
+        try:
+            model_path = run_finetuning(config_path, formatted_data_paths)
+        except Exception as e:
+            logger.error(f"Pipeline failed at fine-tuning step: {e}")
+            sys.exit(1)
+    else:
+        logger.info("Skipping fine-tuning step")
+        # Try to infer the model path from the config
+        config = read_config(config_path)
+        output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
+        model_path = os.path.join(output_dir, "finetuned_model")
+
+    # Step 3: Start vLLM Server
+    server_url = ""
+    server_process = None
+    if not skip_server:
+        try:
+            server_url = run_vllm_server(config_path, model_path)
+        except Exception as e:
+            logger.error(f"Pipeline failed at vLLM server step: {e}")
+            sys.exit(1)
+    else:
+        logger.info("Skipping vLLM server step")
+        # Try to infer the server URL from the config
+        config = read_config(config_path)
+        inference_config = config.get("inference", {})
+        host = inference_config.get("host", "0.0.0.0")
+        port = inference_config.get("port", 8000)
+        server_url = f"http://{host}:{port}/v1"
+
+    # Step 4: Inference
+    if not skip_inference:
+        try:
+            results_path = run_inference(config_path, server_url, formatted_data_paths)
+            logger.info(
+                f"Pipeline completed successfully. Results saved to {results_path}"
+            )
+        except Exception as e:
+            logger.error(f"Pipeline failed at inference step: {e}")
+            sys.exit(1)
+    else:
+        logger.info("Skipping inference step")
+
+    logger.info("Pipeline execution complete")
+
+
+def main():
+    """Main function."""
+    parser = argparse.ArgumentParser(description="Run the end-to-end pipeline")
+
+    # Configuration
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to the configuration file",
+    )
+
+    # Skip flags
+    parser.add_argument(
+        "--skip-data-loading",
+        action="store_true",
+        help="Skip the data loading step",
+    )
+    parser.add_argument(
+        "--skip-finetuning",
+        action="store_true",
+        help="Skip the fine-tuning step",
+    )
+    parser.add_argument(
+        "--skip-server",
+        action="store_true",
+        help="Skip starting the vLLM server",
+    )
+    parser.add_argument(
+        "--skip-inference",
+        action="store_true",
+        help="Skip the inference step",
+    )
+
+    # Only flags
+    parser.add_argument(
+        "--only-data-loading",
+        action="store_true",
+        help="Run only the data loading step",
+    )
+    parser.add_argument(
+        "--only-finetuning",
+        action="store_true",
+        help="Run only the fine-tuning step",
+    )
+    parser.add_argument(
+        "--only-server",
+        action="store_true",
+        help="Run only the vLLM server step",
+    )
+    parser.add_argument(
+        "--only-inference",
+        action="store_true",
+        help="Run only the inference step",
+    )
+
+    args = parser.parse_args()
+
+    # Run the pipeline
+    run_pipeline(
+        config_path=args.config,
+        skip_data_loading=args.skip_data_loading,
+        skip_finetuning=args.skip_finetuning,
+        skip_server=args.skip_server,
+        skip_inference=args.skip_inference,
+        only_data_loading=args.only_data_loading,
+        only_finetuning=args.only_finetuning,
+        only_server=args.only_server,
+        only_inference=args.only_inference,
+    )
+
+
+if __name__ == "__main__":
+    main()
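
Besides the CLI entry point, the pipeline can be driven programmatically. A sketch, assuming the package is importable on PYTHONPATH and using an illustrative config path:

    from finetune_pipeline.run_pipeline import run_pipeline

    # Run data loading and fine-tuning only; skip the server and inference steps.
    run_pipeline(
        config_path="src/finetune_pipeline/config.yaml",
        skip_server=True,
        skip_inference=True,
    )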