@@ -1,39 +1,142 @@
-class Message:
+from abc import ABC, abstractmethod
+from typing import Dict, List, Optional, TypedDict, Union
+
+
+class MessageContent(TypedDict, total=False):
+    """Type definition for message content in LLM requests."""
+
+    type: str  # Required field
+    text: Optional[str]  # Optional field
+    image_url: Optional[Dict[str, str]]  # Optional field
+
+
+class Message(TypedDict):
+    """Type definition for a message in an LLM inference request."""
+
+    role: str
+    content: Union[str, List[MessageContent]]
+
+
+class Conversation:
     """
-    Data class representing a message with a role and content.
+    Data class representing a conversation, which is a list of messages.
     """

-    def __init__(self, role, content,content_type):
-        self.role = role
-        self.content = content
-        self.type = content_type
-
-
-class Formatter:
+    def __init__(self, messages=None):
+        self.messages = messages if messages is not None else []
+
+    def add_message(self, message):
+        """
+        Add a message to the conversation.
+
+        Args:
+            message: Message object to add
+        """
+        self.messages.append(message)
+
+    def get_message(self, index):
+        """
+        Get a message at a specific index.
+
+        Args:
+            index: Index of the message to retrieve
+
+        Returns:
+            Message: The message at the specified index
+
+        Raises:
+            IndexError: If the index is out of bounds
+        """
+        if index < 0 or index >= len(self.messages):
+            raise IndexError(
+                f"Message index {index} out of range (0-{len(self.messages)-1})"
+            )
+        return self.messages[index]
+
+
+class Formatter(ABC):
     """
-    Base class for formatters that convert messages to different formats.
+    Abstract base class for formatters that convert messages to different formats.
     """

-    def __init__(self, message):
-        self.message = message
-
-    def format_data(self, data):
+    def __init__(self):
+        """
+        Initialize the formatter.
+
+        Subclasses can override this method to add specific initialization parameters.
         """
-        Format the message. This method should be overridden by subclasses.
+        pass
+
+    @abstractmethod
+    def format_data(self, data) -> List:
         """
-        raise NotImplementedError("Subclasses must implement format()")
-
-    def format_sample(self, sample):
+        Format the message. This method must be implemented by subclasses.
+
+        Args:
+            data: List of Conversation objects
+
+        Returns:
+            List of formatted data
+        """
+        pass
+
+    @abstractmethod
+    def format_sample(self, sample) -> Union[Dict, str]:
         """
-        Format a sample. This method should be overridden by subclasses.
+        Format a sample. This method must be implemented by subclasses.
+
+        Args:
+            sample: Conversation object
+
+        Returns:
+            Formatted sample in the appropriate format
         """
-        raise NotImplementedError("Subclasses must implement format_sample()")
+        pass

     def read_data(self, data):
         """
-        Format a sample. This method should be overridden by subclasses.
+        Read Hugging Face data and convert it to a list of Conversation objects.
+
+        Args:
+            data: Hugging Face dataset or iterable data
+
+        Returns:
+            list: List of Conversation objects, each containing a list of Message objects
         """
-        raise NotImplementedError("Subclasses must implement format_sample()")
+        conversations = []
+        for item in data:
+            # Extract fields from the Hugging Face dataset item
+            image = item.get("image", None)
+            input_text = item.get("input", "")
+            output_label = item.get("output", "")
+
+            # Create a new conversation
+            conversation = Conversation()
+
+            # Create user content and user message
+            user_content = [
+                {"type": "text", "text": input_text},
+            ]
+            # Add image to user content
+            if image is not None:
+                user_content.append({"type": "image", "image_url": {"url": image}})
+
+            user_message = {"role": "user", "content": user_content}
+
+            # Create assistant message with text content
+            assistant_content = [
+                {"type": "text", "text": output_label},
+            ]
+            assistant_message = {"role": "assistant", "content": assistant_content}
+
+            # Add messages to the conversation
+            conversation.add_message(user_message)
+            conversation.add_message(assistant_message)
+
+            # Add the conversation to the list
+            conversations.append(conversation)
+
+        return conversations


 class TorchtuneFormatter(Formatter):
@@ -41,20 +144,61 @@ class TorchtuneFormatter(Formatter):
     Formatter for Torchtune format.
     """

-    data = None
-
-    def read_data(self, data):
+    def __init__(self, data=None):
+        """
+        Initialize the formatter.
+
+        Args:
+            data: Optional data to initialize with
+        """
+        super().__init__()
+        self.conversation_data = None
+        if data is not None:
+            self.conversation_data = self.read_data(data)
+
+    def format_data(self, data):
+        """
+        Format the data.
+
+        Args:
+            data: List of Conversation objects
+
+        Returns:
+            list: List of formatted data
+        """
+        formatted_data = []
+        for conversation in data:
+            formatted_data.append(self.format_sample(conversation))
+        return formatted_data
+
+    def format_sample(self, sample):
         """
-        Format a sample. This method should be overridden by subclasses.
+        Format a sample.
+
+        Args:
+            sample: Conversation object
+
+        Returns:
+            dict: Formatted sample in Torchtune format
         """
-        raise NotImplementedError("Subclasses must implement format_sample()")
+        formatted_messages = []
+        for message in sample.messages:
+            formatted_messages.append(self.format(message))
+        return {"messages": formatted_messages}

-    def format(self):
+    def format(self, message):
         """
-        Format the message in Torchtune format.
+        Format a message in Torchtune format.
+
+        Args:
+            message: Message object to format
+
+        Returns:
+            dict: Formatted message in Torchtune format
         """
-        # Implementation for Torchtune format
-        return {"role": self.message.role, "content": self.message.content}
+        # For Torchtune format, we can return the Message as is
+        # since it's already in a compatible format
+        return message


 class vLLMFormatter(Formatter):
@@ -62,22 +206,106 @@ class vLLMFormatter(Formatter):
     Formatter for vLLM format.
     """

-    def format(self):
+    def format_data(self, data):
+        """
+        Format the data.
+
+        Args:
+            data: List of Conversation objects
+
+        Returns:
+            list: List of formatted data in vLLM format
+        """
+        formatted_data = []
+        for conversation in data:
+            formatted_data.append(self.format_sample(conversation))
+        return formatted_data
+
+    def format_sample(self, sample):
+        """
+        Format a sample.
+
+        Args:
+            sample: Conversation object
+
+        Returns:
+            str: Formatted sample in vLLM format
+        """
+        formatted_messages = []
+        for message in sample.messages:
+            formatted_messages.append(self.format(message))
+        return "\n".join(formatted_messages)
+
+    def format(self, message):
         """
-        Format the message in vLLM format.
+        Format a message in vLLM format.
+
+        Args:
+            message: Message object to format
+
+        Returns:
+            str: Formatted message in vLLM format
         """
-        # Implementation for vLLM format
-        return f"{self.message.role}: {self.message.content}"
+        role = message["role"]
+        content = message["content"]
+
+        # Handle different content types
+        if isinstance(content, str):
+            return f"{role}: {content}"
+        else:
+            # For multimodal content, extract text parts
+            text_parts = []
+            for item in content:
+                if item["type"] == "text" and "text" in item:
+                    text_parts.append(item["text"])
+            return f"{role}: {' '.join(text_parts)}"


 class OpenAIFormatter(Formatter):
     """
-    Formatter for Hugging Face format.
+    Formatter for OpenAI format.
     """

-    def format(self):
+    def format_data(self, data):
+        """
+        Format the data.
+
+        Args:
+            data: List of Conversation objects
+
+        Returns:
+            dict: Formatted data in OpenAI format
         """
-        Format the message in Hugging Face format.
+        formatted_data = []
+        for conversation in data:
+            formatted_data.append(self.format_sample(conversation))
+        return formatted_data
+
+    def format_sample(self, sample):
+        """
+        Format a sample.
+
+        Args:
+            sample: Conversation object
+
+        Returns:
+            dict: Formatted sample in OpenAI format
+        """
+        formatted_messages = []
+        for message in sample.messages:
+            formatted_messages.append(self.format(message))
+        return {"messages": formatted_messages}
+
+    def format(self, message):
+        """
+        Format a message in OpenAI format.
+
+        Args:
+            message: Message object to format
+
+        Returns:
+            dict: Formatted message in OpenAI format
         """
-        # Implementation for OpenAI format
-        raise NotImplementedError("Subclasses must implement format_sample()")
+        # For OpenAI format, we can return the Message as is
+        # since it's already in a compatible format
+        return message
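
Below is a rough usage sketch, not part of the patch itself, showing how the pieces added in this diff might fit together, assuming the classes defined above are in scope. The sample rows (dicts with "image", "input", and "output" keys, the keys read_data() looks up) are hypothetical stand-ins for a Hugging Face dataset.

# Hedged sketch: vLLMFormatter and read_data() come from the diff above;
# the rows below are invented sample data, not a real dataset.
rows = [
    {"image": None, "input": "What is 2 + 2?", "output": "4"},
    {"image": None, "input": "Name a prime number.", "output": "7"},
]

formatter = vLLMFormatter()
conversations = formatter.read_data(rows)       # -> list of Conversation objects
prompts = formatter.format_data(conversations)  # -> list of "role: text" strings
print(prompts[0])
# user: What is 2 + 2?
# assistant: 4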