
updated data_loader to handle splits and added a custom dataset

Ubuntu, 2 months ago
parent
commit
936c249933

+ 81 - 9
src/finetune_pipeline/config.yaml

@@ -1,23 +1,92 @@
+# # Configuration for data loading, formatting, and fine-tuning
+
+
+# output_dir: "/tmp/finetune_pipeline/outputs/"  # Directory to store output files
+
+# data:
+#   data_path: "dz-osamu/IU-Xray"  # Path to the dataset to format (either a Hugging Face dataset ID or a local path)
+#   is_local: false                  # Whether the data is stored locally
+#   # Maps custom column names to standard field names
+#   column_mapping:
+#     input: "query"             # Field containing the input text
+#     output: "response"              # Field containing the output text
+#     image: "image"           # Field containing the image path (optional)
+#   # Additional arguments to pass to the load_dataset function
+#   # dataset_kwargs:
+#   #   split: "train"                # Dataset split to load
+#   #   # Add any other dataset-specific arguments here
+
+
+# # Formatter configuration
+# formatter:
+#   type: "vllm"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+
+
+# # # Something like this in the torchtune config
+# # dataset:
+# #   _component_: torchtune.datasets.CustomSFTDataset
+# #   packed: False
+# #   split: train
+# # seed: null
+# # shuffle: True
+
+
+# # Training configuration
+# finetuning:
+#   strategy: "lora"               # Training strategy ('fft' or 'lora')
+#   num_epochs: 1                 # Number of training epochs
+#   batch_size: 1                 # Batch size per device for training
+#   torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
+#   num_processes_per_node: 8             # TorchTune-specific configuration
+#   distributed: true             # Whether to use distributed training
+
+
+# # vLLM Inference configuration
+# inference:
+#   # Model configuration
+#   model_path: "/home/ubuntu/yash-workspace/medgemma-4b-it" # Path to the model checkpoint
+#   quantization: null            # Quantization method (awq, gptq, squeezellm)
+
+#   # Server configuration
+#   port: 8000                    # Port to run the server on
+#   host: "0.0.0.0"               # Host to run the server on
+
+#   # Performance configuration
+#   tensor_parallel_size: 1       # Number of GPUs to use for tensor parallelism
+#   max_model_len: 32           # Maximum sequence length
+#   max_num_seqs: 1              # Maximum number of sequences
+#   gpu_memory_utilization: 0.9   # Fraction of GPU memory to use
+#   enforce_eager: false          # Enforce eager execution
+
+#   eval_data: "your/eval/dataset/path" # Path to the evaluation dataset (optional)
+
+#   # Additional vLLM parameters (optional)
+#   # swap_space: 4               # Size of CPU swap space in GiB
+#   # block_size: 16              # Size of blocks used in the KV cache
+#   # disable_log_stats: true     # Disable logging of stats
+#   # disable_log_requests: false # Disable logging of requests
+
+
+
 # Configuration for data loading, formatting, and fine-tuning
 
 
-output_dir: "/tmp/finetune_pipeline/outputs/"  # Directory to store output files
+output_dir: "/home/ubuntu/yash-workspace/outputs"  # Directory to store output files
 
 # Formatter configuration
 formatter:
-  type: "vllm"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+  type: "torchtune"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
   data_path: "dz-osamu/IU-Xray"  # Path to the dataset to format (either a Hugging Face dataset ID or a local path)
   is_local: false                  # Whether the data is stored locally
   # Maps custom column names to standard field names
   column_mapping:
     input: "query"             # Field containing the input text
     output: "response"              # Field containing the output text
-    image: null           # Field containing the image path (optional)
+    image: "images"           # Field containing the image path (optional)
 
   # Additional arguments to pass to the load_dataset function
   dataset_kwargs:
     split: "train"                # Dataset split to load
-    # Add any other dataset-specific arguments here
 
 # Training configuration
 finetuning:
@@ -25,15 +94,17 @@ finetuning:
   num_epochs: 1                 # Number of training epochs
   batch_size: 1                 # Batch size per device for training
   torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
-  num_processes_per_node: 8             # TorchTune-specific configuration
-  distributed: true             # Whether to use distributed training
+  num_processes_per_node: 1             # Number of processes (GPUs) to launch per node
+  distributed: false             # Whether to use distributed training
 
 
 # vLLM Inference configuration
 inference:
   # Model configuration
-  model_path: "/home/ubuntu/yash-workspace/medgemma-4b-it" # Path to the model checkpoint
+  model_path: "your/model/path" # Path to the model checkpoint
   quantization: null            # Quantization method (awq, gptq, squeezellm)
+  dtype: "auto"                 # Data type for model weights (half, float, bfloat16, auto)
+  trust_remote_code: false      # Trust remote code when loading the model
 
   # Server configuration
   port: 8000                    # Port to run the server on
@@ -41,8 +112,8 @@ inference:
 
   # Performance configuration
   tensor_parallel_size: 1       # Number of GPUs to use for tensor parallelism
-  max_model_len: 32           # Maximum sequence length
-  max_num_seqs: 1              # Maximum number of sequences
+  max_model_len: 1024           # Maximum sequence length
+  max_num_seqs: 16              # Maximum number of sequences
   gpu_memory_utilization: 0.9   # Fraction of GPU memory to use
   enforce_eager: false          # Enforce eager execution
 
@@ -53,3 +124,4 @@ inference:
   # block_size: 16              # Size of blocks used in the KV cache
   # disable_log_stats: true     # Disable logging of stats
   # disable_log_requests: false # Disable logging of requests
+
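For context, a minimal sketch of how this config is consumed, mirroring the updated data_loader entry point further down in this commit. It assumes PyYAML and a plain safe_load; the repo's own read_config helper may differ.

    import yaml

    # Load the pipeline config (a sketch; read_config in the repo may add validation)
    with open("src/finetune_pipeline/config.yaml") as f:
        config = yaml.safe_load(f)

    formatter_config = config.get("formatter", {})  # type, data_path, column_mapping, dataset_kwargs
    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")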

+ 101 - 31
src/finetune_pipeline/data/data_loader.py

@@ -79,7 +79,7 @@ def load_data(data_path: str, is_local: bool = False, **kwargs):
         **kwargs: Additional arguments to pass to the load_dataset function
 
     Returns:
-        Dataset object from the datasets library
+        Dataset object from the datasets library with all splits
 
     Raises:
         ImportError: If the datasets package is not installed
@@ -174,9 +174,18 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
         user_content = [
             {"type": "text", "text": input_text},
         ]
-        # Add image to user content
+        # Add image(s) to user content
         if image is not None:
-            user_content.append({"type": "image", "image_url": {"url": image}})
+            if isinstance(image, list):
+                # Handle list of images
+                for img in image:
+                    if img:  # Check if image path is not empty
+                        user_content.append(
+                            {"type": "image", "image_url": {"url": img}}
+                        )
+            else:
+                # Handle single image
+                user_content.append({"type": "image", "image_url": {"url": image}})
 
         user_message = {"role": "user", "content": user_content}
 
@@ -197,7 +206,7 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
 
 
 def save_formatted_data(
-    formatted_data: List[Any], output_dir: str, formatter_type: str
+    formatted_data: List[Any], output_dir: str, formatter_type: str, split: str
 ) -> str:
     """
     Save formatted data to a JSON file.
@@ -215,7 +224,7 @@ def save_formatted_data(
 
     # Define the output file path
     formatted_data_path = os.path.join(
-        output_dir, f"{formatter_type}_formatted_data.json"
+        output_dir, f"{split}_{formatter_type}_formatted_data.json"
     )
 
     # Save the formatted data
@@ -237,7 +246,7 @@ def save_formatted_data(
     return formatted_data_path
 
 
-def save_conversation_data(conversation_data: List, output_dir: str) -> str:
+def save_conversation_data(conversation_data: List, output_dir: str, split: str) -> str:
     """
     Save conversation data to a JSON file.
 
@@ -252,7 +261,7 @@ def save_conversation_data(conversation_data: List, output_dir: str) -> str:
     os.makedirs(output_dir, exist_ok=True)
 
     # Define the output file path
-    conversation_data_path = os.path.join(output_dir, "conversation_data.json")
+    conversation_data_path = os.path.join(output_dir, f"{split}_conversation_data.json")
 
     # Convert Conversation objects to a serializable format
     serializable_conversations = []
@@ -267,37 +276,103 @@ def save_conversation_data(conversation_data: List, output_dir: str) -> str:
     return conversation_data_path
 
 
-def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None):
+def format_data(
+    data,
+    formatter_type: str,
+    output_dir: str,
+    column_mapping: Optional[Dict] = None,
+    dataset_kwargs: Optional[Dict] = None,
+):
     """
-    Format the data using the specified formatter.
+    Format the data using the specified formatter for all splits.
 
     Args:
-        data: Data to format
+        data: Dataset with multiple splits to format or a single dataset
         formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
+        output_dir: Directory to save the formatted data
         column_mapping: Optional mapping of column names
+        dataset_kwargs: Optional dataset kwargs that may contain split information
 
     Returns:
-        Tuple containing formatted data and conversation data
+        Tuple containing (formatted_data_paths, conversation_data_paths) where each is a list of paths to saved files
     """
-    # First convert the data to conversations
-    conversations = convert_to_conversations(data, column_mapping)
+    formatted_data_paths = []
+    conversation_data_paths = []
+
+    # Check if the dataset has explicit splits
+    if (
+        hasattr(data, "keys")
+        and callable(data.keys)
+        and len(data.keys()) > 0
+        and isinstance(data, dict)
+    ):
+        # Dataset has splits (train, validation, test, etc.)
+        splits = data.keys()
+
+        for split in splits:
+            # First convert the data to conversations
+            conversations = convert_to_conversations(data[split], column_mapping)
+
+            # Then get the formatter and format the conversations
+            formatter = get_formatter(formatter_type)
+            formatted_data = formatter.format_data(conversations)
+            print(
+                f"Loaded and formatted data for split '{split}': {len(formatted_data)} samples"
+            )
+
+            # Save the formatted data
+            formatted_data_path = save_formatted_data(
+                formatted_data, output_dir, formatter_type, split
+            )
+            formatted_data_paths.append(formatted_data_path)
+
+            # Save the conversation data
+            conversation_data_path = save_conversation_data(
+                conversations, output_dir, split
+            )
+            conversation_data_paths.append(conversation_data_path)
+    else:
+        # Dataset doesn't have explicit splits, treat it as a single dataset
+        # Check if a split is specified in dataset_kwargs
+        split = "default"
+        if dataset_kwargs and "split" in dataset_kwargs:
+            split = dataset_kwargs["split"]
+
+        # First convert the data to conversations
+        conversations = convert_to_conversations(data, column_mapping)
+
+        # Then get the formatter and format the conversations
+        formatter = get_formatter(formatter_type)
+        formatted_data = formatter.format_data(conversations)
+        print(
+            f"Loaded and formatted data for split '{split}': {len(formatted_data)} samples"
+        )
 
-    # Then get the formatter and format the conversations
-    formatter = get_formatter(formatter_type)
-    formatted_data = formatter.format_data(conversations)
+        # Save the formatted data
+        formatted_data_path = save_formatted_data(
+            formatted_data, output_dir, formatter_type, split
+        )
+        formatted_data_paths.append(formatted_data_path)
+
+        # Save the conversation data
+        conversation_data_path = save_conversation_data(
+            conversations, output_dir, split
+        )
+        conversation_data_paths.append(conversation_data_path)
 
-    return formatted_data, conversations
+    return formatted_data_paths, conversation_data_paths
 
 
-def load_and_format_data(formatter_config: Dict):
+def load_and_format_data(formatter_config: Dict, output_dir: str):
     """
     Load and format data based on the configuration.
 
     Args:
         formatter_config: Dictionary containing formatter configuration parameters
+        output_dir: Directory to save the formatted data
 
     Returns:
-        Formatted data in the specified format
+        Tuple containing (formatted_data_paths, conversation_data_paths) where each is a list of paths to saved files
     """
 
     # Extract parameters from config
@@ -316,11 +391,11 @@ def load_and_format_data(formatter_config: Dict):
     data = load_data(data_path, is_local, **dataset_kwargs)
 
     # Format the data
-    formatted_data, conversation_data = format_data(
-        data, formatter_type, column_mapping
+    formatted_data_paths, conversation_data_paths = format_data(
+        data, formatter_type, output_dir, column_mapping, dataset_kwargs
     )
 
-    return formatted_data, conversation_data
+    return formatted_data_paths, conversation_data_paths
 
 
 if __name__ == "__main__":
@@ -341,14 +416,9 @@ if __name__ == "__main__":
     # Read the configuration
     config = read_config(args.config)
     formatter_config = config.get("formatter", {})
-    output_dir = config.get("output_dir")
+    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
 
     # Load and format the data
-    formatted_data, conversation_data = load_and_format_data(formatter_config)
-    print(f"Loaded and formatted data: {len(formatted_data)} samples")
-
-    # Save the data if output_dir is provided
-    if output_dir:
-        formatter_type = formatter_config.get("type", "torchtune")
-        save_formatted_data(formatted_data, output_dir, formatter_type)
-        save_conversation_data(conversation_data, output_dir)
+    formatted_data_paths, conversation_data_paths = load_and_format_data(
+        formatter_config, output_dir
+    )
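With the split handling above, one formatted file and one conversation file are written per split. A hedged usage sketch, assuming the package is importable as finetune_pipeline and that the requested split exists in dz-osamu/IU-Xray:

    from finetune_pipeline.data.data_loader import load_and_format_data

    formatter_config = {
        "type": "torchtune",
        "data_path": "dz-osamu/IU-Xray",
        "is_local": False,
        "column_mapping": {"input": "query", "output": "response", "image": "images"},
        "dataset_kwargs": {"split": "train"},  # a single split, so the else branch above is taken
    }

    formatted_paths, conversation_paths = load_and_format_data(
        formatter_config, "/home/ubuntu/yash-workspace/outputs"
    )
    # Expected outputs (assumed): .../train_torchtune_formatted_data.json
    #                             .../train_conversation_data.json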

+ 24 - 0
src/finetune_pipeline/finetuning/dataset.py

@@ -0,0 +1,24 @@
+from torchtune.datasets import SFTDataset
+from torchtune.modules.transforms import Transform
+from torchtune.data import OpenAIToMessages
+
+
+def custom_sft_dataset(
+    model_transform: Transform,
+    *,
+    split: str = "train",
+    dataset_path: str = "files/synthetic_data/train.csv",
+    train_on_input: bool = True,
+) -> SFTDataset:
+    """Creates a custom dataset."""
+
+    openaitomessage = OpenAIToMessages(train_on_input=train_on_input)
+
+    ds = SFTDataset(
+        source="json",
+        data_files=dataset_path,
+        split="train",
+        message_transform=openaitomessage,
+        model_transform=Transform,
+    )
+    return ds
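To wire this builder into a torchtune recipe, the dataset section of the torchtune config can point at it, mirroring the commented-out sketch at the top of config.yaml. The _component_ path and the JSON path below are assumptions for illustration, not values set by this commit:

    dataset:
      _component_: finetune_pipeline.finetuning.dataset.custom_sft_dataset  # assumed import path
      dataset_path: /home/ubuntu/yash-workspace/outputs/train_torchtune_formatted_data.json  # assumed data loader output
      train_on_input: True
    seed: null
    shuffle: True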

+ 35 - 1
src/finetune_pipeline/inference/run_inference.py

@@ -70,7 +70,41 @@ class VLLMClient:
             self.logger.error(f"Error sending request to vLLM server: {e}")
             raise
 
-def run_inference_on_eval_data(
+
+def vllm_call_batch(llm, image_paths: List[str], structured: bool):
+    messages_batch = []
+    for img_path in image_paths:
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "image_url", "image_url": {"url": f"file:///{img_path}"}},
+                    {
+                        "type": "text",
+                        "text": generate_prompt(structured),
+                    },
+                ],
+            }
+        ]
+        messages_batch.append(messages)
+
+    # Using greedy decoding
+    if structured:
+        sampling_params = SamplingParams(
+            temperature=0,
+            top_p=1,
+            max_tokens=8192,
+            guided_decoding=guided_decoding_params,
+        )
+    else:
+        sampling_params = SamplingParams(
+            temperature=0,
+            top_p=1,
+            max_tokens=8192,
+        )
+    return llm.chat(messages_batch, sampling_params, use_tqdm=True)
+
+def run_vllm_batch_inference_on_dataset(
     eval_data_path: str,
     server_url: str = "http://localhost:8000/v1",
     is_local: bool = False,