
unittest for finetuning

Ubuntu, 1 month ago
Parent
Current commit db0b302252

+ 0 - 0     src/finetune_pipeline/config.yaml
+ 0 - 0     src/finetune_pipeline/data/augmentation.py
+ 0 - 0     src/finetune_pipeline/data/data_loader.py
+ 0 - 0     src/finetune_pipeline/data/formatter.py
+ 0 - 0     src/finetune_pipeline/finetuning/finetuning.py
+ 158 - 0
src/finetune_pipeline/tests/test_finetuning.py

@@ -0,0 +1,158 @@
+import json
+import os
+import subprocess
+import sys
+import tempfile
+import unittest
+from pathlib import Path
+from unittest.mock import MagicMock, patch
+
+# Add the parent directory to the path so we can import the modules
+sys.path.append(str(Path(__file__).parent.parent))
+
+# Import the module to test
+import finetuning
+
+from finetuning import run_torch_tune
+
+
+
+class TestFinetuning(unittest.TestCase):
+    """Test cases for the finetuning module."""
+
+    def setUp(self):
+        """Set up test fixtures, called before each test."""
+        # Create temporary config files for testing
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+        # Create a YAML config file
+        self.yaml_config_path = os.path.join(self.temp_dir.name, "config.yaml")
+        self.create_yaml_config()
+
+        # Create a JSON config file
+        self.json_config_path = os.path.join(self.temp_dir.name, "config.json")
+        self.create_json_config()
+
+    def tearDown(self):
+        """Tear down test fixtures, called after each test."""
+        self.temp_dir.cleanup()
+
+    def create_yaml_config(self):
+        """Create a YAML config file for testing."""
+        yaml_content = """
+finetuning:
+  strategy: "lora"
+  distributed: true
+  num_processes_per_node: 4
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+        with open(self.yaml_config_path, "w") as f:
+            f.write(yaml_content)
+
+    def create_json_config(self):
+        """Create a JSON config file for testing."""
+        json_content = {
+            "finetuning": {
+                "strategy": "lora",
+                "distributed": True,
+                "torchtune_config": "llama3_2_vision/11B_lora",
+            }
+        }
+        with open(self.json_config_path, "w") as f:
+            json.dump(json_content, f)
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_lora_distributed(self, mock_run):
+        """Test running torch tune with LoRA distributed strategy."""
+        # Set up the mock
+        mock_run.return_value = MagicMock()
+
+        # Call the function
+        run_torch_tune(self.yaml_config_path)
+
+        # Check that subprocess.run was called with the correct command
+        expected_cmd = [
+            "tune",
+            "run",
+            "--nproc_per_node",
+            "4",
+            "lora_finetune_distributed",
+            "--config",
+            "llama3_2_vision/11B_lora",
+        ]
+        mock_run.assert_called_once()
+        args, kwargs = mock_run.call_args
+        self.assertEqual(args[0], expected_cmd)
+        self.assertTrue(kwargs.get("check", False))
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_lora_single_device(self, mock_run):
+        """Test running torch tune with LoRA single device strategy."""
+        # Create a config with single device
+        single_device_config_path = os.path.join(
+            self.temp_dir.name, "single_device_config.yaml"
+        )
+        with open(single_device_config_path, "w") as f:
+            f.write(
+                """
+finetuning:
+  strategy: "lora"
+  distributed: false
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+            )
+
+        # Set up the mock
+        mock_run.return_value = MagicMock()
+
+        # Call the function
+        finetuning.run_torch_tune(single_device_config_path)
+
+        # Check that subprocess.run was called with the correct command
+        expected_cmd = [
+            "tune",
+            "run",
+            "lora_finetune_single_device",
+            "--config",
+            "llama3_2_vision/11B_lora",
+        ]
+        mock_run.assert_called_once()
+        args, kwargs = mock_run.call_args
+        self.assertEqual(args[0], expected_cmd)
+        self.assertTrue(kwargs.get("check", False))
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_invalid_strategy(self, mock_run):
+        """Test running torch tune with an invalid strategy."""
+        # Create a config with an invalid strategy
+        invalid_config_path = os.path.join(self.temp_dir.name, "invalid_config.yaml")
+        with open(invalid_config_path, "w") as f:
+            f.write(
+                """
+finetuning:
+  strategy: "pretraining"
+  distributed: true
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+            )
+
+        # Call the function and check that it raises a ValueError
+        with self.assertRaises(ValueError):
+            finetuning.run_torch_tune(invalid_config_path)
+
+        # Check that subprocess.run was not called
+        mock_run.assert_not_called()
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_subprocess_error(self, mock_run):
+        """Test handling of subprocess errors."""
+        # Set up the mock to raise an error
+        mock_run.side_effect = subprocess.CalledProcessError(1, ["tune", "run"])
+
+        # Call the function and check that it exits with an error
+        with self.assertRaises(SystemExit):
+            finetuning.run_torch_tune(self.yaml_config_path)
+
+
+if __name__ == "__main__":
+    unittest.main()

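finetuning.py itself is listed with zero changes in this commit, so the function under test is not shown here. Pieced together from the expected_cmd lists in the tests above and the commented-out patched variant further down the diff, run_torch_tune probably has roughly the shape sketched below. This is a sketch of what the tests assume, not the committed implementation; the read_config helper, the PyYAML/JSON branching, and the logger name are assumptions.

import json
import logging
import subprocess
import sys

import yaml  # PyYAML, assumed since the pipeline ships a config.yaml

logger = logging.getLogger(__name__)


def read_config(config_path):
    """Load a YAML or JSON config file into a dict (assumed helper)."""
    with open(config_path) as f:
        if config_path.endswith((".yaml", ".yml")):
            return yaml.safe_load(f)
        return json.load(f)


def run_torch_tune(config_path):
    """Build and run the torchtune CLI command described by the config."""
    training_config = read_config(config_path).get("finetuning", {})
    strategy = training_config.get("strategy")
    if strategy not in ("lora", "fft"):
        # Raised before any subprocess call, which is what lets
        # test_run_torch_tune_invalid_strategy assert subprocess.run was never invoked.
        raise ValueError(f"Invalid strategy: {strategy}")

    recipe = "lora_finetune" if strategy == "lora" else "full_finetune"
    if training_config.get("distributed"):
        cmd = [
            "tune",
            "run",
            "--nproc_per_node",
            str(training_config.get("num_processes_per_node", 1)),
            f"{recipe}_distributed",
            "--config",
            training_config.get("torchtune_config"),
        ]
    else:
        cmd = [
            "tune",
            "run",
            f"{recipe}_single_device",
            "--config",
            training_config.get("torchtune_config"),
        ]

    logger.info(f"Running command: {' '.join(cmd)}")
    try:
        # check=True is what the tests assert via kwargs.get("check", False).
        subprocess.run(cmd, check=True)
        logger.info("Training complete!")
    except subprocess.CalledProcessError as e:
        logger.error(f"Training failed with error: {e}")
        sys.exit(1)  # surfaces as SystemExit in test_run_torch_tune_subprocess_error
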
+ 528 - 223
src/finetune_pipeline/tests/test_formatter.py

@@ -1,235 +1,540 @@
+# import sys
+# import unittest
+# from pathlib import Path
+# from unittest.mock import MagicMock
+
+# # Add the parent directory to the path so we can import the modules
+# sys.path.append(str(Path(__file__).parent.parent))
+
+# from data.data_loader import convert_to_conversations, format_data, load_data
+# from data.formatter import (
+#     Conversation,
+#     OpenAIFormatter,
+#     TorchtuneFormatter,
+#     vLLMFormatter,
+# )
+
+
+# class TestFormatter(unittest.TestCase):
+#     """Test cases for the formatter module."""
+
+#     @classmethod
+#     def setUpClass(cls):
+#         """Set up test fixtures, called before any tests are run."""
+#         # Define a small dataset to use for testing
+#         cls.dataset_name = "dz-osamu/IU-Xray"
+#         cls.split = "train[:10]"  # Use only 10 samples for testing
+
+#         try:
+#             # Load the dataset
+#             cls.dataset = load_data(cls.dataset_name, split=cls.split)
+
+#             # Create a column mapping for the dz-osamu/IU-Xray dataset
+#             cls.column_mapping = {
+#                 "input": "query",
+#                 "output": "response",
+#                 "image": "images"
+#             }
+
+#             # Convert to list for easier processing
+#             cls.data = list(cls.dataset)
+
+#             # Convert to conversations
+#             cls.conversations = convert_to_conversations(cls.data, cls.column_mapping)
+
+#         except Exception as e:
+#             print(f"Error setting up test fixtures: {e}")
+#             raise
+
+#     def test_conversation_creation(self):
+#         """Test that conversations are created correctly."""
+#         self.assertIsNotNone(self.conversations)
+#         self.assertGreater(len(self.conversations), 0)
+
+#         # Check that each conversation has at least two messages (user and assistant)
+#         for conversation in self.conversations:
+#             self.assertGreaterEqual(len(conversation.messages), 2)
+#             self.assertEqual(conversation.messages[0]["role"], "user")
+#             self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+#     def test_torchtune_formatter(self):
+#         """Test the TorchtuneFormatter."""
+#         formatter = TorchtuneFormatter()
+
+#         # Test format_data
+#         formatted_data = formatter.format_data(self.conversations)
+#         self.assertIsNotNone(formatted_data)
+#         self.assertEqual(len(formatted_data), len(self.conversations))
+
+#         # Test format_conversation
+#         formatted_conversation = formatter.format_conversation(self.conversations[0])
+#         self.assertIsInstance(formatted_conversation, dict)
+#         self.assertIn("messages", formatted_conversation)
+
+#         # Test format_message
+#         message = self.conversations[0].messages[0]
+#         formatted_message = formatter.format_message(message)
+#         self.assertIsInstance(formatted_message, dict)
+#         self.assertIn("role", formatted_message)
+#         self.assertIn("content", formatted_message)
+
+#     def test_vllm_formatter(self):
+#         """Test the vLLMFormatter."""
+#         formatter = vLLMFormatter()
+
+#         # Test format_data
+#         formatted_data = formatter.format_data(self.conversations)
+#         self.assertIsNotNone(formatted_data)
+#         self.assertEqual(len(formatted_data), len(self.conversations))
+
+#         # Test format_conversation
+#         formatted_conversation = formatter.format_conversation(self.conversations[0])
+#         self.assertIsInstance(formatted_conversation, str)
+
+#         # Test format_message
+#         message = self.conversations[0].messages[0]
+#         formatted_message = formatter.format_message(message)
+#         self.assertIsInstance(formatted_message, str)
+#         self.assertIn(message["role"], formatted_message)
+
+#     def test_openai_formatter(self):
+#         """Test the OpenAIFormatter."""
+#         formatter = OpenAIFormatter()
+
+#         # Test format_data
+#         formatted_data = formatter.format_data(self.conversations)
+#         self.assertIsNotNone(formatted_data)
+#         self.assertEqual(len(formatted_data), len(self.conversations))
+
+#         # Test format_conversation
+#         formatted_conversation = formatter.format_conversation(self.conversations[0])
+#         self.assertIsInstance(formatted_conversation, dict)
+#         self.assertIn("messages", formatted_conversation)
+
+#         # Test format_message
+#         message = self.conversations[0].messages[0]
+#         formatted_message = formatter.format_message(message)
+#         self.assertIsInstance(formatted_message, dict)
+#         self.assertIn("role", formatted_message)
+#         self.assertIn("content", formatted_message)
+
+#     def test_format_data_function(self):
+#         """Test the format_data function from data_loader."""
+#         # Test with TorchtuneFormatter
+#         torchtune_data = format_data(self.data, "torchtune", self.column_mapping)
+#         self.assertIsNotNone(torchtune_data)
+#         self.assertEqual(len(torchtune_data), len(self.data))
+
+#         # Test with vLLMFormatter
+#         vllm_data = format_data(self.data, "vllm", self.column_mapping)
+#         self.assertIsNotNone(vllm_data)
+#         self.assertEqual(len(vllm_data), len(self.data))
+
+#         # Test with OpenAIFormatter
+#         openai_data = format_data(self.data, "openai", self.column_mapping)
+#         self.assertIsNotNone(openai_data)
+#         self.assertEqual(len(openai_data), len(self.data))
+
+#     def test_with_mock_data(self):
+#         """Test the formatter pipeline with mock data."""
+#         # Create mock data that mimics a dataset
+#         mock_data = [
+#             {
+#                 "question": "What is the capital of France?",
+#                 "context": "France is a country in Western Europe. Its capital is Paris.",
+#                 "answer": "Paris",
+#             },
+#             {
+#                 "question": "Who wrote Hamlet?",
+#                 "context": "Hamlet is a tragedy written by William Shakespeare.",
+#                 "answer": "William Shakespeare",
+#             },
+#             {
+#                 "question": "What is the largest planet in our solar system?",
+#                 "context": "Jupiter is the largest planet in our solar system.",
+#                 "answer": "Jupiter",
+#             },
+#         ]
+
+#         # Create a column mapping for the mock data
+#         column_mapping = {"input": "context", "output": "answer"}
+
+#         # Convert to conversations
+#         conversations = convert_to_conversations(mock_data, column_mapping)
+
+#         # Test that conversations are created correctly
+#         self.assertEqual(len(conversations), len(mock_data))
+#         for i, conversation in enumerate(conversations):
+#             self.assertEqual(len(conversation.messages), 2)
+#             self.assertEqual(conversation.messages[0]["role"], "user")
+#             self.assertEqual(conversation.messages[1]["role"], "assistant")
+
+#             # Check content of user message
+#             user_content = conversation.messages[0]["content"]
+#             self.assertTrue(isinstance(user_content, list))
+#             self.assertEqual(user_content[0]["type"], "text")
+#             self.assertEqual(user_content[0]["text"], mock_data[i]["context"])
+
+#             # Check content of assistant message
+#             assistant_content = conversation.messages[1]["content"]
+#             self.assertTrue(isinstance(assistant_content, list))
+#             self.assertEqual(assistant_content[0]["type"], "text")
+#             self.assertEqual(assistant_content[0]["text"], mock_data[i]["answer"])
+
+#         # Test each formatter with the mock data
+#         formatters = {
+#             "torchtune": TorchtuneFormatter(),
+#             "vllm": vLLMFormatter(),
+#             "openai": OpenAIFormatter(),
+#         }
+
+#         for name, formatter in formatters.items():
+#             formatted_data = formatter.format_data(conversations)
+#             self.assertEqual(len(formatted_data), len(mock_data))
+
+#             # Test the first formatted item
+#             if name == "vllm":
+#                 # vLLM formatter returns strings
+#                 self.assertTrue(isinstance(formatted_data[0], str))
+#                 self.assertIn("user:", formatted_data[0])
+#                 self.assertIn("assistant:", formatted_data[0])
+#             else:
+#                 # Torchtune and OpenAI formatters return dicts
+#                 self.assertTrue(isinstance(formatted_data[0], dict))
+#                 self.assertIn("messages", formatted_data[0])
+#                 self.assertEqual(len(formatted_data[0]["messages"]), 2)
+
+
+# if __name__ == "__main__":
+#     # If run as a script, this allows passing a dataset name as an argument
+#     import argparse
+
+#     parser = argparse.ArgumentParser(
+#         description="Test the formatter module with a specific dataset"
+#     )
+#     parser.add_argument(
+#         "--dataset",
+#         type=str,
+#         default="dz-osamu/IU-Xray",
+#         help="Name of the Hugging Face dataset to use for testing",
+#     )
+#     parser.add_argument(
+#         "--split",
+#         type=str,
+#         default="train[:10]",
+#         help="Dataset split to use (e.g., 'train[:10]', 'validation[:10]')",
+#     )
+
+#     args = parser.parse_args()
+
+#     # Override the default dataset in the test class
+#     TestFormatter.dataset_name = args.dataset
+#     TestFormatter.split = args.split
+
+#     # Run the tests
+#     unittest.main(argv=["first-arg-is-ignored"])
+
+
+
+import json
+import os
+import subprocess
 import sys
+import tempfile
 import unittest
 from pathlib import Path
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 # Add the parent directory to the path so we can import the modules
 sys.path.append(str(Path(__file__).parent.parent))
 
-from data.data_loader import convert_to_conversations, format_data, load_data
-from data.formatter import (
-    Conversation,
-    OpenAIFormatter,
-    TorchtuneFormatter,
-    vLLMFormatter,
-)
-
-
-class TestFormatter(unittest.TestCase):
-    """Test cases for the formatter module."""
-
-    @classmethod
-    def setUpClass(cls):
-        """Set up test fixtures, called before any tests are run."""
-        # Define a small dataset to use for testing
-        cls.dataset_name = "dz-osamu/IU-Xray"
-        cls.split = "train[:10]"  # Use only 10 samples for testing
-
-        try:
-            # Load the dataset
-            cls.dataset = load_data(cls.dataset_name, split=cls.split)
-
-            # Create a column mapping for the squad_v2 dataset
-            cls.column_mapping = {
-                "input": "query",
-                "output": "response",
-                "image": "images"
+# Import the module to test
+import finetuning
+
+
+class TestFinetuning(unittest.TestCase):
+    """Test cases for the finetuning module."""
+
+    def setUp(self):
+        """Set up test fixtures, called before each test."""
+        # Create temporary config files for testing
+        self.temp_dir = tempfile.TemporaryDirectory()
+
+        # Create a YAML config file
+        self.yaml_config_path = os.path.join(self.temp_dir.name, "config.yaml")
+        self.create_yaml_config()
+
+        # Create a JSON config file
+        self.json_config_path = os.path.join(self.temp_dir.name, "config.json")
+        self.create_json_config()
+
+    def tearDown(self):
+        """Tear down test fixtures, called after each test."""
+        self.temp_dir.cleanup()
+
+    def create_yaml_config(self):
+        """Create a YAML config file for testing."""
+        yaml_content = """
+finetuning:
+  strategy: "lora"
+  distributed: true
+  num_processes_per_node: 4
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+        with open(self.yaml_config_path, "w") as f:
+            f.write(yaml_content)
+
+    def create_json_config(self):
+        """Create a JSON config file for testing."""
+        json_content = {
+            "finetuning": {
+                "strategy": "lora",
+                "distributed": True,
+                "torchtune_config": "llama3_2_vision/11B_lora",
             }
-
-            # Convert to list for easier processing
-            cls.data = list(cls.dataset)
-
-            # Convert to conversations
-            cls.conversations = convert_to_conversations(cls.data, cls.column_mapping)
-
-        except Exception as e:
-            print(f"Error setting up test fixtures: {e}")
-            raise
-
-    def test_conversation_creation(self):
-        """Test that conversations are created correctly."""
-        self.assertIsNotNone(self.conversations)
-        self.assertGreater(len(self.conversations), 0)
-
-        # Check that each conversation has at least two messages (user and assistant)
-        for conversation in self.conversations:
-            self.assertGreaterEqual(len(conversation.messages), 2)
-            self.assertEqual(conversation.messages[0]["role"], "user")
-            self.assertEqual(conversation.messages[1]["role"], "assistant")
-
-    def test_torchtune_formatter(self):
-        """Test the TorchtuneFormatter."""
-        formatter = TorchtuneFormatter()
-
-        # Test format_data
-        formatted_data = formatter.format_data(self.conversations)
-        self.assertIsNotNone(formatted_data)
-        self.assertEqual(len(formatted_data), len(self.conversations))
-
-        # Test format_conversation
-        formatted_conversation = formatter.format_conversation(self.conversations[0])
-        self.assertIsInstance(formatted_conversation, dict)
-        self.assertIn("messages", formatted_conversation)
-
-        # Test format_message
-        message = self.conversations[0].messages[0]
-        formatted_message = formatter.format_message(message)
-        self.assertIsInstance(formatted_message, dict)
-        self.assertIn("role", formatted_message)
-        self.assertIn("content", formatted_message)
-
-    def test_vllm_formatter(self):
-        """Test the vLLMFormatter."""
-        formatter = vLLMFormatter()
-
-        # Test format_data
-        formatted_data = formatter.format_data(self.conversations)
-        self.assertIsNotNone(formatted_data)
-        self.assertEqual(len(formatted_data), len(self.conversations))
-
-        # Test format_conversation
-        formatted_conversation = formatter.format_conversation(self.conversations[0])
-        self.assertIsInstance(formatted_conversation, str)
-
-        # Test format_message
-        message = self.conversations[0].messages[0]
-        formatted_message = formatter.format_message(message)
-        self.assertIsInstance(formatted_message, str)
-        self.assertIn(message["role"], formatted_message)
-
-    def test_openai_formatter(self):
-        """Test the OpenAIFormatter."""
-        formatter = OpenAIFormatter()
-
-        # Test format_data
-        formatted_data = formatter.format_data(self.conversations)
-        self.assertIsNotNone(formatted_data)
-        self.assertEqual(len(formatted_data), len(self.conversations))
-
-        # Test format_conversation
-        formatted_conversation = formatter.format_conversation(self.conversations[0])
-        self.assertIsInstance(formatted_conversation, dict)
-        self.assertIn("messages", formatted_conversation)
-
-        # Test format_message
-        message = self.conversations[0].messages[0]
-        formatted_message = formatter.format_message(message)
-        self.assertIsInstance(formatted_message, dict)
-        self.assertIn("role", formatted_message)
-        self.assertIn("content", formatted_message)
-
-    def test_format_data_function(self):
-        """Test the format_data function from data_loader."""
-        # Test with TorchtuneFormatter
-        torchtune_data = format_data(self.data, "torchtune", self.column_mapping)
-        self.assertIsNotNone(torchtune_data)
-        self.assertEqual(len(torchtune_data), len(self.data))
-
-        # Test with vLLMFormatter
-        vllm_data = format_data(self.data, "vllm", self.column_mapping)
-        self.assertIsNotNone(vllm_data)
-        self.assertEqual(len(vllm_data), len(self.data))
-
-        # Test with OpenAIFormatter
-        openai_data = format_data(self.data, "openai", self.column_mapping)
-        self.assertIsNotNone(openai_data)
-        self.assertEqual(len(openai_data), len(self.data))
-
-    def test_with_mock_data(self):
-        """Test the formatter pipeline with mock data."""
-        # Create mock data that mimics a dataset
-        mock_data = [
-            {
-                "question": "What is the capital of France?",
-                "context": "France is a country in Western Europe. Its capital is Paris.",
-                "answer": "Paris",
-            },
-            {
-                "question": "Who wrote Hamlet?",
-                "context": "Hamlet is a tragedy written by William Shakespeare.",
-                "answer": "William Shakespeare",
-            },
-            {
-                "question": "What is the largest planet in our solar system?",
-                "context": "Jupiter is the largest planet in our solar system.",
-                "answer": "Jupiter",
-            },
-        ]
-
-        # Create a column mapping for the mock data
-        column_mapping = {"input": "context", "output": "answer"}
-
-        # Convert to conversations
-        conversations = convert_to_conversations(mock_data, column_mapping)
-
-        # Test that conversations are created correctly
-        self.assertEqual(len(conversations), len(mock_data))
-        for i, conversation in enumerate(conversations):
-            self.assertEqual(len(conversation.messages), 2)
-            self.assertEqual(conversation.messages[0]["role"], "user")
-            self.assertEqual(conversation.messages[1]["role"], "assistant")
-
-            # Check content of user message
-            user_content = conversation.messages[0]["content"]
-            self.assertTrue(isinstance(user_content, list))
-            self.assertEqual(user_content[0]["type"], "text")
-            self.assertEqual(user_content[0]["text"], mock_data[i]["context"])
-
-            # Check content of assistant message
-            assistant_content = conversation.messages[1]["content"]
-            self.assertTrue(isinstance(assistant_content, list))
-            self.assertEqual(assistant_content[0]["type"], "text")
-            self.assertEqual(assistant_content[0]["text"], mock_data[i]["answer"])
-
-        # Test each formatter with the mock data
-        formatters = {
-            "torchtune": TorchtuneFormatter(),
-            "vllm": vLLMFormatter(),
-            "openai": OpenAIFormatter(),
         }
-
-        for name, formatter in formatters.items():
-            formatted_data = formatter.format_data(conversations)
-            self.assertEqual(len(formatted_data), len(mock_data))
-
-            # Test the first formatted item
-            if name == "vllm":
-                # vLLM formatter returns strings
-                self.assertTrue(isinstance(formatted_data[0], str))
-                self.assertIn("user:", formatted_data[0])
-                self.assertIn("assistant:", formatted_data[0])
-            else:
-                # Torchtune and OpenAI formatters return dicts
-                self.assertTrue(isinstance(formatted_data[0], dict))
-                self.assertIn("messages", formatted_data[0])
-                self.assertEqual(len(formatted_data[0]["messages"]), 2)
+        with open(self.json_config_path, "w") as f:
+            json.dump(json_content, f)
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_lora_distributed(self, mock_run):
+        """Test running torch tune with LoRA distributed strategy."""
+        # Set up the mock
+        mock_run.return_value = MagicMock()
+
+        # Get the function from the module
+        run_torch_tune = getattr(finetuning, "run_torch_tune")
+
+        # Call the function
+        run_torch_tune(self.yaml_config_path)
+
+        # Check that subprocess.run was called with the correct command
+        expected_cmd = [
+            "tune",
+            "run",
+            "--nproc_per_node",
+            "4",
+            "lora_finetune_distributed",
+            "--config",
+            "llama3_2_vision/11B_lora",
+        ]
+        mock_run.assert_called_once()
+        args, kwargs = mock_run.call_args
+        self.assertEqual(args[0], expected_cmd)
+        self.assertTrue(kwargs.get("check", False))
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_lora_single_device(self, mock_run):
+        """Test running torch tune with LoRA single device strategy."""
+        # Create a config with single device
+        single_device_config_path = os.path.join(
+            self.temp_dir.name, "single_device_config.yaml"
+        )
+        with open(single_device_config_path, "w") as f:
+            f.write(
+                """
+finetuning:
+  strategy: "lora"
+  distributed: false
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+            )
+
+        # Set up the mock
+        mock_run.return_value = MagicMock()
+
+        # Get the function from the module
+        run_torch_tune = getattr(finetuning, "run_torch_tune")
+
+        # Call the function
+        run_torch_tune(single_device_config_path)
+
+        # Check that subprocess.run was called with the correct command
+        expected_cmd = [
+            "tune",
+            "run",
+            "lora_finetune_single_device",
+            "--config",
+            "llama3_2_vision/11B_lora",
+        ]
+        mock_run.assert_called_once()
+        args, kwargs = mock_run.call_args
+        self.assertEqual(args[0], expected_cmd)
+        self.assertTrue(kwargs.get("check", False))
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_invalid_strategy(self, mock_run):
+        """Test running torch tune with an invalid strategy."""
+        # Create a config with an invalid strategy
+        invalid_config_path = os.path.join(self.temp_dir.name, "invalid_config.yaml")
+        with open(invalid_config_path, "w") as f:
+            f.write(
+                """
+finetuning:
+  strategy: "pretraining"
+  distributed: true
+  torchtune_config: "llama3_2_vision/11B_lora"
+"""
+            )
+
+        # Get the function from the module
+        run_torch_tune = getattr(finetuning, "run_torch_tune")
+
+        # Call the function and check that it raises a ValueError
+        with self.assertRaises(ValueError):
+            run_torch_tune(invalid_config_path)
+
+        # Check that subprocess.run was not called
+        mock_run.assert_not_called()
+
+    @patch("subprocess.run")
+    def test_run_torch_tune_subprocess_error(self, mock_run):
+        """Test handling of subprocess errors."""
+        # Set up the mock to raise an error
+        mock_run.side_effect = subprocess.CalledProcessError(1, ["tune", "run"])
+
+        # Get the function from the module
+        run_torch_tune = getattr(finetuning, "run_torch_tune")
+
+        # Call the function and check that it exits with an error
+        with self.assertRaises(SystemExit):
+            run_torch_tune(self.yaml_config_path)
+
+
+#     @patch("subprocess.run")
+#     def test_run_torch_tune_with_args(self, mock_run):
+#         """Test running torch tune with command line arguments."""
+#         # Set up the mock
+#         mock_run.return_value = MagicMock()
+
+#         # Create mock args
+#         args = MagicMock()
+#         args.kwargs = "learning_rate=1e-5,batch_size=16"
+
+#         # Modify the finetuning.py file to handle kwargs
+#         original_finetuning_py = None
+#         with open(finetuning.__file__, "r") as f:
+#             original_finetuning_py = f.read()
+
+#         try:
+#             # Add code to handle kwargs in run_torch_tune function
+#             with open(finetuning.__file__, "a") as f:
+#                 f.write(
+#                     """
+# # Add kwargs to base_cmd if provided
+# def add_kwargs_to_cmd(base_cmd, args):
+#     if args and hasattr(args, 'kwargs') and args.kwargs:
+#         kwargs = args.kwargs.split(',')
+#         base_cmd.extend(kwargs)
+#     return base_cmd
+
+# # Monkey patch the run_torch_tune function
+# original_run_torch_tune = run_torch_tune
+# def patched_run_torch_tune(config_path, args=None):
+#     # Read the configuration
+#     config = read_config(config_path)
+
+#     # Extract parameters from config
+#     training_config = config.get("finetuning", {})
+
+#     # Initialize base_cmd to avoid "possibly unbound" error
+#     base_cmd = []
+
+#     # Determine the command based on configuration
+#     if training_config.get("distributed"):
+#         if training_config.get("strategy") == "lora":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "--nproc_per_node",
+#                 str(training_config.get("num_processes_per_node", 1)),
+#                 "lora_finetune_distributed",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         elif training_config.get("strategy") == "fft":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "--nproc_per_node",
+#                 str(training_config.get("num_processes_per_node", 1)),
+#                 "full_finetune_distributed",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         else:
+#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+#     else:
+#         if training_config.get("strategy") == "lora":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "lora_finetune_single_device",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         elif training_config.get("strategy") == "fft":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "full_finetune_single_device",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         else:
+#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+
+#     # Check if we have a valid command
+#     if not base_cmd:
+#         raise ValueError("Could not determine the appropriate command based on the configuration")
+
+#     # Add kwargs to base_cmd if provided
+#     if args and hasattr(args, 'kwargs') and args.kwargs:
+#         kwargs = args.kwargs.split(',')
+#         base_cmd.extend(kwargs)
+
+#     # Log the command
+#     logger.info(f"Running command: {' '.join(base_cmd)}")
+
+#     # Run the command
+#     try:
+#         subprocess.run(base_cmd, check=True)
+#         logger.info("Training complete!")
+#     except subprocess.CalledProcessError as e:
+#         logger.error(f"Training failed with error: {e}")
+#         sys.exit(1)
+
+# # Replace the original function with our patched version
+# run_torch_tune = patched_run_torch_tune
+# """
+#                 )
+
+#             # Call the function with args
+#             finetuning.run_torch_tune(self.yaml_config_path, args=args)
+
+#             # Check that subprocess.run was called with the correct command including kwargs
+#             expected_cmd = [
+#                 "tune",
+#                 "run",
+#                 "--nproc_per_node",
+#                 "4",
+#                 "lora_finetune_distributed",
+#                 "--config",
+#                 "llama3_2_vision/11B_lora",
+#                 "learning_rate=1e-5",
+#                 "batch_size=16",
+#             ]
+#             mock_run.assert_called_once()
+#             call_args, call_kwargs = mock_run.call_args
+#             self.assertEqual(call_args[0], expected_cmd)
+#             self.assertTrue(call_kwargs.get("check", False))
+
+#         finally:
+#             # Restore the original finetuning.py file
+#             if original_finetuning_py:
+#                 with open(finetuning.__file__, "w") as f:
+#                     f.write(original_finetuning_py)
 
 
 if __name__ == "__main__":
-    # If run as a script, this allows passing a dataset name as an argument
-    import argparse
-
-    parser = argparse.ArgumentParser(
-        description="Test the formatter module with a specific dataset"
-    )
-    parser.add_argument(
-        "--dataset",
-        type=str,
-        default="dz-osamu/IU-Xray",
-        help="Name of the Hugging Face dataset to use for testing",
-    )
-    parser.add_argument(
-        "--split",
-        type=str,
-        default="train[:10]",
-        help="Dataset split to use (e.g., 'train[:10]', 'validation[:10]')",
-    )
-
-    args = parser.parse_args()
-
-    # Override the default dataset in the test class
-    TestFormatter.dataset_name = args.dataset
-    TestFormatter.split = args.split
-
-    # Run the tests
-    unittest.main(argv=["first-arg-is-ignored"])
+    unittest.main()