Browse Source

finetune not working with fsdp

Kai Wu 7 months ago
parent
commit
b566582a86

+ 89 - 0
recipes/quickstart/finetuning/datasets/vqa_dataset.py

@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+# check system prompt token seq or user prompt token seq is in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):
+    for i in range(len(seq)-3):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialog(dialog, images, processor):
+    # If vocab size is above 128000, use the chat template to generate the tokens as it is from Llama 3 family models
+    text_prompt = processor.apply_chat_template(dialog)
+    #print("text_prompt",text_prompt)
+    batch = processor(images=images, text=text_prompt)
+    dialog_tokens = batch["input_ids"].tolist()[0]
+    #print("dialog_tokens",dialog_tokens)
+    #print("dialog_tokens",dialog_tokens)
+    attention_mask = batch["attention_mask"].tolist()[0]
+    #print("attention_mask",attention_mask)
+    labels = copy.copy(dialog_tokens)
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx
+        # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+
+
+    combined_tokens = {
+        # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+        "input_ids": dialog_tokens,
+        "labels": labels,
+        "attention_mask": [1]*len(dialog_tokens),
+        "pixel_values": batch["pixel_values"].tolist()[0],
+        "image_sizes": batch["image_sizes"].tolist()[0]
+    }
+    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values =  batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+#    print("combined_tokens",combined_tokens[image_sizes])
+
+    return combined_tokens
+def image_tokenize(sample, processor):
+    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    images,sample_text = sample["images"],sample["messages"]
+    dialog = []
+    for line in sample_text:
+        content = []
+        messages = line["content"]
+        role = line["role"]
+        for message in messages:
+            if message["type"] == "image":
+                content.append({"type": "image"})
+            elif message["type"] == "text":
+                content.append({"type": "text", "text": message["text"].strip()})
+        dialog.append({"role": role,"content":content})
+    return tokenize_dialog(dialog,images, processor)
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset will return DatasetDict that contains all the data in the train set
+    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
+    dataset = dataset_dict[split]
+    dataset = dataset.select(range(100))
+    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
+    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
+    return tokenized_datasets

+ 1 - 1
src/llama_recipes/configs/datasets.py

@@ -37,4 +37,4 @@ class custom_dataset:
 class llamaguard_toxicchat_dataset:
     dataset: str = "llamaguard_toxicchat_dataset"
     train_split: str = "train"
-    test_split: str = "test"
+    test_split: str = "test"

+ 15 - 7
src/llama_recipes/finetuning.py

@@ -22,6 +22,11 @@ from transformers import (
     BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
+    AutoConfig, 
+    AutoModel,
+    LlavaNextForConditionalGeneration,
+    LlavaNextProcessor
+
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
@@ -116,11 +121,11 @@ def main(**kwargs):
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    #use_cache = False if train_config.enable_fsdp else None
+    model = LlavaNextForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
+    #    use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -129,7 +134,8 @@ def main(**kwargs):
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    processor = LlavaNextProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+    processor.tokenizer.padding_side='right'
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -200,7 +206,7 @@ def main(**kwargs):
 
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        processor,
         dataset_config,
         split="train",
     )
@@ -208,7 +214,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        processor,
         dataset_config,
         split="test",
     )
@@ -219,7 +225,7 @@ def main(**kwargs):
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+    print("length of dataset_train", len(dataset_train))
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -227,6 +233,7 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
@@ -241,6 +248,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:

+ 0 - 1
src/llama_recipes/utils/config_utils.py

@@ -104,5 +104,4 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             kwargs["collate_fn"] = default_data_collator
         else:
             raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
         return kwargs

+ 9 - 2
src/llama_recipes/utils/train_utils.py

@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -130,6 +132,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
+                    #print("batch: ", batch)
                     # stop when the maximum number of training steps is reached
                     if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                         max_steps_reached = True
@@ -149,8 +152,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
+                        assert(next(model.parameters()).device == batch['input_ids'].device)
+                        #print("batch: ", batch)
                         loss = model(**batch).loss
                     loss = loss / gradient_accumulation_steps
+                    #print("loss",loss)
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
                         train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -171,6 +177,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             pbar.update(1)
                     else:
                         # regular backpropagation when fp16 is not used
+                        #print("loss123",loss)
                         loss.backward()
                         if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                             if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
@@ -243,12 +250,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
 
                 else:
-                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    if not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
 
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+                    elif not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")