|
@@ -5,17 +5,6 @@ import pytest
|
|
|
from functools import partial
|
|
|
from unittest.mock import patch
|
|
|
|
|
|
-EXPECTED_RESULTS = {
|
|
|
- "meta-llama/Llama-2-7b-hf":{
|
|
|
- "label": 8432,
|
|
|
- "pos": 242,
|
|
|
- },
|
|
|
- "meta-llama/Meta-Llama-3.1-8B":{
|
|
|
- "label": 2250,
|
|
|
- "pos": 211,
|
|
|
- },
|
|
|
-}
|
|
|
-
|
|
|
@pytest.mark.skip_missing_tokenizer
|
|
|
@patch('llama_recipes.finetuning.train')
|
|
|
@patch('llama_recipes.finetuning.AutoTokenizer')
|
|
@@ -59,9 +48,6 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
|
|
|
assert "input_ids" in batch.keys()
|
|
|
assert "attention_mask" in batch.keys()
|
|
|
|
|
|
- assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
|
|
|
- assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
|
|
|
-
|
|
|
assert batch["input_ids"][0][0] == token.bos_token_id
|
|
|
assert batch["labels"][0][-1] == token.eos_token_id
|
|
|
assert batch["input_ids"][0][-1] == token.eos_token_id
|