formatter.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394
  1. from abc import ABC, abstractmethod
  2. from typing import Dict, List, Optional, TypedDict, Union
  3. class MessageContent(TypedDict, total=False):
  4. """Type definition for message content in LLM requests."""
  5. type: str # Required field
  6. text: Optional[str] # Optional field
  7. image_url: Optional[Dict[str, str]] # Optional field
  8. class Message(TypedDict):
  9. """Type definition for a message in a LLM inference request."""
  10. role: str
  11. content: Union[str, List[MessageContent]]
  12. class Conversation:
  13. """
  14. Data class representing a conversation which are list of messages.
  15. """
  16. def __init__(self, messages=None):
  17. self.messages = messages if messages is not None else []
  18. def add_message(self, message):
  19. """
  20. Add a message to the conversation.
  21. Args:
  22. message: Message object to add
  23. """
  24. self.messages.append(message)
  25. def get_message(self, index):
  26. """
  27. Get a message at a specific index.
  28. Args:
  29. index: Index of the message to retrieve
  30. Returns:
  31. Message: The message at the specified index
  32. Raises:
  33. IndexError: If the index is out of bounds
  34. """
  35. if index < 0 or index >= len(self.messages):
  36. raise IndexError(
  37. f"Message index {index} out of range (0-{len(self.messages)-1})"
  38. )
  39. return self.messages[index]
  40. class Formatter(ABC):
  41. """
  42. Abstract base class for formatters that convert messages to different formats.
  43. """
  44. def __init__(self):
  45. """
  46. Initialize the formatter.
  47. Subclasses can override this method to add specific initialization parameters.
  48. """
  49. pass
  50. @abstractmethod
  51. def format_data(self, data) -> List:
  52. """
  53. Format the message. This method must be implemented by subclasses.
  54. Args:
  55. data: List of Conversation objects
  56. Returns:
  57. List of formatted data
  58. """
  59. pass
  60. @abstractmethod
  61. def format_conversation(self, conversation) -> Union[Dict, str]:
  62. """
  63. Format a sample. This method must be implemented by subclasses.
  64. Args:
  65. sample: Conversation object
  66. Returns:
  67. Formatted sample in the appropriate format
  68. """
  69. pass
  70. @abstractmethod
  71. def format_message(self, message) -> Union[Dict, str]:
  72. """
  73. Format a message. This method must be implemented by subclasses.
  74. Args:
  75. sample: Message object
  76. Returns:
  77. Formatted message in the appropriate format
  78. """
  79. pass
  80. def read_data(self, data, column_mapping=None):
  81. """
  82. Read Hugging Face data and convert it to a list of Conversation objects.
  83. Args:
  84. data: Hugging Face dataset or iterable data
  85. column_mapping: Optional dictionary mapping custom column names to standard field names.
  86. Expected keys are 'input', 'output', and 'image'.
  87. Example: {'input': 'question', 'output': 'answer', 'image': 'image_path'}
  88. If None, default field names ('input', 'output', 'image') will be used.
  89. Returns:
  90. list: List of Conversation objects, each containing a list of Message objects
  91. """
  92. # Default column mapping if none provided
  93. if column_mapping is None:
  94. column_mapping = {"input": "input", "output": "output", "image": "image"}
  95. # Validate column mapping
  96. required_fields = ["input", "output"]
  97. for field in required_fields:
  98. if field not in column_mapping:
  99. raise ValueError(f"Column mapping must include '{field}' field")
  100. conversations = []
  101. for item in data:
  102. # Extract fields from the Hugging Face dataset item using the column mapping
  103. image_field = column_mapping.get("image")
  104. input_field = column_mapping.get("input")
  105. output_field = column_mapping.get("output")
  106. image = item.get(image_field, None) if image_field else None
  107. input_text = item.get(input_field, "")
  108. output_label = item.get(output_field, "")
  109. # Create a new conversation
  110. conversation = Conversation()
  111. # Create user content and user message
  112. user_content = [
  113. {"type": "text", "text": input_text},
  114. ]
  115. # Add image to user content
  116. if image is not None:
  117. user_content.append({"type": "image", "image_url": {"url": image}})
  118. user_message = {"role": "user", "content": user_content}
  119. # Create assistant message with text content
  120. assistant_content = [
  121. {"type": "text", "text": output_label},
  122. ]
  123. assistant_message = {"role": "assistant", "content": assistant_content}
  124. # Add messages to the conversation
  125. conversation.add_message(user_message)
  126. conversation.add_message(assistant_message)
  127. # Add the conversation to the list
  128. conversations.append(conversation)
  129. return conversations
  130. class TorchtuneFormatter(Formatter):
  131. """
  132. Formatter for Torchtune format.
  133. """
  134. def __init__(self, data=None):
  135. """
  136. Initialize the formatter.
  137. Args:
  138. data: Optional data to initialize with
  139. """
  140. super().__init__()
  141. self.conversation_data = None
  142. if data is not None:
  143. self.conversation_data = self.read_data(data)
  144. def format_data(self, data=None):
  145. """
  146. Format the data.
  147. Args:
  148. data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
  149. Returns:
  150. list: List of formatted data
  151. """
  152. # Use stored conversation_data if no data is provided
  153. data_to_format = data if data is not None else self.conversation_data
  154. # Check if we have data to format
  155. if data_to_format is None:
  156. raise ValueError(
  157. "No data to format. Either provide data to format_data() or initialize with data."
  158. )
  159. formatted_data = []
  160. for conversation in data_to_format:
  161. formatted_data.append(self.format_conversation(conversation))
  162. return formatted_data
  163. def format_conversation(self, conversation):
  164. """
  165. Format a sample.
  166. Args:
  167. sample: Conversation object
  168. Returns:
  169. dict: Formatted sample in Torchtune format
  170. """
  171. formatted_messages = []
  172. for message in conversation.messages:
  173. formatted_messages.append(self.format_message(message))
  174. return {"messages": formatted_messages}
  175. def format_message(self, message):
  176. """
  177. Format a message in Torchtune format.
  178. Args:
  179. message: Message object to format
  180. Returns:
  181. dict: Formatted message in Torchtune format
  182. """
  183. # For Torchtune format, we can return the Message as is
  184. # since it's already in a compatible format
  185. return message
  186. class vLLMFormatter(Formatter):
  187. """
  188. Formatter for vLLM format.
  189. """
  190. def __init__(self, data=None):
  191. """
  192. Initialize the formatter.
  193. Args:
  194. data: Optional data to initialize with
  195. """
  196. super().__init__()
  197. self.conversation_data = None
  198. if data is not None:
  199. self.conversation_data = self.read_data(data)
  200. def format_data(self, data=None):
  201. """
  202. Format the data.
  203. Args:
  204. data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
  205. Returns:
  206. list: List of formatted data in vLLM format
  207. """
  208. # Use stored conversation_data if no data is provided
  209. data_to_format = data if data is not None else self.conversation_data
  210. # Check if we have data to format
  211. if data_to_format is None:
  212. raise ValueError(
  213. "No data to format. Either provide data to format_data() or initialize with data."
  214. )
  215. formatted_data = []
  216. for conversation in data_to_format:
  217. formatted_data.append(self.format_conversation(conversation))
  218. return formatted_data
  219. def format_conversation(self, conversation):
  220. """
  221. Format a sample.
  222. Args:
  223. sample: Conversation object
  224. Returns:
  225. str: Formatted sample in vLLM format
  226. """
  227. formatted_messages = []
  228. for message in conversation.messages:
  229. formatted_messages.append(self.format_message(message))
  230. return "\n".join(formatted_messages)
  231. def format_message(self, message):
  232. """
  233. Format a message in vLLM format.
  234. Args:
  235. message: Message object to format
  236. Returns:
  237. str: Formatted message in vLLM format
  238. """
  239. role = message["role"]
  240. content = message["content"]
  241. # Handle different content types
  242. if isinstance(content, str):
  243. return f"{role}: {content}"
  244. else:
  245. # For multimodal content, extract text parts
  246. text_parts = []
  247. for item in content:
  248. if item["type"] == "text" and "text" in item:
  249. text_parts.append(item["text"])
  250. return f"{role}: {' '.join(text_parts)}"
  251. class OpenAIFormatter(Formatter):
  252. """
  253. Formatter for OpenAI format.
  254. """
  255. def __init__(self, data=None):
  256. """
  257. Initialize the formatter.
  258. Args:
  259. data: Optional data to initialize with
  260. """
  261. super().__init__()
  262. self.conversation_data = None
  263. if data is not None:
  264. self.conversation_data = self.read_data(data)
  265. def format_data(self, data=None):
  266. """
  267. Format the data.
  268. Args:
  269. data: List of Conversation objects. If None, uses the conversation_data stored during initialization.
  270. Returns:
  271. dict: Formatted data in OpenAI format
  272. """
  273. # Use stored conversation_data if no data is provided
  274. data_to_format = data if data is not None else self.conversation_data
  275. # Check if we have data to format
  276. if data_to_format is None:
  277. raise ValueError(
  278. "No data to format. Either provide data to format_data() or initialize with data."
  279. )
  280. formatted_data = []
  281. for conversation in data_to_format:
  282. formatted_data.append(self.format_conversation(conversation))
  283. return formatted_data
  284. def format_conversation(self, conversation):
  285. """
  286. Format a sample.
  287. Args:
  288. sample: Conversation object
  289. Returns:
  290. dict: Formatted sample in OpenAI format
  291. """
  292. formatted_messages = []
  293. for message in conversation.messages:
  294. formatted_messages.append(self.format_message(message))
  295. return {"messages": formatted_messages}
  296. def format_message(self, message):
  297. """
  298. Format a message in OpenAI format.
  299. Args:
  300. message: Message object to format
  301. Returns:
  302. dict: Formatted message in OpenAI format
  303. """
  304. # For OpenAI format, we can return the Message as is
  305. # since it's already in a compatible format
  306. return message