# test_batching.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest

from contextlib import nullcontext
from dataclasses import dataclass
from datasets import Dataset
from unittest.mock import patch

@dataclass
class Config:
    model_type: str = "llama"

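# Expected train/eval dataloader lengths (number of batches) per model when
# the fake SAMSum dataset below is packed by the tests.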
EXPECTED_SAMPLE_NUMBER = {
    "meta-llama/Llama-2-7b-hf": {
        "train": 4,
        "eval": 37,
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
        "train": 3,
        "eval": 30,
    },
    "fake_llama": {
        "train": 2,
        "eval": 17,
    },
}

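# 2048 copies of a single short SAMSum-style sample; datasets.load_dataset is
# mocked to return this, so the tests run without downloading the real dataset.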
fake_samsum_dataset = 2048 * [
    {
        'id': '420',
        'dialogue': "Mario: It's a me, Mario!\nLuigi: It's a me, your brother!\nMario: I'm going to save the princess.\nLuigi: I'm going to help Mario.",
        'summary': 'Mario and Luigi are going to save the princess.',
    }
]

@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.datasets.samsum_dataset.datasets')
def test_packing(
    datasets,
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
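    """Pack the mocked samsum dataset and check the resulting dataloaders.

    The train/eval dataloaders should have the lengths recorded in
    EXPECTED_SAMPLE_NUMBER, and every batch should carry input_ids, labels
    and attention_mask packed to 4096 tokens. For non-"llama" model types,
    main() is expected to raise a ValueError.
    """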
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

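    # Give the mocked model an embedding weight shape that matches the
    # tokenizer vocabulary: 32000 for Llama 2, 128256 for Llama 3.1.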
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [
        32000 if "Llama-2" in llama_version else 128256
    ]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
    }

    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]

        batch = next(iter(train_dataloader))

        assert "labels" in batch.keys()
        assert "input_ids" in batch.keys()
        assert "attention_mask" in batch.keys()

        assert batch["labels"][0].size(0) == 4096
        assert batch["input_ids"][0].size(0) == 4096
        assert batch["attention_mask"][0].size(0) == 4096


@pytest.mark.skip_missing_tokenizer
@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
@patch('llama_recipes.datasets.samsum_dataset.datasets')
def test_distributed_packing(
    datasets,
    dist,
    is_initialized,
    fsdp,
    setup,
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    cuda_is_available,
    cuda_is_bf16_supported,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
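    """Same as test_packing, but with FSDP enabled and a simulated world size
    of 2: each rank should receive half of the packed batches.
    """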
    import os

    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [
        32000 if "Llama-2" in llama_version else 128256
    ]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    cuda_is_available.return_value = False
    cuda_is_bf16_supported.return_value = False

    datasets.load_dataset.return_value = Dataset.from_list(fake_samsum_dataset)

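    # Pretend to be rank 1 of a 2-process distributed run so the packed
    # dataset gets sharded across ranks.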
    rank = 1
    os.environ['LOCAL_RANK'] = f'{rank}'
    os.environ['RANK'] = f'{rank}'
    os.environ['WORLD_SIZE'] = '2'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
        "enable_fsdp": True,
    }

    is_initialized.return_value = True
    dist.get_rank.return_value = rank
    dist.get_world_size.return_value = 2

    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2