123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- from typing import Dict, List, Union
def format_message_torchtune(message: Dict) -> Dict:
    """Format a message for Torchtune, which needs no conversion.

    Args:
        message: The message dictionary to format.

    Returns:
        The same ``message`` object that was passed in (not a copy).
    """
    return message
def format_message_openai(message: Dict) -> Dict:
    """Convert a message into the OpenAI ``input_text`` / ``input_image`` layout.

    Args:
        message: Dictionary with a ``"role"`` entry and a ``"content"``
            iterable of parts. Each part has a ``"type"`` of ``"text"``
            (with a ``"text"`` entry) or ``"image_url"`` (with a nested
            ``"image_url": {"url": ...}`` entry).

    Returns:
        A new dictionary ``{"role": ..., "content": [...]}`` whose parts are
        rewritten as ``input_text`` / ``input_image`` entries.

    Raises:
        ValueError: If a part has any other ``"type"``.
    """

    def _convert(part: Dict) -> Dict:
        # Translate one content part; the image URL is flattened to a string.
        if part["type"] == "text":
            return {"type": "input_text", "text": part["text"]}
        if part["type"] == "image_url":
            return {"type": "input_image", "image_url": part["image_url"]["url"]}
        raise ValueError(f"Unknown content type: {part['type']}")

    # Convert the parts first so an invalid part is reported before the
    # role is looked up.
    converted = [_convert(part) for part in message["content"]]
    return {"role": message["role"], "content": converted}
def format_message_vllm(message: Dict) -> Dict:
    """Convert a message into the vLLM chat layout.

    Text parts are passed through unchanged (the same dict objects are
    reused). Image parts are rewritten as ``image_url`` entries whose URL is
    the original ``content["image_url"]["url"]`` value prefixed with
    ``data:image/jpg;base64,``.

    Args:
        message: Dictionary with a ``"role"`` entry and a ``"content"``
            iterable of parts. Each part has a ``"type"`` of ``"text"``,
            ``"image_url"`` or ``"image"``.

    Returns:
        A new dictionary ``{"role": ..., "content": [...]}``.

    Raises:
        ValueError: If a part has any other ``"type"``.
    """
    contents = []
    for content in message["content"]:
        content_type = content["type"]
        if content_type == "text":
            contents.append(content)
        elif content_type in ("image_url", "image"):
            # NOTE(review): presumably this value is raw base64 image data,
            # since it is embedded in a data URI -- confirm against callers.
            # "image" parts are read from the same ["image_url"]["url"] path.
            image_data = content["image_url"]["url"]
            # Looked up outside the f-string: reusing double quotes inside a
            # double-quoted f-string is a SyntaxError before Python 3.12.
            contents.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpg;base64,{image_data}"},
                }
            )
        else:
            raise ValueError(f"Unknown content type: {content_type}")
    return {"role": message["role"], "content": contents}
def apply_format(data: Union[List[Dict], List[List[Dict]]], format_func) -> List[Dict]:
    """Run ``format_func`` over every message in ``data``.

    Whether ``data`` is flat or nested is decided by its first element only.

    Args:
        data: Either a list of message dictionaries, or a list of
            conversations where each conversation is a list of message
            dictionaries.
        format_func: Callable that formats a single message dictionary.

    Returns:
        An empty list when ``data`` is empty. For a flat list, the list of
        formatted messages. For a list of conversations, one
        ``{"messages": [...]}`` dictionary per conversation.
    """
    if not data:
        return []

    # Flat input: format each message directly.
    if not isinstance(data[0], list):
        return [format_func(message) for message in data]

    # Nested input: wrap each formatted conversation under a "messages" key.
    return [
        {"messages": [format_func(message) for message in conversation]}
        for conversation in data
    ]
|