
config updated

khare19yash · 1 month ago · commit f4c99d00f4

+ 20 - 93
src/finetune_pipeline/config.yaml

@@ -1,100 +1,27 @@
-# # Configuration for data loading, formatting, and fine-tuning
-
-
-# output_dir: "/tmp/finetune_pipeline/outputs/"  # Directory to store output files
-
-# data:
-#   data_path: "dz-osamu/IU-Xray"  # Path to the dataset to format (either a Hugging Face dataset ID or a local path)
-#   is_local: false                  # Whether the data is stored locally
-#   # Maps custom column names to standard field names
-#   column_mapping:
-#     input: "query"             # Field containing the input text
-#     output: "response"              # Field containing the output text
-#     image: "image"           # Field containing the image path (optional)
-#   # Additional arguments to pass to the load_dataset function
-#   # dataset_kwargs:
-#   #   split: "train"                # Dataset split to load
-#   #   # Add any other dataset-specific arguments here
-
-
-# # Formatter configuration
-# formatter:
-#   type: "vllm"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
-
-
-# # # Something like this in the torchtune config
-# # dataset:
-# #   _component_: torchtune.datasets.CustomSFTDataset
-# #   packed: False
-# #   split: train
-# # seed: null
-# # shuffle: True
-
-
-# # Training configuration
-# finetuning:
-#   strategy: "lora"               # Training strategy ('fft' or 'lora')
-#   num_epochs: 1                 # Number of training epochs
-#   batch_size: 1                 # Batch size per device for training
-#   torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
-#   num_processes_per_node: 8             # TorchTune-specific configuration
-#   distributed: true             # Whether to use distributed training
-
-
-# # vLLM Inference configuration
-# inference:
-#   # Model configuration
-#   model_path: "/home/ubuntu/yash-workspace/medgemma-4b-it" # Path to the model checkpoint
-#   quantization: null            # Quantization method (awq, gptq, squeezellm)
-
-#   # Server configuration
-#   port: 8000                    # Port to run the server on
-#   host: "0.0.0.0"               # Host to run the server on
-
-#   # Performance configuration
-#   tensor_parallel_size: 1       # Number of GPUs to use for tensor parallelism
-#   max_model_len: 32           # Maximum sequence length
-#   max_num_seqs: 1              # Maximum number of sequences
-#   gpu_memory_utilization: 0.9   # Fraction of GPU memory to use
-#   enforce_eager: false          # Enforce eager execution
-
-#   eval_data: "your/eval/dataset/path" # Path to the evaluation dataset (optional)
-
-#   # Additional vLLM parameters (optional)
-#   # swap_space: 4               # Size of CPU swap space in GiB
-#   # block_size: 16              # Size of blocks used in the KV cache
-#   # disable_log_stats: true     # Disable logging of stats
-#   # disable_log_requests: false # Disable logging of requests
-
-
-
 # Configuration for data loading, formatting, and fine-tuning
+output_dir: "/tmp/finetuning-pipeline/llama3_2_vision/"  # Directory to store output files
 
-
-output_dir: "/home/yashkhare/workspace/finetuning-pipeline/"  # Directory to store output files
-
-# Formatter configuration
-formatter:
-  type: "torchtune"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
-  data_path: "dz-osamu/IU-Xray"  # Path to the dataset to format (either a Hugging Face dataset ID or a local path)
-  is_local: false                  # Whether the data is stored locally
-  # Maps custom column names to standard field names
+data:
+  data_path: "data/path"  # Path to the dataset to load
+  is_local: true           # Whether the data is stored locally
+  formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+  system_prompt: "You are a helpful assistant"  # System prompt to use for the dataset
   column_mapping:
-    input: "query"             # Field containing the input text
-    output: "response"              # Field containing the output text
-    image: "images"           # Field containing the image path (optional)
-
+    input: "instruction"             # Field containing the input text
+    output: "output"              # Field containing the output text
+    image: "image"           # Field containing the image path (optional)
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
-    split: "train"                # Dataset split to load
+    split: "validation"                # Dataset split to load
+    shuffle: false                 # Whether to shuffle the dataset
 
 # Training configuration
 finetuning:
   #formatter_type: "torchtune"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
-  model_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct" # Path to the model checkpoint
-  tokenizer_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct/original/tokenizer.model" # Path to the tokenizer
-  output_dir: /home/yashkhare/workspace/finetuning-pipeline/model_outputs  # Directory to store checkpoints
-  log_dir: /home/yashkhare/workspace/finetuning-pipeline/logs  # Directory to store logs
+  model_path: "path/to/model" # Path to the model checkpoint
+  tokenizer_path: "path/to/tokenizer" # Path to the tokenizer
+  output_dir: /tmp/finetuning-pipeline/model_outputs  # Directory to store checkpoints
+  log_dir: /tmp/finetuning-pipeline/logs  # Directory to store logs
   strategy: "lora"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
   max_steps_per_epoch: null
@@ -107,7 +34,7 @@ finetuning:
 # vLLM Inference configuration
 inference:
   # Model configuration
-  model_path: "/home/yashkhare/workspace/Llama-3.2-11B-Vision-Instruct" # Path to the model checkpoint
+  model_path: "path/to/model/checkpoint" # Path to the model checkpoint
   quantization: null            # Quantization method (awq, gptq, squeezellm)
   dtype: "auto"                 # Data type for model weights (half, float, bfloat16, auto)
   trust_remote_code: false      # Trust remote code when loading the model
@@ -124,12 +51,12 @@ inference:
   enforce_eager: false          # Enforce eager execution
 
   inference_data_kwargs:
-    data_path: "dz-osamu/IU-Xray"     # Path to the inference dataset
+    data_path: "inference/data/path"     # Path to the inference dataset
     split: "validation"               # Dataset split to load
     formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or 'openai')
-    format_data: true                # Whether to format the inference dataset
-    max_samples: 10                 # Maximum number of samples to load (null for all)
-    is_local: false                   # Whether the data is stored locally
+    format_data: false                # Whether to format the inference dataset
+    max_samples: null                 # Maximum number of samples to load (null for all)
+    is_local: true                   # Whether the data is stored locally
 
   # Additional vLLM parameters (optional)
   # swap_space: 4               # Size of CPU swap space in GiB

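For orientation, a minimal sketch of how the renamed data section could be consumed, mirroring the read_config / load_and_format_data flow in data_loader.py below. The config path and printed keys are illustrative only, and the snippet assumes PyYAML is installed:

    import yaml

    with open("src/finetune_pipeline/config.yaml") as f:
        config = yaml.safe_load(f)

    # Same lookups as the __main__ block of data_loader.py
    data_config = config.get("data", {})
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")

    print(data_config["formatter_type"])   # "vllm"
    print(data_config["column_mapping"])   # {"input": "instruction", "output": "output", "image": "image"}
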
+ 34 - 32
src/finetune_pipeline/data/data_loader.py

@@ -6,6 +6,7 @@ import json
 import os
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
+import pandas as pd
 
 # Try to import yaml, but don't fail if it's not available
 try:
@@ -17,7 +18,7 @@ except ImportError:
 
 # Try to import datasets, but don't fail if it's not available
 try:
-    from datasets import load_dataset, load_from_disk
+    from datasets import load_dataset, load_from_disk, Dataset
 
     HAS_DATASETS = True
 except ImportError:
@@ -94,9 +95,15 @@ def load_data(data_path: str, is_local: bool = False, **kwargs):
     if not data_path:
         raise ValueError("data_path must be provided")
 
+    dataset = None
     if is_local:
         # Load from local disk
-        dataset = load_from_disk(data_path)
+        file_extension = Path(data_path).suffix.lower()
+        if file_extension in [".csv"]:
+            data = pd.read_csv(data_path)
+            dataset = Dataset.from_pandas(data)
+        else:
+            dataset = load_from_disk(data_path)
     else:
         # Load from Hugging Face Hub
         dataset = load_dataset(data_path, **kwargs)
@@ -132,24 +139,10 @@ def get_formatter(formatter_type: str) -> Formatter:
     return formatter_map[formatter_type.lower()]()
 
 
-def get_image_path(img: str) -> str:
-    """
-    Get the image path from the image URL.
-
-    Args:
-        img: Image URL
-
-    Returns:
-        str: Image path
-    """
-
-    img_name = img.split("/")[-2]
-    img_id = img.split("/")[-1]
-    img_dir = "/home/yashkhare/workspace/IU-Xray/mnt/bn/haiyang-dataset-lq/medical/home/yisiyang/ysy/medical_dataset/iu_xray/iu_xray/images"
-    return f"{img_dir}/{img_name}/{img_id}"
 
-
-def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
+def convert_to_conversations(
+    data, column_mapping: Optional[Dict] = None, system_prompt: Optional[str] = None
+) -> List[Any]:
     """
     Convert data to a list of Conversation objects.
 
@@ -187,6 +180,13 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
         # Create a new conversation
         conversation = Conversation()
 
+        system_message = None
+        if system_prompt:
+            system_message = {
+                "role": "system",
+                "content": [{"type": "text", "text": system_prompt}],
+            }
+
         # Create user content and user message
         user_content = [
             {"type": "text", "text": input_text},
@@ -197,16 +197,14 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
                 # Handle list of images
                 for img in image:
                     if img:  # Check if image path is not empty
-                        img_path = get_image_path(img)
                         user_content.append(
-                            {"type": "image_url", "image_url": {"url": img_path}}
+                            {"type": "image_url", "image_url": {"url": img}}
                         )
                         break
             else:
                 # Handle single image
-                img_path = get_image_path(image)
                 user_content.append(
-                    {"type": "image_url", "image_url": {"url": img_path}}
+                    {"type": "image_url", "image_url": {"url": image}}
                 )
 
         user_message = {"role": "user", "content": user_content}
@@ -218,6 +216,8 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
         assistant_message = {"role": "assistant", "content": assistant_content}
 
         # Add messages to the conversation
+        if system_message:
+            conversation.add_message(system_message)
         conversation.add_message(user_message)
         conversation.add_message(assistant_message)
 
@@ -303,6 +303,7 @@ def format_data(
     formatter_type: str,
     output_dir: str,
     column_mapping: Optional[Dict] = None,
+    system_prompt: Optional[str] = None,
     dataset_kwargs: Optional[Dict] = None,
 ):
     """
@@ -385,7 +386,7 @@ def format_data(
     return formatted_data_paths, conversation_data_paths
 
 
-def load_and_format_data(formatter_config: Dict, output_dir: str):
+def load_and_format_data(data_config: Dict, output_dir: str):
     """
     Load and format data based on the configuration.
 
@@ -398,23 +399,24 @@ def load_and_format_data(formatter_config: Dict, output_dir: str):
     """
 
     # Extract parameters from config
-    data_path = formatter_config.get("data_path")
+    data_path = data_config.get("data_path")
     if not data_path:
         raise ValueError(
             "data_path must be specified in the formatter section of the config file"
         )
 
-    is_local = formatter_config.get("is_local", False)
-    formatter_type = formatter_config.get("type", "torchtune")
-    column_mapping = formatter_config.get("column_mapping")
-    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
+    is_local = data_config.get("is_local", False)
+    formatter_type = data_config.get("formatter_type", "torchtune")
+    column_mapping = data_config.get("column_mapping")
+    dataset_kwargs = data_config.get("dataset_kwargs", {})
+    system_prompt = data_config.get("system_prompt", None)
 
     # Load the data
     data = load_data(data_path, is_local, **dataset_kwargs)
 
     # Format the data
     formatted_data_paths, conversation_data_paths = format_data(
-        data, formatter_type, output_dir, column_mapping, dataset_kwargs
+        data, formatter_type, output_dir, column_mapping, system_prompt, dataset_kwargs
     )
 
     return formatted_data_paths, conversation_data_paths
@@ -437,10 +439,10 @@ if __name__ == "__main__":
 
     # Read the configuration
     config = read_config(args.config)
-    formatter_config = config.get("formatter", {})
+    data_config = config.get("data", {})
     output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
     output_data_dir = os.path.join(output_dir, "data")
     # Load and format the data
     formatted_data_paths, conversation_data_paths = load_and_format_data(
-        formatter_config, output_data_dir
+        data_config, output_data_dir
     )

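A hedged usage example of the renamed load_and_format_data entry point with the new data section. The dictionary values are placeholders that only mirror the keys read in the function above, and the import assumes the src layout is installed as the finetune_pipeline package:

    from finetune_pipeline.data.data_loader import load_and_format_data

    data_config = {
        "data_path": "data/path",          # local CSV/dataset dir, or a Hugging Face dataset ID
        "is_local": True,                  # local paths skip load_dataset and use pandas/load_from_disk
        "formatter_type": "vllm",          # 'torchtune', 'vllm', or 'openai'
        "system_prompt": "You are a helpful assistant",
        "column_mapping": {"input": "instruction", "output": "output", "image": "image"},
        "dataset_kwargs": {"split": "validation"},
    }
    formatted_paths, conversation_paths = load_and_format_data(
        data_config, "/tmp/finetune-pipeline/data/"
    )
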
+ 3 - 3
src/finetune_pipeline/data/formatter.py

@@ -216,7 +216,7 @@ class vLLMFormatter(Formatter):
         formatted_messages = []
         for message in conversation.messages:
             role = message["role"]
-            if role == "user":
+            if role != "assistant":
                 formatted_messages.append(self.format_message(message))
         return {"messages": formatted_messages}
 
@@ -235,8 +235,8 @@ class vLLMFormatter(Formatter):
 
         for content in message["content"]:
             if content["type"] == "text":
-                contents.append(content["text"])
-            elif content["type"] == "image_url":
+                contents.append(content)
+            elif content["type"] == "image_url" or content["type"] == "image":
                 base64_image = image_to_base64(content["image_url"]["url"])
                 img_content = {
                     "type": "image_url",

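A sketch of the message shape the updated vLLMFormatter now handles: non-assistant messages are formatted, text parts are appended unchanged, and both image_url and image parts go through image_to_base64. The exact data-URL prefix in the comment is an assumption; only the structure comes from the diff above:

    message = {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe the image."},
            {"type": "image_url", "image_url": {"url": "images/example.png"}},
        ],
    }
    # After format_message (illustrative output shape):
    # [
    #     {"type": "text", "text": "Describe the image."},
    #     {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
    # ]
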
+ 1 - 0
src/finetune_pipeline/inference/run_inference.py

@@ -2,6 +2,7 @@ import argparse
 import json
 import logging
 from typing import Any, Dict, List, Optional, TypedDict, Union
+import os
 
 import requests
 from tqdm import tqdm