from pprint import pprint
from typing import Any, Mapping

from torchtune.data import _messages, Message
from torchtune.datasets import SFTDataset
from torchtune.modules.transforms import Transform
from torchtune.modules.transforms.tokenizers import ModelTokenizer

# Keep a handle to torchtune's original validator so validation can be
# restored later if needed (e.g. _messages.validate_messages = original_validate).
original_validate = _messages.validate_messages


def no_validate(messages):
    """No-op replacement for torchtune's message validation.

    NOTE(review): presumably the stock validator rejects this dataset's
    conversation layout (e.g. tool/ipython turns or ordering) — confirm
    before keeping validation disabled.
    """
    pass


# Monkey-patch: module-wide side effect — every dataset built after this
# import skips torchtune's message validation entirely.
_messages.validate_messages = no_validate
class ToolCallMessages(Transform):
    """Convert ShareGPT-style conversation samples into torchtune ``Message``s.

    Reads the sample's ``"cot_conversations"`` list, maps each turn's
    ``"from"`` role onto a torchtune role, and masks every non-assistant
    turn so loss is computed only on assistant outputs.
    """

    def __init__(self):
        # ShareGPT role -> torchtune role.
        self._role_map = {
            "system": "system",
            "human": "user",
            "gpt": "assistant",
            "tool": "ipython",
        }

    def __call__(self, sample: Mapping[str, Any]) -> dict:
        """Return ``{"messages": [...]}`` built from ``sample["cot_conversations"]``.

        Raises:
            KeyError: if a turn's ``"from"`` value is not in the role map,
                or a turn lacks the ``"from"``/``"value"`` keys.
        """
        messages = []
        for msg in sample["cot_conversations"]:
            # Hoist the role lookup — the original resolved it twice per turn.
            role = self._role_map[msg["from"]]
            messages.append(
                Message(
                    role=role,
                    content=msg["value"],
                    # Train only on assistant turns; mask everything else.
                    masked=role != "assistant",
                    eot=True,
                )
            )
        return {"messages": messages}
def custom_dataset(model_transform, **load_dataset_kwargs) -> SFTDataset:
    """Build an SFT dataset from ``train_data.json`` using ToolCallMessages.

    Args:
        model_transform: the tokenizer/model transform applied after the
            message transform (supplied by the recipe/config).
        **load_dataset_kwargs: forwarded to the underlying
            ``datasets.load_dataset`` call.

    Returns:
        A configured ``SFTDataset``.
    """
    message_transform = ToolCallMessages()
    return SFTDataset(
        source="json",
        data_files="train_data.json",
        split="train",
        message_transform=message_transform,
        # Bug fix: the original passed the ModelTokenizer *class* here,
        # silently ignoring the caller's model_transform argument.
        model_transform=model_transform,
        **load_dataset_kwargs,
    )