Kai Wu, 7 months ago
parent
commit 8a11b48022
3 files changed, 15 lines added and 10 lines deleted

  1. recipes/quickstart/finetuning/datasets/vqa_dataset.py (+5 -1)
  2. src/llama_recipes/finetuning.py (+7 -9)
  3. src/llama_recipes/policies/wrapping.py (+3 -0)

+ 5 - 1
recipes/quickstart/finetuning/datasets/vqa_dataset.py

@@ -39,9 +39,13 @@ def tokenize_dialogs(dialogs, images, processor):
                 labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
             else:
                 last_idx = idx+1
-            # Lastly mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
+            #  Mask all the assistant header prompt <|start_header_id|>assistant<|end_header_id|>, which has been tokenized to [128006, 78191, 128007]
         assistant_header_seq = [128006, 78191, 128007]
         labels = replace_target(assistant_header_seq,labels)
+        # Mask the padding token and image token 128256 
+        for i in range(len(labels)):
+            if labels[i] == processor.tokenizer.pad_token_id or labels[i] == 128256: #  128256 is image token index
+                labels[i] = -100
         label_list.append(labels)
     batch["labels"] = torch.tensor(label_list)
     tokenizer_length = len(processor.tokenizer)
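
For reference, a minimal standalone sketch of the label masking added above. replace_target is defined elsewhere in the dataset module, so the version below is only a stand-in; the pad token id (128004) and the plain-text token ids are illustrative assumptions, while the assistant header ids 128006/78191/128007 and the image token id 128256 come from the diff.

def replace_target(target_seq, labels):
    # Stand-in for the recipe's replace_target: mask every occurrence
    # of target_seq in labels with -100 so the loss ignores it.
    n = len(target_seq)
    for i in range(len(labels) - n + 1):
        if labels[i:i + n] == target_seq:
            labels[i:i + n] = [-100] * n
    return labels

pad_token_id = 128004            # assumed pad id, for illustration only
image_token_id = 128256          # image token index, as in the diff
assistant_header_seq = [128006, 78191, 128007]

labels = [128006, 78191, 128007, 15339, 1917, image_token_id, pad_token_id]
labels = replace_target(assistant_header_seq, labels)
labels = [-100 if t in (pad_token_id, image_token_id) else t for t in labels]
print(labels)  # [-100, -100, -100, 15339, 1917, -100, -100]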

+ 7 - 9
src/llama_recipes/finetuning.py

@@ -137,6 +137,7 @@ def main(**kwargs):
         processor = AutoProcessor.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
         processor.tokenizer.padding_side='right'
     else:
+        is_vision = False
         model = LlamaForCausalLM.from_pretrained(
             train_config.model_name,
             quantization_config=bnb_config,
@@ -188,23 +189,20 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [CLIPEncoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # if is_vision:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
+        # else:
+        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()
         elif torch.cuda.is_available():
             device_id = torch.cuda.current_device()
-        if train_config.use_peft:
-            wrapping_policy = my_auto_wrapping_policy
-        else:
-            if is_vision:
-                wrapping_policy = ModuleWrapPolicy([CLIPEncoderLayer, LlamaDecoderLayer])
-            else:
-                wrapping_policy = ModuleWrapPolicy([LlamaDecoderLayer])
         model = FSDP(
             model,
-            auto_wrap_policy= wrapping_policy,
+            auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
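
The FSDP call above now builds my_auto_wrapping_policy from both LlamaDecoderLayer and CLIPEncoderLayer and uses it only on the PEFT path. The fsdp_auto_wrap_policy helper itself is not part of this diff; below is a minimal sketch of what such a PEFT-aware helper can look like, assuming the _or_policy, lambda_auto_wrap_policy, and transformer_auto_wrap_policy utilities from torch.distributed.fsdp.wrap. The repo's actual helper may differ.

import functools

from torch.distributed.fsdp.wrap import (
    _or_policy,
    lambda_auto_wrap_policy,
    transformer_auto_wrap_policy,
)

def fsdp_auto_wrap_policy(model, transformer_layer_names):
    # model is accepted to match the call in finetuning.py; this sketch does not use it.
    def lambda_policy_fn(module):
        # Wrap leaf modules whose weights stay trainable (typical for LoRA/PEFT adapters).
        return (
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
    transformer_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(transformer_layer_names),
    )
    # Wrap a module if either policy selects it.
    return functools.partial(_or_policy, policies=[lambda_policy, transformer_policy])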

+ 3 - 0
src/llama_recipes/policies/wrapping.py

@@ -4,6 +4,8 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
     size_based_auto_wrap_policy,
@@ -27,6 +29,7 @@ def get_llama_wrapper():
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             LlamaDecoderLayer,
+            CLIPEncoderLayer
         },
     )
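
A quick way to sanity-check that get_llama_wrapper now treats CLIP encoder layers as FSDP wrapping units, outside a training run. This is a sketch; the tiny CLIPVisionConfig is illustrative, and real runs use the pretrained model's own config.

import torch.nn as nn
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

from llama_recipes.policies.wrapping import get_llama_wrapper

policy = get_llama_wrapper()
clip_layer = CLIPEncoderLayer(
    CLIPVisionConfig(hidden_size=64, intermediate_size=128, num_attention_heads=4)
)

# FSDP calls the policy per module as policy(module, recurse, nonwrapped_numel);
# with recurse=False it answers "should this exact module become its own FSDP unit?".
print(policy(clip_layer, False, 0))       # True after this change
print(policy(nn.Linear(8, 8), False, 0))  # False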