Sanyam Bhutani 1 month ago
Commit dd7a3a5bbc
1 changed file with 30 additions and 7 deletions

+ 30 - 7
end-to-end-use-cases/data-tool/scripts/finetuning/toolcall.py

@@ -1,3 +1,4 @@
+import os
 from pprint import pprint
 from typing import Any, Mapping
 
@@ -14,21 +15,19 @@ class ToolCallMessages(Transform):
             "human": "user",
             "gpt": "assistant",
             "tool": "ipython",
-            "user": "user",
-            "assistant": "assistant",
+            "user": "user",  # to avoid key errors
+            "assistant": "assistant",  # avoid key errors again, lol
         }
         self.train_on_input = train_on_input
 
     def __call__(self, sample):
         conversations = sample["cot_conversations"]
         messages = []
-
-        # Keep the original list comprehension structure but add the EOT logic
+        # EOT logic settled on after discussion; notes: https://github.com/pytorch/torchtune/issues/2405#issuecomment-2670392887
         for i, msg in enumerate(conversations):
             next_is_tool = (
                 i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
             )
-
             messages.append(
                 Message(
                     role=self._role_map[msg["from"]],
@@ -45,14 +44,38 @@ class ToolCallMessages(Transform):
             )
         return {"messages": messages}
 
+
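+# Previous JSON-based builder, kept commented out for reference: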
+# def custom_dataset(
+#     model_transform, train_on_input=False, **load_dataset_kwargs
+# ) -> SFTDataset:
+#     message_transform = ToolCallMessages(train_on_input=train_on_input)
+#     return SFTDataset(
+#         source="json",
+#         data_files="train_data_tool_tag.json",  # NOTE: hardcoded filename; should be made configurable
+#         split="train",
+#         message_transform=message_transform,
+#         model_transform=model_transform,
+#         **load_dataset_kwargs,
+#     )
+
 
 def custom_dataset(
     model_transform, train_on_input=False, **load_dataset_kwargs
 ) -> SFTDataset:
     message_transform = ToolCallMessages(train_on_input=train_on_input)
+
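+    # Gather every .arrow shard in the local dataset directory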
+    dataset_path = "/home/sanyambhutani/llama-cookbook/end-to-end-use-cases/data-tool/scripts/finetuning/saturday-tool-dict-try-5/train/"
+    arrow_files = [
+        os.path.join(dataset_path, x)
+        for x in os.listdir(dataset_path)
+        if x.endswith(".arrow")
+    ]
+
     return SFTDataset(
-        source="json",
-        data_files="train_final_mix.json",
+        source="arrow",
+        data_files=arrow_files,
         split="train",
         message_transform=message_transform,
         model_transform=model_transform,
         **load_dataset_kwargs,
     )
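
The hunk view above cuts off the rest of the Message constructor, so the per-message flags are not visible here. Below is a minimal sketch of the EOT pattern the loop implies, assuming torchtune.data.Message is built with eot=not next_is_tool (a turn that a tool response will answer ends with end-of-message rather than end-of-turn, matching the Llama 3.1 tool-calling format) and that masked follows the usual train-on-assistant-only convention; both flags are assumptions, not a quote of the commit.

# Minimal sketch of the implied EOT handling (assumptions flagged inline).
from torchtune.data import Message

conversations = [
    {"from": "human", "value": "What's the weather in SF?"},
    {"from": "gpt", "value": '{"name": "get_weather", "city": "SF"}'},
    {"from": "tool", "value": '{"temp_f": 61}'},
    {"from": "gpt", "value": "It is 61F in San Francisco."},
]
role_map = {"human": "user", "gpt": "assistant", "tool": "ipython"}

messages = []
for i, msg in enumerate(conversations):
    # True when the next turn is a tool result answering this message
    next_is_tool = i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
    messages.append(
        Message(
            role=role_map[msg["from"]],
            content=msg["value"],
            masked=role_map[msg["from"]] != "assistant",  # ASSUMPTION: train on assistant turns only
            eot=not next_is_tool,  # ASSUMPTION: suppress end-of-turn before a tool result
        )
    )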
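
For a quick sanity check of the Arrow-backed builder, it can be driven directly with a tokenizer as the model transform. A hedged sketch: the tokenizer path is a placeholder, and it assumes toolcall.py is importable from the working directory; SFTDataset loads the shards through the Hugging Face datasets "arrow" loader.

# Hypothetical smoke test; tokenizer path and import location are placeholders.
from torchtune.models.llama3 import llama3_tokenizer

from toolcall import custom_dataset  # assumes toolcall.py is on PYTHONPATH

tokenizer = llama3_tokenizer("/path/to/tokenizer.model")
ds = custom_dataset(model_transform=tokenizer)

sample = ds[0]  # each item is tokenized: expect "tokens" and "labels" keys
print(len(sample["tokens"]), sample["tokens"][:8])

In a torchtune YAML recipe the same builder would be referenced through the dataset _component_ field, again assuming the module is importable by the recipe.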