
Fix tests for custom dataset, grammar, batching, chat_completion

Matthias Reso 7 months ago
parent
commit 9f5200670d

+ 2 - 2
src/tests/datasets/test_custom_dataset.py

@@ -37,7 +37,7 @@ def check_padded_entry(batch, tokenizer):
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
@@ -97,7 +97,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
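
For readers unfamiliar with the pattern, here is a minimal sketch of how these stacked patches work (the test body below is illustrative, not the file's actual code; the real tests also configure the tokenizer mock through the setup_tokenizer fixture and pass dataset-specific kwargs). Retargeting the patch to LlamaForCausalLM.from_pretrained presumably matches the loader that llama_recipes.finetuning actually calls, rather than AutoModel.

from unittest.mock import patch

@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.AutoTokenizer")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
def test_sketch(step_lr, optimizer, get_model, tokenizer, train):
    # @patch decorators apply bottom-up, so the lowest one (StepLR) arrives
    # as the first mock argument, mirroring the signatures in this commit.
    from llama_recipes.finetuning import main

    # Give the mocked model a plausible embedding-table shape; the real tests
    # in this commit set the same attribute on their mocks.
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]

    main(model_name="meta-llama/Meta-Llama-3.1-8B-Instruct", batch_size_training=1)

    # The patched loader is the one finetuning invokes, so it is hit exactly once.
    assert get_model.call_count == 1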

+ 1 - 1
src/tests/datasets/test_grammar_datasets.py

@@ -13,7 +13,7 @@ DATA_DIR = Path(__file__).parents[2] / "llama_recipes/datasets/grammar_dataset/"
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.AutoConfig.from_pretrained')
-@patch('llama_recipes.finetuning.AutoModel.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_grammar_dataset(step_lr, optimizer, get_model, get_config, tokenizer, train, setup_tokenizer, llama_version):

+ 1 - 1
src/tests/test_batching.py

@@ -9,7 +9,7 @@ EXPECTED_SAMPLE_NUMBER ={
         "train": 96,
         "eval": 42,
     },
-    "meta-llama/Meta-Llama-3.1-8B": {
+    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
         "train": 79,
         "eval": 34,
     }
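
For context, a hypothetical illustration of how a per-model expectation table like this is consumed (the llama_version and dataloaders fixtures below are assumed, not the real ones in test_batching.py): the dictionary keys have to match the exact model id the test is parametrized with, which is why the Llama 3.1 entry moves to the Instruct checkpoint.

import pytest

EXPECTED_SAMPLE_NUMBER = {
    "meta-llama/Llama-2-7b-hf": {"train": 96, "eval": 42},
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {"train": 79, "eval": 34},
}

@pytest.mark.parametrize("split", ["train", "eval"])
def test_batch_count(split, llama_version, dataloaders):
    # llama_version is the model id under test; a stale key here raises
    # KeyError before any assertion about batching runs.
    assert len(dataloaders[split]) == EXPECTED_SAMPLE_NUMBER[llama_version][split]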

+ 19 - 73
src/tests/test_chat_completion.py

@@ -1,6 +1,6 @@
 import sys
 from pathlib import Path
-from typing import List, Literal, TypedDict
+from typing import List, TypedDict
 from unittest.mock import patch
 
 import pytest
@@ -8,46 +8,37 @@ import torch
 from llama_recipes.inference.chat_utils import read_dialogs_from_file
 
 ROOT_DIR = Path(__file__).parents[2]
-CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/quickstart/inference/local_inference/chat_completion/"
 
 sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
 
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-
+default_system_prompt = [{"role": "system", "content": "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n"}]
 
 def _encode_header(message, tokenizer):
     tokens = []
-    tokens.extend(tokenizer.encode("<|start_header_id|>"))
-    tokens.extend(tokenizer.encode(message["role"]))
-    tokens.extend(tokenizer.encode("<|end_header_id|>"))
-    tokens.extend(tokenizer.encode("\n\n"))
+    tokens.extend(tokenizer.encode("<|start_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|end_header_id|>", add_special_tokens=False))
+    tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
     return tokens
 
 
 def _encode_message(message, tokenizer):
     tokens = _encode_header(message, tokenizer)
-    tokens.extend(tokenizer.encode(message["content"].strip()))
-    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    tokens.extend(tokenizer.encode(message["content"], add_special_tokens=False))
+    tokens.extend(tokenizer.encode("<|eot_id|>", add_special_tokens=False))
     return tokens
 
 
 def _format_dialog(dialog, tokenizer):
     tokens = []
-    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    tokens.extend(tokenizer.encode("<|begin_of_text|>", add_special_tokens=False))
+    if dialog[0]["role"] == "system":
+        dialog[0]["content"] = default_system_prompt[0]["content"] + dialog[0]["content"]
+    else:
+        dialog = default_system_prompt + dialog
     for msg in dialog:
         tokens.extend(_encode_message(msg, tokenizer))
-    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
     return tokens
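
As a usage note, a short sketch of what the rewritten reference formatter produces, assuming the _format_dialog helper above is in scope and a Meta-Llama-3.1-8B-Instruct tokenizer is available locally (the snippet is illustrative, not part of the test). Encoding every piece with add_special_tokens=False keeps the tokenizer from inserting a second BOS token, and the default knowledge-cutoff system header is merged in front of any user-supplied system message, mirroring the default system header of the Llama 3.1 chat template.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
dialog = [{"role": "user", "content": "Hello!"}]

ref_tokens = _format_dialog(dialog, tok)
print(tok.decode(ref_tokens))
# Decoded, the reference prompt looks like this (no trailing assistant header,
# since the commit drops that line from _format_dialog):
# <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
# Cutting Knowledge Date: December 2023
# Today Date: 26 Jul 2024
#
# <|eot_id|><|start_header_id|>user<|end_header_id|>
#
# Hello!<|eot_id|>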
 
 
@@ -55,59 +46,19 @@ def _format_tokens_llama3(dialogs, tokenizer):
     return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
 
 
-def _format_tokens_llama2(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-                {
-                    "role": dialog[1]["role"],
-                    "content": B_SYS
-                    + dialog[0]["content"]
-                    + E_SYS
-                    + dialog[1]["content"],
-                }
-            ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                )
-                + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-
-
 @pytest.mark.skip_missing_tokenizer
 @patch("chat_completion.AutoTokenizer")
 @patch("chat_completion.load_model")
 def test_chat_completion(
     load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
 ):
+    if "Llama-2" in llama_version:
+        pytest.skip("skipping test for Llama-2")
+
     from chat_completion import main
 
     setup_tokenizer(tokenizer)
-    load_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
+    load_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]
 
     kwargs = {
         "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
@@ -116,13 +67,8 @@ def test_chat_completion(
     main(llama_version, **kwargs)
 
     dialogs = read_dialogs_from_file(kwargs["prompt_file"])
-    format_tokens = (
-        _format_tokens_llama2
-        if llama_version == "meta-llama/Llama-2-7b-hf"
-        else _format_tokens_llama3
-    )
 
-    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+    REF_RESULT = _format_tokens_llama3(dialogs, llama_tokenizer[llama_version])
 
     assert all(
         (