
updated finetuning and added Readme

khare19yash 1 month ago
parent
commit
8e1174932d

+ 414 - 0
src/finetune_pipeline/README.md

@@ -0,0 +1,414 @@
+# Finetune Pipeline
+
+A comprehensive end-to-end pipeline for fine-tuning large language models using TorchTune and running inference with vLLM. This pipeline provides a unified interface for data loading, model fine-tuning, and inference with support for various data formats and distributed training.
+
+## Features
+
+- **Data Loading & Formatting**: Support for multiple formats (TorchTune, vLLM, OpenAI)
+- **Flexible Fine-tuning**: LoRA and full fine-tuning with single-device and distributed training
+- **Inference**: High-performance inference using vLLM
+- **Configuration-driven**: YAML/JSON configuration files for reproducible experiments
+- **Modular Design**: Use individual components or run the full pipeline
+- **Multi-format Support**: Works with Hugging Face datasets and local data
+
+## Table of Contents
+
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [Configuration](#configuration)
+- [Usage](#usage)
+- [Module Structure](#module-structure)
+- [API Reference](#api-reference)
+- [Examples](#examples)
+- [Advanced Usage](#advanced-usage)
+- [Troubleshooting](#troubleshooting)
+- [Contributing](#contributing)
+- [License](#license)
+
+## Installation
+
+### Prerequisites
+
+```bash
+# Core dependencies
+pip install torch torchtune vllm datasets pyyaml tqdm
+
+# Optional dependencies for development
+pip install pytest black flake8
+```
+
+### Environment Setup
+
+```bash
+# Clone the repository
+git clone <repository-url>
+cd finetune_pipeline
+
+# Install in development mode
+pip install -e .
+```
+
+## Quick Start
+
+1. **Create a configuration file** (`config.yaml`):
+
+```yaml
+output_dir: "/path/to/output/"
+
+formatter:
+  type: "torchtune"
+  data_path: "dz-osamu/IU-Xray"
+  is_local: false
+  column_mapping:
+    input: "query"
+    output: "response"
+    image: null
+  dataset_kwargs:
+    split: "validation"
+
+finetuning:
+  model_path: "/path/to/Llama-3.1-8B-Instruct"
+  tokenizer_path: "/path/to/tokenizer.model"
+  strategy: "lora"
+  num_epochs: 1
+  batch_size: 1
+  torchtune_config: "llama3_1/8B_lora"
+  num_processes_per_node: 8
+  distributed: true
+
+inference:
+  model_path: "/path/to/model"
+  port: 8000
+  host: "0.0.0.0"
+  tensor_parallel_size: 1
+  max_model_len: 512
+  gpu_memory_utilization: 0.95
+  inference_data: "dz-osamu/IU-Xray"
+```
+
+2. **Run the full pipeline**:
+
+```bash
+python run_pipeline.py --config config.yaml
+```
+
+3. **Run individual steps**:
+
+```bash
+# Data loading only
+python run_pipeline.py --config config.yaml --only-data-loading
+
+# Fine-tuning only
+python run_pipeline.py --config config.yaml --only-finetuning
+
+# Inference only
+python run_pipeline.py --config config.yaml --only-inference
+```
+
+## Configuration
+
+The pipeline uses YAML or JSON configuration files with the following main sections:
+
+### Global Configuration
+
+- `output_dir`: Base directory for all outputs
+
+### Data Formatting (`formatter`)
+
+- `type`: Formatter type (`"torchtune"`, `"vllm"`, `"openai"`)
+- `data_path`: Path to dataset (Hugging Face ID or local path)
+- `is_local`: Whether data is stored locally
+- `column_mapping`: Map dataset columns to standard fields
+- `dataset_kwargs`: Additional arguments for data loading
+
+### Fine-tuning (`finetuning`)
+
+- `model_path`: Path to base model
+- `tokenizer_path`: Path to tokenizer
+- `strategy`: Training strategy (`"lora"` for LoRA or `"fft"` for full fine-tuning); together with `distributed`, this selects the TorchTune recipe (see the sketch after this list)
+- `num_epochs`: Number of training epochs
+- `batch_size`: Batch size per device
+- `torchtune_config`: TorchTune configuration name
+- `distributed`: Enable distributed training
+- `num_processes_per_node`: Number of processes for distributed training
+
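+The `strategy` and `distributed` flags together determine which TorchTune recipe is launched. A minimal sketch of that mapping (the actual selection lives in `finetuning/run_finetuning.py`):
+
+```python
+# Illustrative sketch: how `strategy` and `distributed` map to a TorchTune recipe.
+def select_recipe(strategy: str, distributed: bool) -> str:
+    recipes = {
+        ("lora", True): "lora_finetune_distributed",
+        ("fft", True): "full_finetune_distributed",
+        ("lora", False): "lora_finetune_single_device",
+        ("fft", False): "full_finetune_single_device",
+    }
+    if (strategy, distributed) not in recipes:
+        raise ValueError(f"Invalid strategy: {strategy}")
+    return recipes[(strategy, distributed)]
+
+print(select_recipe("lora", True))  # lora_finetune_distributed
+```
+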
+### Inference (`inference`)
+
+- `model_path`: Path to model for inference
+- `port`: Server port
+- `host`: Server host
+- `tensor_parallel_size`: Number of GPUs for tensor parallelism
+- `max_model_len`: Maximum sequence length
+- `gpu_memory_utilization`: GPU memory utilization fraction
+- `inference_data`: Dataset for inference
+
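+These settings map closely onto standard vLLM engine and sampling arguments. A rough sketch of the correspondence using vLLM's offline API (illustrative only; the pipeline's inference module handles this for you):
+
+```python
+# Illustrative only: how the inference settings above map onto vLLM's offline API.
+from vllm import LLM, SamplingParams
+
+llm = LLM(
+    model="/path/to/model",        # inference.model_path
+    tensor_parallel_size=1,        # inference.tensor_parallel_size
+    max_model_len=512,             # inference.max_model_len
+    gpu_memory_utilization=0.95,   # inference.gpu_memory_utilization
+)
+outputs = llm.generate(
+    ["Summarize the key findings of this radiology report."],
+    SamplingParams(temperature=0.0, max_tokens=100),
+)
+print(outputs[0].outputs[0].text)
+```
+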
+## Usage
+
+### Command Line Interface
+
+The main pipeline script provides several options:
+
+```bash
+# Full pipeline
+python run_pipeline.py --config config.yaml
+
+# Skip specific steps
+python run_pipeline.py --config config.yaml --skip-finetuning --skip-inference
+
+# Run only specific steps
+python run_pipeline.py --config config.yaml --only-data-loading
+python run_pipeline.py --config config.yaml --only-finetuning
+python run_pipeline.py --config config.yaml --only-inference
+```
+
+### Individual Components
+
+#### Data Loading and Formatting
+
+```python
+from finetune_pipeline.data.data_loader import load_and_format_data, read_config
+
+config = read_config("config.yaml")
+formatter_config = config.get("formatter", {})
+output_dir = config.get("output_dir", "/tmp/")
+
+formatted_data_paths, conversation_data_paths = load_and_format_data(
+    formatter_config, output_dir
+)
+```
+
+#### Fine-tuning
+
+```python
+from finetune_pipeline.finetuning.run_finetuning import read_config, run_torch_tune
+
+config = read_config("config.yaml")
+
+run_torch_tune(config)
+```
+
+#### Inference
+
+```python
+from finetune_pipeline.inference.run_inference import run_vllm_batch_inference_on_dataset
+
+results = run_vllm_batch_inference_on_dataset(
+    data_path="dataset_name",
+    model_path="/path/to/model",
+    is_local=False,
+    temperature=0.0,
+    max_tokens=100,
+    # ... other parameters
+)
+```
+
+## Module Structure
+
+```
+finetune_pipeline/
+├── __init__.py
+├── config.yaml                 # Example configuration
+├── run_pipeline.py             # Main pipeline orchestrator
+│
+├── data/                       # Data loading and formatting
+│   ├── __init__.py
+│   ├── data_loader.py          # Dataset loading utilities
+│   ├── formatter.py            # Data format converters
+│   └── augmentation.py         # Data augmentation utilities
+│
+├── finetuning/                 # Fine-tuning components
+│   ├── __init__.py
+│   ├── run_finetuning.py       # TorchTune fine-tuning script
+│   └── custom_sft_dataset.py   # Custom dataset for supervised fine-tuning
+│
+├── inference/                  # Inference components
+│   ├── __init__.py
+│   ├── run_inference.py        # Batch inference utilities
+│   ├── start_vllm_server.py    # vLLM server management
+│   └── save_inference_results.py # Result saving utilities
+│
+└── tests/                      # Test suite
+    ├── __init__.py
+    ├── test_formatter.py
+    └── test_finetuning.py
+```
+
+## API Reference
+
+### Data Components
+
+#### `Formatter` (Abstract Base Class)
+- `format_data(data)`: Format list of conversations
+- `format_conversation(conversation)`: Format single conversation
+- `format_message(message)`: Format single message
+
+#### `TorchtuneFormatter`
+Formats data for TorchTune training with message-based structure.
+
+#### `vLLMFormatter`
+Formats data for vLLM inference with optimized structure.
+
+#### `OpenAIFormatter`
+Formats data compatible with OpenAI API format.
+
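+A minimal sketch of dispatching on the `formatter.type` config value to one of these classes (assuming the constructors take no required arguments, which may differ from the actual implementation):
+
+```python
+# Illustrative sketch: pick a formatter class from the `formatter.type` value.
+from finetune_pipeline.data.formatter import (
+    OpenAIFormatter,
+    TorchtuneFormatter,
+    vLLMFormatter,
+)
+
+FORMATTERS = {
+    "torchtune": TorchtuneFormatter,
+    "vllm": vLLMFormatter,
+    "openai": OpenAIFormatter,
+}
+
+def get_formatter(formatter_type: str):
+    """Return a formatter instance for a `formatter.type` value from the config."""
+    if formatter_type not in FORMATTERS:
+        raise ValueError(f"Unknown formatter type: {formatter_type}")
+    return FORMATTERS[formatter_type]()
+
+formatter = get_formatter("torchtune")  # then call formatter.format_data(...)
+```
+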
+### Fine-tuning Components
+
+#### `run_torch_tune(config, args=None)`
+Execute TorchTune training with parameters taken from the configuration file.
+
+**Parameters:**
+- `config`: Full configuration dictionary (the `finetuning` section and top-level `output_dir` are read from it)
+- `args`: Optional parsed command-line arguments; `args.kwargs` may contain extra space-separated `key=value` overrides passed through to TorchTune
+
+### Inference Components
+
+#### `run_vllm_batch_inference_on_dataset(...)`
+Run batch inference on a dataset using vLLM.
+
+**Key Parameters:**
+- `data_path`: Path to dataset
+- `model_path`: Path to model
+- `temperature`: Sampling temperature
+- `max_tokens`: Maximum tokens to generate
+- `gpu_memory_utilization`: GPU memory usage fraction
+
+## Examples
+
+### Example 1: Medical QA Fine-tuning
+
+```yaml
+# config_medical_qa.yaml
+output_dir: "/workspace/medical_qa_output/"
+
+formatter:
+  type: "torchtune"
+  data_path: "medical-qa-dataset"
+  column_mapping:
+    input: "question"
+    output: "answer"
+
+finetuning:
+  model_path: "/models/Llama-3.1-8B-Instruct"
+  strategy: "lora"
+  num_epochs: 3
+  distributed: true
+  num_processes_per_node: 4
+
+inference:
+  model_path: "/workspace/medical_qa_output/"
+  max_model_len: 1024
+  temperature: 0.1
+```
+
+```bash
+python run_pipeline.py --config config_medical_qa.yaml
+```
+
+### Example 2: Multi-modal Fine-tuning
+
+```yaml
+# config_multimodal.yaml
+formatter:
+  type: "torchtune"
+  data_path: "multimodal-dataset"
+  column_mapping:
+    input: "query"
+    output: "response"
+    image: "image_path"
+
+finetuning:
+  strategy: "lora"
+  torchtune_config: "llama3_2_vision/11B_lora"
+  # ... other config
+```
+
+### Example 3: Distributed Training
+
+```yaml
+# config_distributed.yaml
+finetuning:
+  distributed: true
+  num_processes_per_node: 8
+  strategy: "lora"
+  # ... other config
+```
+
+```bash
+# Run with distributed training
+python run_pipeline.py --config config_distributed.yaml
+```
+
+### Example 4: Custom Dataset Format
+
+```python
+# Custom data loading
+from finetune_pipeline.data.formatter import TorchtuneFormatter, Conversation, Message
+
+# Create conversations
+conversations = []
+conversation = Conversation()
+conversation.add_message(Message(role="user", content="What is AI?"))
+conversation.add_message(Message(role="assistant", content="AI is..."))
+conversations.append(conversation)
+
+# Format for training
+formatter = TorchtuneFormatter()
+formatted_data = formatter.format_data(conversations)
+```
+
+## Advanced Usage
+
+### Custom Arguments
+
+Pass additional arguments to TorchTune:
+
+```bash
+python finetuning/run_finetuning.py \
+  --config config.yaml \
+  --kwargs "dataset.train_on_input=True optimizer.lr=1e-5"
+```
+
+### Pipeline Control
+
+Fine-grained control over pipeline execution:
+
+```bash
+# Skip certain steps
+python run_pipeline.py --config config.yaml --skip-finetuning --skip-server
+
+# Run only data loading and fine-tuning
+python run_pipeline.py --config config.yaml --skip-inference --skip-server
+```
+
+### Configuration Validation
+
+The pipeline automatically validates configuration parameters and provides helpful error messages for missing or invalid settings.
+
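+For example, a minimal version of such a check might look like this (an illustrative sketch, not the pipeline's actual validation code):
+
+```python
+# Illustrative sketch: verify that required config sections and keys are present.
+REQUIRED_KEYS = {
+    "formatter": ["type", "data_path"],
+    "finetuning": ["model_path", "strategy", "torchtune_config"],
+    "inference": ["model_path"],
+}
+
+def validate_config(config: dict) -> None:
+    missing = [] if "output_dir" in config else ["output_dir"]
+    for section, keys in REQUIRED_KEYS.items():
+        if section not in config:
+            missing.append(section)
+            continue
+        missing.extend(f"{section}.{k}" for k in keys if k not in config[section])
+    if missing:
+        raise ValueError(f"Missing required configuration fields: {missing}")
+```
+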
+## Troubleshooting
+
+### Common Issues
+
+1. **CUDA Out of Memory**: Reduce `batch_size`, `max_model_len`, or `gpu_memory_utilization`
+2. **Import Errors**: Ensure all dependencies are installed (`torch`, `torchtune`, `vllm`)
+3. **Configuration Errors**: Check YAML syntax and required fields
+4. **Distributed Training Issues**: Verify `num_processes_per_node` matches available GPUs (see the quick check below)
+
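+A quick check of the GPU count referenced in item 4:
+
+```python
+# Compare this number against finetuning.num_processes_per_node in your config.
+import torch
+print(f"Visible GPUs: {torch.cuda.device_count()}")
+```
+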
+### Debug Mode
+
+Enable verbose logging:
+
+```python
+import logging
+logging.basicConfig(level=logging.DEBUG)
+```
+
+## Contributing
+
+1. Fork the repository
+2. Create a feature branch
+3. Add tests for new functionality
+4. Run the test suite: `pytest tests/`
+5. Submit a pull request
+
+## License
+
+This project is licensed under the MIT License - see the LICENSE file for details.

+ 6 - 2
src/finetune_pipeline/config.yaml

@@ -90,11 +90,15 @@ formatter:
 
 # Training configuration
 finetuning:
+  model_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct" # Path to the model checkpoint
+  tokenizer_path: "/home/yashkhare/workspace/Llama-3.1-8B-Instruct/original/tokenizer.model" # Path to the tokenizer
+  output_dir: ${output_dir}/model_outputs  # Directory to store checkpoints
+  log_dir: ${output_dir}/logs  # Directory to store logs
   strategy: "lora"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
-  batch_size: 1                 # Batch size per device for training
+  batch_size: 4                 # Batch size per device for training
   torchtune_config: "llama3_1/8B_lora"             # TorchTune-specific configuration
-  num_processes_per_node: 1             # TorchTune-specific configuration
+  num_processes_per_node: 8             # Number of processes (GPUs) per node for distributed training
   distributed: true             # Whether to use distributed training
 
 

+ 67 - 27
src/finetune_pipeline/finetuning/run_finetuning.py

@@ -5,6 +5,7 @@ Reads parameters from a config file and runs the torch tune command.
 """
 
 import argparse
+import json
 import logging
 import subprocess
 import sys
@@ -63,67 +64,64 @@ def read_config(config_path: str) -> Dict:
     return config
 
 
-def run_torch_tune(training_config: Dict, args=None):
+def run_torch_tune(config: Dict, args=None):
     """
     Run torch tune command with parameters from config file.
 
     Args:
-        config_path: Path to the configuration file
+        config: Full configuration dictionary
         args: Command line arguments that may include additional kwargs to pass to the command
     """
-    # # Read the configuration
-    # config = read_config(config_path)
 
-    # Extract parameters from config
-    # training_config = config.get("finetuning", {})
+    finetuning_config = config.get("finetuning", {})
 
     # Initialize base_cmd to avoid "possibly unbound" error
     base_cmd = []
 
     # Determine the command based on configuration
-    if training_config.get("distributed"):
-        if training_config.get("strategy") == "lora":
+    if finetuning_config.get("distributed"):
+        if finetuning_config.get("strategy") == "lora":
             base_cmd = [
                 "tune",
                 "run",
                 "--nproc_per_node",
-                str(training_config.get("num_processes_per_node", 1)),
+                str(finetuning_config.get("num_processes_per_node", 1)),
                 "lora_finetune_distributed",
                 "--config",
-                training_config.get("torchtune_config"),
+                finetuning_config.get("torchtune_config"),
             ]
-        elif training_config.get("strategy") == "fft":
+        elif finetuning_config.get("strategy") == "fft":
             base_cmd = [
                 "tune",
                 "run",
                 "--nproc_per_node",
-                str(training_config.get("num_processes_per_node", 1)),
+                str(finetuning_config.get("num_processes_per_node", 1)),
                 "full_finetune_distributed",
                 "--config",
-                training_config.get("torchtune_config"),
+                finetuning_config.get("torchtune_config"),
             ]
         else:
-            raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+            raise ValueError(f"Invalid strategy: {finetuning_config.get('strategy')}")
 
     else:
-        if training_config.get("strategy") == "lora":
+        if finetuning_config.get("strategy") == "lora":
             base_cmd = [
                 "tune",
                 "run",
                 "lora_finetune_single_device",
                 "--config",
-                training_config.get("torchtune_config"),
+                finetuning_config.get("torchtune_config"),
             ]
-        elif training_config.get("strategy") == "fft":
+        elif finetuning_config.get("strategy") == "fft":
             base_cmd = [
                 "tune",
                 "run",
                 "full_finetune_single_device",
                 "--config",
-                training_config.get("torchtune_config"),
+                finetuning_config.get("torchtune_config"),
             ]
         else:
-            raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+            raise ValueError(f"Invalid strategy: {finetuning_config.get('strategy')}")
 
     # Check if we have a valid command
     if not base_cmd:
@@ -131,12 +129,54 @@ def run_torch_tune(training_config: Dict, args=None):
             "Could not determine the appropriate command based on the configuration"
         )
 
+    # Add configuration-based arguments
+    config_args = []
+
+    # Add output_dir
+    output_dir = config.get("output_dir")
+    if output_dir:
+        config_args.extend(["output_dir=" + output_dir])
+
+    # Add epochs
+    num_epochs = finetuning_config.get("num_epochs", 1)
+    if num_epochs:
+        config_args.extend(["epochs=" + str(num_epochs)])
+
+    # Add batch_size
+    batch_size = finetuning_config.get("batch_size", 1)
+    if batch_size:
+        config_args.extend(["batch_size=" + str(batch_size)])
+
+    # Add checkpointer.checkpoint_dir (path to the base model checkpoint)
+    model_path = finetuning_config.get("model_path")
+    if model_path:
+        config_args.extend(["checkpointer.checkpoint_dir=" + model_path])
+
+    # Add checkpointer.output_dir (use config output_dir if output_dir not specified) (Model Output Path)
+    model_output_dir = finetuning_config.get("output_dir", config.get("output_dir"))
+    if model_output_dir:
+        config_args.extend(["checkpointer.output_dir=" + model_output_dir])
+
+    # Add tokenizer.path from training config
+    if finetuning_config.get("tokenizer_path"):
+        config_args.extend(["tokenizer.path=" + finetuning_config["tokenizer_path"]])
+
+    # Add log_dir (use config output_dir if log_dir not specified)
+    log_dir = finetuning_config.get("log_dir", config.get("output_dir"))
+    if log_dir:
+        config_args.extend(["metric_logger.log_dir=" + log_dir])
+
+    # Add the config arguments to base_cmd
+    if config_args:
+        base_cmd.extend(config_args)
+        logger.info(f"Added config arguments: {config_args}")
+
     # Add any additional kwargs if provided
-    # if args and args.kwargs:
-    #     # Split the kwargs string by spaces to get individual key=value pairs
-    #     kwargs_list = args.kwargs.split()
-    #     base_cmd.extend(kwargs_list)
-    #     logger.info(f"Added additional kwargs: {kwargs_list}")
+    if args and args.kwargs:
+        # Split the kwargs string by spaces to get individual key=value pairs
+        kwargs_list = args.kwargs.split()
+        base_cmd.extend(kwargs_list)
+        logger.info(f"Added additional kwargs: {kwargs_list}")
 
     # Log the command
     logger.info(f"Running command: {' '.join(base_cmd)}")
@@ -170,10 +210,10 @@ def main():
     args = parser.parse_args()
 
     config = read_config(args.config)
-    finetuning_config = config.get("finetuning", {})
+    # finetuning_config = config.get("finetuning", {})
 
-    run_torch_tune(finetuning_config, args=args)
+    run_torch_tune(config, args=args)
 
 
 if __name__ == "__main__":
-    main()
+    main()

+ 35 - 30
src/finetune_pipeline/run_pipeline.py

@@ -93,7 +93,6 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
     # Read the configuration
     config = read_config(config_path)
     finetuning_config = config.get("finetuning", {})
-    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
 
     # Get the path to the formatted data for the train split
     train_data_path = None
@@ -121,11 +120,14 @@ def run_finetuning(config_path: str, formatted_data_paths: List[str]) -> str:
         logger.info(f"Starting fine-tuning with data from {train_data_path}")
         run_torch_tune(finetuning_config, args=args)
 
-        # Get the path to the fine-tuned model
-        # This is a simplification; in reality, you'd need to extract the actual path from the fine-tuning output
-        model_path = os.path.join(output_dir, "finetuned_model")
-        logger.info(f"Fine-tuning complete. Model saved to {model_path}")
-        return model_path
+        # Get the path to the latest checkpoint of the fine-tuned model
+        model_output_dir = finetuning_config.get("output_dir", config.get("output_dir"))
+        epochs = finetuning_config.get("num_epochs", 1)
+        checkpoint_path = os.path.join(model_output_dir, f"epochs_{epochs-1}")
+        logger.info(
+            f"Fine-tuning complete. Latest checkpoint saved to {checkpoint_path}"
+        )
+        return checkpoint_path
     except Exception as e:
         logger.error(f"Error during fine-tuning: {e}")
         raise
@@ -191,7 +193,9 @@ def run_vllm_server(config_path: str, model_path: str) -> str:
         raise
 
 
-def run_inference(config_path: str, formatted_data_paths: List[str]) -> str:
+def run_inference(
+    config_path: str, formatted_data_paths: List[str], model_path: str = ""
+) -> str:
     """
     Run inference on the fine-tuned model.
 
@@ -211,9 +215,10 @@ def run_inference(config_path: str, formatted_data_paths: List[str]) -> str:
     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
 
     # Model parameters
-    model_path = inference_config.get("model_path", None)
-    if model_path is None:
-        raise ValueError("model_path must be specified in the config")
+    if model_path == "":
+        model_path = inference_config.get("model_path", None)
+        if model_path is None:
+            raise ValueError("model_path must be specified in the config")
 
     # Get data path from parameters or config
     inference_data_path = inference_config.get("inference_data", None)
@@ -356,28 +361,28 @@ def run_pipeline(
         output_dir = config.get("output_dir", "/tmp/finetune-pipeline/")
         model_path = os.path.join(output_dir, "finetuned_model")
 
-    # Step 3: Start vLLM Server
-    server_url = ""
-    server_process = None
-    if not skip_server:
-        try:
-            server_url = run_vllm_server(config_path, model_path)
-        except Exception as e:
-            logger.error(f"Pipeline failed at vLLM server step: {e}")
-            sys.exit(1)
-    else:
-        logger.info("Skipping vLLM server step")
-        # Try to infer the server URL from the config
-        config = read_config(config_path)
-        inference_config = config.get("inference", {})
-        host = inference_config.get("host", "0.0.0.0")
-        port = inference_config.get("port", 8000)
-        server_url = f"http://{host}:{port}/v1"
-
-    # Step 4: Inference
+    # # Step 3: Start vLLM Server
+    # server_url = ""
+    # server_process = None
+    # if not skip_server:
+    #     try:
+    #         server_url = run_vllm_server(config_path, model_path)
+    #     except Exception as e:
+    #         logger.error(f"Pipeline failed at vLLM server step: {e}")
+    #         sys.exit(1)
+    # else:
+    #     logger.info("Skipping vLLM server step")
+    #     # Try to infer the server URL from the config
+    #     config = read_config(config_path)
+    #     inference_config = config.get("inference", {})
+    #     host = inference_config.get("host", "0.0.0.0")
+    #     port = inference_config.get("port", 8000)
+    #     server_url = f"http://{host}:{port}/v1"
+
+    # Step 3: Inference
     if not skip_inference:
         try:
-            results_path = run_inference(config_path, formatted_data_paths)
+            results_path = run_inference(config_path, formatted_data_paths, model_path)
             logger.info(
                 f"Pipeline completed successfully. Results saved to {results_path}"
             )