Bladeren bron

fix toolcall

Sanyam Bhutani 1 maand geleden
bovenliggende
commit
80a6fce069
1 gewijzigd bestand met 35 toevoegingen en 14 verwijderingen
  1. 35 14
      end-to-end-use-cases/data-tool/scripts/finetuning/toolcall.py

+ 35 - 14
end-to-end-use-cases/data-tool/scripts/finetuning/toolcall.py

@@ -1,39 +1,60 @@
 from pprint import pprint
 from typing import Any, Mapping
 
-from torchtune.data import _messages, Message
+from torchtune.data import Message
 from torchtune.datasets import SFTDataset
 from torchtune.modules.transforms import Transform
 from torchtune.modules.transforms.tokenizers import ModelTokenizer
+
+
 class ToolCallMessages(Transform):
-    def __init__(self):
+    def __init__(self, train_on_input=False):
         self._role_map = {
             "system": "system",
             "human": "user",
             "gpt": "assistant",
             "tool": "ipython",
+            "user": "user",
+            "assistant": "assistant",
         }
+        self.train_on_input = train_on_input
 
     def __call__(self, sample):
-        messages = [
-            Message(
-                role=self._role_map[msg["from"]],
-                content=msg["value"],
-                masked=self._role_map[msg["from"]] != "assistant",
-                eot=True,
+        conversations = sample["cot_conversations"]
+        messages = []
+
+        # Keep the original list comprehension structure but add the EOT logic
+        for i, msg in enumerate(conversations):
+            next_is_tool = (
+                i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
+            )
+
+            messages.append(
+                Message(
+                    role=self._role_map[msg["from"]],
+                    content=msg["value"],
+                    masked=(
+                        False
+                        if self.train_on_input
+                        else self._role_map[msg["from"]] != "assistant"
+                    ),
+                    eot=not (
+                        msg["from"] == "tool" or (msg["from"] == "gpt" and next_is_tool)
+                    ),
+                )
             )
-            for msg in sample["cot_conversations"]
-        ]
         return {"messages": messages}
 
 
-def custom_dataset(model_transform, **load_dataset_kwargs) -> SFTDataset:
-    message_transform = ToolCallMessages()
+def custom_dataset(
+    model_transform, train_on_input=False, **load_dataset_kwargs
+) -> SFTDataset:
+    message_transform = ToolCallMessages(train_on_input=train_on_input)
     return SFTDataset(
         source="json",
-        data_files="train_data.json",
+        data_files="train_final_mix.json",
         split="train",
         message_transform=message_transform,
-        model_transform=ModelTokenizer,
+        model_transform=model_transform,
         **load_dataset_kwargs,
     )