functional refactor

Suraj Subramanian, 1 month ago, commit d12bf4f64b

+ 3 - 2
src/finetune_pipeline/config.yaml

@@ -2,10 +2,11 @@
 output_dir: "/tmp/finetuning-pipeline/llama3_2_vision/"  # Directory to store output files
 
 data:
-  data_path: "data/path"  # Path to the dataset to load
+  dataset_id: "data/path"  # Path to the dataset to load
   is_local: true           # Whether the data is stored locally
-  formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or
+  # formatter_type: "vllm"            # Type of formatter to use ('torchtune', 'vllm', or
   system_prompt: "You are a helpful assisstant"  # System prompt to use for the dataset
+  # Note: per datasets.Dataset.rename_columns, keys are the dataset's existing
+  # column names and values are the canonical names used by the pipeline.
+  # Ref: https://huggingface.co/docs/datasets/v4.0.0/en/package_reference/main_classes#datasets.Dataset.rename_columns
   column_mapping:
-    input: "instruction"             # Field containing the input text
+    instruction: "input"             # Dataset column containing the input text
     output: "output"                 # Dataset column containing the output text

+ 183 - 278
src/finetune_pipeline/data/data_loader.py

@@ -2,11 +2,14 @@
 Data loader module for loading and formatting data from Hugging Face.
 """
 
+import base64
+import io
 import json
 import os
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Dict, Optional, Union
+
 import pandas as pd
+from PIL import Image
 
 # Try to import yaml, but don't fail if it's not available
 try:
@@ -18,7 +21,7 @@ except ImportError:
 
 # Try to import datasets, but don't fail if it's not available
 try:
-    from datasets import load_dataset, load_from_disk, Dataset
+    from datasets import Dataset, load_dataset, load_from_disk
 
     HAS_DATASETS = True
 except ImportError:
@@ -32,7 +35,12 @@ except ImportError:
         raise ImportError("The 'datasets' package is required to load data.")
 
 
-from .formatter import Formatter, OpenAIFormatter, TorchtuneFormatter, vLLMFormatter
+def image_to_base64(image: Union[str, Image.Image]) -> str:
+    """Encode an image (file path or PIL Image) as a base64 string."""
+    if isinstance(image, str):
+        with open(image, "rb") as img:
+            return base64.b64encode(img.read()).decode("utf-8")
+    elif isinstance(image, Image.Image):
+        # Image.tobytes() returns raw pixel data, which is not a valid encoded
+        # image; serialize through an in-memory PNG buffer instead.
+        buffer = io.BytesIO()
+        image.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+    raise TypeError(f"Unsupported image type: {type(image)}")
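
A quick usage sketch of the helper above (the file name is a placeholder):

    # Hypothetical usage; "sample.png" stands in for any local image file.
    from PIL import Image

    b64_from_path = image_to_base64("sample.png")
    b64_from_pil = image_to_base64(Image.open("sample.png"))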
 
 
 def read_config(config_path: str) -> Dict:
@@ -70,7 +78,12 @@ def read_config(config_path: str) -> Dict:
     return config
 
 
-def load_data(data_path: str, is_local: bool = False, **kwargs):
+# Named load_data (not load_dataset) so it does not shadow datasets.load_dataset,
+# which is called below when loading from the Hub.
+def load_data(
+    data_path: str,
+    is_local: bool = False,
+    column_mapping: Optional[Dict] = None,
+    **kwargs,
+):
     """
     Load data from Hugging Face Hub or local disk.
 
@@ -108,341 +121,233 @@ def load_data(data_path: str, is_local: bool = False, **kwargs):
         # Load from Hugging Face Hub
         dataset = load_dataset(data_path, **kwargs)
 
-    return dataset
-
-
-def get_formatter(formatter_type: str) -> Formatter:
-    """
-    Get the appropriate formatter based on the formatter type.
-
-    Args:
-        formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
-
-    Returns:
-        Formatter: Formatter instance
-
-    Raises:
-        ValueError: If the formatter type is not supported
-    """
-    formatter_map = {
-        "torchtune": TorchtuneFormatter,
-        "vllm": vLLMFormatter,
-        "openai": OpenAIFormatter,
-    }
-
-    if formatter_type.lower() not in formatter_map:
-        raise ValueError(
-            f"Unsupported formatter type: {formatter_type}. "
-            f"Supported types are: {', '.join(formatter_map.keys())}"
-        )
-
-    return formatter_map[formatter_type.lower()]()
-
-
-
-def convert_to_conversations(
-    data, column_mapping: Optional[Dict] = None, system_prompt: Optional[str] = None
-) -> List[Any]:
-    """
-    Convert data to a list of Conversation objects.
-
-    Args:
-        data: Data to convert
-        column_mapping: Optional mapping of column names
-
-    Returns:
-        list: List of Conversation objects
-    """
-    # Import here to avoid circular imports
-    from .formatter import Conversation
-
-    # Default column mapping if none provided
+    # Rename columns if a column_mapping is provided. Keys are the dataset's
+    # existing column names; values are the canonical names used downstream
+    # ("input", "output", and optionally "image").
+    if column_mapping:
+        required_fields = ["input", "output"]
+        for field in required_fields:
+            if field not in column_mapping.values():
+                raise ValueError(f"Column mapping must map a column to '{field}'")
+        dataset = dataset.rename_columns(column_mapping)
 
-    conversations = []
-    for item in data:
-        # Extract fields from the dataset item using the column mapping
-        image_field = column_mapping.get("image")
-        input_field = column_mapping.get("input")
-        output_field = column_mapping.get("output")
+    return dataset
 
-        image = item.get(image_field, None) if image_field else None
-        input_text = item.get(input_field, "")
-        output_label = item.get(output_field, "")
 
-        # Create a new conversation
-        conversation = Conversation()
+def convert_to_encoded_messages(
+    example: Dict, system_prompt: Optional[str] = None
+) -> Dict:
+    """Convert one dataset row into a JSON-serialized chat under example["messages"]."""
+    image_field = "image"
+    input_field = "input"
+    output_field = "output"
 
-        system_message = None
-        if system_prompt:
-            system_message = {
+    image = example.get(image_field, None)
+    input_text = example.get(input_field, "")
+    output_label = example.get(output_field, "")
+
+    messages = []
+
+    # Create system message if system_prompt is provided
+    if system_prompt:
+        messages.append(
+            {
                 "role": "system",
                 "content": [{"type": "text", "text": system_prompt}],
             }
+        )
 
-        # Create user content and user message
-        user_content = [
-            {"type": "text", "text": input_text},
-        ]
-        # Add image(s) to user content
-        if image is not None:
-            if isinstance(image, list):
-                # Handle list of images
-                for img in image:
-                    if img:  # Check if image path is not empty
-                        user_content.append(
-                            {"type": "image_url", "image_url": {"url": img}}
-                        )
-                        break
-            else:
-                # Handle single image
-                user_content.append(
-                    {"type": "image_url", "image_url": {"url": image}}
-                )
-
-        user_message = {"role": "user", "content": user_content}
-
-        # Create assistant message with text content
-        assistant_content = [
-            {"type": "text", "text": output_label},
-        ]
-        assistant_message = {"role": "assistant", "content": assistant_content}
-
-        # Add messages to the conversation
-        if system_message:
-            conversation.add_message(system_message)
-        conversation.add_message(user_message)
-        conversation.add_message(assistant_message)
-
-        # Add the conversation to the list
-        conversations.append(conversation)
-
-    return conversations
-
-
-def save_formatted_data(
-    formatted_data: List[Any], output_dir: str, formatter_type: str, split: str
-) -> str:
-    """
-    Save formatted data to a JSON file.
+    # Create user content and user message
+    user_content = [
+        {"type": "text", "text": input_text},
+    ]
+    # Add image(s) to user content
+    if image is not None:
+        if not isinstance(image, list):
+            image = [image]
+        for img in image:
+            b64_img = image_to_base64(img)
+            user_content.append({"type": "image_url", "image_url": {"url": b64_img}})
+
+    messages.append({"role": "user", "content": user_content})
+
+    # Create assistant message with text content
+    messages.append(
+        {
+            "role": "assistant",
+            "content": [{"type": "text", "text": output_label}],
+        }
+    )
+    # Serialize to a JSON string before returning: datasets.map unifies the
+    # nested dict schema across rows and pads missing keys with None, which
+    # would corrupt the message structure if stored as a nested feature.
+    example["messages"] = json.dumps(messages)
+    return example
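
A round-trip sketch of the map-then-serialize flow described above (toy rows, hypothetical prompt):

    # Toy dataset; convert_to_encoded_messages stores messages as a JSON string.
    from datasets import Dataset

    ds = Dataset.from_dict({"input": ["What is 2+2?"], "output": ["4"]})
    ds = ds.map(lambda ex: convert_to_encoded_messages(ex, system_prompt="Be brief."))
    messages = json.loads(ds[0]["messages"])  # decode back into a list of dicts
    assert messages[0]["role"] == "system"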
 
-    Args:
-        formatted_data: The formatted data to save
-        output_dir: Directory to save the data
-        formatter_type: Type of formatter used ('torchtune', 'vllm', or 'openai')
 
-    Returns:
-        Path to the saved file
-    """
+def save_encoded_dataset(encoded_dataset, output_dir: str, split: str):
+    """Write the decoded messages for one split to a JSON file under output_dir."""
     # Create the output directory if it doesn't exist
     os.makedirs(output_dir, exist_ok=True)
 
     # Define the output file path
-    formatted_data_path = os.path.join(
-        output_dir, f"{split}_{formatter_type}_formatted_data.json"
-    )
+    conversation_data_path = os.path.join(output_dir, f"{split}_conversation_data.json")
+
+    if "messages" not in encoded_dataset.column_names:
+        raise RuntimeError(
+            "Expected a 'messages' column; run convert_to_encoded_messages first."
+        )
+    messages = [json.loads(x) for x in encoded_dataset["messages"]]
+    with open(conversation_data_path, "w") as f:
+        json.dump(messages, f, indent=2)
+    print(f"Saved conversation data to {conversation_data_path}")
 
-    # Save the formatted data
-    with open(formatted_data_path, "w") as f:
-        # Handle different data types
-        if isinstance(formatted_data, list) and all(
-            isinstance(item, dict) for item in formatted_data
-        ):
-            json.dump(formatted_data, f, indent=2)
-        elif isinstance(formatted_data, list) and all(
-            isinstance(item, str) for item in formatted_data
-        ):
-            json.dump(formatted_data, f, indent=2)
-        else:
-            # For other types, convert to a simple list of strings
-            json.dump([str(item) for item in formatted_data], f, indent=2)
 
-    print(f"Saved formatted data to {formatted_data_path}")
-    return formatted_data_path
+# TODO: Verify if this is actually needed?
+# def format_encoded_dataset(encoded_dataset, output_dir, split, format):
+#     if format == "vllm":
+#         messages = [json.loads(x) for x in encoded_dataset["messages"]]
 
 
-def save_conversation_data(conversation_data: List, output_dir: str, split: str) -> str:
+def get_splits(dataset):
     """
-    Save conversation data to a JSON file.
+    Helper function to get splits from a dataset.
 
     Args:
-        conversation_data: List of Conversation objects
-        output_dir: Directory to save the data
+        dataset: HuggingFace dataset object
 
     Returns:
-        Path to the saved file
+        Dict mapping split names to split datasets
     """
-    # Create the output directory if it doesn't exist
-    os.makedirs(output_dir, exist_ok=True)
+    if hasattr(dataset, "keys"):
+        return {k: dataset[k] for k in dataset.keys()}
+    return {"default": dataset}
 
-    # Define the output file path
-    conversation_data_path = os.path.join(output_dir, f"{split}_conversation_data.json")
-
-    # Convert Conversation objects to a serializable format
-    serializable_conversations = []
-    for conv in conversation_data:
-        serializable_conversations.append({"messages": conv.messages})
-
-    # Save the conversation data
-    with open(conversation_data_path, "w") as f:
-        json.dump(serializable_conversations, f, indent=2)
 
-    print(f"Saved conversation data to {conversation_data_path}")
-    return conversation_data_path
-
-
-def format_data(
-    data,
-    formatter_type: str,
-    output_dir: str,
+def get_hf_dataset(
+    config_path: Optional[str] = None,
+    output_dir: Optional[str] = None,
+    dataset_id: Optional[str] = None,
+    is_local: bool = False,
     column_mapping: Optional[Dict] = None,
-    system_prompt: Optional[str] = None,
     dataset_kwargs: Optional[Dict] = None,
+    system_prompt: Optional[str] = None,
 ):
     """
-    Format the data using the specified formatter for all splits.
+    Load and format data based on either a configuration file or individual parameters.
 
     Args:
-        data: Dataset with multiple splits to format or a single dataset
-        formatter_type: Type of formatter to use ('torchtune', 'vllm', or 'openai')
         output_dir: Directory to save the formatted data
-        column_mapping: Optional mapping of column names
-        dataset_kwargs: Optional dataset kwargs that may contain split information
+        config_path: Path to a configuration file (YAML/JSON). If provided, the remaining parameters are read from it.
+        dataset_id: Hub ID or local path of the dataset to load
+        is_local: Whether the data is stored locally
+        column_mapping: Mapping of existing column names to the canonical names ("input", "output", "image")
+        dataset_kwargs: Additional keyword arguments passed through to dataset loading
+        system_prompt: System prompt to prepend to each conversation
 
     Returns:
-        Tuple containing (formatted_data_paths, conversation_data_paths) where each is a list of paths to saved files
+        str: Path to the output directory containing the formatted data
     """
-    formatted_data_paths = []
-    conversation_data_paths = []
-
-    # Check if the dataset has explicit splits
-    if (
-        hasattr(data, "keys")
-        and callable(data.keys)
-        and len(data.keys()) > 0
-        and isinstance(data, dict)
-    ):
-        # Dataset has splits (train, validation, test, etc.)
-        splits = data.keys()
-
-        for split in splits:
-            # First convert the data to conversations
-            conversations = convert_to_conversations(data[split], column_mapping)
-
-            # Then get the formatter and format the conversations
-            formatter = get_formatter(formatter_type)
-            formatted_data = formatter.format_data(conversations)
-            print(
-                f"Loaded and formatted data for split '{split}': {len(formatted_data)} samples"
-            )
 
-            # Save the formatted data
-            formatted_data_path = save_formatted_data(
-                formatted_data, output_dir, formatter_type, split
-            )
-            formatted_data_paths.append(formatted_data_path)
-
-            # Save the conversation data
-            conversation_data_path = save_conversation_data(
-                conversations, output_dir, split
-            )
-            conversation_data_paths.append(conversation_data_path)
+    # If config_path is provided, load from config file
+    if config_path:
+        config = read_config(config_path)
+        output_dir = config.get("output_dir", "/tmp/finetuning-pipeline/outputs")
+        data_config = config.get("data", {})
+
+        # Extract parameters from data config
+        dataset_id = data_config.get("dataset_id")
+        is_local = data_config.get("is_local", False)
+        column_mapping = data_config.get("column_mapping")
+        dataset_kwargs = data_config.get("dataset_kwargs", {})
+        system_prompt = data_config.get("system_prompt", None)
     else:
-        # Dataset doesn't have explicit splits, treat it as a single dataset
-        # Check if a split is specified in dataset_kwargs
-        split = "default"
-        if dataset_kwargs and "split" in dataset_kwargs:
-            split = dataset_kwargs["split"]
-
-        # First convert the data to conversations
-        conversations = convert_to_conversations(data, column_mapping)
-
-        # Then get the formatter and format the conversations
-        formatter = get_formatter(formatter_type)
-        formatted_data = formatter.format_data(conversations)
-        print(
-            f"Loaded and formatted data for split '{split}': {len(formatted_data)} samples"
-        )
+        # Use individual parameters passed to the function
+        if dataset_kwargs is None:
+            dataset_kwargs = {}
 
-        # Save the formatted data
-        formatted_data_path = save_formatted_data(
-            formatted_data, output_dir, formatter_type, split
+    # Validate required parameters
+    if not dataset_id:
+        raise ValueError(
+            "dataset_id must be specified either in config file or as parameter"
         )
-        formatted_data_paths.append(formatted_data_path)
 
-        # Save the conversation data
-        conversation_data_path = save_conversation_data(
-            conversations, output_dir, split
-        )
-        conversation_data_paths.append(conversation_data_path)
+    # "shuffle" is not a loader kwarg; pop it here and apply it after loading.
+    shuffle = dataset_kwargs.pop("shuffle", False) if dataset_kwargs else False
+
+    # Load the dataset
+    dataset = load_data(
+        data_path=dataset_id,
+        is_local=is_local,
+        column_mapping=column_mapping,
+        **dataset_kwargs,
+    )
+    if shuffle:
+        dataset = dataset.shuffle()
 
-    return formatted_data_paths, conversation_data_paths
+    # Get available splits
+    dataset_splits = get_splits(dataset)
 
+    # Process each split
+    for split_name, split_dataset in dataset_splits.items():
+        # Apply the conversion function
+        encoded_dataset = split_dataset.map(
+            lambda example: convert_to_encoded_messages(example, system_prompt)
+        )
 
-def load_and_format_data(data_config: Dict, output_dir: str):
-    """
-    Load and format data based on the configuration.
+        # Save the encoded dataset
+        save_encoded_dataset(encoded_dataset, output_dir, split_name)
 
-    Args:
-        formatter_config: Dictionary containing formatter configuration parameters
-        output_dir: Directory to save the formatted data
+        # TODO: Evaluate if formatting is needed here
 
-    Returns:
-        Tuple containing (formatted_data_paths, conversation_data_paths) where each is a list of paths to saved files
-    """
-
-    # Extract parameters from config
-    data_path = data_config.get("data_path")
-    if not data_path:
-        raise ValueError(
-            "data_path must be specified in the formatter section of the config file"
-        )
+    return output_dir
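
Example invocation with individual parameters (dataset id, column names, and paths are hypothetical):

    # Hypothetical call; assumes the Hub dataset has "instruction"/"response" columns.
    out_dir = get_hf_dataset(
        output_dir="/tmp/finetuning-pipeline/outputs",
        dataset_id="org/my-dataset",
        column_mapping={"instruction": "input", "response": "output"},
        dataset_kwargs={"split": "train"},
        system_prompt="You are a helpful assistant",
    )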
 
-    is_local = data_config.get("is_local", False)
-    formatter_type = data_config.get("formatter_type", "torchtune")
-    column_mapping = data_config.get("column_mapping")
-    dataset_kwargs = data_config.get("dataset_kwargs", {})
-    system_prompt = data_config.get("system_prompt", None)
 
-    # Load the data
-    data = load_data(data_path, is_local, **dataset_kwargs)
+def main():
+    """
+    Example command-line interface for get_hf_dataset.
+    Shows how to use the function with either a config file or individual arguments.
+    """
+    import argparse
 
-    # Format the data
-    formatted_data_paths, conversation_data_paths = format_data(
-        data, formatter_type, output_dir, column_mapping, system_prompt, dataset_kwargs
+    parser = argparse.ArgumentParser(description="Load and format HuggingFace dataset")
+    parser.add_argument(
+        "--output_dir",
+        default="/tmp/finetuning-pipeline/outputs",
+        help="Directory to save formatted data",
+    )
 
-    return formatted_data_paths, conversation_data_paths
-
-
-if __name__ == "__main__":
-    # Example usage
-    import argparse
+    # Config file option
+    parser.add_argument("--config", help="Path to config file (YAML/JSON)")
 
-    parser = argparse.ArgumentParser(
-        description="Load and format data from Hugging Face"
+    # Individual parameter options
+    parser.add_argument("--dataset_id", help="HF Dataset ID or path")
+    parser.add_argument(
+        "--is_local", action="store_true", help="Dataset is stored locally"
     )
+    parser.add_argument("--system_prompt", help="System prompt for the dataset")
     parser.add_argument(
-        "--config",
-        type=str,
-        required=True,
-        help="Path to the configuration file (JSON or YAML)",
+        "--split", help="Dataset split to load (e.g., train, validation)"
     )
+    parser.add_argument("--shuffle", action="store_true", help="Shuffle the dataset")
+
     args = parser.parse_args()
 
-    # Read the configuration
-    config = read_config(args.config)
-    data_config = config.get("data", {})
-    output_dir = config.get("output_dir", "/tmp/finetune-pipeline/data/")
-    output_data_dir = os.path.join(output_dir, "data")
-    # Load and format the data
-    formatted_data_paths, conversation_data_paths = load_and_format_data(
-        data_config, output_data_dir
-    )
+    if args.config:
+        # Use config file
+        print(f"Loading dataset using config file: {args.config}")
+        result = get_hf_dataset(output_dir=args.output_dir, config_path=args.config)
+    else:
+        # Use individual arguments
+        if not args.dataset_id:
+            raise ValueError("--dataset_id is required when not using --config")
+
+        dataset_kwargs = {}
+        if args.split:
+            dataset_kwargs["split"] = args.split
+        if args.shuffle:
+            dataset_kwargs["shuffle"] = args.shuffle
+
+        print(f"Loading dataset using individual arguments: {args.dataset_id}")
+        result = get_hf_dataset(
+            output_dir=args.output_dir,
+            dataset_id=args.dataset_id,
+            is_local=args.is_local,
+            dataset_kwargs=dataset_kwargs,
+            system_prompt=args.system_prompt,
+        )
+
+    print(f"Dataset processed and saved to: {result}")
+
+
+if __name__ == "__main__":
+    main()

+ 113 - 0
src/finetune_pipeline/data/formatter_functional.py

@@ -0,0 +1,113 @@
+import base64
+from typing import Dict, List
+
+# Conversation is defined in the existing class-based formatter module
+# alongside this file.
+from .formatter import Conversation
+
+
+def image_to_base64(image_path: str) -> str:
+    """Read an image file from disk and return its base64-encoded contents."""
+    with open(image_path, "rb") as img:
+        return base64.b64encode(img.read()).decode("utf-8")
+
+
+def format_message_vllm(message: Dict) -> Dict:
+    """Format a message in vLLM format."""
+    contents = []
+    vllm_message = {}
+
+    for content in message["content"]:
+        if content["type"] == "text":
+            contents.append(content)
+        elif content["type"] in ("image_url", "image"):
+            base64_image = image_to_base64(content["image_url"]["url"])
+            img_content = {
+                "type": "image_url",
+                # "image/jpeg" is the registered MIME type ("image/jpg" is not)
+                "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"},
+            }
+            contents.append(img_content)
+        else:
+            raise ValueError(f"Unknown content type: {content['type']}")
+    vllm_message["role"] = message["role"]
+    vllm_message["content"] = contents
+    return vllm_message
+
+
+def format_conversation_vllm(conversation) -> Dict:
+    """Format a conversation in vLLM format."""
+    formatted_messages = []
+    for message in conversation.messages:
+        # Keep only prompt-side (system/user) messages; assistant turns are
+        # the training labels, not part of the inference request.
+        if message["role"] != "assistant":
+            formatted_messages.append(format_message_vllm(message))
+    return {"messages": formatted_messages}
+
+
+# TODO: Remove
+def format_conversation_openai(conversation) -> Dict:
+    """Format a conversation in OpenAI format."""
+    formatted_messages = []
+    for message in conversation.messages:
+        formatted_messages.append(format_message_openai(message))
+    return {"messages": formatted_messages}
+
+
+# TODO: Remove
+def format_data_torchtune(data: List[Conversation]) -> List[Dict]:
+    """Format data in Torchtune format."""
+    if data is None:
+        raise ValueError("No data provided to format_data()")
+
+    return [format_conversation_torchtune(conversation) for conversation in data]
+
+
+def format_data_vllm(data: List[Conversation]) -> List[Dict]:
+    """Format data in vLLM format."""
+    if data is None:
+        raise ValueError("No data provided to format_data()")
+
+    return [format_conversation_vllm(conversation) for conversation in data]
+
+
+# TODO: Remove
+def format_data_openai(data: List[Conversation]) -> List[Dict]:
+    """Format data in OpenAI format."""
+    if data is None:
+        raise ValueError("No data provided to format_data()")
+
+    return [format_conversation_openai(conversation) for conversation in data]
+
+
+# Dictionary to map format names to functions for easy dispatch.
+# The torchtune and openai entries are commented out until their message- and
+# conversation-level helpers exist in this module (see the TODOs above);
+# referencing them here would raise NameError at import time.
+FORMATTERS = {
+    # "torchtune": {
+    #     "data": format_data_torchtune,
+    #     "conversation": format_conversation_torchtune,
+    #     "message": format_message_torchtune,
+    # },
+    "vllm": {
+        "data": format_data_vllm,
+        "conversation": format_conversation_vllm,
+        "message": format_message_vllm,
+    },
+    # "openai": {
+    #     "data": format_data_openai,
+    #     "conversation": format_conversation_openai,
+    #     "message": format_message_openai,
+    # },
+}
+
+
+def format_data(data: List[Conversation], format_type: str) -> List[Dict]:
+    """
+    Generic function to format data in the specified format.
+
+    Args:
+        data: List of Conversation objects
+        format_type: One of "torchtune", "vllm", "openai"
+
+    Returns:
+        List of formatted data
+    """
+    if format_type not in FORMATTERS:
+        raise ValueError(
+            f"Unknown format type: {format_type}. Supported: {list(FORMATTERS.keys())}"
+        )
+
+    return FORMATTERS[format_type]["data"](data)
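
Usage sketch, assuming Conversation from the existing formatter module behaves as in data_loader.py (add_message appends to .messages):

    # Hypothetical usage of the vLLM dispatch path.
    conv = Conversation()
    conv.add_message({"role": "user", "content": [{"type": "text", "text": "hi"}]})
    requests = format_data([conv], "vllm")  # [{"messages": [...]}]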

+ 42 - 0
src/finetune_pipeline/types.py

@@ -0,0 +1,42 @@
+"""
+Core type definitions for the finetune pipeline.
+
+This module contains TypedDicts and type definitions that are used across
+multiple modules in the pipeline.
+"""
+
+from typing import Dict, List, Optional, TypedDict, Union
+
+
+class MessageContent(TypedDict, total=False):
+    """Type definition for message content in LLM requests."""
+
+    type: str  # e.g. "text" or "image_url"; always expected in practice (total=False makes all keys optional)
+    text: Optional[str]  # Present when type == "text"
+    image_url: Optional[Dict[str, str]]  # Present when type == "image_url"
+
+
+class Message(TypedDict):
+    """Type definition for a message in a LLM inference request."""
+
+    role: str
+    content: Union[str, List[MessageContent]]
+
+
+class TrainingConfig(TypedDict, total=False):
+    """Configuration for training parameters."""
+
+    learning_rate: float
+    batch_size: int
+    epochs: int
+    model_name: str
+    optimizer: Optional[str]
+
+
+class InferenceConfig(TypedDict, total=False):
+    """Configuration for inference parameters."""
+
+    model_path: str
+    max_tokens: Optional[int]
+    temperature: Optional[float]
+    top_p: Optional[float]

+ 46 - 0
src/finetune_pipeline/types/__init__.py

@@ -0,0 +1,46 @@
+"""
+Type definitions for the finetune pipeline.
+
+This package contains domain-specific type definitions organized by functional area.
+"""
+
+# Data processing types
+from .data import (
+    DataLoaderConfig,
+    DatasetStats,
+    FormatterConfig,
+    Message,
+    MessageContent,
+)
+
+# Inference types
+from .inference import (
+    InferenceConfig,
+    InferenceRequest,
+    InferenceResponse,
+    ModelServingConfig,
+    ServingMetrics,
+)
+
+# Training types
+from .training import CheckpointInfo, LoRAConfig, TrainingConfig, TrainingMetrics
+
+__all__ = [
+    # Data types
+    "DataLoaderConfig",
+    "DatasetStats",
+    "FormatterConfig",
+    "Message",
+    "MessageContent",
+    # Training types
+    "CheckpointInfo",
+    "LoRAConfig",
+    "TrainingConfig",
+    "TrainingMetrics",
+    # Inference types
+    "InferenceConfig",
+    "InferenceRequest",
+    "InferenceResponse",
+    "ModelServingConfig",
+    "ServingMetrics",
+]

+ 48 - 0
src/finetune_pipeline/types/data.py

@@ -0,0 +1,48 @@
+"""
+Data processing and formatting type definitions.
+
+Types related to data loading, formatting, and preprocessing.
+"""
+
+from typing import Dict, List, Optional, TypedDict, Union
+
+
+class MessageContent(TypedDict, total=False):
+    """Type definition for message content in LLM requests."""
+
+    type: str  # e.g. "text" or "image_url"; always expected in practice (total=False makes all keys optional)
+    text: Optional[str]  # Present when type == "text"
+    image_url: Optional[Dict[str, str]]  # Present when type == "image_url"
+
+
+class Message(TypedDict):
+    """Type definition for a message in a LLM inference request."""
+
+    role: str
+    content: Union[str, List[MessageContent]]
+
+
+class DataLoaderConfig(TypedDict, total=False):
+    """Configuration for data loading."""
+
+    batch_size: int
+    shuffle: bool
+    data_path: str
+    validation_split: Optional[float]
+
+
+class FormatterConfig(TypedDict, total=False):
+    """Configuration for data formatting."""
+
+    format_type: str  # "torchtune", "vllm", "openai"
+    include_system_prompt: bool
+    max_sequence_length: Optional[int]
+
+
+class DatasetStats(TypedDict):
+    """Statistics about a dataset."""
+
+    total_conversations: int
+    total_messages: int
+    avg_messages_per_conversation: float
+    data_size_mb: float
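
A construction sketch (TypedDicts are plain dicts at runtime; values below are illustrative):

    # Hypothetical values for the data types above.
    msg: Message = {
        "role": "user",
        "content": [{"type": "text", "text": "What is in this image?"}],
    }
    loader_cfg: DataLoaderConfig = {"batch_size": 8, "shuffle": True, "data_path": "/tmp/data"}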

+ 57 - 0
src/finetune_pipeline/types/inference.py

@@ -0,0 +1,57 @@
+"""
+Inference and serving type definitions.
+
+Types related to model inference, serving configurations, and inference results.
+"""
+
+from typing import Dict, List, Optional, TypedDict
+
+
+class InferenceConfig(TypedDict, total=False):
+    """Configuration for inference parameters."""
+
+    model_path: str
+    max_tokens: Optional[int]
+    temperature: Optional[float]
+    top_p: Optional[float]
+    top_k: Optional[int]
+    repetition_penalty: Optional[float]
+
+
+class InferenceRequest(TypedDict):
+    """Request for model inference."""
+
+    messages: List[Dict]  # List of Message objects
+    config: InferenceConfig
+    stream: Optional[bool]
+
+
+class InferenceResponse(TypedDict):
+    """Response from model inference."""
+
+    generated_text: str
+    input_tokens: int
+    output_tokens: int
+    total_time_ms: float
+    tokens_per_second: float
+
+
+class ModelServingConfig(TypedDict, total=False):
+    """Configuration for model serving."""
+
+    host: str
+    port: int
+    model_path: str
+    max_concurrent_requests: int
+    gpu_memory_utilization: Optional[float]
+    tensor_parallel_size: Optional[int]
+
+
+class ServingMetrics(TypedDict):
+    """Metrics for model serving."""
+
+    requests_per_second: float
+    average_latency_ms: float
+    active_requests: int
+    total_requests: int
+    error_rate: float

+ 50 - 0
src/finetune_pipeline/types/training.py

@@ -0,0 +1,50 @@
+"""
+Training and finetuning type definitions.
+
+Types related to model training, finetuning configurations, and training metrics.
+"""
+
+from typing import List, Optional, TypedDict
+
+
+class TrainingConfig(TypedDict, total=False):
+    """Configuration for training parameters."""
+
+    learning_rate: float
+    batch_size: int
+    epochs: int
+    model_name: str
+    optimizer: Optional[str]
+    weight_decay: Optional[float]
+    gradient_accumulation_steps: Optional[int]
+
+
+class LoRAConfig(TypedDict, total=False):
+    """Configuration for LoRA (Low-Rank Adaptation) finetuning."""
+
+    rank: int
+    alpha: int
+    dropout: float
+    target_modules: List[str]
+
+
+class TrainingMetrics(TypedDict):
+    """Training metrics collected during training."""
+
+    epoch: int
+    step: int
+    train_loss: float
+    validation_loss: Optional[float]
+    learning_rate: float
+    throughput_tokens_per_sec: float
+
+
+class CheckpointInfo(TypedDict):
+    """Information about a model checkpoint."""
+
+    checkpoint_path: str
+    epoch: int
+    step: int
+    model_name: str
+    training_config: TrainingConfig
+    metrics: TrainingMetrics
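
Illustrative values for the LoRA settings above (placeholders, not recommendations):

    # Hypothetical LoRA configuration.
    lora: LoRAConfig = {
        "rank": 8,
        "alpha": 16,
        "dropout": 0.05,
        "target_modules": ["q_proj", "v_proj"],
    }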