Selaa lähdekoodia

Fix test_custom_dataset.py

Matthias Reso 7 kuukautta sitten
vanhempi
commit
dd8ca3c211
1 muutettua tiedostoa jossa 5 lisäystä ja 3 poistoa
  1. 5 3
      src/tests/datasets/test_custom_dataset.py

+ 5 - 3
src/tests/datasets/test_custom_dataset.py

@@ -37,7 +37,7 @@ def check_padded_entry(batch, tokenizer):
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
@@ -96,15 +96,17 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
 
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
+@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
 @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker, llama_version):
+def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, get_config, train, mocker, llama_version):
     from llama_recipes.finetuning import main
 
     tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_config.return_value.model_type = "llama"
 
     kwargs = {
         "dataset": "custom_dataset",