浏览代码

added custom message transform

khare19yash 2 月之前
父节点
当前提交
59cb07f6c3

+ 50 - 5
src/finetune_pipeline/finetuning/custom_sft_dataset.py

@@ -1,14 +1,58 @@
 """
 Custom SFT dataset for fine-tuning.
 """
-
+from typing import Any, List, Mapping
 from torchtune.data import OpenAIToMessages
 from torchtune.datasets import SFTDataset
 from torchtune.modules.transforms import Transform
+from torchtune.data import load_image, Message
+
+class MessageTransform(Transform):
+    def __init__(self):
+        super().__init__()
+
+    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
+
+        user_content = []
+        assistant_content = []
+
+        for message in sample["messages"]:
+            contents = message['content']
+            role = message['role']
+            for content in contents:
+                typ = content['type']
+                val = None
+                if typ == 'text':
+                    val = content['text']
+                    if role == 'user':
+                        user_content.append({"type": "text", "content": val})
+                    else:
+                        assistant_content.append({"type": "text", "content": val})
+                elif typ == 'image' or typ == 'image_url':
+                    val = load_image(content['image'])
+                    user_content.append({"type": "image", "content": val})
+
+        messages = [
+            Message(
+                role="user",
+                content=user_content,
+                masked=True,
+                eot=True,
+            ),
+            Message(
+                role="assistant",
+                content=assistant_content,
+                masked=False,
+                eot=True,
+            ),
+        ]
+
+        return {"messages": messages}
 
 
 def custom_sft_dataset(
     model_transform: Transform,
+    *,
     dataset_path: str = "/tmp/train.json",
     train_on_input: bool = False,
     split: str = "train",
@@ -25,13 +69,14 @@ def custom_sft_dataset(
     Returns:
         SFTDataset: A dataset ready for fine-tuning with TorchTune
     """
-    openaitomessage = OpenAIToMessages(train_on_input=train_on_input)
+    # message_transform = OpenAIToMessages(train_on_input=train_on_input)
+    message_transform = MessageTransform()
 
     ds = SFTDataset(
         source="json",
-        data_files=dataset_path,
-        split=split,
-        message_transform=openaitomessage,
+        data_files="/home/ubuntu/yash-workspace/outputs/train_torchtune_formatted_data.json",
+        split="train",
+        message_transform=message_transform,
         model_transform=model_transform,
     )
     return ds

+ 9 - 6
src/finetune_pipeline/finetuning/run_finetuning.py

@@ -132,11 +132,11 @@ def run_torch_tune(training_config: Dict, args=None):
         )
 
     # Add any additional kwargs if provided
-    if args and args.kwargs:
-        # Split the kwargs string by spaces to get individual key=value pairs
-        kwargs_list = args.kwargs.split()
-        base_cmd.extend(kwargs_list)
-        logger.info(f"Added additional kwargs: {kwargs_list}")
+    # if args and args.kwargs:
+    #     # Split the kwargs string by spaces to get individual key=value pairs
+    #     kwargs_list = args.kwargs.split()
+    #     base_cmd.extend(kwargs_list)
+    #     logger.info(f"Added additional kwargs: {kwargs_list}")
 
     # Log the command
     logger.info(f"Running command: {' '.join(base_cmd)}")
@@ -169,7 +169,10 @@ def main():
     )
     args = parser.parse_args()
 
-    run_torch_tune(args.config, args=args)
+    config = read_config(args.config)
+    finetuning_config = config.get("finetuning", {})
+
+    run_torch_tune(finetuning_config, args=args)
 
 
 if __name__ == "__main__":