custom_sft_dataset updated

khare19yash 1 month ago
commit 6dd9dea944

+ 2 - 2
src/finetune_pipeline/config.yaml

@@ -93,9 +93,9 @@ finetuning:
   strategy: "lora"               # Training strategy ('fft' or 'lora')
   num_epochs: 1                 # Number of training epochs
   batch_size: 1                 # Batch size per device for training
-  torchtune_config: "llama3_2_vision/11B_lora"             # TorchTune-specific configuration
+  torchtune_config: "llama3_1/8B_lora"             # TorchTune-specific configuration
   num_processes_per_node: 1             # TorchTune-specific configuration
-  distributed: false             # Whether to use distributed training
+  distributed: true             # Whether to use distributed training


 # vLLM Inference configuration
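For context, a minimal sketch of what the new values could imply at launch time, assuming the pipeline reads the `finetuning:` block above and maps `distributed` and `num_processes_per_node` onto torchtune's `tune run` launcher; the recipe names below are illustrative and not taken from this repo:

```python
# Hypothetical sketch: turning the finetuning section of config.yaml into a
# `tune run` command. Key names match the diff above; the recipe names and CLI
# layout are assumptions about how the pipeline invokes torchtune.
import yaml

with open("src/finetune_pipeline/config.yaml") as f:
    cfg = yaml.safe_load(f)["finetuning"]

if cfg["distributed"]:
    # Multi-process launch, honoring num_processes_per_node from the config.
    cmd = (
        f"tune run --nproc_per_node {cfg['num_processes_per_node']} "
        f"lora_finetune_distributed --config {cfg['torchtune_config']}"
    )
else:
    cmd = f"tune run lora_finetune_single_device --config {cfg['torchtune_config']}"

print(cmd)
# With the values committed here, this would resolve to roughly:
# tune run --nproc_per_node 1 lora_finetune_distributed --config llama3_1/8B_lora
```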

+ 3 - 49
src/finetune_pipeline/finetuning/custom_sft_dataset.py

@@ -1,54 +1,9 @@
 """
 """
 Custom SFT dataset for fine-tuning.
 Custom SFT dataset for fine-tuning.
 """
 """
-from typing import Any, List, Mapping
 from torchtune.data import OpenAIToMessages
 from torchtune.data import OpenAIToMessages
 from torchtune.datasets import SFTDataset
 from torchtune.datasets import SFTDataset
 from torchtune.modules.transforms import Transform
 from torchtune.modules.transforms import Transform
-from torchtune.data import load_image, Message
-
-class MessageTransform(Transform):
-    def __init__(self):
-        super().__init__()
-
-    def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
-
-        user_content = []
-        assistant_content = []
-
-        for message in sample["messages"]:
-            contents = message['content']
-            role = message['role']
-            for content in contents:
-                typ = content['type']
-                val = None
-                if typ == 'text':
-                    val = content['text']
-                    if role == 'user':
-                        user_content.append({"type": "text", "content": val})
-                    else:
-                        assistant_content.append({"type": "text", "content": val})
-                elif typ == 'image' or typ == 'image_url':
-                    val = load_image(content['image'])
-                    user_content.append({"type": "image", "content": val})
-
-        messages = [
-            Message(
-                role="user",
-                content=user_content,
-                masked=True,
-                eot=True,
-            ),
-            Message(
-                role="assistant",
-                content=assistant_content,
-                masked=False,
-                eot=True,
-            ),
-        ]
-
-        return {"messages": messages}
-

 def custom_sft_dataset(
     model_transform: Transform,
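The deleted MessageTransform walked OpenAI-style chat records by hand (including local image entries for the vision model); with the config moving to the text-only llama3_1/8B_lora recipe, the stock OpenAIToMessages converter covers the text case. A made-up example of the record shape involved:

```python
# Made-up sample record in the OpenAI chat format. The deleted MessageTransform
# parsed this structure by hand; OpenAIToMessages consumes the same layout for
# text content (the custom class additionally loaded local "image" entries).
sample = {
    "messages": [
        {
            "role": "user",
            "content": [{"type": "text", "text": "Summarize the ticket below."}],
        },
        {
            "role": "assistant",
            "content": [{"type": "text", "text": "Login fails after the 2.3 upgrade."}],
        },
    ]
}
```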
@@ -69,13 +24,12 @@ def custom_sft_dataset(
     Returns:
         SFTDataset: A dataset ready for fine-tuning with TorchTune
     """
-    # message_transform = OpenAIToMessages(train_on_input=train_on_input)
-    message_transform = MessageTransform()
+    message_transform = OpenAIToMessages(train_on_input=train_on_input)

     ds = SFTDataset(
         source="json",
-        data_files="/home/ubuntu/yash-workspace/outputs/train_torchtune_formatted_data.json",
-        split="train",
+        data_files=dataset_path,
+        split=split,
         message_transform=message_transform,
         model_transform=model_transform,
     )
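A hedged usage sketch of the updated factory: the keyword names `dataset_path`, `split`, and `train_on_input` are taken from the function body above (their signature lines fall outside the hunk), while the tokenizer choice and file paths are placeholders, not values from this repo:

```python
# Hypothetical wiring of custom_sft_dataset; paths and the tokenizer are
# assumptions for illustration only.
from torchtune.models.llama3 import llama3_tokenizer

from finetune_pipeline.finetuning.custom_sft_dataset import custom_sft_dataset

# The tokenizer doubles as the model transform that tokenizes the Message list.
tokenizer = llama3_tokenizer("/path/to/llama3_1/tokenizer.model")

ds = custom_sft_dataset(
    model_transform=tokenizer,
    dataset_path="outputs/train_torchtune_formatted_data.json",  # JSON of OpenAI-style records
    split="train",
    train_on_input=False,  # mask user turns; compute loss only on assistant replies
)

sample = ds[0]  # expected to be a dict of token ids and labels after the transforms
```

In a torchtune YAML, this factory would typically be referenced through a `_component_` entry pointing at `custom_sft_dataset`, with `dataset_path` and `split` supplied alongside it.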