toolcall.py

import os
from typing import Any, Mapping

from torchtune.data import Message
from torchtune.datasets import SFTDataset
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer


class ToolCallMessages(Transform):
    """Map `cot_conversations` turns (`from`/`value` dicts) to torchtune `Message`s."""

    def __init__(self, train_on_input: bool = False):
        self._role_map = {
            "system": "system",
            "human": "user",
            "gpt": "assistant",
            "tool": "ipython",
            "user": "user",  # passthrough, to avoid key errors
            "assistant": "assistant",  # likewise a passthrough
        }
        self.train_on_input = train_on_input
    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
        conversations = sample["cot_conversations"]
        messages = []
        # EOT logic we agreed on after a lot of debate; more notes here:
        # https://github.com/pytorch/torchtune/issues/2405#issuecomment-2670392887
        for i, msg in enumerate(conversations):
            # An assistant turn that is immediately followed by a tool response
            # must not end the turn, so it gets eot=False below.
            next_is_tool = (
                i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
            )
            messages.append(
                Message(
                    role=self._role_map[msg["from"]],
                    content=msg["value"],
                    # Train only on assistant turns unless train_on_input is set.
                    masked=(
                        False
                        if self.train_on_input
                        else self._role_map[msg["from"]] != "assistant"
                    ),
                    # No EOT on tool outputs, nor on assistant turns that lead
                    # into a tool call; EOT on everything else.
                    eot=not (
                        msg["from"] == "tool" or (msg["from"] == "gpt" and next_is_tool)
                    ),
                )
            )
        return {"messages": messages}
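
# A sketch of the assumed input schema and the masked/eot pattern the transform
# produces for it with the default train_on_input=False (illustrative values,
# not taken from the real training data):
#
#   {"cot_conversations": [
#       {"from": "system", "value": "..."},  # masked=True,  eot=True
#       {"from": "human",  "value": "..."},  # masked=True,  eot=True
#       {"from": "gpt",    "value": "..."},  # masked=False, eot=False (tool call next)
#       {"from": "tool",   "value": "..."},  # masked=True,  eot=False
#       {"from": "gpt",    "value": "..."},  # masked=False, eot=True
#   ]}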


# Earlier JSON-backed variant, kept for reference:
# def custom_dataset(
#     model_transform, train_on_input=False, **load_dataset_kwargs
# ) -> SFTDataset:
#     message_transform = ToolCallMessages(train_on_input=train_on_input)
#     return SFTDataset(
#         source="json",
#         data_files="train_data_tool_tag.json",  # yes, it's hardcoded; yes, I will say I'll fix it and hopefully not forget
#         split="train",
#         message_transform=message_transform,
#         model_transform=model_transform,
#         **load_dataset_kwargs,
#     )
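
# To point a torchtune config at the builder below, something like this should
# work (a sketch; the `toolcall` module path assumes this file is importable as
# toolcall.py on your PYTHONPATH):
#
#   dataset:
#     _component_: toolcall.custom_dataset
#     train_on_input: False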
def custom_dataset(
    model_transform: ModelTokenizer, train_on_input: bool = False, **load_dataset_kwargs
) -> SFTDataset:
    message_transform = ToolCallMessages(train_on_input=train_on_input)
    # Collect every Arrow shard from the (still hardcoded) training directory.
    dataset_path = "/home/sanyambhutani/llama-cookbook/end-to-end-use-cases/data-tool/scripts/finetuning/saturday-tool-dict-try-5/train/"
    arrow_files = [
        os.path.join(dataset_path, x)
        for x in os.listdir(dataset_path)
        if x.endswith(".arrow")
    ]
    return SFTDataset(
        source="arrow",
        data_files=arrow_files,
        split="train",
        message_transform=message_transform,
        model_transform=model_transform,
        **load_dataset_kwargs,
    )
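

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: exercises the message transform without
    # a tokenizer or the Arrow shards. The payload below is illustrative and
    # matches the schema sketched above, not the real training data.
    sample = {
        "cot_conversations": [
            {"from": "system", "value": "You are a helpful assistant."},
            {"from": "human", "value": "What's the weather in SF?"},
            {"from": "gpt", "value": '{"name": "get_weather", "city": "SF"}'},
            {"from": "tool", "value": '{"temp_f": 61}'},
            {"from": "gpt", "value": "It's 61F in San Francisco."},
        ]
    }
    for m in ToolCallMessages()(sample)["messages"]:
        print(m.role, "masked:", m.masked, "eot:", m.eot)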