test_custom_dataset.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest

from contextlib import nullcontext
from unittest.mock import patch

from transformers import LlamaTokenizer

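# Expected decoded prompts per checkpoint: Llama 2 wraps user turns in plain
# [INST] tags, while Llama 3.1 uses <|start_header_id|>/<|eot_id|> special tokens.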
EXPECTED_RESULTS = {
    "meta-llama/Llama-2-7b-hf": {
        "example_1": "[INST] Who made Berlin [/INST] dunno",
        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
    },
}

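# Helper shared by the tests below: checks that a padded batch entry masks the
# prompt and the trailing padding with -100 in the labels, ends the real
# sequence with the end-of-turn token, and begins/ends input_ids with the
# BOS/EOS (padding) token.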
def check_padded_entry(batch, tokenizer):
    seq_len = sum(batch["attention_mask"][0])
    assert seq_len < len(batch["attention_mask"][0])

    # Llama 3 tokenizers have a vocabulary of 128k+ tokens; 128009 is <|eot_id|>.
    if tokenizer.vocab_size >= 128000:
        END_OF_TEXT_ID = 128009
    else:
        END_OF_TEXT_ID = tokenizer.eos_token_id

    assert batch["labels"][0][0] == -100
    assert batch["labels"][0][seq_len - 1] == END_OF_TEXT_ID
    assert batch["labels"][0][-1] == -100
    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id

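# End-to-end smoke test: runs finetuning.main on the custom_dataset recipe with
# the model, optimizer, scheduler, and train loop mocked out, then inspects the
# dataloaders that were handed to train().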
@pytest.mark.skip(reason="Flaky due to random dataset order @todo fix order")
@pytest.mark.skip_missing_tokenizer
@patch('llama_cookbook.finetuning.train')
@patch('llama_cookbook.finetuning.AutoTokenizer')
@patch('llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_cookbook.finetuning.optim.AdamW')
@patch('llama_cookbook.finetuning.StepLR')
def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
    from llama_cookbook.finetuning import main

    setup_tokenizer(tokenizer)

    # The Llama 2 expectations are written without special tokens; the
    # Llama 3.1 expectations include them, so only strip them for Llama 2.
    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]

    kwargs = {
        "dataset": "custom_dataset",
        "model_name": llama_version,
        "custom_dataset.file": "getting-started/finetuning/datasets/custom_dataset.py",
        "custom_dataset.train_split": "validation",
        "batch_size_training": 2,
        "val_batch_size": 4,
        "use_peft": False,
        "batching_strategy": "padding",
    }
    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]
    tokenizer = args[3]

    assert len(train_dataloader) == 1120
    assert len(eval_dataloader) == 1120 // 2

    # First eval batch: the decoded first sample should match the expectation.
    it = iter(eval_dataloader)
    batch = next(it)
    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])

    assert batch["input_ids"].size(0) == 4
    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())

    check_padded_entry(batch, tokenizer)

    # Second train batch: same checks with the training batch size.
    it = iter(train_dataloader)
    next(it)
    batch = next(it)
    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])

    assert batch["input_ids"].size(0) == 2
    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())

    check_padded_entry(batch, tokenizer)

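# Negative test: pointing custom_dataset.file at a function that does not exist
# in the module should surface as an AttributeError from main().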
@patch('llama_cookbook.finetuning.train')
@patch('llama_cookbook.finetuning.AutoConfig.from_pretrained')
@patch('llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_cookbook.finetuning.AutoTokenizer.from_pretrained')
@patch('llama_cookbook.finetuning.optim.AdamW')
@patch('llama_cookbook.finetuning.StepLR')
def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, get_config, train, mocker, llama_version):
    from llama_cookbook.finetuning import main

    # Minimal tokenizer stub: maps any input to zero-filled ids and mask.
    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids": [len(x) * [0]], "attention_mask": [len(x) * [0]]})
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_config.return_value.model_type = "llama"

    kwargs = {
        "dataset": "custom_dataset",
        "custom_dataset.file": "getting-started/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
        "batch_size_training": 1,
        "use_peft": False,
    }

    with pytest.raises(AttributeError):
        main(**kwargs)

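# Unit test for the tokenize_dialog helper in the custom_dataset recipe: user
# turns must be fully masked with -100 so the loss is computed only on the
# assistant answers.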
@pytest.mark.skip_missing_tokenizer
@patch('llama_cookbook.finetuning.AutoTokenizer')
def test_tokenize_dialog(tokenizer, monkeypatch, setup_tokenizer, llama_version):
    monkeypatch.syspath_prepend("getting-started/finetuning/datasets/")
    from custom_dataset import tokenize_dialog

    setup_tokenizer(tokenizer)
    tokenizer = tokenizer.from_pretrained()

    dialog = [
        {"role": "user", "content": "Who made Berlin?"},
        {"role": "assistant", "content": "dunno"},
        {"role": "user", "content": "And Rome?"},
        {"role": "assistant", "content": "Romans"},
    ]

    # The stubbed "fake_llama" tokenizer cannot tokenize a dialog, so
    # tokenize_dialog is expected to raise an AttributeError for it.
    c = pytest.raises(AttributeError) if llama_version == "fake_llama" else nullcontext()

    with c:
        result = tokenize_dialog(dialog, tokenizer)

    if "Llama-2" in llama_version:
        # Both user turns are masked: 12 leading tokens, then 11 for the second turn.
        assert result["labels"][:12] == [-100] * 12
        assert result["labels"][17:28] == [-100] * 11
        assert result["labels"].count(-100) == 11 + 12
    elif "Llama-3" in llama_version:
        assert result["labels"][:38] == [-100] * 38
        assert result["labels"][43:54] == [-100] * 11
        assert result["labels"].count(-100) == 38 + 11