toolcall.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. from pprint import pprint
  2. from typing import Any, Mapping
  3. from torchtune.data import _messages, Message
  4. from torchtune.datasets import SFTDataset
  5. from torchtune.modules.transforms import Transform
  6. from torchtune.modules.transforms.tokenizers import ModelTokenizer
  7. # Store original validate_messages for reference if needed
  8. original_validate = _messages.validate_messages
  9. # Replace with no-op function
  10. def no_validate(messages):
  11. pass
  12. # Monkey patch
  13. _messages.validate_messages = no_validate
  14. class ToolCallMessages(Transform):
  15. def __init__(self):
  16. self._role_map = {
  17. "system": "system",
  18. "human": "user",
  19. "gpt": "assistant",
  20. "tool": "ipython",
  21. }
  22. def __call__(self, sample):
  23. messages = [
  24. Message(
  25. role=self._role_map[msg["from"]],
  26. content=msg["value"],
  27. masked=self._role_map[msg["from"]] != "assistant",
  28. eot=True,
  29. )
  30. for msg in sample["cot_conversations"]
  31. ]
  32. return {"messages": messages}
  33. def custom_dataset(model_transform, **load_dataset_kwargs) -> SFTDataset:
  34. message_transform = ToolCallMessages()
  35. return SFTDataset(
  36. source="json",
  37. data_files="train_data.json",
  38. split="train",
  39. message_transform=message_transform,
  40. model_transform=ModelTokenizer,
  41. **load_dataset_kwargs,
  42. )