瀏覽代碼

Cleanup Data module

Suraj Subramanian 1 月之前
父節點
當前提交
6364a6f2b7

+ 29 - 296
src/finetune_pipeline/data/formatter.py

@@ -1,309 +1,42 @@
-import base64
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, TypedDict, Union
+from types import Message
+from typing import Callable, List, Union
 
 
-def image_to_base64(image_path):
-    with open(image_path, "rb") as img:
-        return base64.b64encode(img.read()).decode("utf-8")
+def format_huggingface(base_message: Message) -> Message:
+    if base_message["role"] != "user":
+        return base_message
 
+    contents = []
+    for content in base_message["content"]:
+        if content["type"] == "text":
+            contents.append(content)
+        elif content["type"] == "image_url":
+            contents.append({"type": "image", "url": content["image_url"]["url"]})
 
-class MessageContent(TypedDict, total=False):
-    """Type definition for message content in LLM requests."""
+    return {"role": "user", "content": contents}
 
-    type: str  # Required field
-    text: Optional[str]  # Optional field
-    image_url: Optional[Dict[str, str]]  # Optional field
 
-
-class Message(TypedDict):
-    """Type definition for a message in a LLM inference request."""
-
-    role: str
-    content: Union[str, List[MessageContent]]
-
-
-class Conversation:
-    """
-    Data class representing a conversation which are list of messages.
-    """
-
-    def __init__(self, messages=None):
-        self.messages = messages if messages is not None else []
-
-    def add_message(self, message):
-        """
-        Add a message to the conversation.
-
-        Args:
-            message: Message object to add
-        """
-        self.messages.append(message)
-
-    def get_message(self, index):
-        """
-        Get a message at a specific index.
-
-        Args:
-            index: Index of the message to retrieve
-
-        Returns:
-            Message: The message at the specified index
-
-        Raises:
-            IndexError: If the index is out of bounds
-        """
-        if index < 0 or index >= len(self.messages):
-            raise IndexError(
-                f"Message index {index} out of range (0-{len(self.messages)-1})"
-            )
-        return self.messages[index]
-
-
-class Formatter(ABC):
-    """
-    Abstract base class for formatters that convert messages to different formats.
-    """
-
-    def __init__(self):
-        """
-        Initialize the formatter.
-
-        Subclasses can override this method to add specific initialization parameters.
-        """
-        pass
-
-    @abstractmethod
-    def format_data(self, data) -> List:
-        """
-        Format the message. This method must be implemented by subclasses.
-
-        Args:
-            data: List of Conversation objects
-
-        Returns:
-            List of formatted data
-        """
-        pass
-
-    @abstractmethod
-    def format_conversation(self, conversation) -> Union[Dict, str]:
-        """
-        Format a sample. This method must be implemented by subclasses.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            Formatted sample in the appropriate format
-        """
-        pass
-
-    @abstractmethod
-    def format_message(self, message) -> Union[Dict, str]:
-        """
-        Format a message. This method must be implemented by subclasses.
-
-        Args:
-            sample: Message object
-
-        Returns:
-            Formatted message in the appropriate format
-        """
-        pass
-
-    # The read_data function has been moved to convert_to_conversations in data_loader.py
-
-
-class TorchtuneFormatter(Formatter):
-    """
-    Formatter for Torchtune format.
-    """
-
-    def __init__(self):
-        """
-        Initialize the formatter.
-        """
-        super().__init__()
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects.
-
-        Returns:
-            list: List of formatted data
-        """
-        if data is None:
-            raise ValueError("No data provided to format_data()")
-
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_conversation(conversation))
-        return formatted_data
-
-    def format_conversation(self, conversation):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            dict: Formatted sample in Torchtune format
-        """
-        formatted_messages = []
-        for message in conversation.messages:
-            formatted_messages.append(self.format_message(message))
-        return {"messages": formatted_messages}
-
-    def format_message(self, message):
-        """
-        Format a message in Torchtune format.
-
-        Args:
-            message: Message object to format
-
-        Returns:
-            dict: Formatted message in Torchtune format
-        """
-        # For Torchtune format, we can return the Message as is
-        # since it's already in a compatible format
-        return message
-
-
-class vLLMFormatter(Formatter):
-    """
-    Formatter for vLLM format.
+def apply_format(
+    data: Union[List[Message], List[List[Message]]], format_func: Callable
+):
     """
+    Apply the format function to the data.
 
-    def __init__(self):
-        """
-        Initialize the formatter.
-        """
-        super().__init__()
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects.
-
-        Returns:
-            list: List of formatted data in vLLM format
-        """
-        if data is None:
-            raise ValueError("No data provided to format_data()")
-
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_conversation(conversation))
-        return formatted_data
+    Args:
+        data: Either a list of message dictionaries or a list of conversations
+              (where each conversation is a list of message dictionaries)
+        format_func: Function that formats a single message dictionary
 
-    def format_conversation(self, conversation):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            str: Formatted sample in vLLM format
-        """
-        formatted_messages = []
-        for message in conversation.messages:
-            role = message["role"]
-            if role != "assistant":
-                formatted_messages.append(self.format_message(message))
-        return {"messages": formatted_messages}
-
-    def format_message(self, message):
-        """
-        Format a message in vLLM format.
-
-        Args:
-            message: Message object to format
-
-        Returns:
-            str: Formatted message in vLLM format
-        """
-        contents = []
-        vllm_message = {}
-
-        for content in message["content"]:
-            if content["type"] == "text":
-                contents.append(content)
-            elif content["type"] == "image_url" or content["type"] == "image":
-                base64_image = image_to_base64(content["image_url"]["url"])
-                img_content = {
-                    "type": "image_url",
-                    "image_url": {"url": f"data:image/jpg;base64,{base64_image}"},
-                }
-                contents.append(img_content)
-            else:
-                raise ValueError(f"Unknown content type: {content['type']}")
-        vllm_message["role"] = message["role"]
-        vllm_message["content"] = contents
-        return vllm_message
-
-
-class OpenAIFormatter(Formatter):
-    """
-    Formatter for OpenAI format.
+    Returns:
+        List of formatted dictionaries
     """
+    if not data:
+        return []
 
-    def __init__(self):
-        """
-        Initialize the formatter.
-        """
-        super().__init__()
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects.
-
-        Returns:
-            dict: Formatted data in OpenAI format
-        """
-        if data is None:
-            raise ValueError("No data provided to format_data()")
-
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_conversation(conversation))
-        return formatted_data
-
-    def format_conversation(self, conversation):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            dict: Formatted sample in OpenAI format
-        """
-        formatted_messages = []
-        for message in conversation.messages:
-            formatted_messages.append(self.format_message(message))
-        return {"messages": formatted_messages}
-
-    def format_message(self, message):
-        """
-        Format a message in OpenAI format.
+    if isinstance(data[0], Message):
+        return [format_func(message) for message in data]
 
-        Args:
-            message: Message object to format
+    if isinstance(data[0][0], Message):
+        return [apply_format(conversation, format_func) for conversation in data]
 
-        Returns:
-            dict: Formatted message in OpenAI format
-        """
-        # For OpenAI format, we can return the Message as is
-        # since it's already in a compatible format
-        return message
+    raise ValueError("Invalid data format")

+ 0 - 80
src/finetune_pipeline/data/formatter_functional.py

@@ -1,80 +0,0 @@
-from typing import Dict, List, Union
-
-
-def format_message_torchtune(message: Dict) -> Dict:
-    """Format a message in Torchtune format."""
-    return message
-
-
-def format_message_openai(message: Dict) -> Dict:
-    """Format a message in OpenAI format."""
-    contents = []
-    for content in message["content"]:
-        if content["type"] == "text":
-            contents.append({"type": "input_text", "text": content["text"]})
-        elif content["type"] == "image_url":
-            contents.append(
-                {"type": "input_image", "image_url": content["image_url"]["url"]}
-            )
-        else:
-            raise ValueError(f"Unknown content type: {content['type']}")
-    return {"role": message["role"], "content": contents}
-
-
-def format_message_vllm(message: Dict) -> Dict:
-    """Format a message in vLLM format."""
-    contents = []
-    vllm_message = {}
-
-    for content in message["content"]:
-        if content["type"] == "text":
-            contents.append(content)
-        elif content["type"] == "image_url" or content["type"] == "image":
-            img_content = {
-                "type": "image_url",
-                "image_url": {
-                    "url": f"data:image/jpg;base64,{content["image_url"]["url"]}"
-                },
-            }
-            contents.append(img_content)
-        else:
-            raise ValueError(f"Unknown content type: {content['type']}")
-    vllm_message["role"] = message["role"]
-    vllm_message["content"] = contents
-    return vllm_message
-
-
-def apply_format(data: Union[List[Dict], List[List[Dict]]], format_func) -> List[Dict]:
-    """
-    Apply the format function to the data.
-
-    Args:
-        data: Either a list of message dictionaries or a list of conversations
-              (where each conversation is a list of message dictionaries)
-        format_func: Function that formats a single message dictionary
-
-    Returns:
-        List of formatted dictionaries
-    """
-    if not data:
-        return []
-
-    # Check if data is a list of conversations (list of lists) or a list of messages
-    if isinstance(data[0], list):
-        # data is a list of conversations, each conversation is a list of messages
-        formatted_conversations = []
-        for conversation in data:
-            formatted_messages = []
-            for message in conversation:
-                formatted_message = format_func(message)
-                formatted_messages.append(formatted_message)
-            # Return the conversation as a dictionary with "messages" key
-            formatted_conversations.append({"messages": formatted_messages})
-        return formatted_conversations
-    else:
-        # data is a list of messages
-        formatted_messages = []
-        for message in data:
-            formatted_message = format_func(message)
-            formatted_messages.append(formatted_message)
-        return formatted_messages

+ 14 - 120
src/finetune_pipeline/data/data_loader.py

@@ -2,111 +2,17 @@
 Data loader module for loading and formatting data from Hugging Face.
 """
 
-import base64
 import json
 import os
 from pathlib import Path
-from typing import Dict, Optional, Union
-
-import pandas as pd
-from PIL import Image
-
-# Try to import yaml, but don't fail if it's not available
-try:
-    import yaml
-
-    HAS_YAML = True
-except ImportError:
-    HAS_YAML = False
-
-# Try to import datasets, but don't fail if it's not available
-try:
-    from datasets import Dataset, load_dataset, load_from_disk
-
-    HAS_DATASETS = True
-except ImportError:
-    HAS_DATASETS = False
-
-    # Define dummy functions to avoid "possibly unbound" errors
-    def load_dataset(*args, **kwargs):
-        raise ImportError("The 'datasets' package is required to load data.")
-
-    def load_from_disk(*args, **kwargs):
-        raise ImportError("The 'datasets' package is required to load data.")
-
-
-def is_base64_encoded(s: str) -> bool:
-    """Check if a string is already base64 encoded."""
-    try:
-        # Basic character check - base64 only contains these characters
-        if not all(
-            c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
-            for c in s
-        ):
-            return False
-
-        # Try to decode - if it fails, it's not valid base64
-        decoded = base64.b64decode(s, validate=True)
-
-        # Re-encode and compare - if they match, it was valid base64
-        re_encoded = base64.b64encode(decoded).decode("utf-8")
-        return s == re_encoded or s == re_encoded.rstrip(
-            "="
-        )  # Handle padding differences
-    except Exception:
-        return False
-
-
-def image_to_base64(image: Union[str, list, Image.Image]):
-    if isinstance(image, str):
-        # Check if the string is already base64 encoded
-        if is_base64_encoded(image):
-            return image
-        # Otherwise, treat it as a file path
-        with open(image, "rb") as img:
-            return base64.b64encode(img.read()).decode("utf-8")
-    elif isinstance(image, Image.Image):
-        return base64.b64encode(image.tobytes()).decode("utf-8")
-    elif isinstance(image, list):
-        return [image_to_base64(img) for img in image]
-
-
-def read_config(config_path: str) -> Dict:
-    """
-    Read the configuration file (supports both JSON and YAML formats).
+from typing import Dict, Optional
 
-    Args:
-        config_path: Path to the configuration file
+from datasets import load_dataset, load_from_disk
 
-    Returns:
-        dict: Configuration parameters
+from .utils import image_to_base64, read_config
 
-    Raises:
-        ValueError: If the file format is not supported
-        ImportError: If the required package for the file format is not installed
-    """
-    file_extension = Path(config_path).suffix.lower()
-
-    with open(config_path, "r") as f:
-        if file_extension in [".json"]:
-            config = json.load(f)
-        elif file_extension in [".yaml", ".yml"]:
-            if not HAS_YAML:
-                raise ImportError(
-                    "The 'pyyaml' package is required to load YAML files. "
-                    "Please install it with 'pip install pyyaml'."
-                )
-            config = yaml.safe_load(f)
-        else:
-            raise ValueError(
-                f"Unsupported config file format: {file_extension}. "
-                f"Supported formats are: .json, .yaml, .yml"
-            )
-
-    return config
 
-
-def load_data(
+def process_hf_dataset(
     data_path: str,
     is_local: bool = False,
     column_mapping: Optional[Dict] = None,
@@ -127,22 +33,15 @@ def load_data(
         ImportError: If the datasets package is not installed
         ValueError: If data_path is None or empty
     """
-    if not HAS_DATASETS:
-        raise ImportError(
-            "The 'datasets' package is required to load data. "
-            "Please install it with 'pip install datasets'."
-        )
-
     if not data_path:
         raise ValueError("data_path must be provided")
+
     dataset = None
     if is_local:
         # Load from local disk
         file_extension = Path(data_path).suffix.lower()
         if file_extension in [".csv"]:
-            data = pd.read_csv(data_path)
-            dataset = Dataset.from_pandas(data)
-        else:
+            dataset = load_dataset("csv", data_files=data_path)
             dataset = load_from_disk(data_path)
     else:
         # Load from Hugging Face Hub
@@ -208,12 +107,13 @@ def convert_to_encoded_messages(
     messages.append({"role": "user", "content": user_content})
 
     # Create assistant message with text content
-    messages.append(
-        {
-            "role": "assistant",
-            "content": [{"type": "text", "text": output_label}],
-        }
-    )
+    if output_label:
+        messages.append(
+            {
+                "role": "assistant",
+                "content": [{"type": "text", "text": output_label}],
+            }
+        )
     # Serialize to string and return. This is required because datasets.map adds extra keys to each dict in messages
     example["messages"] = json.dumps(messages)
     return example
@@ -233,12 +133,6 @@ def save_encoded_dataset(encoded_dataset, output_dir: str, split: str):
         json.dump(messages, f, indent=2)
 
 
-# TODO: Verify if this is actually needed?
-# def format_encoded_dataset(encoded_dataset, output_dir, split, format):
-#     if format == "vllm":
-#         messages = [json.loads(x) for x in encoded_dataset["messages"]]
-
-
 def get_splits(dataset):
     """
     Helper function to get splits from a dataset.
@@ -305,7 +199,7 @@ def get_hf_dataset(
         )
 
     # Load the dataset
-    dataset = load_data(
+    dataset = process_hf_dataset(
         data_path=dataset_id,
         is_local=is_local,
         column_mapping=column_mapping,

+ 74 - 0
src/finetune_pipeline/data/utils.py

@@ -0,0 +1,74 @@
+import base64
+import json
+from pathlib import Path
+from typing import Dict, Union
+
+import yaml
+
+from PIL import Image
+
+
+def is_base64_encoded(s: str) -> bool:
+    """Check if a string is already base64 encoded."""
+    try:
+        # Basic character check - base64 only contains these characters
+        if not all(
+            c in "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="
+            for c in s
+        ):
+            return False
+
+        # Try to decode - if it fails, it's not valid base64
+        decoded = base64.b64decode(s, validate=True)
+
+        # Re-encode and compare - if they match, it was valid base64
+        re_encoded = base64.b64encode(decoded).decode("utf-8")
+        return s == re_encoded or s == re_encoded.rstrip(
+            "="
+        )  # Handle padding differences
+    except Exception:
+        return False
+
+
+def image_to_base64(image: Union[str, list, Image.Image]):
+    if isinstance(image, str):
+        # Check if the string is already base64 encoded
+        if is_base64_encoded(image):
+            return image
+        # Otherwise, treat it as a file path
+        with open(image, "rb") as img:
+            return base64.b64encode(img.read()).decode("utf-8")
+    elif isinstance(image, Image.Image):
+        return base64.b64encode(image.tobytes()).decode("utf-8")
+    elif isinstance(image, list):
+        return [image_to_base64(img) for img in image]
+
+
+def read_config(config_path: str) -> Dict:
+    """
+    Read the configuration file (supports both JSON and YAML formats).
+
+    Args:
+        config_path: Path to the configuration file
+
+    Returns:
+        dict: Configuration parameters
+
+    Raises:
+        ValueError: If the file format is not supported
+        ImportError: If the required package for the file format is not installed
+    """
+    file_extension = Path(config_path).suffix.lower()
+
+    with open(config_path, "r") as f:
+        if file_extension in [".json"]:
+            config = json.load(f)
+        elif file_extension in [".yaml", ".yml"]:
+            config = yaml.safe_load(f)
+        else:
+            raise ValueError(
+                f"Unsupported config file format: {file_extension}. "
+                f"Supported formats are: .json, .yaml, .yml"
+            )
+
+    return config

+ 1 - 1
src/finetune_pipeline/types.py

@@ -11,7 +11,7 @@ from typing import Dict, List, Optional, TypedDict, Union
 class MessageContent(TypedDict, total=False):
     """Type definition for message content in LLM requests."""
 
-    type: str  # Required field
+    type: str  # "text", "image_url""
     text: Optional[str]  # Optional field
     image_url: Optional[Dict[str, str]]  # Optional field
 

+ 0 - 46
src/finetune_pipeline/types/__init__.py

@@ -1,46 +0,0 @@
-"""
-Type definitions for the finetune pipeline.
-
-This package contains domain-specific type definitions organized by functional area.
-"""
-
-# Data processing types
-from .data import (
-    DataLoaderConfig,
-    DatasetStats,
-    FormatterConfig,
-    Message,
-    MessageContent,
-)
-
-# Inference types
-from .inference import (
-    InferenceConfig,
-    InferenceRequest,
-    InferenceResponse,
-    ModelServingConfig,
-    ServingMetrics,
-)
-
-# Training types
-from .training import CheckpointInfo, LoRAConfig, TrainingConfig, TrainingMetrics
-
-__all__ = [
-    # Data types
-    "DataLoaderConfig",
-    "DatasetStats",
-    "FormatterConfig",
-    "Message",
-    "MessageContent",
-    # Training types
-    "CheckpointInfo",
-    "LoRAConfig",
-    "TrainingConfig",
-    "TrainingMetrics",
-    # Inference types
-    "InferenceConfig",
-    "InferenceRequest",
-    "InferenceResponse",
-    "ModelServingConfig",
-    "ServingMetrics",
-]

+ 0 - 48
src/finetune_pipeline/types/data.py

@@ -1,48 +0,0 @@
-"""
-Data processing and formatting type definitions.
-
-Types related to data loading, formatting, and preprocessing.
-"""
-
-from typing import Dict, List, Optional, TypedDict, Union
-
-
-class MessageContent(TypedDict, total=False):
-    """Type definition for message content in LLM requests."""
-
-    type: str  # Required field
-    text: Optional[str]  # Optional field
-    image_url: Optional[Dict[str, str]]  # Optional field
-
-
-class Message(TypedDict):
-    """Type definition for a message in a LLM inference request."""
-
-    role: str
-    content: Union[str, List[MessageContent]]
-
-
-class DataLoaderConfig(TypedDict, total=False):
-    """Configuration for data loading."""
-
-    batch_size: int
-    shuffle: bool
-    data_path: str
-    validation_split: Optional[float]
-
-
-class FormatterConfig(TypedDict, total=False):
-    """Configuration for data formatting."""
-
-    format_type: str  # "torchtune", "vllm", "openai"
-    include_system_prompt: bool
-    max_sequence_length: Optional[int]
-
-
-class DatasetStats(TypedDict):
-    """Statistics about a dataset."""
-
-    total_conversations: int
-    total_messages: int
-    avg_messages_per_conversation: float
-    data_size_mb: float

+ 0 - 57
src/finetune_pipeline/types/inference.py

@@ -1,57 +0,0 @@
-"""
-Inference and serving type definitions.
-
-Types related to model inference, serving configurations, and inference results.
-"""
-
-from typing import Dict, List, Optional, TypedDict
-
-
-class InferenceConfig(TypedDict, total=False):
-    """Configuration for inference parameters."""
-
-    model_path: str
-    max_tokens: Optional[int]
-    temperature: Optional[float]
-    top_p: Optional[float]
-    top_k: Optional[int]
-    repetition_penalty: Optional[float]
-
-
-class InferenceRequest(TypedDict):
-    """Request for model inference."""
-
-    messages: List[Dict]  # List of Message objects
-    config: InferenceConfig
-    stream: Optional[bool]
-
-
-class InferenceResponse(TypedDict):
-    """Response from model inference."""
-
-    generated_text: str
-    input_tokens: int
-    output_tokens: int
-    total_time_ms: float
-    tokens_per_second: float
-
-
-class ModelServingConfig(TypedDict, total=False):
-    """Configuration for model serving."""
-
-    host: str
-    port: int
-    model_path: str
-    max_concurrent_requests: int
-    gpu_memory_utilization: Optional[float]
-    tensor_parallel_size: Optional[int]
-
-
-class ServingMetrics(TypedDict):
-    """Metrics for model serving."""
-
-    requests_per_second: float
-    average_latency_ms: float
-    active_requests: int
-    total_requests: int
-    error_rate: float

+ 0 - 50
src/finetune_pipeline/types/training.py

@@ -1,50 +0,0 @@
-"""
-Training and finetuning type definitions.
-
-Types related to model training, finetuning configurations, and training metrics.
-"""
-
-from typing import List, Optional, TypedDict
-
-
-class TrainingConfig(TypedDict, total=False):
-    """Configuration for training parameters."""
-
-    learning_rate: float
-    batch_size: int
-    epochs: int
-    model_name: str
-    optimizer: Optional[str]
-    weight_decay: Optional[float]
-    gradient_accumulation_steps: Optional[int]
-
-
-class LoRAConfig(TypedDict, total=False):
-    """Configuration for LoRA (Low-Rank Adaptation) finetuning."""
-
-    rank: int
-    alpha: int
-    dropout: float
-    target_modules: List[str]
-
-
-class TrainingMetrics(TypedDict):
-    """Training metrics collected during training."""
-
-    epoch: int
-    step: int
-    train_loss: float
-    validation_loss: Optional[float]
-    learning_rate: float
-    throughput_tokens_per_sec: float
-
-
-class CheckpointInfo(TypedDict):
-    """Information about a model checkpoint."""
-
-    checkpoint_path: str
-    epoch: int
-    step: int
-    model_name: str
-    training_config: TrainingConfig
-    metrics: TrainingMetrics