瀏覽代碼

added data loader and config file

khare19yash 1 月之前
父節點
當前提交
7ec9941960
共有 3 個文件被更改,包括 341 次插入和 30 次刪除
  1. 20 0
      src/finetune_pipeline/config.yaml
  2. 209 0
      src/finetune_pipeline/data/data_loader.py
  3. 112 30
      src/finetune_pipeline/data/formatter.py

+ 20 - 0
src/finetune_pipeline/config.yaml

@@ -0,0 +1,20 @@
+# Configuration for data loading and formatting
+
+# Data source configuration
+data_path: "your/dataset/path"  # Path to the dataset (either a Hugging Face dataset ID or a local path)
+is_local: true                  # Whether the data is stored locally
+
+# Formatter configuration
+formatter_type: "torchtune"     # Type of formatter to use ('torchtune', 'vllm', or 'openai')
+
+# Column mapping configuration
+# Maps custom column names to standard field names
+column_mapping:
+  input: "question"             # Field containing the input text
+  output: "answer"              # Field containing the output text
+  image: "image_path"           # Field containing the image path (optional)
+
+# Additional arguments to pass to the load_dataset function
+dataset_kwargs:
+  split: "train"                # Dataset split to load
+  # Add any other dataset-specific arguments here

+ 209 - 0
src/finetune_pipeline/data/data_loader.py

@@ -0,0 +1,209 @@
"""
Data loader module for loading and formatting data from Hugging Face.
"""

import json
import os
from pathlib import Path
from typing import Dict, Optional

# Optional dependency: PyYAML is only required when the config file is YAML.
# The flag lets read_config() raise a clear error instead of failing at import.
try:
    import yaml

    HAS_YAML = True
except ImportError:
    HAS_YAML = False

# Optional dependency: 'datasets' is only required when data is actually loaded.
try:
    from datasets import load_dataset, load_from_disk

    HAS_DATASETS = True
except ImportError:
    HAS_DATASETS = False

    # Dummy stand-ins keep the names bound (avoids "possibly unbound" errors);
    # they raise immediately so a missing dependency surfaces with a clear message.
    def load_dataset(*args, **kwargs):
        raise ImportError("The 'datasets' package is required to load data.")

    def load_from_disk(*args, **kwargs):
        raise ImportError("The 'datasets' package is required to load data.")


from .formatter import Formatter, OpenAIFormatter, TorchtuneFormatter, vLLMFormatter
+
+
def read_config(config_path: str) -> Dict:
    """
    Read a configuration file (supports both JSON and YAML formats).

    The file format is decided by the extension, and the extension is
    validated *before* the file is opened, so an unsupported format fails
    fast without touching the filesystem.

    Args:
        config_path: Path to the configuration file.

    Returns:
        dict: Parsed configuration parameters.

    Raises:
        ValueError: If the file extension is not .json, .yaml, or .yml.
        ImportError: If the file is YAML but PyYAML is not installed.
    """
    file_extension = Path(config_path).suffix.lower()

    if file_extension == ".json":
        with open(config_path, "r") as f:
            return json.load(f)

    if file_extension in (".yaml", ".yml"):
        if not HAS_YAML:
            raise ImportError(
                "The 'pyyaml' package is required to load YAML files. "
                "Please install it with 'pip install pyyaml'."
            )
        # 'yaml' was already imported at module level (HAS_YAML is True here),
        # so no re-import is needed.
        with open(config_path, "r") as f:
            return yaml.safe_load(f)

    raise ValueError(
        f"Unsupported config file format: {file_extension}. "
        f"Supported formats are: .json, .yaml, .yml"
    )
+
+
def load_data(data_path: str, is_local: bool = False, **kwargs):
    """
    Load a dataset either from local disk or from the Hugging Face Hub.

    Args:
        data_path: Hugging Face dataset ID, or a local path when is_local is True.
        is_local: Whether the data is stored locally (loaded via load_from_disk).
        **kwargs: Extra arguments forwarded to load_dataset (Hub loads only).

    Returns:
        Dataset object from the 'datasets' library.

    Raises:
        ImportError: If the 'datasets' package is not installed.
        ValueError: If data_path is None or empty.
    """
    # Guard clauses: fail fast on missing dependency or missing path.
    if not HAS_DATASETS:
        raise ImportError(
            "The 'datasets' package is required to load data. "
            "Please install it with 'pip install datasets'."
        )
    if not data_path:
        raise ValueError("data_path must be provided")

    # Local snapshots go through load_from_disk; anything else is resolved
    # against the Hugging Face Hub.
    if is_local:
        return load_from_disk(data_path)
    return load_dataset(data_path, **kwargs)
+
+
def get_formatter(formatter_type: str) -> Formatter:
    """
    Get the appropriate formatter based on the formatter type.

    Args:
        formatter_type: Type of formatter to use ('torchtune', 'vllm', or
            'openai'). Matching is case-insensitive.

    Returns:
        Formatter: A new formatter instance of the requested type.

    Raises:
        ValueError: If the formatter type is not supported.
    """
    formatter_map = {
        "torchtune": TorchtuneFormatter,
        "vllm": vLLMFormatter,
        "openai": OpenAIFormatter,
    }

    # Normalize once instead of calling .lower() twice, and use a single
    # dict lookup (EAFP) instead of a membership test followed by indexing.
    key = formatter_type.lower()
    try:
        formatter_cls = formatter_map[key]
    except KeyError:
        raise ValueError(
            f"Unsupported formatter type: {formatter_type}. "
            f"Supported types are: {', '.join(formatter_map.keys())}"
        ) from None

    return formatter_cls()
+
+
def format_data(data, formatter_type: str, column_mapping: Optional[Dict] = None):
    """
    Format the data using the specified formatter.

    Args:
        data: Data to format.
        formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai').
        column_mapping: Optional mapping of column names.

    Returns:
        Formatted data in the specified format.
    """
    formatter = get_formatter(formatter_type)

    # Convert the raw records into Conversation objects, then render them
    # in the formatter's target format.
    return formatter.format_data(formatter.read_data(data, column_mapping))
+
+
def load_and_format_data(config_path: str):
    """
    Load and format data based on the configuration.

    Args:
        config_path: Path to the configuration file.

    Returns:
        Formatted data in the specified format.

    Raises:
        ValueError: If the config file does not specify 'data_path'.
    """
    config = read_config(config_path)

    # 'data_path' is the only mandatory key; everything else has a default.
    data_path = config.get("data_path")
    if not data_path:
        raise ValueError("data_path must be specified in the config file")

    dataset = load_data(
        data_path,
        config.get("is_local", False),
        **config.get("dataset_kwargs", {}),
    )

    return format_data(
        dataset,
        config.get("formatter_type", "torchtune"),
        config.get("column_mapping"),
    )
+
+
if __name__ == "__main__":
    # CLI entry point: read a JSON/YAML config and run the load+format pipeline.
    import argparse

    arg_parser = argparse.ArgumentParser(
        description="Load and format data from Hugging Face"
    )
    arg_parser.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path to the configuration file (JSON or YAML)",
    )
    cli_args = arg_parser.parse_args()

    result = load_and_format_data(cli_args.config)
    print(f"Loaded and formatted data: {len(result)} samples")

+ 112 - 30
src/finetune_pipeline/data/formatter.py

@@ -81,7 +81,7 @@ class Formatter(ABC):
         pass
 
     @abstractmethod
-    def format_sample(self, sample) -> Union[Dict, str]:
+    def format_conversation(self, conversation) -> Union[Dict, str]:
         """
         Format a sample. This method must be implemented by subclasses.
 
@@ -93,22 +93,53 @@ class Formatter(ABC):
         """
         pass
 
-    def read_data(self, data):
+    @abstractmethod
+    def format_message(self, message) -> Union[Dict, str]:
+        """
+        Format a message. This method must be implemented by subclasses.
+
+        Args:
+            sample: Message object
+
+        Returns:
+            Formatted message in the appropriate format
+        """
+        pass
+
+    def read_data(self, data, column_mapping=None):
         """
         Read Hugging Face data and convert it to a list of Conversation objects.
 
         Args:
             data: Hugging Face dataset or iterable data
+            column_mapping: Optional dictionary mapping custom column names to standard field names.
+                            Expected keys are 'input', 'output', and 'image'.
+                            Example: {'input': 'question', 'output': 'answer', 'image': 'image_path'}
+                            If None, default field names ('input', 'output', 'image') will be used.
 
         Returns:
             list: List of Conversation objects, each containing a list of Message objects
         """
+        # Default column mapping if none provided
+        if column_mapping is None:
+            column_mapping = {"input": "input", "output": "output", "image": "image"}
+
+        # Validate column mapping
+        required_fields = ["input", "output"]
+        for field in required_fields:
+            if field not in column_mapping:
+                raise ValueError(f"Column mapping must include '{field}' field")
+
         conversations = []
         for item in data:
-            # Extract fields from the Hugging Face dataset item
-            image = item.get("image", None)
-            input_text = item.get("input", "")
-            output_label = item.get("output", "")
+            # Extract fields from the Hugging Face dataset item using the column mapping
+            image_field = column_mapping.get("image")
+            input_field = column_mapping.get("input")
+            output_field = column_mapping.get("output")
+
+            image = item.get(image_field, None) if image_field else None
+            input_text = item.get(input_field, "")
+            output_label = item.get(output_field, "")
 
             # Create a new conversation
             conversation = Conversation()
@@ -156,22 +187,31 @@ class TorchtuneFormatter(Formatter):
         if data is not None:
             self.conversation_data = self.read_data(data)
 
-    def format_data(self, data):
+    def format_data(self, data=None):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects
+            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
 
         Returns:
             list: List of formatted data
         """
+        # Use stored conversation_data if no data is provided
+        data_to_format = data if data is not None else self.conversation_data
+
+        # Check if we have data to format
+        if data_to_format is None:
+            raise ValueError(
+                "No data to format. Either provide data to format_data() or initialize with data."
+            )
+
         formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
+        for conversation in data_to_format:
+            formatted_data.append(self.format_conversation(conversation))
         return formatted_data
 
-    def format_sample(self, sample):
+    def format_conversation(self, conversation):
         """
         Format a sample.
 
@@ -182,11 +222,11 @@ class TorchtuneFormatter(Formatter):
             dict: Formatted sample in Torchtune format
         """
         formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
+        for message in conversation.messages:
+            formatted_messages.append(self.format_message(message))
         return {"messages": formatted_messages}
 
-    def format(self, message):
+    def format_message(self, message):
         """
         Format a message in Torchtune format.
 
@@ -206,22 +246,43 @@ class vLLMFormatter(Formatter):
     Formatter for vLLM format.
     """
 
-    def format_data(self, data):
+    def __init__(self, data=None):
+        """
+        Initialize the formatter.
+
+        Args:
+            data: Optional data to initialize with
+        """
+        super().__init__()
+        self.conversation_data = None
+        if data is not None:
+            self.conversation_data = self.read_data(data)
+
+    def format_data(self, data=None):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects
+            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
 
         Returns:
             list: List of formatted data in vLLM format
         """
+        # Use stored conversation_data if no data is provided
+        data_to_format = data if data is not None else self.conversation_data
+
+        # Check if we have data to format
+        if data_to_format is None:
+            raise ValueError(
+                "No data to format. Either provide data to format_data() or initialize with data."
+            )
+
         formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
+        for conversation in data_to_format:
+            formatted_data.append(self.format_conversation(conversation))
         return formatted_data
 
-    def format_sample(self, sample):
+    def format_conversation(self, conversation):
         """
         Format a sample.
 
@@ -232,11 +293,11 @@ class vLLMFormatter(Formatter):
             str: Formatted sample in vLLM format
         """
         formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
+        for message in conversation.messages:
+            formatted_messages.append(self.format_message(message))
         return "\n".join(formatted_messages)
 
-    def format(self, message):
+    def format_message(self, message):
         """
         Format a message in vLLM format.
 
@@ -266,22 +327,43 @@ class OpenAIFormatter(Formatter):
     Formatter for OpenAI format.
     """
 
-    def format_data(self, data):
+    def __init__(self, data=None):
+        """
+        Initialize the formatter.
+
+        Args:
+            data: Optional data to initialize with
+        """
+        super().__init__()
+        self.conversation_data = None
+        if data is not None:
+            self.conversation_data = self.read_data(data)
+
+    def format_data(self, data=None):
         """
         Format the data.
 
         Args:
-            data: List of Conversation objects
+            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
 
         Returns:
             dict: Formatted data in OpenAI format
         """
+        # Use stored conversation_data if no data is provided
+        data_to_format = data if data is not None else self.conversation_data
+
+        # Check if we have data to format
+        if data_to_format is None:
+            raise ValueError(
+                "No data to format. Either provide data to format_data() or initialize with data."
+            )
+
         formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
+        for conversation in data_to_format:
+            formatted_data.append(self.format_conversation(conversation))
         return formatted_data
 
-    def format_sample(self, sample):
+    def format_conversation(self, conversation):
         """
         Format a sample.
 
@@ -292,11 +374,11 @@ class OpenAIFormatter(Formatter):
             dict: Formatted sample in OpenAI format
         """
         formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
+        for message in conversation.messages:
+            formatted_messages.append(self.format_message(message))
         return {"messages": formatted_messages}
 
-    def format(self, message):
+    def format_message(self, message):
         """
         Format a message in OpenAI format.