# test_chat_completion.py

import sys
from pathlib import Path
from typing import List, TypedDict
from unittest.mock import patch

import pytest
import torch

from llama_cookbook.inference.chat_utils import read_dialogs_from_file

ROOT_DIR = Path(__file__).parents[2]
CHAT_COMPLETION_DIR = (
    ROOT_DIR / "getting-started/inference/local_inference/chat_completion/"
)
sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path

default_system_prompt = [
    {
        "role": "system",
        "content": "Cutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n",
    }
]

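# Helper functions that rebuild the Llama 3 chat template token by token, so the
# test can compute reference input_ids independently of the script under test.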
def _encode_header(message, tokenizer):
    tokens = []
    tokens.extend(tokenizer.encode("<|start_header_id|>", add_special_tokens=False))
    tokens.extend(tokenizer.encode(message["role"], add_special_tokens=False))
    tokens.extend(tokenizer.encode("<|end_header_id|>", add_special_tokens=False))
    tokens.extend(tokenizer.encode("\n\n", add_special_tokens=False))
    return tokens


def _encode_message(message, tokenizer):
    tokens = _encode_header(message, tokenizer)
    tokens.extend(tokenizer.encode(message["content"], add_special_tokens=False))
    tokens.extend(tokenizer.encode("<|eot_id|>", add_special_tokens=False))
    return tokens

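# Build the token sequence for a single dialog: the default system prompt is merged
# into an existing system message or otherwise prepended as its own message.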
def _format_dialog(dialog, tokenizer):
    tokens = []
    tokens.extend(tokenizer.encode("<|begin_of_text|>", add_special_tokens=False))
    if dialog[0]["role"] == "system":
        dialog[0]["content"] = (
            default_system_prompt[0]["content"] + dialog[0]["content"]
        )
    else:
        dialog = default_system_prompt + dialog
    for msg in dialog:
        tokens.extend(_encode_message(msg, tokenizer))
    return tokens


def _format_tokens_llama3(dialogs, tokenizer):
    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]

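# The test patches load_model and AutoTokenizer inside the chat_completion script,
# runs its main() on the bundled chats.json, and then checks that the input_ids
# passed to model.generate match the reference encodings built above.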
@pytest.mark.skip_missing_tokenizer
@patch("chat_completion.AutoTokenizer")
@patch("chat_completion.load_model")
def test_chat_completion(
    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
):
    if "Llama-2" in llama_version or llama_version == "fake_llama":
        pytest.skip(f"skipping test for {llama_version}")

    from chat_completion import main

    setup_tokenizer(tokenizer)
    # Give the mocked model an embedding weight shape that matches the Llama 3
    # vocabulary size (128256).
    load_model.return_value.get_input_embeddings.return_value.weight.shape = [128256]

    kwargs = {
        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
    }
    main(llama_version, **kwargs)
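    # Re-encode the same dialogs with the real tokenizer to obtain the reference input_ids.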
    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
    REF_RESULT = _format_tokens_llama3(dialogs, llama_tokenizer[llama_version])
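    # mock_calls on generate also records calls made on its return value, hence the
    # stride of 4 between successive generate(...) entries; element [2] of a call
    # record holds its keyword arguments.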
    for i, ref in enumerate(REF_RESULT):
        assert all(
            (
                load_model.return_value.generate.mock_calls[i * 4][2]["input_ids"].cpu()
                == torch.tensor(ref).long()
            ).tolist()
        )