| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394 |
- from abc import ABC, abstractmethod
- from typing import Dict, List, Optional, TypedDict, Union
class _MessageContentRequired(TypedDict):
    """Keys that every message content part must carry."""

    # Discriminator for the part; this file produces "text" and "image".
    type: str


class MessageContent(_MessageContentRequired, total=False):
    """Type definition for one content part of a message in LLM requests.

    ``type`` is required (inherited from the total base class). The keys
    declared directly on this class are optional, because ``total=False``
    only applies to the keys of the class it is written on.
    """

    # Text payload of the part; may be omitted or None.
    text: Optional[str]
    # Image reference; this file stores it as {"url": <image>}.
    image_url: Optional[Dict[str, str]]
class Message(TypedDict):
    """A single message of an LLM inference request.

    Both keys are required, since the TypedDict is declared without
    ``total=False``.
    """

    # Who produced the message; this file emits "user" and "assistant".
    role: str
    # Either a plain string or a list of structured content parts.
    content: Union[str, List[MessageContent]]
class Conversation:
    """
    Data class representing a conversation, which is an ordered list of messages.
    """

    def __init__(self, messages: Optional[List["Message"]] = None):
        """
        Initialize the conversation.

        Args:
            messages: Optional initial list of messages. The list object is
                stored as-is (not copied). If None, a new empty list is used.
        """
        self.messages = messages if messages is not None else []

    def add_message(self, message: "Message") -> None:
        """
        Add a message to the end of the conversation.

        Args:
            message: Message object to add
        """
        self.messages.append(message)

    def get_message(self, index: int) -> "Message":
        """
        Get a message at a specific index.

        Negative indices are rejected rather than counted from the end.

        Args:
            index: Zero-based index of the message to retrieve

        Returns:
            Message: The message at the specified index

        Raises:
            IndexError: If the index is negative or not smaller than the
                number of stored messages
        """
        count = len(self.messages)
        if 0 <= index < count:
            return self.messages[index]
        if count == 0:
            # An empty conversation has no valid range to report; printing
            # "0-{count - 1}" here would render as the confusing "(0--1)".
            raise IndexError(
                f"Message index {index} out of range (conversation is empty)"
            )
        raise IndexError(f"Message index {index} out of range (0-{count - 1})")
class Formatter(ABC):
    """
    Abstract base class for formatters that convert conversations to different formats.

    Subclasses implement the three ``format_*`` methods; ``read_data`` is
    shared and builds Conversation objects from row-like input.
    """

    def __init__(self):
        """
        Initialize the formatter.

        The base class keeps no state. Subclasses can override this method to
        add specific initialization parameters.
        """

    @abstractmethod
    def format_data(self, data) -> List:
        """
        Format a collection of conversations. This method must be implemented by subclasses.

        Args:
            data: List of Conversation objects

        Returns:
            List of formatted data
        """

    @abstractmethod
    def format_conversation(self, conversation) -> Union[Dict, str]:
        """
        Format a conversation. This method must be implemented by subclasses.

        Args:
            conversation: Conversation object

        Returns:
            Formatted conversation in the appropriate format
        """

    @abstractmethod
    def format_message(self, message) -> Union[Dict, str]:
        """
        Format a message. This method must be implemented by subclasses.

        Args:
            message: Message object

        Returns:
            Formatted message in the appropriate format
        """

    def read_data(self, data, column_mapping=None):
        """
        Read Hugging Face data and convert it to a list of Conversation objects.

        Each item becomes a two-message conversation: a user message holding
        the input text (plus an image part when an image value is present) and
        an assistant message holding the output.

        Args:
            data: Hugging Face dataset or iterable data; each item must
                provide a dict-style ``get(key, default)``.
            column_mapping: Optional dictionary mapping custom column names to standard field names.
                Expected keys are 'input', 'output', and 'image'.
                Example: {'input': 'question', 'output': 'answer', 'image': 'image_path'}
                If None, default field names ('input', 'output', 'image') will be used.

        Returns:
            list: List of Conversation objects, each containing a list of Message objects

        Raises:
            ValueError: If column_mapping lacks the 'input' or 'output' key
        """
        # Default column mapping if none provided
        if column_mapping is None:
            column_mapping = {"input": "input", "output": "output", "image": "image"}

        # 'image' is optional; only the text columns are mandatory.
        for field in ("input", "output"):
            if field not in column_mapping:
                raise ValueError(f"Column mapping must include '{field}' field")

        # The mapping is not modified below, so resolve the column names once
        # instead of on every item.
        image_field = column_mapping.get("image")
        input_field = column_mapping.get("input")
        output_field = column_mapping.get("output")

        conversations = []
        for item in data:
            # A falsy image column name means "no image column configured".
            image = item.get(image_field, None) if image_field else None
            # Missing text columns fall back to an empty string.
            input_text = item.get(input_field, "")
            output_label = item.get(output_field, "")

            # User turn: text part first, then the image part when available.
            user_content = [{"type": "text", "text": input_text}]
            if image is not None:
                user_content.append({"type": "image", "image_url": {"url": image}})

            # Assistant turn: a single text part carrying the expected output.
            assistant_content = [{"type": "text", "text": output_label}]

            conversation = Conversation()
            conversation.add_message({"role": "user", "content": user_content})
            conversation.add_message(
                {"role": "assistant", "content": assistant_content}
            )
            conversations.append(conversation)

        return conversations
class TorchtuneFormatter(Formatter):
    """
    Formatter that emits conversations in Torchtune format.
    """

    def __init__(self, data=None):
        """
        Set up the formatter, optionally loading data right away.

        Args:
            data: Optional raw data; when given it is passed through
                read_data() and kept in ``conversation_data``.
        """
        super().__init__()
        self.conversation_data = None
        if data is None:
            return
        self.conversation_data = self.read_data(data)

    def format_data(self, data=None):
        """
        Format every conversation.

        Args:
            data: List of Conversation objects. When None, the
                conversation_data stored at initialization is used instead.

        Returns:
            list: One formatted entry per conversation

        Raises:
            ValueError: If neither ``data`` nor stored conversation_data exists
        """
        source = self.conversation_data if data is None else data
        if source is None:
            raise ValueError(
                "No data to format. Either provide data to format_data() or initialize with data."
            )
        return [self.format_conversation(entry) for entry in source]

    def format_conversation(self, conversation):
        """
        Format one conversation.

        Args:
            conversation: Conversation object

        Returns:
            dict: ``{"messages": [...]}`` holding the formatted messages
        """
        return {
            "messages": [self.format_message(turn) for turn in conversation.messages]
        }

    def format_message(self, message):
        """
        Format one message in Torchtune format.

        Args:
            message: Message object to format

        Returns:
            dict: The very same message object, handed back unchanged
        """
        # The Message layout is treated as already compatible with Torchtune,
        # so no conversion is performed.
        return message
class vLLMFormatter(Formatter):
    """
    Formatter for vLLM format.

    Each message is rendered as a ``"<role>: <text>"`` line and a conversation
    is the newline-joined sequence of those lines.
    """

    def __init__(self, data=None):
        """
        Initialize the formatter.

        Args:
            data: Optional raw data; when given it is passed through
                read_data() and kept in ``conversation_data``.
        """
        super().__init__()
        self.conversation_data = None
        if data is not None:
            self.conversation_data = self.read_data(data)

    def format_data(self, data=None):
        """
        Format the data.

        Args:
            data: List of Conversation objects. If None, uses the conversation_data stored during initialization.

        Returns:
            list: List of formatted data in vLLM format (one string per conversation)

        Raises:
            ValueError: If neither ``data`` nor stored conversation_data exists
        """
        # Use stored conversation_data if no data is provided
        data_to_format = data if data is not None else self.conversation_data
        if data_to_format is None:
            raise ValueError(
                "No data to format. Either provide data to format_data() or initialize with data."
            )
        return [
            self.format_conversation(conversation) for conversation in data_to_format
        ]

    def format_conversation(self, conversation):
        """
        Format a conversation.

        Args:
            conversation: Conversation object

        Returns:
            str: Formatted messages joined with newlines
        """
        return "\n".join(
            self.format_message(message) for message in conversation.messages
        )

    def format_message(self, message):
        """
        Format a message in vLLM format.

        String content is used verbatim. For list content only the parts of
        type "text" contribute; other parts (for example images) are dropped.

        Args:
            message: Message object to format

        Returns:
            str: Formatted message as ``"<role>: <text>"``
        """
        role = message["role"]
        content = message["content"]

        if isinstance(content, str):
            return f"{role}: {content}"

        # Use .get() so a part without "type" is skipped instead of raising
        # KeyError. Skip None text and coerce the rest with str(), because
        # str.join() raises TypeError on non-string values (e.g. None or an
        # integer label).
        text_parts = [
            str(part["text"])
            for part in content
            if part.get("type") == "text" and part.get("text") is not None
        ]
        return f"{role}: {' '.join(text_parts)}"
class OpenAIFormatter(Formatter):
    """
    Formatter that emits conversations in OpenAI format.
    """

    def __init__(self, data=None):
        """
        Set up the formatter, optionally loading data right away.

        Args:
            data: Optional raw data; when given it is passed through
                read_data() and kept in ``conversation_data``.
        """
        super().__init__()
        self.conversation_data = None
        if data is None:
            return
        self.conversation_data = self.read_data(data)

    def format_data(self, data=None):
        """
        Format every conversation.

        Args:
            data: List of Conversation objects. When None, the
                conversation_data stored at initialization is used instead.

        Returns:
            list: One formatted entry per conversation

        Raises:
            ValueError: If neither ``data`` nor stored conversation_data exists
        """
        source = self.conversation_data if data is None else data
        if source is None:
            raise ValueError(
                "No data to format. Either provide data to format_data() or initialize with data."
            )
        return [self.format_conversation(entry) for entry in source]

    def format_conversation(self, conversation):
        """
        Format one conversation.

        Args:
            conversation: Conversation object

        Returns:
            dict: ``{"messages": [...]}`` holding the formatted messages
        """
        return {
            "messages": [self.format_message(turn) for turn in conversation.messages]
        }

    def format_message(self, message):
        """
        Format one message in OpenAI format.

        Args:
            message: Message object to format

        Returns:
            dict: The very same message object, handed back unchanged
        """
        # The Message layout is treated as already compatible with OpenAI,
        # so no conversion is performed.
        return message
|