Forráskód Böngészése

moved formatter to data folder

khare19yash 3 hónapja
szülő
commit
15e63acd13
1 módosított fájl, 0 hozzáadás és 311 törlés
  1. 0 311
      src/finetune_pipeline/formatter.py

+ 0 - 311
src/finetune_pipeline/formatter.py

@@ -1,311 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, TypedDict, Union
-
-
-class MessageContent(TypedDict, total=False):
-    """Type definition for message content in LLM requests."""
-
-    type: str  # Required field
-    text: Optional[str]  # Optional field
-    image_url: Optional[Dict[str, str]]  # Optional field
-
-
-class Message(TypedDict):
-    """Type definition for a message in a LLM inference request."""
-
-    role: str
-    content: Union[str, List[MessageContent]]
-
-
-class Conversation:
-    """
-    Data class representing a conversation which are list of messages.
-    """
-
-    def __init__(self, messages=None):
-        self.messages = messages if messages is not None else []
-
-    def add_message(self, message):
-        """
-        Add a message to the conversation.
-
-        Args:
-            message: Message object to add
-        """
-        self.messages.append(message)
-
-    def get_message(self, index):
-        """
-        Get a message at a specific index.
-
-        Args:
-            index: Index of the message to retrieve
-
-        Returns:
-            Message: The message at the specified index
-
-        Raises:
-            IndexError: If the index is out of bounds
-        """
-        if index < 0 or index >= len(self.messages):
-            raise IndexError(
-                f"Message index {index} out of range (0-{len(self.messages)-1})"
-            )
-        return self.messages[index]
-
-
-class Formatter(ABC):
-    """
-    Abstract base class for formatters that convert messages to different formats.
-    """
-
-    def __init__(self):
-        """
-        Initialize the formatter.
-
-        Subclasses can override this method to add specific initialization parameters.
-        """
-        pass
-
-    @abstractmethod
-    def format_data(self, data) -> List:
-        """
-        Format the message. This method must be implemented by subclasses.
-
-        Args:
-            data: List of Conversation objects
-
-        Returns:
-            List of formatted data
-        """
-        pass
-
-    @abstractmethod
-    def format_sample(self, sample) -> Union[Dict, str]:
-        """
-        Format a sample. This method must be implemented by subclasses.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            Formatted sample in the appropriate format
-        """
-        pass
-
-    def read_data(self, data):
-        """
-        Read Hugging Face data and convert it to a list of Conversation objects.
-
-        Args:
-            data: Hugging Face dataset or iterable data
-
-        Returns:
-            list: List of Conversation objects, each containing a list of Message objects
-        """
-        conversations = []
-        for item in data:
-            # Extract fields from the Hugging Face dataset item
-            image = item.get("image", None)
-            input_text = item.get("input", "")
-            output_label = item.get("output", "")
-
-            # Create a new conversation
-            conversation = Conversation()
-
-            # Create user content and user message
-            user_content = [
-                {"type": "text", "text": input_text},
-            ]
-            # Add image to user content
-            if image is not None:
-                user_content.append({"type": "image", "image_url": {"url": image}})
-
-            user_message = {"role": "user", "content": user_content}
-
-            # Create assistant message with text content
-            assistant_content = [
-                {"type": "text", "text": output_label},
-            ]
-            assistant_message = {"role": "assistant", "content": assistant_content}
-
-            # Add messages to the conversation
-            conversation.add_message(user_message)
-            conversation.add_message(assistant_message)
-
-            # Add the conversation to the list
-            conversations.append(conversation)
-
-        return conversations
-
-
-class TorchtuneFormatter(Formatter):
-    """
-    Formatter for Torchtune format.
-    """
-
-    def __init__(self, data=None):
-        """
-        Initialize the formatter.
-
-        Args:
-            data: Optional data to initialize with
-        """
-        super().__init__()
-        self.conversation_data = None
-        if data is not None:
-            self.conversation_data = self.read_data(data)
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects
-
-        Returns:
-            list: List of formatted data
-        """
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
-        return formatted_data
-
-    def format_sample(self, sample):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            dict: Formatted sample in Torchtune format
-        """
-        formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
-        return {"messages": formatted_messages}
-
-    def format(self, message):
-        """
-        Format a message in Torchtune format.
-
-        Args:
-            message: Message object to format
-
-        Returns:
-            dict: Formatted message in Torchtune format
-        """
-        # For Torchtune format, we can return the Message as is
-        # since it's already in a compatible format
-        return message
-
-
-class vLLMFormatter(Formatter):
-    """
-    Formatter for vLLM format.
-    """
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects
-
-        Returns:
-            list: List of formatted data in vLLM format
-        """
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
-        return formatted_data
-
-    def format_sample(self, sample):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            str: Formatted sample in vLLM format
-        """
-        formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
-        return "\n".join(formatted_messages)
-
-    def format(self, message):
-        """
-        Format a message in vLLM format.
-
-        Args:
-            message: Message object to format
-
-        Returns:
-            str: Formatted message in vLLM format
-        """
-        role = message["role"]
-        content = message["content"]
-
-        # Handle different content types
-        if isinstance(content, str):
-            return f"{role}: {content}"
-        else:
-            # For multimodal content, extract text parts
-            text_parts = []
-            for item in content:
-                if item["type"] == "text" and "text" in item:
-                    text_parts.append(item["text"])
-            return f"{role}: {' '.join(text_parts)}"
-
-
-class OpenAIFormatter(Formatter):
-    """
-    Formatter for OpenAI format.
-    """
-
-    def format_data(self, data):
-        """
-        Format the data.
-
-        Args:
-            data: List of Conversation objects
-
-        Returns:
-            dict: Formatted data in OpenAI format
-        """
-        formatted_data = []
-        for conversation in data:
-            formatted_data.append(self.format_sample(conversation))
-        return formatted_data
-
-    def format_sample(self, sample):
-        """
-        Format a sample.
-
-        Args:
-            sample: Conversation object
-
-        Returns:
-            dict: Formatted sample in OpenAI format
-        """
-        formatted_messages = []
-        for message in sample.messages:
-            formatted_messages.append(self.format(message))
-        return {"messages": formatted_messages}
-
-    def format(self, message):
-        """
-        Format a message in OpenAI format.
-
-        Args:
-            message: Message object to format
-
-        Returns:
-            dict: Formatted message in OpenAI format
-        """
-        # For OpenAI format, we can return the Message as is
-        # since it's already in a compatible format
-        return message