Explorar o código

add get_custom_data_collator feature

Kai Wu hai 7 meses
pai
achega
ce299b3439

+ 50 - 66
recipes/quickstart/finetuning/datasets/vqa_dataset.py

@@ -6,6 +6,7 @@ import copy
 from datasets import load_dataset
 import itertools
 import torch
+
 # check system prompt token seq or user prompt token seq is in the current token list
 def check_header(targets,seq):
     for i in range(len(seq)-3):
@@ -17,78 +18,61 @@ def replace_target(target,seq):
         if seq[i:i+3] == target:
             seq[i],seq[i+1],seq[i+2] = -100,-100,-100
     return seq
-def tokenize_dialog(dialog, images, processor):
+def tokenize_dialogs(dialogs, images, processor):
     # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
-    text_prompt = processor.apply_chat_template(dialog)
+    text_prompt = processor.apply_chat_template(dialogs)
     #print("text_prompt",text_prompt)
-    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")    
-    labels = copy.copy(batch["input_ids"].tolist()[0])
-    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
-    last_idx = 0
-    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
-    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
-    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
-    for n, idx in enumerate(eot_indices):
-        current_seq = labels[last_idx:idx+1]
-        if check_header(prompt_header_seqs,current_seq):
-            # found prompt header, indicating that this seq should be masked
-            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
-        else:
-            last_idx = idx+1
-        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
-    assistant_header_seq = [128006, 78191, 128007]
-    labels = replace_target(assistant_header_seq,labels)
-    #print("labels",labels)
-    # print("pixel_values .shape",batch["pixel_values"].shape)
-    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
-
-    batch["labels"] = torch.tensor(labels)
-    #pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
-    batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
-    # pixel_values .shape torch.Size([1, 4, 3, 560, 560])
-    print("pixel_values .shape",batch["pixel_values"].shape)
-    # exit()
-    # combined_tokens = {
-    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
-    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
-    #     "input_ids": dialog_tokens,
-    #     "labels": labels,
-    #     "attention_mask": [1]*len(dialog_tokens),
-    #     "pixel_values": batch["pixel_values"],
-    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
-    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
-    #     "cross_attention_mask": batch["cross_attention_mask"]
-    # }
-    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
-    # labels = list(itertools.chain(*(t for t in labels_tokens))),
-    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
-    # pixel_values =  batch["pixel_values"],
-    # image_sizes = batch["image_sizes"]
-#    print("combined_tokens",combined_tokens[image_sizes])
-    
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    batch["labels"] = copy.copy(batch["input_ids"])
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+            # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        batch["labels"][i] = torch.tensor(labels)
     return batch
-def image_tokenize(sample, processor):
-    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
-    images,sample_text = sample["images"],sample["messages"]
-    dialog = []
-    for line in sample_text:
-        content = []
-        messages = line["content"]
-        role = line["role"]
-        for message in messages:
-            if message["type"] == "image":
-                content.append({"type": "image"})
-            elif message["type"] == "text":
-                content.append({"type": "text", "text": message["text"].strip()})
-        dialog.append({"role": role,"content":content})
-    return tokenize_dialog(dialog,images, processor)
-
 
 def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
     # load_dataset will return DatasetDict that contains all the data in the train set
     dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
     dataset = dataset_dict[split]
     dataset = dataset.select(range(100))
-    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
-    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
-    return tokenized_datasets
+    return dataset
+
+class VQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image,sample_text = sample["images"],sample["messages"]
+            dialog = []
+            for line in sample_text:
+                content = []
+                messages = line["content"]
+                role = line["role"]
+                for message in messages:
+                    if message["type"] == "image":
+                        content.append({"type": "image"})
+                    elif message["type"] == "text":
+                        content.append({"type": "text", "text": message["text"].strip()})
+                dialog.append({"role": role,"content":content})
+            dialogs.append(dialog)
+            images.append(image)
+        return tokenize_dialogs(dialogs,images, self.processor)
+def get_data_collator(processor):
+    return VQADataCollator(processor)

+ 94 - 0
recipes/quickstart/finetuning/datasets/vqa_dataset_old.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+# check system prompt token seq or user prompt token seq is in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialog(dialog, images, processor):
+    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
+    text_prompt = processor.apply_chat_template(dialog)
+    #print("text_prompt",text_prompt)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")    
+    labels = copy.copy(batch["input_ids"].tolist()[0])
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx+1
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+    # print("pixel_values .shape",batch["pixel_values"].shape)
+    # print("batch_size, num_concurrent_media, num_tiles, num_channels, height, width = pixel_values.shape")
+
+    batch["labels"] = torch.tensor(labels)
+    #pixel_values .shape torch.Size([1, 1, 4, 3, 560, 560])
+    batch["pixel_values"] = torch.squeeze(batch["pixel_values"], 1)
+    # pixel_values .shape torch.Size([1, 4, 3, 560, 560])
+    print("pixel_values .shape",batch["pixel_values"].shape)
+    # exit()
+    # combined_tokens = {
+    #     # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+    #     # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    #     "input_ids": dialog_tokens,
+    #     "labels": labels,
+    #     "attention_mask": [1]*len(dialog_tokens),
+    #     "pixel_values": batch["pixel_values"],
+    #     "aspect_ratio_ids": batch["aspect_ratio_ids"],
+    #     "aspect_ratio_mask": batch["aspect_ratio_mask"],
+    #     "cross_attention_mask": batch["cross_attention_mask"]
+    # }
+    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values =  batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+#    print("combined_tokens",combined_tokens[image_sizes])
+    
+    return batch
+def image_tokenize(sample, processor):
+    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    images,sample_text = sample["images"],sample["messages"]
+    dialog = []
+    for line in sample_text:
+        content = []
+        messages = line["content"]
+        role = line["role"]
+        for message in messages:
+            if message["type"] == "image":
+                content.append({"type": "image"})
+            elif message["type"] == "text":
+                content.append({"type": "text", "text": message["text"].strip()})
+        dialog.append({"role": role,"content":content})
+    return tokenize_dialog(dialog,images, processor)
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset will return DatasetDict that contains all the data in the train set
+    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
+    dataset = dataset_dict[split]
+    dataset = dataset.select(range(100))
+    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
+    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
+    return tokenized_datasets

+ 5 - 3
src/llama_recipes/datasets/__init__.py

@@ -5,14 +5,16 @@ from functools import partial
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset,get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

+ 20 - 0
src/llama_recipes/datasets/custom_dataset.py

@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
 
+def get_data_collator(dataset_processer,dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None

+ 7 - 3
src/llama_recipes/finetuning.py

@@ -45,7 +45,7 @@ from llama_recipes.utils.config_utils import (
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -252,8 +252,12 @@ def main(**kwargs):
     if train_config.batching_strategy == "packing":
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -268,7 +272,7 @@ def main(**kwargs):
         if train_config.batching_strategy == "packing":
             dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,

+ 26 - 29
src/llama_recipes/utils/config_utils.py

@@ -14,11 +14,11 @@ from peft import (
 )
 from transformers import default_data_collator
 from transformers.data import DataCollatorForSeq2Seq
+from functools import partial
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -76,39 +76,36 @@ def generate_dataset_config(train_config, kwargs):
     return  dataset_config
 
 
-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode,collate_fn=None):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            if not collate_fn:
-                kwargs["collate_fn"] = collate_fn
-            else:
-                kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
-                drop_last=True,
             )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-        return kwargs
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+            dataset,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+            shuffle=mode=="train",
+            drop_last=True,
+        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs
 
 
 def check_fsdp_config(fsdp_config):

+ 11 - 1
src/llama_recipes/utils/dataset_utils.py

@@ -4,7 +4,7 @@
 import torch
 
 from llama_recipes.data.concatenator import ConcatDataset
-from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
 from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
@@ -27,6 +27,16 @@ def get_preprocessed_dataset(
         get_split(),
     )
 
+def get_custom_data_collator(
+    dataset_processer, dataset_config
+) -> torch.utils.data.Dataset:
+    if not dataset_config.dataset in DATALOADER_COLLATE_FUNC:
+        return None
+
+    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
+        dataset_processer,
+        dataset_config
+    )
 
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)