|
@@ -1,32 +1,27 @@
|
|
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
|
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
|
|
|
|
+from pathlib import Path
|
|
|
import pytest
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
-
|
|
|
-EXPECTED_RESULTS = {
|
|
|
- "meta-llama/Llama-2-7b-hf":{
|
|
|
- "label": 1152,
|
|
|
- "pos": 31,
|
|
|
- },
|
|
|
- "meta-llama/Meta-Llama-3.1-8B":{
|
|
|
- "label": 40,
|
|
|
- "pos": 26,
|
|
|
- },
|
|
|
-}
|
|
|
+DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
|
|
|
|
|
|
@pytest.mark.skip_missing_tokenizer
|
|
|
+@pytest.mark.skipif(not Path(DATA_DIR / "grammar_validation.csv").exists(), reason="grammar_validation.csv not found")
|
|
|
+@pytest.mark.skipif(not Path(DATA_DIR / "gtrain_10k.csv").exists(), reason="gtrain_10k.csv not found")
|
|
|
@patch('llama_recipes.finetuning.train')
|
|
|
@patch('llama_recipes.finetuning.AutoTokenizer')
|
|
|
-@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
|
|
|
+@patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
|
|
|
+@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
|
|
|
@patch('llama_recipes.finetuning.optim.AdamW')
|
|
|
@patch('llama_recipes.finetuning.StepLR')
|
|
|
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
|
|
|
+def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):
|
|
|
from llama_recipes.finetuning import main
|
|
|
|
|
|
setup_tokenizer(tokenizer)
|
|
|
get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
|
|
|
+ get_config.return_value.model_type = "llama"
|
|
|
|
|
|
BATCH_SIZE = 8
|
|
|
kwargs = {
|
|
@@ -58,9 +53,6 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_
|
|
|
assert "input_ids" in batch.keys()
|
|
|
assert "attention_mask" in batch.keys()
|
|
|
|
|
|
- assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
|
|
|
- assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
|
|
|
-
|
|
|
token = args[3]
|
|
|
assert batch["input_ids"][0][0] == token.bos_token_id
|
|
|
assert batch["labels"][0][-1] == token.eos_token_id
|