Browse Source

Merge pull request #13 from meta-llama/lmm_finetune

add vision model finetune recipe
Kai Wu, 6 months ago
commit e1bbffcbff

+ 3 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1451,3 +1451,6 @@ openhathi
 sarvam
 subtask
 acc
+OCRVQA
+OCRVQADataCollator
+ocrvqa

+ 90 - 0
recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py

@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+
+# check whether the system prompt or user prompt header token sequence appears in the current token list
+def check_header(targets, seq):
+    for i in range(len(seq) - 2):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target, seq):
+    for i in range(len(seq) - 2):
+        if seq[i:i+3] == target:
+            seq[i], seq[i+1], seq[i+2] = -100, -100, -100
+    return seq
+def tokenize_dialogs(dialogs, images, processor):
+    text_prompt = processor.apply_chat_template(dialogs)
+    batch = processor(images=images, text=text_prompt, padding=True, return_tensors="pt")
+    label_list = []
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i, n in enumerate(labels) if n == 128009]  # 128009 is the <|eot_id|> token
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+        # Mask the assistant header "<|start_header_id|>assistant<|end_header_id|>", which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding token and image token 128256 
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: #  128256 is image token index
+                labels[i] = -100
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
+    return batch
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset returns a DatasetDict whose "train" split contains all the data
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
+    dataset = dataset_dict['train']
+    # for quick testing we only use 2000 samples; comment out the following line to use the full dataset
+    dataset = dataset.select(range(2000))
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
+    return dataset
+
+class OCRVQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right"  # training always pads on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image_list,sample_list = sample["images"],sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only one image per sample is supported")
+            image = image_list[0].convert("RGB")  # the sample's single image
+            dialog = []
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only attach the image to the first user turn
+                    dialog += [
+                        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
+                else:
+                    dialog += [
+                        {"role": "user", "content": [{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role": "assistant", "content": [{"type": "text", "text": sample_dict["assistant"].strip()}]},
+                    ]
+            dialogs.append(dialog)
+            images.append([image])
+        return tokenize_dialogs(dialogs, images, self.processor)
+
+def get_data_collator(processor):
+    return OCRVQADataCollator(processor)
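
For orientation, the recipe above can be smoke-tested on its own before it is wired into the finetuning script. A minimal sketch, assuming the file path below and a Llama 3.2 Vision checkpoint (neither is pinned by this commit):

    # Load the recipe file directly and build a single batch (illustrative only).
    import importlib.util
    from transformers import AutoProcessor

    spec = importlib.util.spec_from_file_location(
        "ocrvqa_dataset", "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"
    )
    ocrvqa = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(ocrvqa)

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    train_set = ocrvqa.get_custom_dataset(None, processor, "train")  # dataset_config is unused here
    collator = ocrvqa.get_data_collator(processor)
    batch = collator([train_set[0], train_set[1]])
    print(batch["input_ids"].shape, batch["labels"].shape)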

File diff suppressed because it is too large
+ 33 - 0
recipes/quickstart/finetuning/finetune_vision_model.md


+ 5 - 3
src/llama_recipes/datasets/__init__.py

@@ -5,14 +5,16 @@ from functools import partial
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset, get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

+ 20 - 0
src/llama_recipes/datasets/custom_dataset.py

@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
 
+def get_data_collator(dataset_processer, dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None
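
The collator is resolved with the same file / file:func_name convention already used by get_custom_dataset, so a dataset config only has to point at the recipe file. A rough sketch of such a config follows; the dataclass is a simplified stand-in for llama_recipes.configs.datasets.custom_dataset, not its actual definition:

    from dataclasses import dataclass

    @dataclass
    class custom_dataset:
        dataset: str = "custom_dataset"
        # "path.py" looks up get_data_collator by name; "path.py:my_collator" selects a function explicitly
        file: str = "recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py"
        train_split: str = "train"
        test_split: str = "test"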

+ 61 - 22
src/llama_recipes/finetuning.py

@@ -14,16 +14,18 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    LlamaConfig,
+    AutoProcessor, 
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -39,7 +41,7 @@ from llama_recipes.utils.config_utils import (
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -118,19 +120,35 @@ def main(**kwargs):
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
-
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side='right'
+    elif config.model_type == "llama":
+        is_vision = False
+        model = AutoModelForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if tokenizer.pad_token_id is None:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -169,8 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -198,12 +220,16 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
 
-     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="train",
     )
@@ -211,7 +237,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         split="test",
     )
@@ -219,10 +245,17 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -230,13 +263,19 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -244,6 +283,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -266,7 +306,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,

+ 3 - 3
src/llama_recipes/policies/wrapping.py

@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -25,9 +27,7 @@ def get_llama_wrapper():
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaVisionEncoderLayer, MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy

+ 25 - 27
src/llama_recipes/utils/config_utils.py

@@ -17,8 +17,7 @@ from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -76,37 +75,36 @@ def generate_dataset_config(train_config, kwargs):
     return  dataset_config
 
 
-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
-                drop_last=True,
             )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
-        return kwargs
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+            dataset,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+            shuffle=mode=="train",
+            drop_last=True,
+        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs
 
 
 def check_fsdp_config(fsdp_config):

+ 11 - 1
src/llama_recipes/utils/dataset_utils.py

@@ -4,7 +4,7 @@
 import torch
 
 from llama_recipes.data.concatenator import ConcatDataset
-from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
 from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
@@ -27,6 +27,16 @@ def get_preprocessed_dataset(
         get_split(),
     )
 
+def get_custom_data_collator(
+    dataset_processer, dataset_config
+):
+    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
+        return None
+
+    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
+        dataset_processer,
+        dataset_config
+    )
 
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
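
Putting the pieces together: get_custom_data_collator only returns a collator for dataset names registered in DATALOADER_COLLATE_FUNC and yields None otherwise, in which case the default collator chosen by get_dataloader_kwargs is used. A minimal sketch of the lookup; the model id and config object are illustrative:

    from types import SimpleNamespace
    from transformers import AutoProcessor
    from llama_recipes.utils.dataset_utils import get_custom_data_collator

    processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    cfg = SimpleNamespace(dataset="custom_dataset",
                          file="recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py")
    collator = get_custom_data_collator(processor, cfg)  # -> OCRVQADataCollator
    # a dataset name not in DATALOADER_COLLATE_FUNC (e.g. "samsum_dataset") would return None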

+ 2 - 4
src/llama_recipes/utils/fsdp_utils.py

@@ -3,7 +3,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os 
 
-def fsdp_auto_wrap_policy(model, transformer_layer_name):
+def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -20,9 +20,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            transformer_layer_name,
-        ),
+        transformer_layer_cls=set(transformer_layer_names)
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])

+ 3 - 2
src/llama_recipes/utils/train_utils.py

@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -143,10 +145,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
-
                             if is_xpu_available():
                                 batch[key] = batch[key].to('xpu:0')
-                            else:
+                            elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
                         loss = model(**batch).loss