@@ -1,3 +1,4 @@
+import os
 from pprint import pprint
 from typing import Any, Mapping
@@ -14,21 +15,19 @@ class ToolCallMessages(Transform):
             "human": "user",
             "gpt": "assistant",
             "tool": "ipython",
-            "user": "user",
-            "assistant": "assistant",
+            # identity mappings so already-normalized roles don't raise a KeyError
+            "user": "user",
+            "assistant": "assistant",
         }
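+        # e.g. a raw turn {"from": "human", ...} maps to role "user", and
+        # "tool" output becomes an "ipython" message (the role torchtune's
+        # Message type uses for tool results)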
         self.train_on_input = train_on_input

     def __call__(self, sample):
         conversations = sample["cot_conversations"]
         messages = []
-
-        # Keep the original list comprehension structure but add the EOT logic
+        # EOT handling follows the approach discussed in
+        # https://github.com/pytorch/torchtune/issues/2405#issuecomment-2670392887
         for i, msg in enumerate(conversations):
             next_is_tool = (
                 i < len(conversations) - 1 and conversations[i + 1]["from"] == "tool"
             )
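+            # next_is_tool flags assistant turns that are immediately followed
+            # by a tool response; presumably this drives the Message's eot flag,
+            # since Llama 3.x ends tool calls with <|eom_id|> rather than <|eot_id|>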
-
             messages.append(
                 Message(
                     role=self._role_map[msg["from"]],
@@ -45,14 +44,38 @@ class ToolCallMessages(Transform):
             )
         return {"messages": messages}


 def custom_dataset(
     model_transform, train_on_input=False, **load_dataset_kwargs
 ) -> SFTDataset:
     message_transform = ToolCallMessages(train_on_input=train_on_input)
+
+    # TODO: make this path configurable instead of hardcoding a local checkout
+    dataset_path = "/home/sanyambhutani/llama-cookbook/end-to-end-use-cases/data-tool/scripts/finetuning/saturday-tool-dict-try-5/train/"
+    arrow_files = [
+        os.path.join(dataset_path, x)
+        for x in os.listdir(dataset_path)
+        if x.endswith(".arrow")
+    ]
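+    # assumes the directory holds Hugging Face Arrow shards (e.g.
+    # data-00000-of-00002.arrow), as written by Dataset.save_to_disk or similar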
     return SFTDataset(
-        source="json",
-        data_files="train_final_mix.json",
+        source="arrow",
+        data_files=arrow_files,
         split="train",
         message_transform=message_transform,
         model_transform=model_transform,