
Merge pull request #13 from meta-llama/lmm_finetune

add vision model finetune recipe
Kai Wu 10 months ago
parent commit e1bbffcbff

+ 3 - 0
.github/scripts/spellcheck_conf/wordlist.txt

@@ -1451,3 +1451,6 @@ openhathi
 sarvam
 subtask
 acc
+OCRVQA
+OCRVQADataCollator
+ocrvqa

+ 90 - 0
recipes/quickstart/finetuning/datasets/ocrvqa_dataset.py

@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+import torch
+
+# Check whether the system prompt or user prompt token sequence appears in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target, seq):  # replace every 3-token target sequence with -100 so those positions are ignored by the loss
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialogs(dialogs, images, processor):
+    text_prompt = processor.apply_chat_template(dialogs)
+    batch = processor(images=images, text=text_prompt,padding = True, return_tensors="pt")
+    label_list = []
+    for i in range(len(batch["input_ids"])):
+        dialog_tokens = batch["input_ids"][i].tolist()
+        labels = copy.copy(dialog_tokens)
+        eot_indices = [i for i,n in enumerate(labels) if n == 128009]  # 128009 is the <|eot_id|> token
+        last_idx = 0
+        # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+        # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+        prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+        for n, idx in enumerate(eot_indices):
+            current_seq = labels[last_idx:idx+1]
+            if check_header(prompt_header_seqs,current_seq):
+                # found prompt header, indicating that this seq should be masked
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+            else:
+                last_idx = idx+1
+        # Mask the assistant header <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+        assistant_header_seq = [128006, 78191, 128007]
+        labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding token and image token 128256 
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: #  128256 is image token index
+                labels[i] = -100
+        label_list.append(labels)
+    batch["labels"] = torch.tensor(label_list)
+    return batch
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset returns a DatasetDict; for this config all of the data lives in the "train" split
+    dataset_dict = load_dataset("HuggingFaceM4/the_cauldron", name="ocrvqa")
+    dataset = dataset_dict['train']
+    # For quick testing we only use 2000 samples; comment out the following line to use the full dataset
+    dataset = dataset.select(range(2000))
+    dataset = dataset.train_test_split(test_size=1-split_ratio, shuffle=True, seed=42)[split]
+    return dataset
+
+class OCRVQADataCollator:
+    def __init__(self, processor):
+        self.processor = processor
+        self.processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    def __call__(self, samples):
+        dialogs,images = [],[]
+        for sample in samples:
+            image_list,sample_list = sample["images"],sample["texts"]
+            if len(image_list) > 1:
+                raise ValueError("Only one image per sample is supported")
+            image = image_list[0].convert("RGB") # only use the first image
+            dialog = []
+            for sample_dict in sample_list:
+                if not dialog:
+                    # only append image to the first sentence
+                    dialog += [
+                        {"role":"user","content":[{"type": "image"},{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                    ]
+
+                else:
+                    dialog += [
+                        {"role":"user","content":[{"type": "text", "text": sample_dict["user"].strip()}]},
+                        {"role":"assistant","content":[{"type": "text", "text": sample_dict["assistant"].strip()}]}
+                    ]
+            dialogs.append(dialog)
+            images.append([image])
+        return tokenize_dialogs(dialogs,images, self.processor)
+def get_data_collator(processor):
+    return OCRVQADataCollator(processor)

+ 33 - 0
recipes/quickstart/finetuning/finetune_vision_model.md

File diff suppressed because it is too large

+ 5 - 3
src/llama_recipes/datasets/__init__.py

@@ -5,14 +5,16 @@ from functools import partial
 
 
 from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
 from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.custom_dataset import get_custom_dataset
+from llama_recipes.datasets.custom_dataset import get_custom_dataset,get_data_collator
 from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
 from llama_recipes.datasets.toxicchat_dataset import get_llamaguard_toxicchat_dataset as get_llamaguard_toxicchat_dataset
-
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
     "llamaguard_toxicchat_dataset": get_llamaguard_toxicchat_dataset,
-}
+}
+DATALOADER_COLLATE_FUNC = {
+    "custom_dataset": get_data_collator
+}

+ 20 - 0
src/llama_recipes/datasets/custom_dataset.py

@@ -35,3 +35,23 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
         raise e
 
 
+def get_data_collator(dataset_processer,dataset_config):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_data_collator"
+
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_processer)
+    except AttributeError as e:
+        print(f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()}).")
+        print("Using the default data_collator instead.")
+        return None
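For reference, the module that dataset_config.file points at only needs to expose these two entry points; get_custom_dataset and get_data_collator are the default names looked up by the loaders, and a "path.py:func_name" value selects a differently named function. Below is a hedged skeleton with a toy in-memory dataset (illustrative only, not part of the commit):

# my_dataset.py -- hypothetical custom-dataset module (names and data are made up)
import torch

def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
    # Return anything indexable; a real recipe would load and split an HF dataset here.
    data = [{"text": f"sample {i}"} for i in range(100)]
    cut = int(len(data) * split_ratio)
    return data[:cut] if split == "train" else data[cut:]

def get_data_collator(processor):
    # Return a callable mapping a list of samples to a dict of tensors.
    # If this function is omitted, get_data_collator above prints a notice
    # and the default collator is used instead.
    def collate(samples):
        ids = [[ord(c) for c in s["text"]] for s in samples]
        width = max(len(x) for x in ids)
        ids = [x + [0] * (width - len(x)) for x in ids]
        input_ids = torch.tensor(ids)
        return {"input_ids": input_ids, "labels": input_ids.clone()}
    return collate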

+ 61 - 22
src/llama_recipes/finetuning.py

@@ -14,16 +14,18 @@ from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
     ShardingStrategy
 )
-
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoConfig,
     AutoTokenizer,
     BitsAndBytesConfig,
-    LlamaForCausalLM,
-    LlamaConfig,
+    AutoProcessor, 
+    MllamaForConditionalGeneration,
+    AutoModelForCausalLM,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import  MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
 
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
@@ -39,7 +41,7 @@ from llama_recipes.utils.config_utils import (
     get_dataloader_kwargs,
     check_fsdp_config,
 )
-from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
+from llama_recipes.utils.dataset_utils import get_preprocessed_dataset,get_custom_data_collator
 
 from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
 from llama_recipes.utils.train_utils import (
@@ -118,19 +120,35 @@ def main(**kwargs):
 
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    config = AutoConfig.from_pretrained(train_config.model_name)
+    if config.model_type == "mllama":
+        is_vision = True
+        model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
     )
-
+        processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+        processor.tokenizer.padding_side='right'
+    elif config.model_type == "llama":
+        is_vision = False
+        model = AutoModelForCausalLM.from_pretrained(
+            train_config.model_name,
+            quantization_config=bnb_config,
+            use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
+            device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
+            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
+        )
+    else:
+        raise ValueError(f"Model type {config.model_type} is not supported. Please use llama or mllama model.")
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
-    tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    if not tokenizer.pad_token_id:
+        tokenizer.pad_token_id = tokenizer.eos_token_id
+
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -169,8 +187,12 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-
+        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
+        if is_vision:
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
+        else:
+            # Create the FSDP wrapper for LlamaDecoderLayer in text models
+            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
@@ -198,12 +220,16 @@ def main(**kwargs):
             model.to("xpu:0")
             model.to("xpu:0")
         elif torch.cuda.is_available():
         elif torch.cuda.is_available():
             model.to("cuda")
             model.to("cuda")
-
     dataset_config = generate_dataset_config(train_config, kwargs)
     dataset_config = generate_dataset_config(train_config, kwargs)
+    if is_vision:
+        dataset_processer = processor
+    else:
+        dataset_processer = tokenizer
+
+    # Load and preprocess the dataset for training and validation
 
 
-     # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         dataset_config,
         split="train",
         split="train",
     )
     )
@@ -211,7 +237,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
         print(f"--> Training Set Length = {len(dataset_train)}")
 
 
     dataset_val = get_preprocessed_dataset(
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        dataset_processer,
         dataset_config,
         dataset_config,
         split="test",
         split="test",
     )
     )
@@ -219,10 +245,17 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
 
     if train_config.batching_strategy == "packing":
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
-
-    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
+    print("length of dataset_train", len(dataset_train))
+    custom_data_collator = get_custom_data_collator(dataset_processer,dataset_config)
+    if custom_data_collator:
+        print("custom_data_collator is used")
+        train_dl_kwargs["collate_fn"] = custom_data_collator
     # Create DataLoaders for the training and validation dataset
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
         dataset_train,
@@ -230,13 +263,19 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
-        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
+        if custom_data_collator:
+            val_dl_kwargs["collate_fn"] = custom_data_collator
 
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
@@ -244,6 +283,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
@@ -266,7 +306,6 @@ def main(**kwargs):
             weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
-    # Start the training process
     results = train(
         model,
         train_dataloader,
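The entry point now dispatches on config.model_type instead of hard-coding LlamaForCausalLM: "mllama" checkpoints take the vision path (MllamaForConditionalGeneration plus AutoProcessor) while "llama" checkpoints take the text path. A quick way to see which branch a checkpoint will take (the repo ids below are examples; both are gated on Hugging Face, so authenticated access is assumed):

# Inspect model_type without downloading weights.
from transformers import AutoConfig

for name in ["meta-llama/Llama-3.2-11B-Vision-Instruct",  # expected: "mllama" -> vision branch
             "meta-llama/Llama-3.1-8B-Instruct"]:          # expected: "llama"  -> text branch
    print(name, "->", AutoConfig.from_pretrained(name).model_type)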

+ 3 - 3
src/llama_recipes/policies/wrapping.py

@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.mllama.modeling_mllama import   MllamaSelfAttentionDecoderLayer,MllamaCrossAttentionDecoderLayer,MllamaVisionEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -25,9 +27,7 @@ def get_llama_wrapper():
 
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer,MllamaVisionEncoderLayer,MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy

+ 25 - 27
src/llama_recipes/utils/config_utils.py

@@ -17,8 +17,7 @@ from transformers.data import DataCollatorForSeq2Seq
 
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
-from llama_recipes.utils.dataset_utils import DATASET_PREPROC
-
+from llama_recipes.datasets import DATASET_PREPROC
 
 def update_config(config, **kwargs):
     if isinstance(config, (tuple, list)):
@@ -76,37 +75,36 @@ def generate_dataset_config(train_config, kwargs):
     return  dataset_config
 
 
-def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
-        kwargs = {}
-        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.batching_strategy == "padding":
-            if train_config.enable_fsdp:
-                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
-                    dataset,
-                    batch_size=batch_size,
-                    rank=dist.get_rank(),
-                    num_replicas=dist.get_world_size(),
-                    shuffle=mode=="train",
-                )
-            else:
-                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-        elif train_config.batching_strategy == "packing":
-            if train_config.enable_fsdp:
-                kwargs["sampler"] = DistributedSampler(
+def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
+    kwargs = {}
+    batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+    if train_config.batching_strategy == "padding":
+        if train_config.enable_fsdp:
+            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                 dataset,
+                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
-                drop_last=True,
             )
-            kwargs["batch_size"] = batch_size
-            kwargs["drop_last"] = True
-            kwargs["collate_fn"] = default_data_collator
         else:
-            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
-        return kwargs
+            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+    elif train_config.batching_strategy == "packing":
+        if train_config.enable_fsdp:
+            kwargs["sampler"] = DistributedSampler(
+            dataset,
+            rank=dist.get_rank(),
+            num_replicas=dist.get_world_size(),
+            shuffle=mode=="train",
+            drop_last=True,
+        )
+        kwargs["batch_size"] = batch_size
+        kwargs["drop_last"] = True
+        kwargs["collate_fn"] = default_data_collator
+    else:
+        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+    return kwargs
 
 
 def check_fsdp_config(fsdp_config):

+ 11 - 1
src/llama_recipes/utils/dataset_utils.py

@@ -4,7 +4,7 @@
 import torch
 
 from llama_recipes.data.concatenator import ConcatDataset
-from llama_recipes.datasets import DATASET_PREPROC, get_custom_dataset
+from llama_recipes.datasets import DATASET_PREPROC, DATALOADER_COLLATE_FUNC
 from llama_recipes.utils.config_utils import get_dataloader_kwargs
 
 
@@ -27,6 +27,16 @@ def get_preprocessed_dataset(
         get_split(),
     )
 
+def get_custom_data_collator(
+    dataset_processer, dataset_config
+):
+    if dataset_config.dataset not in DATALOADER_COLLATE_FUNC:
+        return None
+
+    return DATALOADER_COLLATE_FUNC[dataset_config.dataset](
+        dataset_processer,
+        dataset_config
+    )
 
 
 def get_dataloader(tokenizer, dataset_config, train_config, split: str = "train"):
     dataset = get_preprocessed_dataset(tokenizer, dataset_config, split)
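get_custom_data_collator intentionally returns None when the configured dataset has no entry in DATALOADER_COLLATE_FUNC, so finetuning.py only overrides collate_fn when a collator exists. A small sketch of that contract, assuming llama-recipes is installed (the config object is a stand-in for the real dataset config):

# Illustration only: "samsum_dataset" has no registered collator, so the helper
# returns None and the default collate_fn from get_dataloader_kwargs stays in effect.
from dataclasses import dataclass
from llama_recipes.utils.dataset_utils import get_custom_data_collator

@dataclass
class ToyDatasetConfig:      # stand-in for the real dataset config object
    dataset: str = "samsum_dataset"

collator = get_custom_data_collator(None, ToyDatasetConfig())
print(collator)  # None -> keep the default collator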

+ 2 - 4
src/llama_recipes/utils/fsdp_utils.py

@@ -3,7 +3,7 @@
 from torch.distributed._tensor.device_mesh import init_device_mesh
 import os
 
-def fsdp_auto_wrap_policy(model, transformer_layer_name):
+def fsdp_auto_wrap_policy(model, transformer_layer_names):
     import functools
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
@@ -20,9 +20,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls=(
-            transformer_layer_name,
-        ),
+        transformer_layer_cls=set(transformer_layer_names)
     )
 
     auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
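With this change fsdp_auto_wrap_policy accepts a list of transformer layer classes instead of a single class, which is what lets finetuning.py wrap all three Mllama layer types. A hedged usage sketch; model stands for the already-loaded (PEFT-wrapped) model from the recipe and is not created here:

# Usage sketch of the updated signature (model is assumed to exist already).
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaSelfAttentionDecoderLayer,
    MllamaCrossAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)
from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy

text_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
vision_policy = fsdp_auto_wrap_policy(
    model,
    [MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer],
)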

+ 3 - 2
src/llama_recipes/utils/train_utils.py

@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -143,10 +145,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to(local_rank)
                         else:
-
                             if is_xpu_available():
                                 batch[key] = batch[key].to('xpu:0')
-                            else:
+                            elif torch.cuda.is_available():
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
                         loss = model(**batch).loss