toolcall.py

from pprint import pprint
from typing import Any, Mapping

from torchtune.data import Message
from torchtune.datasets import SFTDataset
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer


class ToolCallMessages(Transform):
    def __init__(self, train_on_input: bool = False):
        # Map the dataset's "from" field onto torchtune's message roles.
        # Tool responses are mapped to "ipython", torchtune's tool-output role.
        self._role_map = {
            "system": "system",
            "human": "user",
            "gpt": "assistant",
            "tool": "ipython",
            "user": "user",
            "assistant": "assistant",
        }
        self.train_on_input = train_on_input

    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        conversations = sample["cot_conversations"]
        messages = []
        for i, msg in enumerate(conversations):
            # True when the next turn is a tool response, i.e. the current
            # assistant turn is a tool call rather than a final answer.
            next_is_tool = (
                i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
            )
            messages.append(
                Message(
                    role=self._role_map[msg["from"]],
                    content=msg["value"],
                    # Mask all non-assistant turns from the loss unless we are
                    # explicitly training on the inputs as well.
                    masked=(
                        False
                        if self.train_on_input
                        else self._role_map[msg["from"]] != "assistant"
                    ),
                    # Suppress the end-of-turn token on tool responses and on
                    # assistant turns that issue a tool call, so a call, its
                    # result, and the follow-up answer form one continuous turn.
                    eot=not (
                        msg["from"] == "tool" or (msg["from"] == "gpt" and next_is_tool)
                    ),
                )
            )
        return {"messages": messages}


def custom_dataset(
    model_transform: ModelTokenizer, train_on_input: bool = False, **load_dataset_kwargs
) -> SFTDataset:
    message_transform = ToolCallMessages(train_on_input=train_on_input)
    return SFTDataset(
        source="json",
        data_files="train_final_mix.json",
        split="train",
        message_transform=message_transform,
        model_transform=model_transform,
        **load_dataset_kwargs,
    )
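

# Minimal smoke test, added for illustration: the sample below is made up,
# assuming the dataset stores turns under a "cot_conversations" key as the
# transform expects. Running the file directly prints each message's mask/eot
# flags so the tool-call turn handling can be checked without tokenizing.
if __name__ == "__main__":
    sample = {
        "cot_conversations": [
            {"from": "human", "value": "What is the weather in Paris?"},
            {"from": "gpt", "value": '{"name": "get_weather", "city": "Paris"}'},
            {"from": "tool", "value": '{"temp_c": 18}'},
            {"from": "gpt", "value": "It is currently 18 °C in Paris."},
        ]
    }
    for m in ToolCallMessages()(sample)["messages"]:
        # Expect: user masked, tool-call gpt turn with eot=False, tool turn
        # masked with eot=False, final gpt answer unmasked with eot=True.
        pprint({"role": m.role, "masked": m.masked, "eot": m.eot})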