
Fix test_grammar_dataset.py

Matthias Reso 7 months ago
parent
commit
6a0f956831
1 file changed, 8 insertions and 16 deletions

+ 8 - 16
src/tests/datasets/test_grammar_datasets.py

@@ -1,32 +1,27 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from pathlib import Path
 import pytest
 from unittest.mock import patch
 
-
-EXPECTED_RESULTS = {
-    "meta-llama/Llama-2-7b-hf":{
-        "label": 1152,
-        "pos": 31,
-    },
-    "meta-llama/Meta-Llama-3.1-8B":{
-        "label": 40,
-        "pos": 26,
-    },
-}
+DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
 
 @pytest.mark.skip_missing_tokenizer
+@pytest.mark.skipif(not Path(DATA_DIR / "grammar_validation.csv").exists(), reason="grammar_validation.csv not found")
+@pytest.mark.skipif(not Path(DATA_DIR / "gtrain_10k.csv").exists(), reason="gtrain_10k.csv not found")
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
+@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
+def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
     get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    get_config.return_value.model_type = "llama"
 
     BATCH_SIZE = 8
     kwargs = {
@@ -58,9 +53,6 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
-    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
-
     token = args[3]
     assert batch["input_ids"][0][0] == token.bos_token_id
     assert batch["labels"][0][-1] == token.eos_token_id