浏览代码

Create toolcall.py

Sanyam Bhutani 2 月之前
父节点
当前提交
2d19724207
共有 1 个文件被更改,包括 53 次插入0 次删除
  1. 53 0
      end-to-end-use-cases/data-tool/scripts/finetuning/toolcall.py

+ 53 - 0
end-to-end-use-cases/data-tool/scripts/finetuning/toolcall.py

@@ -0,0 +1,53 @@
+from pprint import pprint
+from typing import Any, Mapping
+
+from torchtune.data import _messages, Message
+from torchtune.datasets import SFTDataset
+from torchtune.modules.transforms import Transform
+from torchtune.modules.transforms.tokenizers import ModelTokenizer
+
+# Store original validate_messages for reference if needed
+original_validate = _messages.validate_messages
+
+
+# Replace with no-op function
+def no_validate(messages):
+    pass
+
+
+# Monkey patch
+_messages.validate_messages = no_validate
+
+
+class ToolCallMessages(Transform):
+    def __init__(self):
+        self._role_map = {
+            "system": "system",
+            "human": "user",
+            "gpt": "assistant",
+            "tool": "ipython",
+        }
+
+    def __call__(self, sample):
+        messages = [
+            Message(
+                role=self._role_map[msg["from"]],
+                content=msg["value"],
+                masked=self._role_map[msg["from"]] != "assistant",
+                eot=True,
+            )
+            for msg in sample["cot_conversations"]
+        ]
+        return {"messages": messages}
+
+
+def custom_dataset(model_transform, **load_dataset_kwargs) -> SFTDataset:
+    message_transform = ToolCallMessages()
+    return SFTDataset(
+        source="json",
+        data_files="train_data.json",
+        split="train",
+        message_transform=message_transform,
+        model_transform=ModelTokenizer,
+        **load_dataset_kwargs,
+    )