formatter_functional.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
import base64
import mimetypes
from typing import Dict, List
  3. def image_to_base64(image_path):
  4. with open(image_path, "rb") as img:
  5. return base64.b64encode(img.read()).decode("utf-8")
  6. def format_message_vllm(message: Dict) -> Dict:
  7. """Format a message in vLLM format."""
  8. contents = []
  9. vllm_message = {}
  10. for content in message["content"]:
  11. if content["type"] == "text":
  12. contents.append(content)
  13. elif content["type"] == "image_url" or content["type"] == "image":
  14. base64_image = image_to_base64(content["image_url"]["url"])
  15. img_content = {
  16. "type": "image_url",
  17. "image_url": {"url": f"data:image/jpg;base64,{base64_image}"},
  18. }
  19. contents.append(img_content)
  20. else:
  21. raise ValueError(f"Unknown content type: {content['type']}")
  22. vllm_message["role"] = message["role"]
  23. vllm_message["content"] = contents
  24. return vllm_message
  25. def format_conversation_vllm(conversation) -> Dict:
  26. """Format a conversation in vLLM format."""
  27. formatted_messages = []
  28. for message in conversation.messages:
  29. role = message["role"]
  30. if role != "assistant":
  31. formatted_messages.append(format_message_vllm(message))
  32. return {"messages": formatted_messages}
  33. # TODO: Remove
  34. def format_conversation_openai(conversation) -> Dict:
  35. """Format a conversation in OpenAI format."""
  36. formatted_messages = []
  37. for message in conversation.messages:
  38. formatted_messages.append(format_message_openai(message))
  39. return {"messages": formatted_messages}
  40. # TODO: Remove
  41. def format_data_torchtune(data: List[Conversation]) -> List[Dict]:
  42. """Format data in Torchtune format."""
  43. if data is None:
  44. raise ValueError("No data provided to format_data()")
  45. return [format_conversation_torchtune(conversation) for conversation in data]
  46. def format_data_vllm(data: List[Conversation]) -> List[Dict]:
  47. """Format data in vLLM format."""
  48. if data is None:
  49. raise ValueError("No data provided to format_data()")
  50. return [format_conversation_vllm(conversation) for conversation in data]
  51. # TODO: Remove
  52. def format_data_openai(data: List[Conversation]) -> List[Dict]:
  53. """Format data in OpenAI format."""
  54. if data is None:
  55. raise ValueError("No data provided to format_data()")
  56. return [format_conversation_openai(conversation) for conversation in data]
  57. # Dictionary to map format names to functions for easy dispatch
  58. FORMATTERS = {
  59. "torchtune": {
  60. "data": format_data_torchtune,
  61. "conversation": format_conversation_torchtune,
  62. "message": format_message_torchtune,
  63. },
  64. "vllm": {
  65. "data": format_data_vllm,
  66. "conversation": format_conversation_vllm,
  67. "message": format_message_vllm,
  68. },
  69. "openai": {
  70. "data": format_data_openai,
  71. "conversation": format_conversation_openai,
  72. "message": format_message_openai,
  73. },
  74. }
  75. def format_data(data: List[Conversation], format_type: str) -> List[Dict]:
  76. """
  77. Generic function to format data in the specified format.
  78. Args:
  79. data: List of Conversation objects
  80. format_type: One of "torchtune", "vllm", "openai"
  81. Returns:
  82. List of formatted data
  83. """
  84. if format_type not in FORMATTERS:
  85. raise ValueError(
  86. f"Unknown format type: {format_type}. Supported: {list(FORMATTERS.keys())}"
  87. )
  88. return FORMATTERS[format_type]["data"](data)