|
|
@@ -1,10 +1,24 @@
|
|
|
-import base64
|
|
|
-from typing import Dict, List
|
|
|
+from typing import Dict, List, Union
|
|
|
|
|
|
|
|
|
-def image_to_base64(image_path):
|
|
|
- with open(image_path, "rb") as img:
|
|
|
- return base64.b64encode(img.read()).decode("utf-8")
|
|
|
+def format_message_torchtune(message: Dict) -> Dict:
|
|
|
+ """Format a message in Torchtune format."""
|
|
|
+ return message
|
|
|
+
|
|
|
+
|
|
|
+def format_message_openai(message: Dict) -> Dict:
|
|
|
+ """Format a message in OpenAI format."""
|
|
|
+ contents = []
|
|
|
+ for content in message["content"]:
|
|
|
+ if content["type"] == "text":
|
|
|
+ contents.append({"type": "input_text", "text": content["text"]})
|
|
|
+ elif content["type"] == "image_url":
|
|
|
+ contents.append(
|
|
|
+ {"type": "input_image", "image_url": content["image_url"]["url"]}
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Unknown content type: {content['type']}")
|
|
|
+ return {"role": message["role"], "content": contents}
|
|
|
|
|
|
|
|
|
def format_message_vllm(message: Dict) -> Dict:
|
|
|
@@ -16,10 +30,11 @@ def format_message_vllm(message: Dict) -> Dict:
|
|
|
if content["type"] == "text":
|
|
|
contents.append(content)
|
|
|
elif content["type"] == "image_url" or content["type"] == "image":
|
|
|
- base64_image = image_to_base64(content["image_url"]["url"])
|
|
|
img_content = {
|
|
|
"type": "image_url",
|
|
|
- "image_url": {"url": f"data:image/jpg;base64,{base64_image}"},
|
|
|
+ "image_url": {
|
|
|
+ "url": f"data:image/jpg;base64,{content["image_url"]["url"]}"
|
|
|
+ },
|
|
|
}
|
|
|
contents.append(img_content)
|
|
|
else:
|
|
|
@@ -29,85 +44,37 @@ def format_message_vllm(message: Dict) -> Dict:
|
|
|
return vllm_message
|
|
|
|
|
|
|
|
|
-def format_conversation_vllm(conversation) -> Dict:
|
|
|
- """Format a conversation in vLLM format."""
|
|
|
- formatted_messages = []
|
|
|
- for message in conversation.messages:
|
|
|
- role = message["role"]
|
|
|
- if role != "assistant":
|
|
|
- formatted_messages.append(format_message_vllm(message))
|
|
|
- return {"messages": formatted_messages}
|
|
|
-
|
|
|
-
|
|
|
-# TODO: Remove
|
|
|
-def format_conversation_openai(conversation) -> Dict:
|
|
|
- """Format a conversation in OpenAI format."""
|
|
|
- formatted_messages = []
|
|
|
- for message in conversation.messages:
|
|
|
- formatted_messages.append(format_message_openai(message))
|
|
|
- return {"messages": formatted_messages}
|
|
|
-
|
|
|
-
|
|
|
-# TODO: Remove
|
|
|
-def format_data_torchtune(data: List[Conversation]) -> List[Dict]:
|
|
|
- """Format data in Torchtune format."""
|
|
|
- if data is None:
|
|
|
- raise ValueError("No data provided to format_data()")
|
|
|
-
|
|
|
- return [format_conversation_torchtune(conversation) for conversation in data]
|
|
|
-
|
|
|
-
|
|
|
-def format_data_vllm(data: List[Conversation]) -> List[Dict]:
|
|
|
- """Format data in vLLM format."""
|
|
|
- if data is None:
|
|
|
- raise ValueError("No data provided to format_data()")
|
|
|
-
|
|
|
- return [format_conversation_vllm(conversation) for conversation in data]
|
|
|
-
|
|
|
-
|
|
|
-# TODO: Remove
|
|
|
-def format_data_openai(data: List[Conversation]) -> List[Dict]:
|
|
|
- """Format data in OpenAI format."""
|
|
|
- if data is None:
|
|
|
- raise ValueError("No data provided to format_data()")
|
|
|
-
|
|
|
- return [format_conversation_openai(conversation) for conversation in data]
|
|
|
-
|
|
|
-
|
|
|
-# Dictionary to map format names to functions for easy dispatch
|
|
|
-FORMATTERS = {
|
|
|
- "torchtune": {
|
|
|
- "data": format_data_torchtune,
|
|
|
- "conversation": format_conversation_torchtune,
|
|
|
- "message": format_message_torchtune,
|
|
|
- },
|
|
|
- "vllm": {
|
|
|
- "data": format_data_vllm,
|
|
|
- "conversation": format_conversation_vllm,
|
|
|
- "message": format_message_vllm,
|
|
|
- },
|
|
|
- "openai": {
|
|
|
- "data": format_data_openai,
|
|
|
- "conversation": format_conversation_openai,
|
|
|
- "message": format_message_openai,
|
|
|
- },
|
|
|
-}
|
|
|
-
|
|
|
-
|
|
|
-def format_data(data: List[Conversation], format_type: str) -> List[Dict]:
|
|
|
+def apply_format(data: Union[List[Dict], List[List[Dict]]], format_func) -> List[Dict]:
|
|
|
"""
|
|
|
- Generic function to format data in the specified format.
|
|
|
+ Apply the format function to the data.
|
|
|
|
|
|
Args:
|
|
|
- data: List of Conversation objects
|
|
|
- format_type: One of "torchtune", "vllm", "openai"
|
|
|
+ data: Either a list of message dictionaries or a list of conversations
|
|
|
+ (where each conversation is a list of message dictionaries)
|
|
|
+ format_func: Function that formats a single message dictionary
|
|
|
|
|
|
Returns:
|
|
|
- List of formatted data
|
|
|
+ List of formatted dictionaries
|
|
|
"""
|
|
|
- if format_type not in FORMATTERS:
|
|
|
- raise ValueError(
|
|
|
- f"Unknown format type: {format_type}. Supported: {list(FORMATTERS.keys())}"
|
|
|
- )
|
|
|
-
|
|
|
- return FORMATTERS[format_type]["data"](data)
|
|
|
+ if not data:
|
|
|
+ return []
|
|
|
+
|
|
|
+ # Check if data is a list of conversations (list of lists) or a list of messages
|
|
|
+ if isinstance(data[0], list):
|
|
|
+ # data is a list of conversations, each conversation is a list of messages
|
|
|
+ formatted_conversations = []
|
|
|
+ for conversation in data:
|
|
|
+ formatted_messages = []
|
|
|
+ for message in conversation:
|
|
|
+ formatted_message = format_func(message)
|
|
|
+ formatted_messages.append(formatted_message)
|
|
|
+ # Return the conversation as a dictionary with "messages" key
|
|
|
+ formatted_conversations.append({"messages": formatted_messages})
|
|
|
+ return formatted_conversations
|
|
|
+ else:
|
|
|
+ # data is a list of messages
|
|
|
+ formatted_messages = []
|
|
|
+ for message in data:
|
|
|
+ formatted_message = format_func(message)
|
|
|
+ formatted_messages.append(formatted_message)
|
|
|
+ return formatted_messages
|