Add unit test for config check

Matthias Reso 11 months ago
parent
commit
853f928987
1 changed file with 56 additions and 7 deletions

+ 56 - 7
tests/test_finetuning.py

@@ -110,13 +110,62 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
-@patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
-@patch('llama_recipes.finetuning.get_preprocessed_dataset')
-@patch('llama_recipes.finetuning.get_peft_model')
-@patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.setup")
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+def test_finetuning_peft_llama_adapter(
+    get_dataset, tokenizer, get_model, train, setup, get_peft_model, mocker
+):
+    kwargs = {
+        "use_peft": True,
+        "peft_method": "llama_adapter",
+        "enable_fsdp": True,
+    }
+
+    get_dataset.return_value = get_fake_dataset()
+
+    model = mocker.MagicMock(name="Model")
+    model.parameters.return_value = [torch.ones(1, 1)]
+    model.get_input_embeddings.return_value.weight.shape = [0]
+
+    get_model.return_value = model
+
+    os.environ["RANK"] = "0"
+    os.environ["LOCAL_RANK"] = "0"
+    os.environ["WORLD_SIZE"] = "1"
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12345"
+
+    with pytest.raises(
+        RuntimeError,
+        match="Llama_adapter is currently not supported in combination with FSDP",
+    ):
+        main(**kwargs)
+
+    GET_ME_OUT = "Get me out of here"
+    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
+
+    kwargs["enable_fsdp"] = False
+
+    with pytest.raises(
+        RuntimeError,
+        match=GET_ME_OUT,
+    ):
+        main(**kwargs)
+
+
+@patch("llama_recipes.finetuning.train")
+@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
+@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
+@patch("llama_recipes.finetuning.get_preprocessed_dataset")
+@patch("llama_recipes.finetuning.get_peft_model")
+@patch("llama_recipes.finetuning.StepLR")
+def test_finetuning_weight_decay(
+    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker
+):
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
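
The new test pins down two behaviors of llama_recipes.finetuning.main: with use_peft=True, peft_method="llama_adapter", and enable_fsdp=True it must fail fast with the exact RuntimeError message, and with enable_fsdp=False it must get past that guard, which the test proves by making get_peft_model raise a sentinel RuntimeError and asserting that this sentinel, not the config error, surfaces. Below is a minimal, self-contained sketch of the kind of guard the test presumably exercises; the TrainConfig dataclass and check_peft_fsdp_compat helper are illustrative names, not code from this commit.

import pytest
from dataclasses import dataclass


@dataclass
class TrainConfig:
    # Illustrative stand-in for the training config main() assembles from kwargs.
    use_peft: bool = False
    peft_method: str = "lora"
    enable_fsdp: bool = False


def check_peft_fsdp_compat(cfg: TrainConfig) -> None:
    # Hypothetical guard: reject llama_adapter combined with FSDP up front,
    # raising the message the test matches via pytest.raises(..., match=...).
    if cfg.enable_fsdp and cfg.use_peft and cfg.peft_method == "llama_adapter":
        raise RuntimeError(
            "Llama_adapter is currently not supported in combination with FSDP"
        )


# Mirrors the first assertion in the new test: the bad combination raises.
with pytest.raises(RuntimeError, match="not supported in combination with FSDP"):
    check_peft_fsdp_compat(
        TrainConfig(use_peft=True, peft_method="llama_adapter", enable_fsdp=True)
    )

# Without FSDP the same PEFT config passes the guard, as the second half
# of the test expects.
check_peft_fsdp_compat(TrainConfig(use_peft=True, peft_method="llama_adapter"))

Two reading aids for the diff itself: mock.patch decorators apply bottom-up, so the bottom-most @patch("...get_preprocessed_dataset") binds to the first parameter get_dataset and the top-most @patch("...get_peft_model") to the last mock parameter; and the RANK/LOCAL_RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables are set because the FSDP path reads the distributed environment even though setup itself is patched out. The new test can be run in isolation with python -m pytest tests/test_finetuning.py -k llama_adapter.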