Browse Source

Add unit test for tokenizer_dialog for custom dataset

Matthias Reso 8 months ago
parent
commit
d45222ec1a
2 changed files with 34 additions and 2 deletions
  1. 1 1
      src/tests/conftest.py
  2. 33 1
      src/tests/datasets/test_custom_dataset.py

+ 1 - 1
src/tests/conftest.py

@@ -6,7 +6,7 @@ import pytest
 from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
-LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B"]
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Meta-Llama-3.1-8B-Instruct"]
 
 @pytest.fixture(params=LLAMA_VERSIONS)
 def llama_version(request):

+ 33 - 1
src/tests/datasets/test_custom_dataset.py

@@ -11,7 +11,7 @@ EXPECTED_RESULTS={
         "example_1": "[INST] Who made Berlin [/INST] dunno",
         "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
     },
-    "meta-llama/Meta-Llama-3.1-8B":{
+    "meta-llama/Meta-Llama-3.1-8B-Instruct":{
         "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
         "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
     },
@@ -114,3 +114,35 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
         }
     with pytest.raises(AttributeError):
         main(**kwargs)
+
+@pytest.mark.skip_missing_tokenizer
+@patch('llama_recipes.finetuning.AutoTokenizer')
+def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version):
+    monkeypatch.syspath_prepend("recipes/quickstart/finetuning/datasets/")
+    from custom_dataset import tokenize_dialog
+
+    setup_tokenizer(tokenizer)
+    tokenizer = tokenizer.from_pretrained()
+
+    dialog = [
+        {"role":"user", "content":"Who made Berlin?"},
+        {"role":"assistant", "content":"dunno"},
+        {"role":"user", "content":"And Rome?"},
+        {"role":"assistant", "content":"Romans"},
+    ]
+
+    result = tokenize_dialog(dialog, tokenizer)
+    print(f"{tokenizer.encode('system')=}")
+    print(f"{tokenizer.encode('user')=}")
+    print(f"{tokenizer.encode('assistant')=}")
+    print(f"{tokenizer.decode(result['input_ids'])=}")
+    print(f"{result['labels']=}")
+
+    if "Llama-2" in llama_version:
+        assert result["labels"][:12] == [-100] * 12
+        assert result["labels"][17:28] == [-100] * 11
+        assert result["labels"].count(-100) == 11 + 12
+    else:
+        assert result["labels"][:35] == [-100] * 35
+        assert result["labels"][42:51] == [-100] * 9
+        assert result["labels"].count(-100) == 35 + 9