@@ -110,13 +110,62 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.setup")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+def test_finetuning_peft_llama_adapter(
+    get_dataset, tokenizer, get_model, train, setup, get_peft_model, mocker
+):
+    kwargs = {
+        "use_peft": True,
+        "peft_method": "llama_adapter",
+        "enable_fsdp": True,
+    }
+
+    get_dataset.return_value = get_fake_dataset()
+
+    model = mocker.MagicMock(name="Model")
+    model.parameters.return_value = [torch.ones(1, 1)]
+    model.get_input_embeddings.return_value.weight.shape = [0]
+
+    get_model.return_value = model
+
+    os.environ["RANK"] = "0"
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12345"
+
+    with pytest.raises(
+        RuntimeError,
+        match="Llama_adapter is currently not supported in combination with FSDP",
+    ):
+        main(**kwargs)
+
+    GET_ME_OUT = "Get me out of here"
+    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
+
+    kwargs["enable_fsdp"] = False
+
+    with pytest.raises(
+        RuntimeError,
+        match=GET_ME_OUT,
+    ):
+        main(**kwargs)
+
+
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.StepLR")
+def test_finetuning_weight_decay(
+    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker
+):
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
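
# Note: both tests in this hunk stub llama_recipes.finetuning.get_preprocessed_dataset
# with the module-level get_fake_dataset() helper defined earlier in this test file.
# The sketch below is a hypothetical minimal version for illustration only, not the
# repository's actual definition:
def get_fake_dataset():
    # One pre-tokenized sample shaped like get_preprocessed_dataset's output
    # (input_ids / attention_mask / labels); since train() is mocked, the
    # tests only need the dataset to exist, not to drive real optimization.
    return [{"input_ids": [1], "attention_mask": [1], "labels": [1]}]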