
finetune not working with fsdp

Kai Wu 7 months ago
parent
commit
b566582a86

+ 89 - 0
recipes/quickstart/finetuning/datasets/vqa_dataset.py

@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.
+
+
+import copy
+from datasets import load_dataset
+import itertools
+# check whether the system prompt or user prompt header token sequence appears in the current token list
+def check_header(targets,seq):
+    for i in range(len(seq)-2):  # i is the start index of each 3-token window
+        if seq[i:i+3] in targets:
+            return True
+    return False
+def replace_target(target,seq):  # mask every occurrence of the 3-token target sequence with -100
+    for i in range(len(seq)-2):
+        if seq[i:i+3] == target:
+            seq[i],seq[i+1],seq[i+2] = -100,-100,-100
+    return seq
+def tokenize_dialog(dialog, images, processor):
+    # Render the dialog into a text prompt with the processor's chat template (Llama 3 style special tokens)
+    text_prompt = processor.apply_chat_template(dialog)
+    #print("text_prompt",text_prompt)
+    batch = processor(images=images, text=text_prompt)
+    dialog_tokens = batch["input_ids"].tolist()[0]
+    #print("dialog_tokens",dialog_tokens)
+    #print("dialog_tokens",dialog_tokens)
+    attention_mask = batch["attention_mask"].tolist()[0]
+    #print("attention_mask",attention_mask)
+    labels = copy.copy(dialog_tokens)
+    eot_indices = [i for i,n in enumerate(labels) if n == 128009]
+    last_idx = 0
+    # system prompt header "<|start_header_id|>system<|end_header_id|>" has been tokenized to [128006, 9125, 128007]
+    # user prompt header "<|start_header_id|>user<|end_header_id|>" has been tokenized to [128006, 882, 128007]
+    prompt_header_seqs = [[128006, 9125, 128007],[128006, 882, 128007]]
+    for n, idx in enumerate(eot_indices):
+        current_seq = labels[last_idx:idx+1]
+        if check_header(prompt_header_seqs,current_seq):
+            # found prompt header, indicating that this seq should be masked
+            labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+        else:
+            last_idx = idx
+    # Lastly, mask the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+    assistant_header_seq = [128006, 78191, 128007]
+    labels = replace_target(assistant_header_seq,labels)
+    #print("labels",labels)
+
+
+    combined_tokens = {
+        # "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        # "labels": list(itertools.chain(*(t for t in labels_tokens))),
+        "input_ids": dialog_tokens,
+        "labels": labels,
+        "attention_mask": [1]*len(dialog_tokens),
+        "pixel_values": batch["pixel_values"].tolist()[0],
+        "image_sizes": batch["image_sizes"].tolist()[0]
+    }
+    # input_ids =  list(itertools.chain(*(t for t in dialog_tokens))),
+    # labels = list(itertools.chain(*(t for t in labels_tokens))),
+    # attention_mask = [1]*len(list(itertools.chain(*(t for t in dialog_tokens)))),
+    # pixel_values =  batch["pixel_values"],
+    # image_sizes = batch["image_sizes"]
+#    print("combined_tokens",combined_tokens[image_sizes])
+
+    return combined_tokens
+def image_tokenize(sample, processor):
+    processor.tokenizer.padding_side = "right" # during training, one always uses padding on the right
+    images,sample_text = sample["images"],sample["messages"]
+    dialog = []
+    for line in sample_text:
+        content = []
+        messages = line["content"]
+        role = line["role"]
+        for message in messages:
+            if message["type"] == "image":
+                content.append({"type": "image"})
+            elif message["type"] == "text":
+                content.append({"type": "text", "text": message["text"].strip()})
+        dialog.append({"role": role,"content":content})
+    return tokenize_dialog(dialog,images, processor)
+
+
+def get_custom_dataset(dataset_config, processor, split, split_ratio=0.9):
+    # load_dataset returns a DatasetDict containing every available split of the dataset
+    dataset_dict = load_dataset("remyxai/vqasynth_spacellava")
+    dataset = dataset_dict[split]
+    dataset = dataset.select(range(100))
+    tokenized_datasets = dataset.map(lambda x: image_tokenize(x, processor))
+    tokenized_datasets = tokenized_datasets.remove_columns(dataset.column_names)
+    return tokenized_datasets
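
The label-masking scheme above can be checked in isolation: any segment ending in `<|eot_id|>` (token 128009) that contains a system or user header is fully masked with -100, and the assistant header tokens are masked afterwards, so the loss is computed only on assistant responses. Below is a minimal, self-contained sketch of that logic; the three-token header ids and the eot id are the real Llama 3 ones quoted in the comments above, while the remaining token ids are made up for illustration.

```python
import copy

# Real Llama 3 special-token ids quoted in the recipe above; the other ids are made up.
PROMPT_HEADER_SEQS = [[128006, 9125, 128007], [128006, 882, 128007]]  # system / user headers
ASSISTANT_HEADER_SEQ = [128006, 78191, 128007]
EOT_ID = 128009

def mask_labels(input_ids):
    labels = copy.copy(input_ids)
    eot_indices = [i for i, tok in enumerate(labels) if tok == EOT_ID]
    last_idx = 0
    for idx in eot_indices:
        segment = labels[last_idx:idx + 1]
        # Mask the whole segment if it contains a system or user header.
        if any(segment[i:i + 3] in PROMPT_HEADER_SEQS for i in range(len(segment) - 2)):
            labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
        else:
            last_idx = idx
    # Finally mask the assistant header tokens themselves.
    for i in range(len(labels) - 2):
        if labels[i:i + 3] == ASSISTANT_HEADER_SEQ:
            labels[i:i + 3] = [-100, -100, -100]
    return labels

# Toy dialog: a user turn (ids 11, 12) followed by an assistant turn (ids 21, 22).
tokens = [128006, 882, 128007, 11, 12, 128009,
          128006, 78191, 128007, 21, 22, 128009]
print(mask_labels(tokens))
# Only the assistant answer (and its closing <|eot_id|>) keeps real labels:
# [-100, -100, -100, -100, -100, -100, -100, -100, -100, 21, 22, 128009]
```

As in the recipe, `last_idx` only advances past assistant segments, so an already-masked user segment is re-scanned together with the following assistant turn; that is harmless because the masked span no longer contains a header.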

+ 1 - 1
src/llama_recipes/configs/datasets.py

@@ -37,4 +37,4 @@ class custom_dataset:
 class llamaguard_toxicchat_dataset:
     dataset: str = "llamaguard_toxicchat_dataset"
     train_split: str = "train"
-    test_split: str = "test"
+    test_split: str = "test"
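
The change in this file appears to be whitespace-only (the removed and added `test_split` lines are textually identical); the new VQA recipe itself is not registered here. For orientation, wiring it in through the existing `custom_dataset` entry (whose header is visible in the hunk context) would presumably look like the sketch below; the field names and the default file path are assumptions based on the surrounding dataclasses, not part of this commit.

```python
from dataclasses import dataclass

# Hypothetical sketch (not in this commit): pointing the existing custom_dataset
# config at the new VQA recipe file. Field names mirror the dataclasses visible
# in datasets.py; adjust to the actual definition of custom_dataset.
@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    file: str = "recipes/quickstart/finetuning/datasets/vqa_dataset.py"
    train_split: str = "train"
    test_split: str = "test"
```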

+ 15 - 7
src/llama_recipes/finetuning.py

@@ -22,6 +22,11 @@ from transformers import (
     BitsAndBytesConfig,
     LlamaForCausalLM,
     LlamaConfig,
+    AutoConfig,
+    AutoModel,
+    LlavaNextForConditionalGeneration,
+    LlavaNextProcessor
+
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
@@ -116,11 +121,11 @@ def main(**kwargs):
         bnb_config = quant_config.create_bnb_config(train_config.quantization)
 
     # Load the pre-trained model and setup its configuration
-    use_cache = False if train_config.enable_fsdp else None
-    model = LlamaForCausalLM.from_pretrained(
+    #use_cache = False if train_config.enable_fsdp else None
+    model = LlavaNextForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        use_cache=use_cache,
+    #    use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -129,7 +134,8 @@ def main(**kwargs):
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
-
+    processor = LlavaNextProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+    processor.tokenizer.padding_side='right'
     # If there is a mismatch between tokenizer vocab size and embedding matrix,
     # throw a warning and then expand the embedding matrix
     if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
@@ -200,7 +206,7 @@ def main(**kwargs):
 
 
      # Load and preprocess the dataset for training and validation
     dataset_train = get_preprocessed_dataset(
-        tokenizer,
+        processor,
         dataset_config,
         split="train",
     )
@@ -208,7 +214,7 @@ def main(**kwargs):
         print(f"--> Training Set Length = {len(dataset_train)}")
 
     dataset_val = get_preprocessed_dataset(
-        tokenizer,
+        processor,
         dataset_config,
         split="test",
     )
@@ -219,7 +225,7 @@ def main(**kwargs):
         dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
-
+    print("length of dataset_train", len(dataset_train))
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
@@ -227,6 +233,7 @@ def main(**kwargs):
         pin_memory=True,
         **train_dl_kwargs,
     )
+    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")
 
     eval_dataloader = None
     if train_config.run_validation:
@@ -241,6 +248,7 @@ def main(**kwargs):
             pin_memory=True,
             **val_dl_kwargs,
         )
+        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
         if len(eval_dataloader) == 0:
             raise ValueError("The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set.")
         else:
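
The patch above swaps `LlamaForCausalLM` for `LlavaNextForConditionalGeneration`, comments out the `use_cache` override, builds a `LlavaNextProcessor` next to the tokenizer, and hands the processor (instead of the tokenizer) to `get_preprocessed_dataset`. Stripped of the recipe's config plumbing, the load step reduces to roughly the following; the checkpoint name is only an example, not taken from the commit.

```python
import torch
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor

# Example checkpoint (assumption): any LLaVA-NeXT checkpoint on the Hub would do here.
model_id = "llava-hf/llava-v1.6-mistral-7b-hf"

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,   # the recipe picks fp16 or bf16 from train_config
    attn_implementation="sdpa",   # mirrors use_fast_kernels
)

# The processor bundles the image processor with the tokenizer; right padding
# matches what both the patch and vqa_dataset.py set for training.
processor = LlavaNextProcessor.from_pretrained(model_id)
processor.tokenizer.padding_side = "right"
```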

+ 0 - 1
src/llama_recipes/utils/config_utils.py

@@ -104,5 +104,4 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
             kwargs["collate_fn"] = default_data_collator
         else:
             raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
-
         return kwargs

+ 9 - 2
src/llama_recipes/utils/train_utils.py

@@ -118,6 +118,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     max_steps_reached = False  # Flag to indicate max training steps reached
     # Start the training loop
     for epoch in range(train_config.num_epochs):
+        print(f"Starting epoch {epoch}/{train_config.num_epochs}")
+        print(f"train_config.max_train_step: {train_config.max_train_step}")
         # stop when the maximum number of training steps is reached
         if max_steps_reached:
             break
@@ -130,6 +132,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             with profile(train_config,local_rank) as profile_context:
                 for step, batch in enumerate(train_dataloader):
                     total_train_steps += 1
+                    #print("batch: ", batch)
                     # stop when the maximum number of training steps is reached
                     if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                         max_steps_reached = True
@@ -149,8 +152,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             else:
                                 batch[key] = batch[key].to('cuda:0')
                     with autocast():
+                        assert(next(model.parameters()).device == batch['input_ids'].device)
+                        #print("batch: ", batch)
                         loss = model(**batch).loss
                     loss = loss / gradient_accumulation_steps
+                    #print("loss",loss)
                     if train_config.save_metrics:
                         train_step_loss.append(loss.detach().float().item())
                         train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -171,6 +177,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             pbar.update(1)
                     else:
                         # regular backpropagation when fp16 is not used
+                        #print("loss123",loss)
                         loss.backward()
                         if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                             if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
@@ -243,12 +250,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
 
                 else:
-                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
+                    if not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
 
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
-                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
+                    elif not train_config.use_peft and fsdp_config and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
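
The last hunk guards the checkpointing branch: when FSDP is disabled, `fsdp_config` is `None`, so reading `fsdp_config.checkpoint_type` unguarded raises an `AttributeError` during checkpoint saving; the added `fsdp_config and ...` check short-circuits that. A minimal sketch of the pattern, with stand-in names rather than the real recipe objects:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional

# Stand-ins for illustration only; the recipe uses torch.distributed.fsdp.StateDictType
# and its own fsdp_config dataclass.
class StateDictType(Enum):
    FULL_STATE_DICT = auto()
    SHARDED_STATE_DICT = auto()

@dataclass
class FSDPConfig:
    checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT

def checkpoint_mode(use_peft: bool, fsdp_config: Optional[FSDPConfig]) -> str:
    if use_peft:
        return "peft"
    # Guard first, exactly like the patched condition: only touch checkpoint_type
    # when fsdp_config is not None.
    if fsdp_config and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
        return "full_state_dict"
    if fsdp_config and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
        return "sharded_state_dict"
    return "no_fsdp_checkpoint"

print(checkpoint_mode(use_peft=False, fsdp_config=None))          # no_fsdp_checkpoint, no AttributeError
print(checkpoint_mode(use_peft=False, fsdp_config=FSDPConfig()))  # sharded_state_dict
```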