formatter_functional.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. from typing import Dict, List, Union
  2. def format_message_torchtune(message: Dict) -> Dict:
  3. """Format a message in Torchtune format."""
  4. return message
  5. def format_message_openai(message: Dict) -> Dict:
  6. """Format a message in OpenAI format."""
  7. contents = []
  8. for content in message["content"]:
  9. if content["type"] == "text":
  10. contents.append({"type": "input_text", "text": content["text"]})
  11. elif content["type"] == "image_url":
  12. contents.append(
  13. {"type": "input_image", "image_url": content["image_url"]["url"]}
  14. )
  15. else:
  16. raise ValueError(f"Unknown content type: {content['type']}")
  17. return {"role": message["role"], "content": contents}
  18. def format_message_vllm(message: Dict) -> Dict:
  19. """Format a message in vLLM format."""
  20. contents = []
  21. vllm_message = {}
  22. for content in message["content"]:
  23. if content["type"] == "text":
  24. contents.append(content)
  25. elif content["type"] == "image_url" or content["type"] == "image":
  26. img_content = {
  27. "type": "image_url",
  28. "image_url": {
  29. "url": f"data:image/jpg;base64,{content["image_url"]["url"]}"
  30. },
  31. }
  32. contents.append(img_content)
  33. else:
  34. raise ValueError(f"Unknown content type: {content['type']}")
  35. vllm_message["role"] = message["role"]
  36. vllm_message["content"] = contents
  37. return vllm_message
  38. def apply_format(data: Union[List[Dict], List[List[Dict]]], format_func) -> List[Dict]:
  39. """
  40. Apply the format function to the data.
  41. Args:
  42. data: Either a list of message dictionaries or a list of conversations
  43. (where each conversation is a list of message dictionaries)
  44. format_func: Function that formats a single message dictionary
  45. Returns:
  46. List of formatted dictionaries
  47. """
  48. if not data:
  49. return []
  50. # Check if data is a list of conversations (list of lists) or a list of messages
  51. if isinstance(data[0], list):
  52. # data is a list of conversations, each conversation is a list of messages
  53. formatted_conversations = []
  54. for conversation in data:
  55. formatted_messages = []
  56. for message in conversation:
  57. formatted_message = format_func(message)
  58. formatted_messages.append(formatted_message)
  59. # Return the conversation as a dictionary with "messages" key
  60. formatted_conversations.append({"messages": formatted_messages})
  61. return formatted_conversations
  62. else:
  63. # data is a list of messages
  64. formatted_messages = []
  65. for message in data:
  66. formatted_message = format_func(message)
  67. formatted_messages.append(formatted_message)
  68. return formatted_messages