toolcall.py 1.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940
  1. from pprint import pprint
  2. from typing import Any, Mapping
  3. from torchtune.data import _messages, Message
  4. from torchtune.datasets import SFTDataset
  5. from torchtune.modules.transforms import Transform
  6. from torchtune.modules.transforms.tokenizers import ModelTokenizer
  class ToolCallMessages(Transform):
      """Turn one raw conversation sample into a list of ``Message`` objects.

      The sample is expected to hold its turns under ``"cot_conversations"``;
      each turn is read through its ``"from"`` (speaker tag) and ``"value"``
      (content) keys.
      """

      def __init__(self):
          # Speaker tags found in the data -> role names handed to ``Message``.
          self._role_map = dict(
              system="system",
              human="user",
              gpt="assistant",
              tool="ipython",
          )

      def __call__(self, sample):
          """Convert ``sample["cot_conversations"]`` into messages.

          Every turn becomes a ``Message`` with ``eot=True``; all turns whose
          mapped role is not ``"assistant"`` are created with ``masked=True``.

          Returns:
              A dict with the single key ``"messages"`` holding the list of
              converted messages, in the original turn order.

          Raises:
              KeyError: if a turn's ``"from"`` tag is not in the role map, or
                  a required key is missing from the sample or a turn.
          """
          converted = []
          for turn in sample["cot_conversations"]:
              # Resolve the role once and reuse it for both fields below.
              role = self._role_map[turn["from"]]
              converted.append(
                  Message(
                      role=role,
                      content=turn["value"],
                      masked=(role != "assistant"),
                      eot=True,
                  )
              )
          return {"messages": converted}
  def custom_dataset(model_transform, **load_dataset_kwargs) -> SFTDataset:
      """Build an ``SFTDataset`` over ``train_data.json`` for tool-call data.

      The dataset is loaded with ``source="json"`` and ``split="train"``, and
      each sample is converted to messages by ``ToolCallMessages``.

      Args:
          model_transform: Transform forwarded to ``SFTDataset`` as its
              ``model_transform`` (presumably a tokenizer instance; confirm
              against the caller/config).
          **load_dataset_kwargs: Extra keyword arguments forwarded unchanged
              to ``SFTDataset``.

      Returns:
          The constructed ``SFTDataset``.
      """
      message_transform = ToolCallMessages()
      return SFTDataset(
          source="json",
          data_files="train_data.json",
          split="train",
          message_transform=message_transform,
          # Forward the caller-supplied transform. Passing the ModelTokenizer
          # class here would ignore the argument and hand SFTDataset a type
          # rather than a usable transform object.
          model_transform=model_transform,
          **load_dataset_kwargs,
      )