瀏覽代碼

updated custom_sft_dataset

Ubuntu 1 月之前
父節點
當前提交
8c6f486d4a

+ 9 - 1
src/finetune_pipeline/finetuning/__init__.py

@@ -1,4 +1,12 @@
 """
-Test suite for the finetune_pipeline package.
+Fine-tuning utilities for LLMs.
 
+This module provides tools for fine-tuning language models using various strategies.
 """
+
+# Import the custom_sft_dataset function from the custom_sft_dataset module
+from .custom_sft_dataset import custom_sft_dataset
+
+__all__ = [
+    "custom_sft_dataset",
+]

+ 18 - 7
src/finetune_pipeline/finetuning/dataset.py

@@ -1,17 +1,27 @@
+"""
+Custom SFT dataset for fine-tuning.
+"""
+
+from torchtune.data import OpenAIToMessages
 from torchtune.datasets import SFTDataset
 from torchtune.modules.transforms import Transform
-from torchtune.data import OpenAIToMessages
 
 
 def custom_sft_dataset(
     model_transform: Transform,
-    *,
-    split: str = "train",
-    dataset_path: str = "files/synthetic_data/train.csv",
-    train_on_input: bool = True,
+    dataset_path: str = "/tmp/train.json",
+    train_on_input: bool = False,
 ) -> SFTDataset:
-    """Creates a custom dataset."""
+    """
+    Creates a custom SFT dataset for fine-tuning.
 
+    Args:
+        dataset_path: Path to the formatted data JSON file
+        train_on_input: Whether to train on input tokens
+
+    Returns:
+        SFTDataset: A dataset ready for fine-tuning with TorchTune
+    """
     openaitomessage = OpenAIToMessages(train_on_input=train_on_input)
 
     ds = SFTDataset(
@@ -19,6 +29,7 @@ def custom_sft_dataset(
         data_files=dataset_path,
         split="train",
         message_transform=openaitomessage,
-        model_transform=Transform,
+        model_transform=model_transform,
     )
     return ds
+

+ 179 - 1
src/finetune_pipeline/finetuning/run_finetuning.py

@@ -165,4 +165,182 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
+    main()
+
+
+# #!/usr/bin/env python
+# """
+# Fine-tuning script for language models using torch tune.
+# Reads parameters from a config file and runs the torch tune command.
+# """
+
+# import argparse
+# import logging
+# import subprocess
+# import sys
+# from pathlib import Path
+# from typing import Dict
+
+# try:
+#     import yaml
+
+#     HAS_YAML = True
+# except ImportError:
+#     HAS_YAML = False
+
+# # Configure logging
+# logging.basicConfig(
+#     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+#     datefmt="%Y-%m-%d %H:%M:%S",
+#     level=logging.INFO,
+# )
+# logger = logging.getLogger(__name__)
+
+
+# ## Will import from dataloader eventually
+# def read_config(config_path: str) -> Dict:
+#     """
+#     Read the configuration file (supports both JSON and YAML formats).
+
+#     Args:
+#         config_path: Path to the configuration file
+
+#     Returns:
+#         dict: Configuration parameters
+
+#     Raises:
+#         ValueError: If the file format is not supported
+#         ImportError: If the required package for the file format is not installed
+#     """
+#     file_extension = Path(config_path).suffix.lower()
+
+#     with open(config_path, "r") as f:
+#         if file_extension in [".json"]:
+#             config = json.load(f)
+#         elif file_extension in [".yaml", ".yml"]:
+#             if not HAS_YAML:
+#                 raise ImportError(
+#                     "The 'pyyaml' package is required to load YAML files. "
+#                     "Please install it with 'pip install pyyaml'."
+#                 )
+#             config = yaml.safe_load(f)
+#         else:
+#             raise ValueError(
+#                 f"Unsupported config file format: {file_extension}. "
+#                 f"Supported formats are: .json, .yaml, .yml"
+#             )
+
+#     return config
+
+
+# def run_torch_tune(config_path: str, args=None):
+#     """
+#     Run torch tune command with parameters from config file.
+
+#     Args:
+#         config_path: Path to the configuration file
+#         args: Command line arguments that may include additional kwargs to pass to the command
+#     """
+#     # Read the configuration
+#     config = read_config(config_path)
+
+#     # Extract parameters from config
+#     training_config = config.get("finetuning", {})
+
+#     # Initialize base_cmd to avoid "possibly unbound" error
+#     base_cmd = []
+
+#     # Determine the command based on configuration
+#     if training_config.get("distributed"):
+#         if training_config.get("strategy") == "lora":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "--nproc_per_node",
+#                 str(training_config.get("num_processes_per_node", 1)),
+#                 "lora_finetune_distributed",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         elif training_config.get("strategy") == "fft":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "--nproc_per_node",
+#                 str(training_config.get("num_processes_per_node", 1)),
+#                 "full_finetune_distributed",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         else:
+#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+
+#     else:
+#         if training_config.get("strategy") == "lora":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "lora_finetune_single_device",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         elif training_config.get("strategy") == "fft":
+#             base_cmd = [
+#                 "tune",
+#                 "run",
+#                 "full_finetune_single_device",
+#                 "--config",
+#                 training_config.get("torchtune_config"),
+#             ]
+#         else:
+#             raise ValueError(f"Invalid strategy: {training_config.get('strategy')}")
+
+#     # Check if we have a valid command
+#     if not base_cmd:
+#         raise ValueError(
+#             "Could not determine the appropriate command based on the configuration"
+#         )
+
+#     # Add any additional kwargs if provided
+#     if args and args.kwargs:
+#         # Split the kwargs string by spaces to get individual key=value pairs
+#         kwargs_list = args.kwargs.split()
+#         base_cmd.extend(kwargs_list)
+#         logger.info(f"Added additional kwargs: {kwargs_list}")
+
+#     # Log the command
+#     logger.info(f"Running command: {' '.join(base_cmd)}")
+
+#     # Run the command
+#     try:
+#         subprocess.run(base_cmd, check=True)
+#         logger.info("Training complete!")
+#     except subprocess.CalledProcessError as e:
+#         logger.error(f"Training failed with error: {e}")
+#         sys.exit(1)
+
+
+# def main():
+#     """Main function."""
+#     parser = argparse.ArgumentParser(
+#         description="Fine-tune a language model using torch tune"
+#     )
+#     parser.add_argument(
+#         "--config",
+#         type=str,
+#         required=True,
+#         help="Path to the configuration file (JSON or YAML)",
+#     )
+#     parser.add_argument(
+#         "--kwargs",
+#         type=str,
+#         default=None,
+#         help="Additional key-value pairs to pass to the command (space-separated, e.g., 'dataset=module.function dataset.param=value')",
+#     )
+#     args = parser.parse_args()
+
+#     run_torch_tune(args.config, args=args)
+
+
+# if __name__ == "__main__":
+#     main()