
updated data loader

Ubuntu 1 month ago
commit 694e2354c8

+ 7 - 0
src/finetune_pipeline/__init__.py

@@ -0,0 +1,7 @@
+"""
+Fine-tuning pipeline for LLMs.
+
+This package provides tools for data loading, formatting, fine-tuning, and inference.
+"""
+
+__version__ = "0.1.0"

+ 21 - 19
src/finetune_pipeline/config.yaml

@@ -1,32 +1,32 @@
 # Configuration for data loading, formatting, and fine-tuning
 
 
-# Data source configuration
-data_path: "your/dataset/path"  # Path to the dataset (either a Hugging Face dataset ID or a local path)
-is_local: true                  # Whether the data is stored locally
 
 
-# Formatter configuration
-formatter_type: "torchtune"     # Type of formatter to use ('torchtune', 'vllm', or 'openai')
-
-# Column mapping configuration
-# Maps custom column names to standard field names
-column_mapping:
-  input: "question"             # Field containing the input text
-  output: "answer"              # Field containing the output text
-  image: "image_path"           # Field containing the image path (optional)
+output_dir: "/tmp/finetune_pipeline/outputs/"  # Directory to store output files
 
 
-# Additional arguments to pass to the load_dataset function
-dataset_kwargs:
-  split: "train"                # Dataset split to load
-  # Add any other dataset-specific arguments here
+# Formatter configuration
+formatter:
+  type: "vllm"  # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+  data_path: "dz-osamu/IU-Xray"  # Path to the dataset to format (either a Hugging Face dataset ID or a local path)
+  is_local: false                  # Whether the data is stored locally
+  # Maps custom column names to standard field names
+  column_mapping:
+    input: "query"             # Field containing the input text
+    output: "response"              # Field containing the output text
+    image: null           # Field containing the image path (optional)
+
+  # Additional arguments to pass to the load_dataset function
+  dataset_kwargs:
+    split: "train"                # Dataset split to load
+    # Add any other dataset-specific arguments here
 
 
 # Training configuration
 finetuning:
-  strategy: "lora"               # Training strategy ('fft' or 'lora')
+  strategy: "fft"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
   batch_size: 1                 # Batch size per device for training
   torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
-  num_processes_per_node: 8             # TorchTune-specific configuration
-  distributed: true             # Whether to use distributed training
+  num_processes_per_node: 1             # Number of processes per node for distributed training
+  distributed: false             # Whether to use distributed training
 
 
 
 
 # vLLM Inference configuration
@@ -48,6 +48,8 @@ inference:
   gpu_memory_utilization: 0.9   # Fraction of GPU memory to use
   enforce_eager: false          # Enforce eager execution
 
 
+  eval_data: "your/eval/dataset/path" # Path to the evaluation dataset (optional)
+
   # Additional vLLM parameters (optional)
   # swap_space: 4               # Size of CPU swap space in GiB
   # block_size: 16              # Size of blocks used in the KV cache
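For reference, a minimal sketch of how the reorganized config is consumed after this change, based on the read_config usage shown in data_loader.py further down. The config path and the finetune_pipeline import path are assumptions, not part of the commit:

from finetune_pipeline.data.data_loader import read_config  # assumes src/ is importable

# Illustrative path; point this at the repository's config.yaml
config = read_config("src/finetune_pipeline/config.yaml")

# The data-source keys now live under a nested "formatter" section
formatter_config = config.get("formatter", {})
output_dir = config.get("output_dir")  # new top-level key for saved outputs

print(formatter_config.get("type", "torchtune"))  # -> "vllm" with the values above
print(formatter_config.get("data_path"))          # -> "dz-osamu/IU-Xray"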

+ 46 - 0
src/finetune_pipeline/data/__init__.py

@@ -0,0 +1,46 @@
+"""
+Data loading and formatting utilities.
+
+This module provides tools for loading data from various sources and formatting it
+for fine-tuning and inference.
+"""
+
+from .data_loader import (
+    convert_to_conversations,
+    format_data,
+    get_formatter,
+    load_and_format_data,
+    load_data,
+    read_config,
+    save_conversation_data,
+    save_formatted_data,
+)
+from .formatter import (
+    Conversation,
+    Formatter,
+    Message,
+    MessageContent,
+    OpenAIFormatter,
+    TorchtuneFormatter,
+    vLLMFormatter,
+)
+
+__all__ = [
+    # From data_loader
+    "load_data",
+    "convert_to_conversations",
+    "format_data",
+    "get_formatter",
+    "load_and_format_data",
+    "read_config",
+    "save_formatted_data",
+    "save_conversation_data",
+    # From formatter
+    "Conversation",
+    "Formatter",
+    "Message",
+    "MessageContent",
+    "OpenAIFormatter",
+    "TorchtuneFormatter",
+    "vLLMFormatter",
+]

+ 101 - 16
src/finetune_pipeline/data/data_loader.py

@@ -5,7 +5,7 @@ Data loader module for loading and formatting data from Hugging Face.
 import json
 import os
 from pathlib import Path
-from typing import Dict, Optional
+from typing import Any, Dict, List, Optional, Union
 
 
 # Try to import yaml, but don't fail if it's not available
 try:
@@ -196,6 +196,77 @@ def convert_to_conversations(data, column_mapping: Optional[Dict] = None):
     return conversations
 
 
 
 
+def save_formatted_data(
+    formatted_data: List[Any], output_dir: str, formatter_type: str
+) -> str:
+    """
+    Save formatted data to a JSON file.
+
+    Args:
+        formatted_data: The formatted data to save
+        output_dir: Directory to save the data
+        formatter_type: Type of formatter used ('torchtune', 'vllm', or 'openai')
+
+    Returns:
+        Path to the saved file
+    """
+    # Create the output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Define the output file path
+    formatted_data_path = os.path.join(
+        output_dir, f"{formatter_type}_formatted_data.json"
+    )
+
+    # Save the formatted data
+    with open(formatted_data_path, "w") as f:
+        # Handle different data types
+        if isinstance(formatted_data, list) and all(
+            isinstance(item, dict) for item in formatted_data
+        ):
+            json.dump(formatted_data, f, indent=2)
+        elif isinstance(formatted_data, list) and all(
+            isinstance(item, str) for item in formatted_data
+        ):
+            json.dump(formatted_data, f, indent=2)
+        else:
+            # For other types, convert to a simple list of strings
+            json.dump([str(item) for item in formatted_data], f, indent=2)
+
+    print(f"Saved formatted data to {formatted_data_path}")
+    return formatted_data_path
+
+
+def save_conversation_data(conversation_data: List, output_dir: str) -> str:
+    """
+    Save conversation data to a JSON file.
+
+    Args:
+        conversation_data: List of Conversation objects
+        output_dir: Directory to save the data
+
+    Returns:
+        Path to the saved file
+    """
+    # Create the output directory if it doesn't exist
+    os.makedirs(output_dir, exist_ok=True)
+
+    # Define the output file path
+    conversation_data_path = os.path.join(output_dir, "conversation_data.json")
+
+    # Convert Conversation objects to a serializable format
+    serializable_conversations = []
+    for conv in conversation_data:
+        serializable_conversations.append({"messages": conv.messages})
+
+    # Save the conversation data
+    with open(conversation_data_path, "w") as f:
+        json.dump(serializable_conversations, f, indent=2)
+
+    print(f"Saved conversation data to {conversation_data_path}")
+    return conversation_data_path
+
+
 def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None):
     """
     """
     Format the data using the specified formatter.
@@ -206,7 +277,7 @@ def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None
         column_mapping: Optional mapping of column names
 
 
     Returns:
-        Formatted data in the specified format
+        Tuple containing formatted data and conversation data
     """
     """
     # First convert the data to conversations
     conversations = convert_to_conversations(data, column_mapping)
@@ -215,39 +286,41 @@ def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None
     formatter = get_formatter(formatter_type)
     formatted_data = formatter.format_data(conversations)
 
 
-    return formatted_data
+    return formatted_data, conversations
 
 
 
 
-def load_and_format_data(config_path: str):
+def load_and_format_data(formatter_config: Dict):
     """
     """
     Load and format data based on the configuration.
 
 
     Args:
-        config_path: Path to the configuration file
+        formatter_config: Dictionary containing formatter configuration parameters
 
 
     Returns:
         Formatted data in the specified format
     """
     """
-    # Read the configuration
-    config = read_config(config_path)
 
 
     # Extract parameters from config
-    data_path = config.get("data_path")
+    data_path = formatter_config.get("data_path")
     if not data_path:
-        raise ValueError("data_path must be specified in the config file")
+        raise ValueError(
+            "data_path must be specified in the formatter section of the config file"
+        )
 
 
-    is_local = config.get("is_local", False)
-    formatter_type = config.get("formatter_type", "torchtune")
-    column_mapping = config.get("column_mapping")
-    dataset_kwargs = config.get("dataset_kwargs", {})
+    is_local = formatter_config.get("is_local", False)
+    formatter_type = formatter_config.get("type", "torchtune")
+    column_mapping = formatter_config.get("column_mapping")
+    dataset_kwargs = formatter_config.get("dataset_kwargs", {})
 
 
     # Load the data
     data = load_data(data_path, is_local, **dataset_kwargs)
 
 
     # Format the data
-    formatted_data = format_data(data, formatter_type, column_mapping)
+    formatted_data, conversation_data = format_data(
+        data, formatter_type, column_mapping
+    )
 
 
-    return formatted_data
+    return formatted_data, conversation_data
 
 
 
 
 if __name__ == "__main__":
@@ -265,5 +338,17 @@ if __name__ == "__main__":
     )
     args = parser.parse_args()
 
 
-    formatted_data = load_and_format_data(args.config)
+    # Read the configuration
+    config = read_config(args.config)
+    formatter_config = config.get("formatter", {})
+    output_dir = config.get("output_dir")
+
+    # Load and format the data
+    formatted_data, conversation_data = load_and_format_data(formatter_config)
     print(f"Loaded and formatted data: {len(formatted_data)} samples")
     print(f"Loaded and formatted data: {len(formatted_data)} samples")
+
+    # Save the data if output_dir is provided
+    if output_dir:
+        formatter_type = formatter_config.get("type", "torchtune")
+        save_formatted_data(formatted_data, output_dir, formatter_type)
+        save_conversation_data(conversation_data, output_dir)
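A short usage sketch of the updated API, mirroring the new __main__ block above: load_and_format_data now takes the formatter section directly and returns a (formatted_data, conversations) tuple, and the two save helpers write JSON files into output_dir. The literal values are copied from config.yaml and are illustrative only:

from finetune_pipeline.data.data_loader import (
    load_and_format_data,
    save_conversation_data,
    save_formatted_data,
)

formatter_config = {
    "type": "vllm",
    "data_path": "dz-osamu/IU-Xray",
    "is_local": False,
    "column_mapping": {"input": "query", "output": "response", "image": None},
    "dataset_kwargs": {"split": "train"},
}

# Returns both the formatter-specific output and the intermediate Conversation objects
formatted_data, conversation_data = load_and_format_data(formatter_config)

output_dir = "/tmp/finetune_pipeline/outputs/"
save_formatted_data(formatted_data, output_dir, formatter_config["type"])  # vllm_formatted_data.json
save_conversation_data(conversation_data, output_dir)                      # conversation_data.json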

+ 4 - 0
src/finetune_pipeline/finetuning/__init__.py

@@ -0,0 +1,4 @@
+"""
+Fine-tuning utilities for the finetune_pipeline package.
+
+"""

+ 25 - 0
src/finetune_pipeline/inference/__init__.py

@@ -0,0 +1,25 @@
+"""
+Inference utilities for LLMs.
+
+This module provides tools for running inference with fine-tuned models.
+"""
+
+from .inference import (
+    run_inference_from_config,
+    run_inference_on_eval_data,
+    VLLMClient,
+    VLLMInferenceRequest,
+)
+from .start_vllm_server import check_vllm_installed, read_config, start_vllm_server
+
+__all__ = [
+    # From inference
+    "VLLMClient",
+    "VLLMInferenceRequest",
+    "run_inference_on_eval_data",
+    "run_inference_from_config",
+    # From start_vllm_server
+    "start_vllm_server",
+    "read_config",
+    "check_vllm_installed",
+]

+ 5 - 0
src/finetune_pipeline/tests/__init__.py

@@ -0,0 +1,5 @@
+"""
+Test suite for the finetune_pipeline package.
+
+This package contains tests for the data loading, formatting, fine-tuning, and inference modules.
+"""

+ 231 - 527
src/finetune_pipeline/tests/test_formatter.py

@@ -1,540 +1,244 @@
-# import sys
-# import unittest
-# from pathlib import Path
-# from unittest.mock import MagicMock
-
-# # Add the parent directory to the path so we can import the modules
-# sys.path.append(str(Path(__file__).parent.parent))
-
-# from data.data_loader import convert_to_conversations, format_data, load_data
-# from data.formatter import (
-#     Conversation,
-#     OpenAIFormatter,
-#     TorchtuneFormatter,
-#     vLLMFormatter,
-# )
-
-
-# class TestFormatter(unittest.TestCase):
-#     """Test cases for the formatter module."""
-
-#     @classmethod
-#     def setUpClass(cls):
-#         """Set up test fixtures, called before any tests are run."""
-#         # Define a small dataset to use for testing
-#         cls.dataset_name = "dz-osamu/IU-Xray"
-#         cls.split = "train[:10]"  # Use only 10 samples for testing
-
-#         try:
-#             # Load the dataset
-#             cls.dataset = load_data(cls.dataset_name, split=cls.split)
-
-#             # Create a column mapping for the dz-osamu/IU-Xray dataset
-#             cls.column_mapping = {
-#                 "input": "query",
-#                 "output": "response",
-#                 "image": "images"
-#             }
-
-#             # Convert to list for easier processing
-#             cls.data = list(cls.dataset)
-
-#             # Convert to conversations
-#             cls.conversations = convert_to_conversations(cls.data, cls.column_mapping)
-
-#         except Exception as e:
-#             print(f"Error setting up test fixtures: {e}")
-#             raise
-
-#     def test_conversation_creation(self):
-#         """Test that conversations are created correctly."""
-#         self.assertIsNotNone(self.conversations)
-#         self.assertGreater(len(self.conversations), 0)
-
-#         # Check that each conversation has at least two messages (user and assistant)
-#         for conversation in self.conversations:
-#             self.assertGreaterEqual(len(conversation.messages), 2)
-#             self.assertEqual(conversation.messages[0]["role"], "user")
-#             self.assertEqual(conversation.messages[1]["role"], "assistant")
-
-#     def test_torchtune_formatter(self):
-#         """Test the TorchtuneFormatter."""
-#         formatter = TorchtuneFormatter()
-
-#         # Test format_data
-#         formatted_data = formatter.format_data(self.conversations)
-#         self.assertIsNotNone(formatted_data)
-#         self.assertEqual(len(formatted_data), len(self.conversations))
-
-#         # Test format_conversation
-#         formatted_conversation = formatter.format_conversation(self.conversations[0])
-#         self.assertIsInstance(formatted_conversation, dict)
-#         self.assertIn("messages", formatted_conversation)
-
-#         # Test format_message
-#         message = self.conversations[0].messages[0]
-#         formatted_message = formatter.format_message(message)
-#         self.assertIsInstance(formatted_message, dict)
-#         self.assertIn("role", formatted_message)
-#         self.assertIn("content", formatted_message)
-
-#     def test_vllm_formatter(self):
-#         """Test the vLLMFormatter."""
-#         formatter = vLLMFormatter()
-
-#         # Test format_data
-#         formatted_data = formatter.format_data(self.conversations)
-#         self.assertIsNotNone(formatted_data)
-#         self.assertEqual(len(formatted_data), len(self.conversations))
-
-#         # Test format_conversation
-#         formatted_conversation = formatter.format_conversation(self.conversations[0])
-#         self.assertIsInstance(formatted_conversation, str)
-
-#         # Test format_message
-#         message = self.conversations[0].messages[0]
-#         formatted_message = formatter.format_message(message)
-#         self.assertIsInstance(formatted_message, str)
-#         self.assertIn(message["role"], formatted_message)
-
-#     def test_openai_formatter(self):
-#         """Test the OpenAIFormatter."""
-#         formatter = OpenAIFormatter()
-
-#         # Test format_data
-#         formatted_data = formatter.format_data(self.conversations)
-#         self.assertIsNotNone(formatted_data)
-#         self.assertEqual(len(formatted_data), len(self.conversations))
-
-#         # Test format_conversation
-#         formatted_conversation = formatter.format_conversation(self.conversations[0])
-#         self.assertIsInstance(formatted_conversation, dict)
-#         self.assertIn("messages", formatted_conversation)
-
-#         # Test format_message
-#         message = self.conversations[0].messages[0]
-#         formatted_message = formatter.format_message(message)
-#         self.assertIsInstance(formatted_message, dict)
-#         self.assertIn("role", formatted_message)
-#         self.assertIn("content", formatted_message)
-
-#     def test_format_data_function(self):
-#         """Test the format_data function from data_loader."""
-#         # Test with TorchtuneFormatter
-#         torchtune_data = format_data(self.data, "torchtune", self.column_mapping)
-#         self.assertIsNotNone(torchtune_data)
-#         self.assertEqual(len(torchtune_data), len(self.data))
-
-#         # Test with vLLMFormatter
-#         vllm_data = format_data(self.data, "vllm", self.column_mapping)
-#         self.assertIsNotNone(vllm_data)
-#         self.assertEqual(len(vllm_data), len(self.data))
-
-#         # Test with OpenAIFormatter
-#         openai_data = format_data(self.data, "openai", self.column_mapping)
-#         self.assertIsNotNone(openai_data)
-#         self.assertEqual(len(openai_data), len(self.data))
-
-#     def test_with_mock_data(self):
-#         """Test the formatter pipeline with mock data."""
-#         # Create mock data that mimics a dataset
-#         mock_data = [
-#             {
-#                 "question": "What is the capital of France?",
-#                 "context": "France is a country in Western Europe. Its capital is Paris.",
-#                 "answer": "Paris",
-#             },
-#             {
-#                 "question": "Who wrote Hamlet?",
-#                 "context": "Hamlet is a tragedy written by William Shakespeare.",
-#                 "answer": "William Shakespeare",
-#             },
-#             {
-#                 "question": "What is the largest planet in our solar system?",
-#                 "context": "Jupiter is the largest planet in our solar system.",
-#                 "answer": "Jupiter",
-#             },
-#         ]
-
-#         # Create a column mapping for the mock data
-#         column_mapping = {"input": "context", "output": "answer"}
-
-#         # Convert to conversations
-#         conversations = convert_to_conversations(mock_data, column_mapping)
-
-#         # Test that conversations are created correctly
-#         self.assertEqual(len(conversations), len(mock_data))
-#         for i, conversation in enumerate(conversations):
-#             self.assertEqual(len(conversation.messages), 2)
-#             self.assertEqual(conversation.messages[0]["role"], "user")
-#             self.assertEqual(conversation.messages[1]["role"], "assistant")
-
-#             # Check content of user message
-#             user_content = conversation.messages[0]["content"]
-#             self.assertTrue(isinstance(user_content, list))
-#             self.assertEqual(user_content[0]["type"], "text")
-#             self.assertEqual(user_content[0]["text"], mock_data[i]["context"])
-
-#             # Check content of assistant message
-#             assistant_content = conversation.messages[1]["content"]
-#             self.assertTrue(isinstance(assistant_content, list))
-#             self.assertEqual(assistant_content[0]["type"], "text")
-#             self.assertEqual(assistant_content[0]["text"], mock_data[i]["answer"])
-
-#         # Test each formatter with the mock data
-#         formatters = {
-#             "torchtune": TorchtuneFormatter(),
-#             "vllm": vLLMFormatter(),
-#             "openai": OpenAIFormatter(),
-#         }
-
-#         for name, formatter in formatters.items():
-#             formatted_data = formatter.format_data(conversations)
-#             self.assertEqual(len(formatted_data), len(mock_data))
-
-#             # Test the first formatted item
-#             if name == "vllm":
-#                 # vLLM formatter returns strings
-#                 self.assertTrue(isinstance(formatted_data[0], str))
-#                 self.assertIn("user:", formatted_data[0])
-#                 self.assertIn("assistant:", formatted_data[0])
-#             else:
-#                 # Torchtune and OpenAI formatters return dicts
-#                 self.assertTrue(isinstance(formatted_data[0], dict))
-#                 self.assertIn("messages", formatted_data[0])
-#                 self.assertEqual(len(formatted_data[0]["messages"]), 2)
-
-
-# if __name__ == "__main__":
-#     # If run as a script, this allows passing a dataset name as an argument
-#     import argparse
-
-#     parser = argparse.ArgumentParser(
-#         description="Test the formatter module with a specific dataset"
-#     )
-#     parser.add_argument(
-#         "--dataset",
-#         type=str,
-#         default="dz-osamu/IU-Xray",
-#         help="Name of the Hugging Face dataset to use for testing",
-#     )
-#     parser.add_argument(
-#         "--split",
-#         type=str,
-#         default="train[:10]",
-#         help="Dataset split to use (e.g., 'train[:10]', 'validation[:10]')",
-#     )
-
-#     args = parser.parse_args()
-
-#     # Override the default dataset in the test class
-#     TestFormatter.dataset_name = args.dataset
-#     TestFormatter.split = args.split
-
-#     # Run the tests
-#     unittest.main(argv=["first-arg-is-ignored"])
-
-
-
-import json
-import os
-import subprocess
 import sys
-import tempfile
 import unittest
 from pathlib import Path
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock
 
 
 # Add the parent directory to the path so we can import the modules
 sys.path.append(str(Path(__file__).parent.parent))
 
 
-# Import the module to test
-import finetuning
-
-
-class TestFinetuning(unittest.TestCase):
-    """Test cases for the finetuning module."""
-
-    def setUp(self):
-        """Set up test fixtures, called before each test."""
-        # Create temporary config files for testing
-        self.temp_dir = tempfile.TemporaryDirectory()
-
-        # Create a YAML config file
-        self.yaml_config_path = os.path.join(self.temp_dir.name, "config.yaml")
-        self.create_yaml_config()
-
-        # Create a JSON config file
-        self.json_config_path = os.path.join(self.temp_dir.name, "config.json")
-        self.create_json_config()
-
-    def tearDown(self):
-        """Tear down test fixtures, called after each test."""
-        self.temp_dir.cleanup()
-
-    def create_yaml_config(self):
-        """Create a YAML config file for testing."""
-        yaml_content = """
-finetuning:
-  strategy: "lora"
-  distributed: true
-  num_processes_per_node: 4
-  torchtune_config: "llama3_2_vision/11B_lora"
-"""
-        with open(self.yaml_config_path, "w") as f:
-            f.write(yaml_content)
-
-    def create_json_config(self):
-        """Create a JSON config file for testing."""
-        json_content = {
-            "finetuning": {
-                "strategy": "lora",
-                "distributed": True,
-                "torchtune_config": "llama3_2_vision/11B_lora",
+from data.data_loader import convert_to_conversations, format_data, load_data
+from data.formatter import (
+    Conversation,
+    OpenAIFormatter,
+    TorchtuneFormatter,
+    vLLMFormatter,
+)
+
+
+class TestFormatter(unittest.TestCase):
+    """Test cases for the formatter module."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Set up test fixtures, called before any tests are run."""
+        # Define a small dataset to use for testing
+        cls.dataset_name = "dz-osamu/IU-Xray"
+        cls.split = "train[:10]"  # Use only 10 samples for testing
+
+        try:
+            # Load the dataset
+            cls.dataset = load_data(cls.dataset_name, split=cls.split)
+
+            # Create a column mapping for the dz-osamu/IU-Xray dataset
+            cls.column_mapping = {
+                "input": "query",
+                "output": "response",
+                "image": "images",
             }
-        }
-        with open(self.json_config_path, "w") as f:
-            json.dump(json_content, f)
-
-    @patch("subprocess.run")
-    def test_run_torch_tune_lora_distributed(self, mock_run):
-        """Test running torch tune with LoRA distributed strategy."""
-        # Set up the mock
-        mock_run.return_value = MagicMock()
-
-        # Get the function from the module
-        run_torch_tune = getattr(finetuning, "run_torch_tune")
-
-        # Call the function
-        run_torch_tune(self.yaml_config_path)
-
-        # Check that subprocess.run was called with the correct command
-        expected_cmd = [
-            "tune",
-            "run",
-            "--nproc_per_node",
-            "4",
-            "lora_finetune_distributed",
-            "--config",
-            "llama3_2_vision/11B_lora",
-        ]
-        mock_run.assert_called_once()
-        args, kwargs = mock_run.call_args
-        self.assertEqual(args[0], expected_cmd)
-        self.assertTrue(kwargs.get("check", False))
-
-    @patch("subprocess.run")
-    def test_run_torch_tune_lora_single_device(self, mock_run):
-        """Test running torch tune with LoRA single device strategy."""
-        # Create a config with single device
-        single_device_config_path = os.path.join(
-            self.temp_dir.name, "single_device_config.yaml"
+
+            # Convert to list for easier processing
+            cls.data = list(cls.dataset)
+
+            # Convert to conversations
+            cls.conversations = convert_to_conversations(cls.data, cls.column_mapping)
+
+        except Exception as e:
+            print(f"Error setting up test fixtures: {e}")
+            raise
+
+    def test_conversation_creation(self):
+        """Test that conversations are created correctly."""
+        self.assertIsNotNone(self.conversations)
+        self.assertGreater(len(self.conversations), 0)
+
+        # Check that each conversation has at least two messages (user and assistant)
+        for conversation in self.conversations:
+            self.assertGreaterEqual(len(conversation.messages), 2)
+            self.assertEqual(conversation.messages[0]["role"], "user")
+            self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+    def test_torchtune_formatter(self):
+        """Test the TorchtuneFormatter."""
+        formatter = TorchtuneFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, dict)
+        self.assertIn("messages", formatted_conversation)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, dict)
+        self.assertIn("role", formatted_message)
+        self.assertIn("content", formatted_message)
+
+    def test_vllm_formatter(self):
+        """Test the vLLMFormatter."""
+        formatter = vLLMFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, str)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, str)
+        self.assertIn(message["role"], formatted_message)
+
+    def test_openai_formatter(self):
+        """Test the OpenAIFormatter."""
+        formatter = OpenAIFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, dict)
+        self.assertIn("messages", formatted_conversation)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, dict)
+        self.assertIn("role", formatted_message)
+        self.assertIn("content", formatted_message)
+
+    def test_format_data_function(self):
+        """Test the format_data function from data_loader."""
+        # Test with TorchtuneFormatter
+        torchtune_data, torchtune_conversations = format_data(
+            self.data, "torchtune", self.column_mapping
         )
-        with open(single_device_config_path, "w") as f:
-            f.write(
-                """
-finetuning:
-  strategy: "lora"
-  distributed: false
-  torchtune_config: "llama3_2_vision/11B_lora"
-"""
-            )
-
-        # Set up the mock
-        mock_run.return_value = MagicMock()
-
-        # Get the function from the module
-        run_torch_tune = getattr(finetuning, "run_torch_tune")
-
-        # Call the function
-        run_torch_tune(single_device_config_path)
-
-        # Check that subprocess.run was called with the correct command
-        expected_cmd = [
-            "tune",
-            "run",
-            "lora_finetune_single_device",
-            "--config",
-            "llama3_2_vision/11B_lora",
+        self.assertIsNotNone(torchtune_data)
+        self.assertEqual(len(torchtune_data), len(self.data))
+        self.assertEqual(len(torchtune_conversations), len(self.data))
+
+        # Test with vLLMFormatter
+        vllm_data, vllm_conversations = format_data(
+            self.data, "vllm", self.column_mapping
+        )
+        self.assertIsNotNone(vllm_data)
+        self.assertEqual(len(vllm_data), len(self.data))
+        self.assertEqual(len(vllm_conversations), len(self.data))
+
+        # Test with OpenAIFormatter
+        openai_data, openai_conversations = format_data(
+            self.data, "openai", self.column_mapping
+        )
+        self.assertIsNotNone(openai_data)
+        self.assertEqual(len(openai_data), len(self.data))
+        self.assertEqual(len(openai_conversations), len(self.data))
+
+    def test_with_mock_data(self):
+        """Test the formatter pipeline with mock data."""
+        # Create mock data that mimics a dataset
+        mock_data = [
+            {
+                "question": "What is the capital of France?",
+                "context": "France is a country in Western Europe. Its capital is Paris.",
+                "answer": "Paris",
+            },
+            {
+                "question": "Who wrote Hamlet?",
+                "context": "Hamlet is a tragedy written by William Shakespeare.",
+                "answer": "William Shakespeare",
+            },
+            {
+                "question": "What is the largest planet in our solar system?",
+                "context": "Jupiter is the largest planet in our solar system.",
+                "answer": "Jupiter",
+            },
         ]
-        mock_run.assert_called_once()
-        args, kwargs = mock_run.call_args
-        self.assertEqual(args[0], expected_cmd)
-        self.assertTrue(kwargs.get("check", False))
-
-    @patch("subprocess.run")
-    def test_run_torch_tune_invalid_strategy(self, mock_run):
-        """Test running torch tune with an invalid strategy."""
-        # Create a config with an invalid strategy
-        invalid_config_path = os.path.join(self.temp_dir.name, "invalid_config.yaml")
-        with open(invalid_config_path, "w") as f:
-            f.write(
-                """
-finetuning:
-  strategy: "pretraining"
-  distributed: true
-  torchtune_config: "llama3_2_vision/11B_lora"
-"""
-            )
-
-        # Get the function from the module
-        run_torch_tune = getattr(finetuning, "run_torch_tune")
-
-        # Call the function and check that it raises a ValueError
-        with self.assertRaises(ValueError):
-            run_torch_tune(invalid_config_path)
-
-        # Check that subprocess.run was not called
-        mock_run.assert_not_called()
-
-    @patch("subprocess.run")
-    def test_run_torch_tune_subprocess_error(self, mock_run):
-        """Test handling of subprocess errors."""
-        # Set up the mock to raise an error
-        mock_run.side_effect = subprocess.CalledProcessError(1, ["tune", "run"])
-
-        # Get the function from the module
-        run_torch_tune = getattr(finetuning, "run_torch_tune")
-
-        # Call the function and check that it exits with an error
-        with self.assertRaises(SystemExit):
-            run_torch_tune(self.yaml_config_path)
-
-
-#     @patch("subprocess.run")
-#     def test_run_torch_tune_with_args(self, mock_run):
-#         """Test running torch tune with command line arguments."""
-#         # Set up the mock
-#         mock_run.return_value = MagicMock()
-
-#         # Create mock args
-#         args = MagicMock()
-#         args.kwargs = "learning_rate=1e-5,batch_size=16"
-
-#         # Modify the finetuning.py file to handle kwargs
-#         original_finetuning_py = None
-#         with open(finetuning.__file__, "r") as f:
-#             original_finetuning_py = f.read()
-
-#         try:
-#             # Add code to handle kwargs in run_torch_tune function
-#             with open(finetuning.__file__, "a") as f:
-#                 f.write(
-#                     """
-# # Add kwargs to base_cmd if provided
-# def add_kwargs_to_cmd(base_cmd, args):
-#     if args and hasattr(args, 'kwargs') and args.kwargs:
-#         kwargs = args.kwargs.split(',')
-#         base_cmd.extend(kwargs)
-#     return base_cmd
-
-# # Monkey patch the run_torch_tune function
-# original_run_torch_tune = run_torch_tune
-# def patched_run_torch_tune(config_path, args=None):
-#     # Read the configuration
-#     config = read_config(config_path)
-
-#     # Extract parameters from config
-#     training_config = config.get("finetuning", {})
-
-#     # Initialize base_cmd to avoid "possibly unbound" error
-#     base_cmd = []
-
-#     # Determine the command based on configuration
-#     if training_config.get("distributed"):
-#         if training_config.get("strategy") == "lora":
-#             base_cmd = [
-#                 "tune",
-#                 "run",
-#                 "--nproc_per_node",
-#                 str(training_config.get("num_processes_per_node", 1)),
-#                 "lora_finetune_distributed",
-#                 "--config",
-#                 training_config.get("torchtune_config"),
-#             ]
-#         elif training_config.get("strategy") == "fft":
-#             base_cmd = [
-#                 "tune",
-#                 "run",
-#                 "--nproc_per_node",
-#                 str(training_config.get("num_processes_per_node", 1)),
-#                 "full_finetune_distributed",
-#                 "--config",
-#                 training_config.get("torchtune_config"),
-#             ]
-#         else:
-#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
-#     else:
-#         if training_config.get("strategy") == "lora":
-#             base_cmd = [
-#                 "tune",
-#                 "run",
-#                 "lora_finetune_single_device",
-#                 "--config",
-#                 training_config.get("torchtune_config"),
-#             ]
-#         elif training_config.get("strategy") == "fft":
-#             base_cmd = [
-#                 "tune",
-#                 "run",
-#                 "full_finetune_single_device",
-#                 "--config",
-#                 training_config.get("torchtune_config"),
-#             ]
-#         else:
-#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
-
-#     # Check if we have a valid command
-#     if not base_cmd:
-#         raise ValueError("Could not determine the appropriate command based on the configuration")
-
-#     # Add kwargs to base_cmd if provided
-#     if args and hasattr(args, 'kwargs') and args.kwargs:
-#         kwargs = args.kwargs.split(',')
-#         base_cmd.extend(kwargs)
-
-#     # Log the command
-#     logger.info(f"Running command: {' '.join(base_cmd)}")
-
-#     # Run the command
-#     try:
-#         subprocess.run(base_cmd, check=True)
-#         logger.info("Training complete!")
-#     except subprocess.CalledProcessError as e:
-#         logger.error(f"Training failed with error: {e}")
-#         sys.exit(1)
-
-# # Replace the original function with our patched version
-# run_torch_tune = patched_run_torch_tune
-# """
-#                 )
-
-#             # Call the function with args
-#             finetuning.run_torch_tune(self.yaml_config_path, args=args)
-
-#             # Check that subprocess.run was called with the correct command including kwargs
-#             expected_cmd = [
-#                 "tune",
-#                 "run",
-#                 "--nproc_per_node",
-#                 "4",
-#                 "lora_finetune_distributed",
-#                 "--config",
-#                 "llama3_2_vision/11B_lora",
-#                 "learning_rate=1e-5",
-#                 "batch_size=16",
-#             ]
-#             mock_run.assert_called_once()
-#             call_args, call_kwargs = mock_run.call_args
-#             self.assertEqual(call_args[0], expected_cmd)
-#             self.assertTrue(call_kwargs.get("check", False))
-
-#         finally:
-#             # Restore the original finetuning.py file
-#             if original_finetuning_py:
-#                 with open(finetuning.__file__, "w") as f:
-#                     f.write(original_finetuning_py)
+
+        # Create a column mapping for the mock data
+        column_mapping = {"input": "context", "output": "answer"}
+
+        # Convert to conversations
+        conversations = convert_to_conversations(mock_data, column_mapping)
+
+        # Test that conversations are created correctly
+        self.assertEqual(len(conversations), len(mock_data))
+        for i, conversation in enumerate(conversations):
+            self.assertEqual(len(conversation.messages), 2)
+            self.assertEqual(conversation.messages[0]["role"], "user")
+            self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+            # Check content of user message
+            user_content = conversation.messages[0]["content"]
+            self.assertTrue(isinstance(user_content, list))
+            self.assertEqual(user_content[0]["type"], "text")
+            self.assertEqual(user_content[0]["text"], mock_data[i]["context"])
+
+            # Check content of assistant message
+            assistant_content = conversation.messages[1]["content"]
+            self.assertTrue(isinstance(assistant_content, list))
+            self.assertEqual(assistant_content[0]["type"], "text")
+            self.assertEqual(assistant_content[0]["text"], mock_data[i]["answer"])
+
+        # Test each formatter with the mock data
+        formatters = {
+            "torchtune": TorchtuneFormatter(),
+            "vllm": vLLMFormatter(),
+            "openai": OpenAIFormatter(),
+        }
+
+        for name, formatter in formatters.items():
+            formatted_data = formatter.format_data(conversations)
+            self.assertEqual(len(formatted_data), len(mock_data))
+
+            # Test the first formatted item
+            if name == "vllm":
+                # vLLM formatter returns strings
+                self.assertTrue(isinstance(formatted_data[0], str))
+                self.assertIn("user:", formatted_data[0])
+                self.assertIn("assistant:", formatted_data[0])
+            else:
+                # Torchtune and OpenAI formatters return dicts
+                self.assertTrue(isinstance(formatted_data[0], dict))
+                self.assertIn("messages", formatted_data[0])
+                self.assertEqual(len(formatted_data[0]["messages"]), 2)
 
 
 
 
 if __name__ == "__main__":
-    unittest.main()
+    # If run as a script, this allows passing a dataset name as an argument
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Test the formatter module with a specific dataset"
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="dz-osamu/IU-Xray",
+        help="Name of the Hugging Face dataset to use for testing",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="train[:10]",
+        help="Dataset split to use (e.g., 'train[:10]', 'validation[:10]')",
+    )
+
+    args = parser.parse_args()
+
+    # Override the default dataset in the test class
+    TestFormatter.dataset_name = args.dataset
+    TestFormatter.split = args.split
+
+    # Run the tests
+    unittest.main(argv=["first-arg-is-ignored"])
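The rewritten test can also be driven programmatically instead of through the argparse wrapper. A minimal sketch, assuming the tests directory is on sys.path so test_formatter imports cleanly:

import unittest

from test_formatter import TestFormatter  # import path is an assumption

# Override the class-level dataset settings, just as the CLI wrapper does
TestFormatter.dataset_name = "dz-osamu/IU-Xray"
TestFormatter.split = "train[:5]"

suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestFormatter)
unittest.TextTestRunner(verbosity=2).run(suite)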