|
@@ -81,7 +81,7 @@ class Formatter(ABC):
|
|
|
pass
|
|
|
|
|
|
@abstractmethod
|
|
|
- def format_sample(self, sample) -> Union[Dict, str]:
|
|
|
+ def format_conversation(self, conversation) -> Union[Dict, str]:
|
|
|
"""
|
|
|
Format a sample. This method must be implemented by subclasses.
|
|
|
|
|
@@ -93,22 +93,53 @@ class Formatter(ABC):
|
|
|
"""
|
|
|
pass
|
|
|
|
|
|
- def read_data(self, data):
|
|
|
+ @abstractmethod
|
|
|
+ def format_message(self, message) -> Union[Dict, str]:
|
|
|
+ """
|
|
|
+ Format a message. This method must be implemented by subclasses.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ message: Message object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Formatted message in the appropriate format
|
|
|
+ """
|
|
|
+ pass
|
|
|
+
|
|
|
+ def read_data(self, data, column_mapping=None):
|
|
|
"""
|
|
|
Read Hugging Face data and convert it to a list of Conversation objects.
|
|
|
|
|
|
Args:
|
|
|
data: Hugging Face dataset or iterable data
|
|
|
+ column_mapping: Optional dictionary mapping standard field names (keys) to the dataset's column names (values).
|
|
|
+ Keys 'input' and 'output' are required; 'image' is optional (no image is read if it is omitted).
|
|
|
+ Example: {'input': 'question', 'output': 'answer', 'image': 'image_path'}
|
|
|
+ If None, default field names ('input', 'output', 'image') will be used.
|
|
|
|
|
|
Returns:
|
|
|
list: List of Conversation objects, each containing a list of Message objects
|
|
|
"""
|
|
|
+ # Default column mapping if none provided
|
|
|
+ if column_mapping is None:
|
|
|
+ column_mapping = {"input": "input", "output": "output", "image": "image"}
|
|
|
+
|
|
|
+ # Validate column mapping
|
|
|
+ required_fields = ["input", "output"]
|
|
|
+ for field in required_fields:
|
|
|
+ if field not in column_mapping:
|
|
|
+ raise ValueError(f"Column mapping must include '{field}' field")
|
|
|
+
|
|
|
conversations = []
|
|
|
for item in data:
|
|
|
- # Extract fields from the Hugging Face dataset item
|
|
|
- image = item.get("image", None)
|
|
|
- input_text = item.get("input", "")
|
|
|
- output_label = item.get("output", "")
|
|
|
+ # Extract fields from the Hugging Face dataset item using the column mapping
|
|
|
+ image_field = column_mapping.get("image")
|
|
|
+ input_field = column_mapping.get("input")
|
|
|
+ output_field = column_mapping.get("output")
|
|
|
+
|
|
|
+ image = item.get(image_field, None) if image_field else None
|
|
|
+ input_text = item.get(input_field, "")
|
|
|
+ output_label = item.get(output_field, "")
|
|
|
|
|
|
# Create a new conversation
|
|
|
conversation = Conversation()
|
|
@@ -156,22 +187,31 @@ class TorchtuneFormatter(Formatter):
|
|
|
if data is not None:
|
|
|
self.conversation_data = self.read_data(data)
|
|
|
|
|
|
- def format_data(self, data):
|
|
|
+ def format_data(self, data=None):
|
|
|
"""
|
|
|
Format the data.
|
|
|
|
|
|
Args:
|
|
|
- data: List of Conversation objects
|
|
|
+ data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
|
|
|
|
|
|
Returns:
|
|
|
list: List of formatted data
|
|
|
"""
|
|
|
+ # Use stored conversation_data if no data is provided
|
|
|
+ data_to_format = data if data is not None else self.conversation_data
|
|
|
+
|
|
|
+ # Check if we have data to format
|
|
|
+ if data_to_format is None:
|
|
|
+ raise ValueError(
|
|
|
+ "No data to format. Either provide data to format_data() or initialize with data."
|
|
|
+ )
|
|
|
+
|
|
|
formatted_data = []
|
|
|
- for conversation in data:
|
|
|
- formatted_data.append(self.format_sample(conversation))
|
|
|
+ for conversation in data_to_format:
|
|
|
+ formatted_data.append(self.format_conversation(conversation))
|
|
|
return formatted_data
|
|
|
|
|
|
- def format_sample(self, sample):
|
|
|
+ def format_conversation(self, conversation):
|
|
|
"""
|
|
|
Format a sample.
|
|
|
|
|
@@ -182,11 +222,11 @@ class TorchtuneFormatter(Formatter):
|
|
|
dict: Formatted sample in Torchtune format
|
|
|
"""
|
|
|
formatted_messages = []
|
|
|
- for message in sample.messages:
|
|
|
- formatted_messages.append(self.format(message))
|
|
|
+ for message in conversation.messages:
|
|
|
+ formatted_messages.append(self.format_message(message))
|
|
|
return {"messages": formatted_messages}
|
|
|
|
|
|
- def format(self, message):
|
|
|
+ def format_message(self, message):
|
|
|
"""
|
|
|
Format a message in Torchtune format.
|
|
|
|
|
@@ -206,22 +246,43 @@ class vLLMFormatter(Formatter):
|
|
|
Formatter for vLLM format.
|
|
|
"""
|
|
|
|
|
|
- def format_data(self, data):
|
|
|
+ def __init__(self, data=None):
|
|
|
+ """
|
|
|
+ Initialize the formatter.
|
|
|
+
|
|
|
+ Args:
|
|
|
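+ data: Optional raw data (Hugging Face dataset or iterable). If provided, it is converted with read_data() and stored as conversation_data.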
|
|
|
+ """
|
|
|
+ super().__init__()
|
|
|
+ self.conversation_data = None
|
|
|
+ if data is not None:
|
|
|
+ self.conversation_data = self.read_data(data)
|
|
|
+
|
|
|
+ def format_data(self, data=None):
|
|
|
"""
|
|
|
Format the data.
|
|
|
|
|
|
Args:
|
|
|
- data: List of Conversation objects
|
|
|
+ data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
|
|
|
|
|
|
Returns:
|
|
|
list: List of formatted data in vLLM format
|
|
|
"""
|
|
|
+ # Use stored conversation_data if no data is provided
|
|
|
+ data_to_format = data if data is not None else self.conversation_data
|
|
|
+
|
|
|
+ # Check if we have data to format
|
|
|
+ if data_to_format is None:
|
|
|
+ raise ValueError(
|
|
|
+ "No data to format. Either provide data to format_data() or initialize with data."
|
|
|
+ )
|
|
|
+
|
|
|
formatted_data = []
|
|
|
- for conversation in data:
|
|
|
- formatted_data.append(self.format_sample(conversation))
|
|
|
+ for conversation in data_to_format:
|
|
|
+ formatted_data.append(self.format_conversation(conversation))
|
|
|
return formatted_data
|
|
|
|
|
|
- def format_sample(self, sample):
|
|
|
+ def format_conversation(self, conversation):
|
|
|
"""
|
|
|
Format a sample.
|
|
|
|
|
@@ -232,11 +293,11 @@ class vLLMFormatter(Formatter):
|
|
|
str: Formatted sample in vLLM format
|
|
|
"""
|
|
|
formatted_messages = []
|
|
|
- for message in sample.messages:
|
|
|
- formatted_messages.append(self.format(message))
|
|
|
+ for message in conversation.messages:
|
|
|
+ formatted_messages.append(self.format_message(message))
|
|
|
return "\n".join(formatted_messages)
|
|
|
|
|
|
- def format(self, message):
|
|
|
+ def format_message(self, message):
|
|
|
"""
|
|
|
Format a message in vLLM format.
|
|
|
|
|
@@ -266,22 +327,43 @@ class OpenAIFormatter(Formatter):
|
|
|
Formatter for OpenAI format.
|
|
|
"""
|
|
|
|
|
|
- def format_data(self, data):
|
|
|
+ def __init__(self, data=None):
|
|
|
+ """
|
|
|
+ Initialize the formatter.
|
|
|
+
|
|
|
+ Args:
|
|
|
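+ data: Optional raw data (Hugging Face dataset or iterable). If provided, it is converted with read_data() and stored as conversation_data.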
|
|
|
+ """
|
|
|
+ super().__init__()
|
|
|
+ self.conversation_data = None
|
|
|
+ if data is not None:
|
|
|
+ self.conversation_data = self.read_data(data)
|
|
|
+
|
|
|
+ def format_data(self, data=None):
|
|
|
"""
|
|
|
Format the data.
|
|
|
|
|
|
Args:
|
|
|
- data: List of Conversation objects
|
|
|
+ data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
|
|
|
|
|
|
Returns:
|
|
|
dict: Formatted data in OpenAI format
|
|
|
"""
|
|
|
+ # Use stored conversation_data if no data is provided
|
|
|
+ data_to_format = data if data is not None else self.conversation_data
|
|
|
+
|
|
|
+ # Check if we have data to format
|
|
|
+ if data_to_format is None:
|
|
|
+ raise ValueError(
|
|
|
+ "No data to format. Either provide data to format_data() or initialize with data."
|
|
|
+ )
|
|
|
+
|
|
|
formatted_data = []
|
|
|
- for conversation in data:
|
|
|
- formatted_data.append(self.format_sample(conversation))
|
|
|
+ for conversation in data_to_format:
|
|
|
+ formatted_data.append(self.format_conversation(conversation))
|
|
|
return formatted_data
|
|
|
|
|
|
- def format_sample(self, sample):
|
|
|
+ def format_conversation(self, conversation):
|
|
|
"""
|
|
|
Format a sample.
|
|
|
|
|
@@ -292,11 +374,11 @@ class OpenAIFormatter(Formatter):
|
|
|
dict: Formatted sample in OpenAI format
|
|
|
"""
|
|
|
formatted_messages = []
|
|
|
- for message in sample.messages:
|
|
|
- formatted_messages.append(self.format(message))
|
|
|
+ for message in conversation.messages:
|
|
|
+ formatted_messages.append(self.format_message(message))
|
|
|
return {"messages": formatted_messages}
|
|
|
|
|
|
- def format(self, message):
|
|
|
+ def format_message(self, message):
|
|
|
"""
|
|
|
Format a message in OpenAI format.
|
|
|
|