瀏覽代碼

Disable prefix tuning and limit llama adapter (#482)

Hamid Shojanazeri 11 月之前
父節點
當前提交
5f11aeb88a

+ 1 - 1
recipes/finetuning/README.md

@@ -70,7 +70,7 @@ It lets us specify the training settings for everything from `model_name` to `da
 
 * [Datasets config file](../../src/llama_recipes/configs/datasets.py) provides the available options for datasets.
 
-* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified.
+* [peft config file](../../src/llama_recipes/configs/peft.py) provides the supported PEFT methods and respective settings that can be modified. We currently support LoRA and Llama-Adapter. Please note that LoRA is the only technique which is supported in combination with FSDP.
 
 * [FSDP config file](../../src/llama_recipes/configs/fsdp.py) provides FSDP settings such as:
 

+ 2 - 1
src/llama_recipes/configs/peft.py

@@ -20,7 +20,8 @@ class llama_adapter_config:
      adapter_layers: int= 30
      task_type: str= "CAUSAL_LM"
 
+#CAUTION prefix tuning is currently not supported
 @dataclass
 class prefix_config:
      num_virtual_tokens: int=30
-     task_type: str= "CAUSAL_LM"    
+     task_type: str= "CAUSAL_LM"

+ 1 - 1
src/llama_recipes/configs/training.py

@@ -29,7 +29,7 @@ class train_config:
     mixed_precision: bool=True
     val_batch_size: int=1
     dataset = "samsum_dataset"
-    peft_method: str = "lora" # None,llama_adapter, prefix
+    peft_method: str = "lora" # None, llama_adapter (Caution: llama_adapter is currently not supported with FSDP)
     use_peft: bool=False
     output_dir: str = "PATH/to/save/PEFT/model"
     freeze_layers: bool = False

+ 8 - 1
src/llama_recipes/utils/config_utils.py

@@ -45,7 +45,14 @@ def generate_peft_config(train_config, kwargs):
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
 
-    assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
+    if train_config.peft_method not in names:
+        raise RuntimeError(f"Peft config not found: {train_config.peft_method}")
+
+    if train_config.peft_method == "prefix":
+        raise RuntimeError("PrefixTuning is currently not supported (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089350811)")
+
+    if train_config.enable_fsdp and train_config.peft_method == "llama_adapter":
+        raise RuntimeError("Llama_adapter is currently not supported in combination with FSDP (see https://github.com/meta-llama/llama-recipes/issues/359#issuecomment-2089274425)")
 
     config = configs[names.index(train_config.peft_method)]()
 

+ 0 - 8
src/llama_recipes/utils/fsdp_utils.py

@@ -8,8 +8,6 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
 
     from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
 
-    from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder
-
     def lambda_policy_fn(module):
         if (
             len(list(module.named_children())) == 0
@@ -23,13 +21,7 @@ def fsdp_auto_wrap_policy(model, transformer_layer_name):
     transformer_wrap_policy = functools.partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls=(
-            PrefixEncoder,
-            PromptEncoder,
-            PromptEmbedding,
             transformer_layer_name,
-            # FullyShardedDataParallelPlugin.get_module_class_from_name(
-            #     model, transformer_layer_name
-            # ),
         ),
     )
 

+ 143 - 55
tests/test_finetuning.py

@@ -1,40 +1,56 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-import pytest
-from pytest import approx
+import os
 from unittest.mock import patch
 
+import pytest
+
 import torch
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+from llama_recipes.finetuning import main
+from pytest import approx
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.sampler import BatchSampler
 
-from llama_recipes.finetuning import main
-from llama_recipes.data.sampler import LengthBasedBatchSampler
-
 
 def get_fake_dataset():
-    return [{
-        "input_ids":[1],
-        "attention_mask":[1],
-        "labels":[1],
-        }]
-
-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+    return [
+        {
+            "input_ids": [1],
+            "attention_mask": [1],
+            "labels": [1],
+        }
+    ]
+
+
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_no_validation(
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
 
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -53,20 +69,31 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
         assert get_model.return_value.to.call_count == 0
 
 
-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_with_validation(
+    step_lr,
+    optimizer,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
 
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)
 
     assert train.call_count == 1
@@ -83,22 +110,36 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     else:
         assert get_model.return_value.to.call_count == 0
 
-@patch('llama_recipes.finetuning.torch.cuda.is_available')
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.generate_peft_config')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
+
+@patch("llama_recipes.finetuning.torch.cuda.is_available")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.generate_peft_config")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
 @pytest.mark.parametrize("cuda_is_available", [True, False])
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
+def test_finetuning_peft_lora(
+    step_lr,
+    optimizer,
+    get_peft_model,
+    gen_peft_config,
+    get_dataset,
+    tokenizer,
+    get_model,
+    train,
+    cuda,
+    cuda_is_available,
+):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
     cuda.return_value = cuda_is_available
 
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)
 
     if cuda_is_available:
@@ -110,21 +151,64 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
-    kwargs = {"weight_decay": 0.01}
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.setup")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+def test_finetuning_peft_llama_adapter(
+    get_dataset, tokenizer, get_model, train, setup, get_peft_model
+):
+    kwargs = {
+        "use_peft": True,
+        "peft_method": "llama_adapter",
+        "enable_fsdp": True,
+    }
 
     get_dataset.return_value = get_fake_dataset()
 
-    model = mocker.MagicMock(name="Model")
-    model.parameters.return_value = [torch.ones(1,1)]
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
+    os.environ["RANK"] = "0"
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12345"
+
+    with pytest.raises(
+        RuntimeError,
+        match="Llama_adapter is currently not supported in combination with FSDP",
+    ):
+        main(**kwargs)
+
+    GET_ME_OUT = "Get me out of here"
+    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
+
+    kwargs["enable_fsdp"] = False
+
+    with pytest.raises(
+        RuntimeError,
+        match=GET_ME_OUT,
+    ):
+        main(**kwargs)
+
 
-    get_model.return_value = model
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.StepLR")
+def test_finetuning_weight_decay(
+    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
+):
+    kwargs = {"weight_decay": 0.01}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
 
     main(**kwargs)
 
@@ -139,17 +223,21 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
 
 
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.optim.AdamW')
-@patch('llama_recipes.finetuning.StepLR')
-def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.optim.AdamW")
+@patch("llama_recipes.finetuning.StepLR")
+def test_batching_strategy(
+    step_lr, optimizer, get_dataset, tokenizer, get_model, train
+):
     kwargs = {"batching_strategy": "packing"}
 
     get_dataset.return_value = get_fake_dataset()
 
+    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
+
     main(**kwargs)
 
     assert train.call_count == 1