
lora+fsdp working

Kai Wu, 7 months ago
parent
commit 1a76080807

+ 14 - 11
src/llama_recipes/finetuning.py

@@ -30,7 +30,8 @@ from transformers import (
     MllamaForConditionalGeneration
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
+
 from llama_recipes.configs import fsdp_config as FSDP_CONFIG
 from llama_recipes.configs import train_config as TRAIN_CONFIG
 from llama_recipes.configs import quantization_config  as QUANTIZATION_CONFIG
@@ -129,7 +130,6 @@ def main(**kwargs):
         model = MllamaForConditionalGeneration.from_pretrained(
         train_config.model_name,
         quantization_config=bnb_config,
-        #use_cache=use_cache,
         attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
         torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
@@ -146,7 +146,7 @@ def main(**kwargs):
             device_map="auto" if train_config.quantization and not train_config.enable_fsdp else None,
             torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
         )
-
+    print(model)
     # Load the tokenizer and add special tokens
     tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
@@ -189,11 +189,7 @@ def main(**kwargs):
             freeze_transformer_layers(model, train_config.num_freeze_layers)
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
-        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # if is_vision:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer,CLIPEncoderLayer])
-        # else:
-        #     my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])
+        my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer])
         print("FSDP is enabled",my_auto_wrapping_policy)
         device_id = 0
         if is_xpu_available():
@@ -222,7 +218,8 @@ def main(**kwargs):
             model.to("xpu:0")
         elif torch.cuda.is_available():
             model.to("cuda")
-
+    print("-------------------")
+    print("FSDP model", model)
     dataset_config = generate_dataset_config(train_config, kwargs)
     if is_vision:
         dataset_processer = processor
@@ -248,7 +245,10 @@ def main(**kwargs):
         print(f"--> Validation Set Length = {len(dataset_val)}")
 
     if train_config.batching_strategy == "packing":
-        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+        if is_vision:
+            raise ValueError("Packing is not supported for vision datasets")
+        else:
+            dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
     train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, dataset_processer, "train")
     print("length of dataset_train", len(dataset_train))
@@ -268,7 +268,10 @@ def main(**kwargs):
     eval_dataloader = None
     if train_config.run_validation:
         if train_config.batching_strategy == "packing":
-            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+            if is_vision:
+                raise ValueError("Packing is not supported for vision datasets")
+            else:
+                dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
 
         val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, dataset_processer, "val")
         if custom_data_collator:
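
The hunks above replace the CLIP encoder layer with the Mllama self-attention, cross-attention, and vision encoder layer classes in the FSDP auto-wrapping policy, print the model before and after wrapping, and reject the "packing" batching strategy for vision datasets, whose image-text samples cannot be concatenated into fixed-length token chunks. Below is a minimal sketch, not part of the commit, of how the policy built here is typically handed to FSDP; it assumes model is the model loaded earlier in main(), and the FSDP constructor arguments shown are illustrative, since the real call in finetuning.py passes additional config-driven options.

    import torch
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer
    from transformers.models.mllama.modeling_mllama import (
        MllamaSelfAttentionDecoderLayer,
        MllamaCrossAttentionDecoderLayer,
        MllamaVisionEncoderLayer,
    )
    from llama_recipes.utils.fsdp_utils import fsdp_auto_wrap_policy

    # Each listed class becomes its own FSDP unit, so the text decoder blocks and
    # the Mllama vision / cross-attention blocks are all sharded individually.
    my_auto_wrapping_policy = fsdp_auto_wrap_policy(
        model,
        [
            LlamaDecoderLayer,
            MllamaSelfAttentionDecoderLayer,
            MllamaCrossAttentionDecoderLayer,
            MllamaVisionEncoderLayer,
        ],
    )

    model = FSDP(
        model,
        auto_wrap_policy=my_auto_wrapping_policy,
        device_id=torch.cuda.current_device(),  # assumption: one CUDA device per rank
    )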

+ 2 - 5
src/llama_recipes/policies/wrapping.py

@@ -4,7 +4,7 @@
 import functools
 
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
-from transformers.models.clip.modeling_clip import CLIPEncoder, CLIPEncoderLayer
+from transformers.models.mllama.modeling_mllama import MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer
 
 from torch.distributed.fsdp.wrap import (
     transformer_auto_wrap_policy,
@@ -27,10 +27,7 @@ def get_llama_wrapper():
 
     llama_auto_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
-        transformer_layer_cls={
-            LlamaDecoderLayer,
-            CLIPEncoderLayer
-        },
+        transformer_layer_cls=set([LlamaDecoderLayer, MllamaSelfAttentionDecoderLayer, MllamaVisionEncoderLayer, MllamaCrossAttentionDecoderLayer])
     )
 
     return llama_auto_wrap_policy
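
For reference, a minimal sketch (an assumption, not code from the repo) of how FSDP consults the policy returned by get_llama_wrapper: during auto-wrapping it calls the policy on every submodule with (module, recurse, nonwrapped_numel), and with recurse=False the transformer policy reduces to an isinstance check against the layer-class set above, so each Llama decoder block and each Mllama self-attention, cross-attention, or vision encoder block is wrapped as its own FSDP unit.

    from llama_recipes.policies.wrapping import get_llama_wrapper

    policy = get_llama_wrapper()

    def wraps_as_unit(module) -> bool:
        # Positional call mirroring FSDP's own: (module, recurse, nonwrapped_numel).
        # True means this submodule would become a standalone FSDP unit.
        return bool(policy(module, False, 0))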

+ 7 - 12
src/llama_recipes/utils/fsdp_utils.py

@@ -16,19 +16,14 @@ def fsdp_auto_wrap_policy(model, transformer_layer_names):
         ):
             return True
         return False
-    transformer_wrap_policies = []
+
     lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn)
-    for transformer_layer_name in transformer_layer_names:
-        
-        transformer_wrap_policy = functools.partial(
-            transformer_auto_wrap_policy,
-            transformer_layer_cls=(
-                transformer_layer_name,
-            ),
-        )
-        transformer_wrap_policies.append(transformer_wrap_policy)
-    policies = transformer_wrap_policies
-    auto_wrap_policy = functools.partial(_or_policy, policies=policies)
+    transformer_wrap_policy = functools.partial(
+        transformer_auto_wrap_policy,
+        transformer_layer_cls=set(transformer_layer_names)
+    )
+
+    auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy])
     return auto_wrap_policy
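
The rewrite above collapses the previous one-policy-per-layer-class loop into a single transformer_auto_wrap_policy over the whole set of classes, combined via _or_policy with the lambda policy so that a module is wrapped if either test matches. The lambda side is what makes LoRA work under FSDP: it targets leaf modules whose weight still requires gradients, which after PEFT freezing are exactly the adapter matrices. A tiny sketch of that criterion follows; the body of lambda_policy_fn is only partly visible in this hunk, so the condition written here is a reconstruction and should be treated as an assumption.

    import torch.nn as nn

    def trainable_leaf(module: nn.Module) -> bool:
        # Assumed mirror of lambda_policy_fn: a leaf module (no children) that
        # owns a weight tensor which still requires gradients.
        return bool(
            len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad
        )

    frozen = nn.Linear(16, 16)               # stands in for a frozen base-model projection
    frozen.weight.requires_grad = False

    lora_a = nn.Linear(16, 8, bias=False)    # stands in for a trainable LoRA adapter matrix

    assert not trainable_leaf(frozen)
    assert trainable_leaf(lora_a)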