
added finetuning module and formatter unit test

Ubuntu · 1 month ago · commit c04bddcec2

+ 10 - 1
src/finetune_pipeline/config.yaml

@@ -1,4 +1,4 @@
-# Configuration for data loading and formatting
+# Configuration for data loading, formatting, and fine-tuning
 
 
 # Data source configuration
 data_path: "your/dataset/path"  # Path to the dataset (either a Hugging Face dataset ID or a local path)
@@ -18,3 +18,12 @@ column_mapping:
 dataset_kwargs:
   split: "train"                # Dataset split to load
   # Add any other dataset-specific arguments here
+
+# Training configuration
+finetuning:
+  strategy: "lora"               # Training strategy ('fft' or 'lora')
+  num_epochs: 1                 # Number of training epochs
+  batch_size: 1                 # Batch size per device for training
+  torchtune_config: "llama3_2_vision/11B_lora"             # torchtune recipe config (passed to 'tune run --config')
+  num_processes_per_node: 8             # Number of processes per node for distributed training
+  distributed: true             # Whether to use distributed training
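
Note: a minimal sketch of how this finetuning block is consumed, mirroring run_torch_tune in the finetuning.py added below (the config path is an assumption):

    # Read only the 'finetuning' section of the YAML above.
    import yaml
    with open("src/finetune_pipeline/config.yaml") as f:  # assumed path
        cfg = yaml.safe_load(f).get("finetuning", {})

    # Choose the torchtune recipe implied by strategy + distributed
    # (the real script raises ValueError for strategies other than 'lora'/'fft').
    recipe = "lora_finetune" if cfg.get("strategy") == "lora" else "full_finetune"
    recipe += "_distributed" if cfg.get("distributed") else "_single_device"

    cmd = ["tune", "run"]
    if cfg.get("distributed"):
        cmd += ["--nproc_per_node", str(cfg.get("num_processes_per_node", 1))]
    cmd += [recipe, "--config", cfg.get("torchtune_config")]
    print(" ".join(cmd))
    # -> tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_2_vision/11B_lora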

+ 0 - 3
src/finetune_pipeline/data/data_loader.py

@@ -59,9 +59,6 @@ def read_config(config_path: str) -> Dict:
                     "The 'pyyaml' package is required to load YAML files. "
                     "Please install it with 'pip install pyyaml'."
                 )
-            # Only use yaml if it's available (HAS_YAML is True here)
-            import yaml  # This import will succeed because we've already checked HAS_YAML
-
             config = yaml.safe_load(f)
         else:
             raise ValueError(

+ 168 - 0
src/finetune_pipeline/finetuning/finetuning.py

@@ -0,0 +1,168 @@
+#!/usr/bin/env python
+"""
+Fine-tuning script for language models using torchtune.
+Reads parameters from a config file and runs the torchtune CLI command.
+"""
+
+import argparse
+import json
+import logging
+import subprocess
+import sys
+from pathlib import Path
+from typing import Dict
+
+try:
+    import yaml
+    HAS_YAML = True
+except ImportError:
+    HAS_YAML = False
+
+# Configure logging
+logging.basicConfig(
+    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO,
+)
+logger = logging.getLogger(__name__)
+
+
+# TODO: import read_config from data.data_loader instead of duplicating it here
+def read_config(config_path: str) -> Dict:
+    """
+    Read the configuration file (supports both JSON and YAML formats).
+
+    Args:
+        config_path: Path to the configuration file
+
+    Returns:
+        dict: Configuration parameters
+
+    Raises:
+        ValueError: If the file format is not supported
+        ImportError: If the required package for the file format is not installed
+    """
+    file_extension = Path(config_path).suffix.lower()
+
+    with open(config_path, "r") as f:
+        if file_extension in [".json"]:
+            config = json.load(f)
+        elif file_extension in [".yaml", ".yml"]:
+            if not HAS_YAML:
+                raise ImportError(
+                    "The 'pyyaml' package is required to load YAML files. "
+                    "Please install it with 'pip install pyyaml'."
+                )
+            config = yaml.safe_load(f)
+        else:
+            raise ValueError(
+                f"Unsupported config file format: {file_extension}. "
+                f"Supported formats are: .json, .yaml, .yml"
+            )
+
+    return config
+
+
+def run_torch_tune(config_path: str, args=None):
+    """
+    Run the torchtune command with parameters from the config file.
+
+    Args:
+        config_path: Path to the configuration file
+        args: Parsed command-line arguments (currently unused)
+    """
+    # Read the configuration
+    config = read_config(config_path)
+
+    # Extract parameters from config
+    training_config = config.get("finetuning", {})
+
+    # Initialize base_cmd to avoid "possibly unbound" error
+    base_cmd = []
+
+    # Determine the command based on configuration
+    if training_config.get("distributed"):
+        if training_config.get("strategy") == "lora":
+            base_cmd = [
+                "tune",
+                "run",
+                "--nproc_per_node",
+                str(training_config.get("num_processes_per_node", 1)),
+                "lora_finetune_distributed",
+                "--config",
+                training_config.get("torchtune_config"),
+            ]
+        elif training_config.get("strategy") == "fft":
+            base_cmd = [
+                "tune",
+                "run",
+                "--nproc_per_node",
+                str(training_config.get("num_processes_per_node", 1)),
+                "full_finetune_distributed",
+                "--config",
+                training_config.get("torchtune_config"),
+            ]
+        else:
+            raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+
+    else:
+        if training_config.get("strategy") == "lora":
+            base_cmd = [
+                "tune",
+                "run",
+                "lora_finetune_single_device",
+                "--config",
+                training_config.get("torchtune_config"),
+            ]
+        elif training_config.get("strategy") == "fft":
+            base_cmd = [
+                "tune",
+                "run",
+                "full_finetune_single_device",
+                "--config",
+                training_config.get("torchtune_config"),
+            ]
+        else:
+            raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+
+    # Check if we have a valid command
+    if not base_cmd:
+        raise ValueError(
+            "Could not determine the appropriate command based on the configuration"
+        )
+
+    # Log the command
+    logger.info(f"Running command: {' '.join(base_cmd)}")
+
+    # Run the command
+    try:
+        subprocess.run(base_cmd, check=True)
+        logger.info("Training complete!")
+    except subprocess.CalledProcessError as e:
+        logger.error(f"Training failed with error: {e}")
+        sys.exit(1)
+
+
+def main():
+    """Main function."""
+    parser = argparse.ArgumentParser(
+        description="Fine-tune a language model using torchtune"
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        required=True,
+        help="Path to the configuration file (JSON or YAML)",
+    )
+    parser.add_argument(
+        "--kwargs",
+        type=str,
+        default=None,
+        help="Additional key-value pairs to pass to the command (comma-separated)",
+    )
+    args = parser.parse_args()
+
+    run_torch_tune(args.config, args=args)
+
+
+if __name__ == "__main__":
+    main()
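
A minimal usage sketch, assuming the config above is saved as src/finetune_pipeline/config.yaml:

    python src/finetune_pipeline/finetuning/finetuning.py --config src/finetune_pipeline/config.yaml

With strategy "lora" and distributed: true, this shells out to the lora_finetune_distributed recipe via the torchtune CLI and exits non-zero if the subprocess fails.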

+ 235 - 0
src/finetune_pipeline/tests/test_formatter.py

@@ -0,0 +1,235 @@
+import sys
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock
+
+# Add the parent directory to the path so we can import the modules
+sys.path.append(str(Path(__file__).parent.parent))
+
+from data.data_loader import convert_to_conversations, format_data, load_data
+from data.formatter import (
+    Conversation,
+    OpenAIFormatter,
+    TorchtuneFormatter,
+    vLLMFormatter,
+)
+
+
+class TestFormatter(unittest.TestCase):
+    """Test cases for the formatter module."""
+
+    @classmethod
+    def setUpClass(cls):
+        """Set up test fixtures, called before any tests are run."""
+        # Define a small dataset to use for testing
+        cls.dataset_name = "dz-osamu/IU-Xray"
+        cls.split = "train[:10]"  # Use only 10 samples for testing
+
+        try:
+            # Load the dataset
+            cls.dataset = load_data(cls.dataset_name, split=cls.split)
+
+            # Create a column mapping for the IU-Xray dataset
+            cls.column_mapping = {
+                "input": "query",
+                "output": "response",
+                "image": "images"
+            }
+
+            # Convert to list for easier processing
+            cls.data = list(cls.dataset)
+
+            # Convert to conversations
+            cls.conversations = convert_to_conversations(cls.data, cls.column_mapping)
+
+        except Exception as e:
+            print(f"Error setting up test fixtures: {e}")
+            raise
+
+    def test_conversation_creation(self):
+        """Test that conversations are created correctly."""
+        self.assertIsNotNone(self.conversations)
+        self.assertGreater(len(self.conversations), 0)
+
+        # Check that each conversation has at least two messages (user and assistant)
+        for conversation in self.conversations:
+            self.assertGreaterEqual(len(conversation.messages), 2)
+            self.assertEqual(conversation.messages[0]["role"], "user")
+            self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+    def test_torchtune_formatter(self):
+        """Test the TorchtuneFormatter."""
+        formatter = TorchtuneFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, dict)
+        self.assertIn("messages", formatted_conversation)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, dict)
+        self.assertIn("role", formatted_message)
+        self.assertIn("content", formatted_message)
+
+    def test_vllm_formatter(self):
+        """Test the vLLMFormatter."""
+        formatter = vLLMFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, str)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, str)
+        self.assertIn(message["role"], formatted_message)
+
+    def test_openai_formatter(self):
+        """Test the OpenAIFormatter."""
+        formatter = OpenAIFormatter()
+
+        # Test format_data
+        formatted_data = formatter.format_data(self.conversations)
+        self.assertIsNotNone(formatted_data)
+        self.assertEqual(len(formatted_data), len(self.conversations))
+
+        # Test format_conversation
+        formatted_conversation = formatter.format_conversation(self.conversations[0])
+        self.assertIsInstance(formatted_conversation, dict)
+        self.assertIn("messages", formatted_conversation)
+
+        # Test format_message
+        message = self.conversations[0].messages[0]
+        formatted_message = formatter.format_message(message)
+        self.assertIsInstance(formatted_message, dict)
+        self.assertIn("role", formatted_message)
+        self.assertIn("content", formatted_message)
+
+    def test_format_data_function(self):
+        """Test the format_data function from data_loader."""
+        # Test with TorchtuneFormatter
+        torchtune_data = format_data(self.data, "torchtune", self.column_mapping)
+        self.assertIsNotNone(torchtune_data)
+        self.assertEqual(len(torchtune_data), len(self.data))
+
+        # Test with vLLMFormatter
+        vllm_data = format_data(self.data, "vllm", self.column_mapping)
+        self.assertIsNotNone(vllm_data)
+        self.assertEqual(len(vllm_data), len(self.data))
+
+        # Test with OpenAIFormatter
+        openai_data = format_data(self.data, "openai", self.column_mapping)
+        self.assertIsNotNone(openai_data)
+        self.assertEqual(len(openai_data), len(self.data))
+
+    def test_with_mock_data(self):
+        """Test the formatter pipeline with mock data."""
+        # Create mock data that mimics a dataset
+        mock_data = [
+            {
+                "question": "What is the capital of France?",
+                "context": "France is a country in Western Europe. Its capital is Paris.",
+                "answer": "Paris",
+            },
+            {
+                "question": "Who wrote Hamlet?",
+                "context": "Hamlet is a tragedy written by William Shakespeare.",
+                "answer": "William Shakespeare",
+            },
+            {
+                "question": "What is the largest planet in our solar system?",
+                "context": "Jupiter is the largest planet in our solar system.",
+                "answer": "Jupiter",
+            },
+        ]
+
+        # Create a column mapping for the mock data
+        column_mapping = {"input": "context", "output": "answer"}
+
+        # Convert to conversations
+        conversations = convert_to_conversations(mock_data, column_mapping)
+
+        # Test that conversations are created correctly
+        self.assertEqual(len(conversations), len(mock_data))
+        for i, conversation in enumerate(conversations):
+            self.assertEqual(len(conversation.messages), 2)
+            self.assertEqual(conversation.messages[0]["role"], "user")
+            self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+            # Check content of user message
+            user_content = conversation.messages[0]["content"]
+            self.assertTrue(isinstance(user_content, list))
+            self.assertEqual(user_content[0]["type"], "text")
+            self.assertEqual(user_content[0]["text"], mock_data[i]["context"])
+
+            # Check content of assistant message
+            assistant_content = conversation.messages[1]["content"]
+            self.assertTrue(isinstance(assistant_content, list))
+            self.assertEqual(assistant_content[0]["type"], "text")
+            self.assertEqual(assistant_content[0]["text"], mock_data[i]["answer"])
+
+        # Test each formatter with the mock data
+        formatters = {
+            "torchtune": TorchtuneFormatter(),
+            "vllm": vLLMFormatter(),
+            "openai": OpenAIFormatter(),
+        }
+
+        for name, formatter in formatters.items():
+            formatted_data = formatter.format_data(conversations)
+            self.assertEqual(len(formatted_data), len(mock_data))
+
+            # Test the first formatted item
+            if name == "vllm":
+                # vLLM formatter returns strings
+                self.assertTrue(isinstance(formatted_data[0], str))
+                self.assertIn("user:", formatted_data[0])
+                self.assertIn("assistant:", formatted_data[0])
+            else:
+                # Torchtune and OpenAI formatters return dicts
+                self.assertTrue(isinstance(formatted_data[0], dict))
+                self.assertIn("messages", formatted_data[0])
+                self.assertEqual(len(formatted_data[0]["messages"]), 2)
+
+
+if __name__ == "__main__":
+    # If run as a script, this allows passing a dataset name as an argument
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description="Test the formatter module with a specific dataset"
+    )
+    parser.add_argument(
+        "--dataset",
+        type=str,
+        default="dz-osamu/IU-Xray",
+        help="Name of the Hugging Face dataset to use for testing",
+    )
+    parser.add_argument(
+        "--split",
+        type=str,
+        default="train[:10]",
+        help="Dataset split to use (e.g., 'train[:10]', 'validation[:10]')",
+    )
+
+    args = parser.parse_args()
+
+    # Override the default dataset in the test class
+    TestFormatter.dataset_name = args.dataset
+    TestFormatter.split = args.split
+
+    # Run the tests
+    unittest.main(argv=["first-arg-is-ignored"])
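
As a usage sketch, the test module can also be run standalone against another dataset (loading dz-osamu/IU-Xray via load_data assumes access to the Hugging Face Hub):

    python src/finetune_pipeline/tests/test_formatter.py --dataset dz-osamu/IU-Xray --split "train[:10]"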