|
@@ -1,54 +1,9 @@
|
|
"""
|
|
"""
|
|
Custom SFT dataset for fine-tuning.
|
|
Custom SFT dataset for fine-tuning.
|
|
"""
|
|
"""
|
|
-from typing import Any, List, Mapping
|
|
|
|
from torchtune.data import OpenAIToMessages
|
|
from torchtune.data import OpenAIToMessages
|
|
from torchtune.datasets import SFTDataset
|
|
from torchtune.datasets import SFTDataset
|
|
from torchtune.modules.transforms import Transform
|
|
from torchtune.modules.transforms import Transform
|
|
-from torchtune.data import load_image, Message
|
|
|
|
-
|
|
|
|
-class MessageTransform(Transform):
|
|
|
|
- def __init__(self):
|
|
|
|
- super().__init__()
|
|
|
|
-
|
|
|
|
- def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
|
|
|
|
-
|
|
|
|
- user_content = []
|
|
|
|
- assistant_content = []
|
|
|
|
-
|
|
|
|
- for message in sample["messages"]:
|
|
|
|
- contents = message['content']
|
|
|
|
- role = message['role']
|
|
|
|
- for content in contents:
|
|
|
|
- typ = content['type']
|
|
|
|
- val = None
|
|
|
|
- if typ == 'text':
|
|
|
|
- val = content['text']
|
|
|
|
- if role == 'user':
|
|
|
|
- user_content.append({"type": "text", "content": val})
|
|
|
|
- else:
|
|
|
|
- assistant_content.append({"type": "text", "content": val})
|
|
|
|
- elif typ == 'image' or typ == 'image_url':
|
|
|
|
- val = load_image(content['image'])
|
|
|
|
- user_content.append({"type": "image", "content": val})
|
|
|
|
-
|
|
|
|
- messages = [
|
|
|
|
- Message(
|
|
|
|
- role="user",
|
|
|
|
- content=user_content,
|
|
|
|
- masked=True,
|
|
|
|
- eot=True,
|
|
|
|
- ),
|
|
|
|
- Message(
|
|
|
|
- role="assistant",
|
|
|
|
- content=assistant_content,
|
|
|
|
- masked=False,
|
|
|
|
- eot=True,
|
|
|
|
- ),
|
|
|
|
- ]
|
|
|
|
-
|
|
|
|
- return {"messages": messages}
|
|
|
|
-
|
|
|
|
|
|
|
|
def custom_sft_dataset(
|
|
def custom_sft_dataset(
|
|
model_transform: Transform,
|
|
model_transform: Transform,
|
|
@@ -69,13 +24,12 @@ def custom_sft_dataset(
|
|
Returns:
|
|
Returns:
|
|
SFTDataset: A dataset ready for fine-tuning with TorchTune
|
|
SFTDataset: A dataset ready for fine-tuning with TorchTune
|
|
"""
|
|
"""
|
|
- # message_transform = OpenAIToMessages(train_on_input=train_on_input)
|
|
|
|
- message_transform = MessageTransform()
|
|
|
|
|
|
+ message_transform = OpenAIToMessages(train_on_input=train_on_input)
|
|
|
|
|
|
ds = SFTDataset(
|
|
ds = SFTDataset(
|
|
source="json",
|
|
source="json",
|
|
- data_files="/home/ubuntu/yash-workspace/outputs/train_torchtune_formatted_data.json",
|
|
|
|
- split="train",
|
|
|
|
|
|
+ data_files=dataset_path,
|
|
|
|
+ split=split,
|
|
message_transform=message_transform,
|
|
message_transform=message_transform,
|
|
model_transform=model_transform,
|
|
model_transform=model_transform,
|
|
)
|
|
)
|