
updated pipeline and added requirements.txt

khare19yash 1 month ago
commit 2ac2ef26fb

+ 11 - 5
src/finetune_pipeline/config.yaml

@@ -86,14 +86,14 @@ formatter:
 
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
-    split: "validation"                # Dataset split to load
+    split: "train"                # Dataset split to load
 
 # Training configuration
 finetuning:
   model_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct" # Path to the model checkpoint
   tokenizer_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct/original/tokenizer.model" # Path to the tokenizer
-  output_dir: ${output_dir}/model_outputs  # Directory to store checkpoints
-  log_dir: ${output_dir}/logs  # Directory to store logs
+  output_dir: /home/yashkhare/workspace/finetuning-pipeline/model_outputs  # Directory to store checkpoints
+  log_dir: /home/yashkhare/workspace/finetuning-pipeline/logs  # Directory to store logs
   strategy: "lora"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
   batch_size: 4                 # Batch size per device for training
@@ -105,7 +105,7 @@ finetuning:
 # vLLM Inference configuration
 inference:
   # Model configuration
-  model_path: "/home/yashkhare/workspace/medgemma-4b-it" # Path to the model checkpoint
+  model_path: "/home/yashkhare/workspace/finetuning-pipeline/model_outputs/epoch_0" # Path to the model checkpoint
   quantization: null            # Quantization method (awq, gptq, squeezellm)
   dtype: "auto"                 # Data type for model weights (half, float, bfloat16, auto)
   trust_remote_code: false      # Trust remote code when loading the model
@@ -121,7 +121,13 @@ inference:
   gpu_memory_utilization: 0.95   # Fraction of GPU memory to use
   enforce_eager: false          # Enforce eager execution
 
-  inference_data: "dz-osamu/IU-Xray" # Path to the inference dataset (optional)
+  inference_data_kwargs:
+    data_path: "dz-osamu/IU-Xray"     # Path to the inference dataset
+    split: "validation"               # Dataset split to load
+    formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+    format_data: true                # Whether to format the inference dataset
+    max_samples: null                 # Maximum number of samples to load (null for all)
+    is_local: false                   # Whether the data is stored locally
 
   # Additional vLLM parameters (optional)
   # swap_space: 4               # Size of CPU swap space in GiB
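
Note: the single inference_data path is replaced by an inference_data_kwargs block. A minimal sketch of how these keys are consumed, assuming the config is read with plain PyYAML (the pipeline itself goes through its read_config helper; the file path below is illustrative):

import yaml
from finetune_pipeline.inference.run_inference import load_inference_data

# Load the pipeline config (path is illustrative)
with open("src/finetune_pipeline/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# The keys shown above are read verbatim by the new loader added in run_inference.py
inference_data = load_inference_data(
    inference_data_kwargs=config["inference"]["inference_data_kwargs"],
    formatter_config=config.get("formatter"),  # reused only for column_mapping
)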

+ 2 - 2
src/finetune_pipeline/data/data_loader.py

@@ -417,8 +417,8 @@ if __name__ == "__main__":
     config = read_config(args.config)
     formatter_config = config.get("formatter", {})
     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
-
+    output_data_dir = os.path.join(output_dir, "data")
     # Load and format the data
     formatted_data_paths, conversation_data_paths = load_and_format_data(
-        formatter_config, output_dir
+        formatter_config, output_data_dir
     )
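
Note: with this change the formatted splits are written to a data/ subdirectory of output_dir, matching the inference_results.json path introduced in run_inference.py below. A minimal sketch of the resulting layout (the output_dir value is illustrative):

import os

output_dir = "/tmp/finetune-pipeline/"              # illustrative value from the config
output_data_dir = os.path.join(output_dir, "data")  # where formatted splits now land
results_path = os.path.join(output_data_dir, "inference_results.json")  # written by run_inference.py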

+ 0 - 20
src/finetune_pipeline/inference/__init__.py

@@ -3,23 +3,3 @@ Inference utilities for LLMs.
 
 This module provides tools for running inference with fine-tuned models.
 """
-
-# from .inference import (
-#     run_inference_from_config,
-#     run_inference_on_eval_data,
-#     VLLMClient,
-#     VLLMInferenceRequest,
-# )
-# from .start_vllm_server import check_vllm_installed, read_config, start_vllm_server
-
-# __all__ = [
-#     # From inference
-#     "VLLMClient",
-#     "VLLMInferenceRequest",
-#     "run_inference_on_eval_data",
-#     "run_inference_from_config",
-#     # From start_vllm_server
-#     "start_vllm_server",
-#     "read_config",
-#     "check_vllm_installed",
-# ]

+ 233 - 44
src/finetune_pipeline/inference/run_inference.py

@@ -13,12 +13,231 @@ from ..data.data_loader import (
     get_formatter,
     load_data,
     read_config,
+    save_formatted_data,
 )
 
 # Set up logging
 logger = logging.getLogger(__name__)
 
 
+def load_inference_data(
+    inference_data_kwargs: Dict, formatter_config: Optional[Dict] = None
+) -> List[Dict]:
+    """
+    Load and format inference data using inference_data_kwargs configuration.
+
+    Args:
+        inference_data_kwargs: Dictionary containing all inference data parameters
+        formatter_config: Fallback formatter configuration for compatibility
+
+    Returns:
+        List of formatted data dictionaries ready for inference
+
+    Raises:
+        ValueError: If required parameters are missing or invalid
+        FileNotFoundError: If local file path doesn't exist
+        Exception: For data loading or formatting errors
+    """
+    # Extract parameters from inference_data_kwargs
+    path = inference_data_kwargs.get("data_path")
+    if not path:
+        raise ValueError("data_path is required in inference_data_kwargs")
+
+    format_data = inference_data_kwargs.get("format_data", True)
+    formatter_type = inference_data_kwargs.get("formatter_type", "vllm")
+    max_samples = inference_data_kwargs.get("max_samples")
+    is_local = inference_data_kwargs.get("is_local", False)
+    split = inference_data_kwargs.get("split", "validation")
+
+    # Build dataset_kwargs
+    dataset_kwargs = {"split": split}
+    if "dataset_kwargs" in inference_data_kwargs:
+        dataset_kwargs.update(inference_data_kwargs["dataset_kwargs"])
+
+    # Use formatter_config for column mapping if provided
+    column_mapping = {}
+    if formatter_config:
+        column_mapping = formatter_config.get("column_mapping", {})
+
+    # Validate formatter type
+    valid_formatters = ["vllm", "torchtune", "openai"]
+    if formatter_type not in valid_formatters:
+        raise ValueError(
+            f"Invalid formatter_type '{formatter_type}'. Must be one of {valid_formatters}"
+        )
+
+    logger.info(f"Loading inference data from: {path}")
+    logger.info(f"Format data: {format_data}, Formatter type: {formatter_type}")
+    if max_samples:
+        logger.info(f"Max samples: {max_samples}")
+
+    formatted_data = []
+
+    if format_data:
+        # Validate column mapping
+        if not column_mapping:
+            logger.warning("No column mapping provided. Using default column names.")
+            column_mapping = {"input": "input", "output": "output"}
+
+        # Check if local file exists
+        if is_local:
+            import os
+
+            if not os.path.exists(path):
+                raise FileNotFoundError(f"Local file not found: {path}")
+            logger.info(f"Loading data from local file: {path}")
+        else:
+            logger.info(f"Loading data from Hugging Face dataset: {path}")
+
+        # Load the data with progress tracking
+        try:
+            logger.info("Loading raw data...")
+            data = load_data(path, is_local, **dataset_kwargs)
+
+            # Apply sample limit if specified
+            if max_samples and hasattr(data, "__len__") and len(data) > max_samples:
+                logger.info(
+                    f"Limiting dataset to {max_samples} samples (original: {len(data)})"
+                )
+                if hasattr(data, "select"):
+                    # For HuggingFace datasets
+                    data = data.select(range(max_samples))
+                else:
+                    # For other iterable data
+                    data = list(data)[:max_samples]
+
+            data_size = len(data) if hasattr(data, "__len__") else "unknown"
+            logger.info(f"Successfully loaded {data_size} samples")
+
+        except Exception as e:
+            logger.error(f"Failed to load data from {path}: {e}")
+            logger.error(f"Dataset kwargs: {dataset_kwargs}")
+            raise RuntimeError(f"Data loading failed: {str(e)}") from e
+
+        # Convert to conversations with progress tracking
+        try:
+            logger.info("Converting data to conversation format...")
+            conversations = convert_to_conversations(data, column_mapping)
+            logger.info(f"Created {len(conversations)} conversations")
+
+            # Validate conversations
+            if not conversations:
+                raise ValueError("No conversations were created from the data")
+
+            # Log sample conversation for debugging
+            if conversations and logger.isEnabledFor(logging.DEBUG):
+                sample_conv = conversations[0]
+                logger.debug(
+                    f"Sample conversation: {sample_conv.messages[:2] if hasattr(sample_conv, 'messages') else sample_conv}"
+                )
+
+        except Exception as e:
+            logger.error(f"Failed to convert data to conversations: {e}")
+            logger.error(f"Column mapping: {column_mapping}")
+            raise RuntimeError(f"Conversation conversion failed: {str(e)}") from e
+
+        # Format conversations using specified formatter
+        try:
+            logger.info(f"Formatting conversations using {formatter_type} formatter...")
+            formatter = get_formatter(formatter_type)
+
+            # Add progress bar for large datasets
+            if len(conversations) > 1000:
+                logger.info("Processing large dataset with progress tracking...")
+                from tqdm import tqdm
+
+                formatted_data = []
+                for conv in tqdm(conversations, desc="Formatting conversations"):
+                    formatted_data.append(formatter.format_conversation(conv))
+            else:
+                formatted_data = formatter.format_data(conversations)
+
+            logger.info(f"Successfully formatted {len(formatted_data)} samples")
+
+            # Validate formatted data
+            if not formatted_data:
+                raise ValueError("No formatted data was produced")
+
+            # Log sample formatted data for debugging
+            if formatted_data and logger.isEnabledFor(logging.DEBUG):
+                logger.debug(f"Sample formatted data: {formatted_data[0]}")
+
+        except Exception as e:
+            logger.error(f"Failed to format conversations: {e}")
+            logger.error(f"Formatter type: {formatter_type}")
+            raise RuntimeError(f"Data formatting failed: {str(e)}") from e
+
+    else:
+        # Load pre-formatted data
+        logger.info("Loading pre-formatted data...")
+        try:
+            import os
+            from pathlib import Path
+
+            file_path = Path(path)
+            if not file_path.exists():
+                raise FileNotFoundError(f"Pre-formatted file not found: {path}")
+
+            # Support different file formats
+            if file_path.suffix.lower() == ".json":
+                with open(path, "r", encoding="utf-8") as f:
+                    formatted_data = json.load(f)
+            elif file_path.suffix.lower() in [".jsonl", ".ndjson"]:
+                # Support JSONL format
+                formatted_data = []
+                with open(path, "r", encoding="utf-8") as f:
+                    for line_num, line in enumerate(f, 1):
+                        try:
+                            if line.strip():  # Skip empty lines
+                                formatted_data.append(json.loads(line))
+                        except json.JSONDecodeError as e:
+                            logger.warning(
+                                f"Skipping invalid JSON on line {line_num}: {e}"
+                            )
+            else:
+                raise ValueError(
+                    f"Unsupported file format: {file_path.suffix}. Supported: .json, .jsonl, .ndjson"
+                )
+
+            # Apply sample limit if specified
+            if max_samples and len(formatted_data) > max_samples:
+                logger.info(
+                    f"Limiting pre-formatted data to {max_samples} samples (original: {len(formatted_data)})"
+                )
+                formatted_data = formatted_data[:max_samples]
+
+            logger.info(
+                f"Successfully loaded {len(formatted_data)} pre-formatted samples"
+            )
+
+        except FileNotFoundError:
+            raise
+        except Exception as e:
+            logger.error(f"Failed to load pre-formatted data from {path}: {e}")
+            raise RuntimeError(f"Pre-formatted data loading failed: {str(e)}") from e
+
+    # Final validation
+    if not formatted_data:
+        raise ValueError("No data was loaded. Check your path and configuration.")
+
+    # Validate data structure
+    if not isinstance(formatted_data, list):
+        raise ValueError("Formatted data must be a list")
+
+    # Basic structure validation for first sample
+    if formatted_data:
+        sample = formatted_data[0]
+        if not isinstance(sample, dict):
+            logger.warning(
+                "First sample is not a dictionary. This might cause issues during inference."
+            )
+
+    logger.info(
+        f"Data loading completed successfully. Total samples: {len(formatted_data)}"
+    )
+    return formatted_data
+
+
 def vllm_call_batch(
     llm: LLM, data: List[Dict], sampling_params: SamplingParams
 ) -> List[str]:
@@ -54,9 +273,8 @@ def vllm_call_batch(
 
 
 def run_vllm_batch_inference_on_dataset(
-    inference_data_path: str,
+    inference_data: List[Dict],
     model_path: str,
-    is_local: bool = False,
     temperature: float = 0.0,
     top_p: float = 1.0,
     max_tokens: int = 100,
@@ -64,47 +282,25 @@ def run_vllm_batch_inference_on_dataset(
     structured: bool = False,
     gpu_memory_utilization: float = 0.95,
     max_model_len: int = 4096,
-    dataset_kwargs: Optional[Dict[str, Any]] = None,
-    column_mapping: Optional[Dict[str, str]] = None,
 ) -> Dict[str, Any]:
     """
     Run batch inference on pre-formatted inference data using vLLM.
 
     Args:
-        eval_data_path: Path to the evaluation data
-        server_url: URL of the vLLM server
-        is_local: Whether the data is stored locally
+        inference_data: Inference data to run inference on
+        model_path: Path to the vLLM model
         temperature: Temperature for sampling
         top_p: Top-p for sampling
         max_tokens: Maximum number of tokens to generate
         seed: Random seed for reproducibility
-        dataset_kwargs: Additional arguments to pass to the load_dataset function
-        column_mapping: Mapping of column names
+        structured: Whether to use structured output
+        gpu_memory_utilization: GPU memory utilization for the vLLM server
+        max_model_len: Maximum model length for the vLLM server
 
     Returns:
         Dictionary containing the generated outputs and inference metadata
     """
 
-    logger.info(f"Loading Inference data from {inference_data_path}")
-
-    # Load the evaluation data
-    if dataset_kwargs is None:
-        dataset_kwargs = {}
-
-    try:
-        data = load_data(inference_data_path, is_local, **dataset_kwargs)
-    except Exception as e:
-        logger.error(f"Failed to load data from {inference_data_path}: {e}")
-        raise
-
-    # Convert the data to conversations
-    logger.info("Converting data to conversation format")
-    conversations = convert_to_conversations(data, column_mapping)
-
-    # Convert the conversations to vLLM format
-    vllm_formatter = get_formatter("vllm")
-    formatted_data = vllm_formatter.format_data(conversations)
-
     # Create an LLM
     logger.info(f"Initializing vLLM with model: {model_path}")
     try:
@@ -139,8 +335,8 @@ def run_vllm_batch_inference_on_dataset(
         )
 
     # Run inference on the formatted data
-    logger.info(f"Running inference on {len(formatted_data)} examples")
-    outputs = vllm_call_batch(llm, formatted_data, sampling_params)
+    logger.info(f"Running inference on {len(inference_data)} examples")
+    outputs = vllm_call_batch(llm, inference_data, sampling_params)
 
     # Return a dictionary containing the outputs and metadata
     return {
@@ -211,12 +407,8 @@ def main():
     if model_path is None:
         raise ValueError("model_path must be specified in the config")
 
-    # Get data path from parameters or config
-    inference_data_path = inference_config.get("inference_data", None)
-    if inference_data_path is None:
-        raise ValueError("Inference data path must be specified in config")
     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
-    results_path = f"{output_dir}/inference_results.json"
+    results_path = f"{output_dir}/data/inference_results.json"
 
     # Performance parameters
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
@@ -232,15 +424,14 @@ def main():
     seed = inference_config.get("seed")
     structured = inference_config.get("structured", False)
 
-    # Data parameters
-    is_local = formatter_config.get("is_local", False)
-    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
-    column_mapping = formatter_config.get("column_mapping", {})
+    # Inference Data parameters
+    inference_data_kwargs = inference_config.get("inference_data_kwargs", {})
+
+    inference_data = load_inference_data(inference_data_kwargs, formatter_config)
 
     results = run_vllm_batch_inference_on_dataset(
-        inference_data_path,
+        inference_data,
         model_path,
-        is_local,
         temperature,
         top_p,
         max_tokens,
@@ -248,8 +439,6 @@ def main():
         structured,
         gpu_memory_utilization,
         max_model_len,
-        dataset_kwargs,
-        column_mapping,
     )
 
     save_inference_results(results, results_path)
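
Note: the refactor splits data loading out of the inference routine; load_inference_data produces the formatted prompt list and run_vllm_batch_inference_on_dataset only consumes it. A minimal usage sketch: the kwargs and model path mirror the config.yaml changes above, while the column_mapping values and results path are illustrative placeholders.

from finetune_pipeline.inference.run_inference import (
    load_inference_data,
    run_vllm_batch_inference_on_dataset,
    save_inference_results,
)

# Values mirror the inference_data_kwargs block in config.yaml above
inference_data_kwargs = {
    "data_path": "dz-osamu/IU-Xray",
    "split": "validation",
    "formatter_type": "vllm",
    "format_data": True,
    "max_samples": None,
    "is_local": False,
}

# column_mapping is illustrative -- it must match the dataset's actual column names
formatter_config = {"column_mapping": {"input": "query", "output": "response"}}

inference_data = load_inference_data(inference_data_kwargs, formatter_config)

results = run_vllm_batch_inference_on_dataset(
    inference_data=inference_data,
    model_path="/home/yashkhare/workspace/finetuning-pipeline/model_outputs/epoch_0",
    temperature=0.0,
    top_p=1.0,
    max_tokens=100,
    seed=None,
    structured=False,
    gpu_memory_utilization=0.95,
    max_model_len=4096,
)

save_inference_results(results, "/tmp/finetune-pipeline/data/inference_results.json")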

+ 6 - 0
src/finetune_pipeline/requirements.txt

@@ -0,0 +1,6 @@
+datasets==4.0.0
+PyYAML==6.0.2
+Requests==2.32.4
+torchtune==0.0.0
+tqdm==4.67.1
+vllm==0.10.0

+ 52 - 48
src/finetune_pipeline/run_pipeline.py

@@ -38,8 +38,8 @@ from finetune_pipeline.finetuning.run_finetuning import run_torch_tune
 
 from finetune_pipeline.inference.run_inference import (
     run_vllm_batch_inference_on_dataset,
+    save_inference_results,
 )
-from finetune_pipeline.inference.save_inference_results import save_inference_results
 from finetune_pipeline.inference.start_vllm_server import start_vllm_server
 
 
@@ -118,7 +118,7 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
     # Run the fine-tuning
     try:
         logger.info(f"Starting fine-tuning with data from {train_data_path}")
-        run_torch_tune(finetuning_config, args=args)
+        run_torch_tune(config, args=args)
 
         # Get the path to the latest checkpoint of the fine-tuned model
         model_output_dir = finetuning_config.get("output_dir", config.get("output_dir"))
@@ -201,8 +201,7 @@ def run_inference(
 
     Args:
         config_path: Path to the configuration file
-        server_url: URL of the vLLM server
-        formatted_data_paths: Paths to the formatted data
+        formatted_data_paths: Paths to the formatted data (for compatibility)
 
     Returns:
         Path to the inference results
@@ -217,21 +216,19 @@ def run_inference(
     # Model parameters
     if model_path == "":
         model_path = inference_config.get("model_path", None)
-        if model_path is None:
-            raise ValueError("model_path must be specified in the config")
-
-    # Get data path from parameters or config
-    inference_data_path = inference_config.get("inference_data", None)
-    if inference_data_path is None:
-        raise ValueError("Inference data path must be specified in config")
-    output_path = f"{output_dir}/inference_results.json"
+    if model_path is None:
+        raise ValueError("model_path must be specified in the config")
+
+    # Get inference data configuration
+    inference_data_kwargs = inference_config.get("inference_data_kwargs", {})
+    if not inference_data_kwargs or not inference_data_kwargs.get("data_path"):
+        raise ValueError(
+            "inference_data_kwargs with data_path must be specified in config"
+        )
 
     # Performance parameters
     gpu_memory_utilization = inference_config.get("gpu_memory_utilization", 0.95)
     max_model_len = inference_config.get("max_model_len", 512)
-    tensor_parallel_size = inference_config.get("tensor_parallel_size", 1)
-    dtype = inference_config.get("dtype", "auto")
-    trust_remote_code = inference_config.get("trust_remote_code", False)
 
     # Generation parameters
     max_tokens = inference_config.get("max_tokens", 100)
@@ -240,27 +237,34 @@ def run_inference(
     seed = inference_config.get("seed")
     structured = inference_config.get("structured", False)
 
-    # Data parameters
-    is_local = formatter_config.get("is_local", False)
-    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
-    column_mapping = formatter_config.get("column_mapping", {})
+    # Load inference data using the new function
+    try:
+        logger.info("Loading inference data...")
+        from finetune_pipeline.inference.run_inference import load_inference_data
+
+        inference_data = load_inference_data(
+            inference_data_kwargs=inference_data_kwargs,
+            formatter_config=formatter_config,
+        )
+        logger.info(f"Loaded {len(inference_data)} samples for inference")
+
+    except Exception as e:
+        logger.error(f"Failed to load inference data: {e}")
+        raise
 
     # Run inference
     try:
-        logger.info(f"Running inference on {inference_data_path}")
+        logger.info(f"Running inference with model: {model_path}")
         results = run_vllm_batch_inference_on_dataset(
-            inference_data_path,
-            model_path,
-            is_local,
-            temperature,
-            top_p,
-            max_tokens,
-            seed,
-            structured,
-            gpu_memory_utilization,
-            max_model_len,
-            dataset_kwargs,
-            column_mapping,
+            inference_data=inference_data,
+            model_path=model_path,
+            temperature=temperature,
+            top_p=top_p,
+            max_tokens=max_tokens,
+            seed=seed,
+            structured=structured,
+            gpu_memory_utilization=gpu_memory_utilization,
+            max_model_len=max_model_len,
         )
 
         # Save the results
@@ -361,7 +365,22 @@ def run_pipeline(
         output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
         model_path = os.path.join(output_dir, "finetuned_model")
 
-    # # Step 3: Start vLLM Server
+    # Step 3: Inference
+    if not skip_inference:
+        try:
+            results_path = run_inference(config_path, formatted_data_paths, model_path)
+            logger.info(
+                f"Pipeline completed successfully. Results saved to {results_path}"
+            )
+        except Exception as e:
+            logger.error(f"Pipeline failed at inference step: {e}")
+            sys.exit(1)
+    else:
+        logger.info("Skipping inference step")
+
+    logger.info("Pipeline execution complete")
+
+    # # Step 4: Start vLLM Server
     # server_url = ""
     # server_process = None
     # if not skip_server:
@@ -379,21 +398,6 @@ def run_pipeline(
     #     port = inference_config.get("port", 8000)
     #     server_url = f"http://{host}:{port}/v1"
 
-    # Step 3: Inference
-    if not skip_inference:
-        try:
-            results_path = run_inference(config_path, formatted_data_paths, model_path)
-            logger.info(
-                f"Pipeline completed successfully. Results saved to {results_path}"
-            )
-        except Exception as e:
-            logger.error(f"Pipeline failed at inference step: {e}")
-            sys.exit(1)
-    else:
-        logger.info("Skipping inference step")
-
-    logger.info("Pipeline execution complete")
-
 
 def main():
     """Main function."""