# test_batching.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest
from dataclasses import dataclass
from contextlib import nullcontext
from unittest.mock import patch
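
# Minimal stand-in for the Hugging Face model config: the tests below patch
# AutoConfig.from_pretrained to return this object, so only the model_type
# field consulted by the finetuning entry point needs to exist.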
@dataclass
class Config:
    model_type: str = "llama"
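
# Expected train/eval dataloader lengths when the samsum dataset is packed for
# each model configuration ("fake_llama" is assumed to be a lightweight
# test-only config supplied by the suite's fixtures).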
EXPECTED_SAMPLE_NUMBER = {
    "meta-llama/Llama-2-7b-hf": {
        "train": 96,
        "eval": 42,
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
        "train": 79,
        "eval": 34,
    },
    "fake_llama": {
        "train": 50,
        "eval": 21,
    },
}
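
# Verify that packing the samsum dataset produces the expected number of
# batches and that each packed sequence fills a 4096-token context. Model,
# tokenizer, optimizer and the training loop are all mocked; setup_tokenizer,
# setup_processor, llama_version and model_type are assumed to be fixtures
# from the suite's conftest.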
@pytest.mark.skip_missing_tokenizer
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

    # Give the mocked models embedding tables matching the Llama 2 / Llama 3 vocab sizes.
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
    }

    # Packing is expected to work only for text-only llama models; otherwise main() should raise.
    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]

        # The first packed batch should carry full-context sequences of 4096 tokens.
        batch = next(iter(train_dataloader))

        assert "labels" in batch.keys()
        assert "input_ids" in batch.keys()
        assert "attention_mask" in batch.keys()

        assert batch["labels"][0].size(0) == 4096
        assert batch["input_ids"][0].size(0) == 4096
        assert batch["attention_mask"][0].size(0) == 4096
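
# Repeat the packing check under mocked 2-way FSDP: with world_size=2 each
# rank should receive half of the batches. torch.distributed, FSDP and CUDA
# are patched out so the test runs in a single CPU process.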
@pytest.mark.skip_missing_tokenizer
@patch("llama_recipes.utils.train_utils.torch.cuda.is_bf16_supported")
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
def test_distributed_packing(
    dist,
    is_initialized,
    fsdp,
    setup,
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    cuda_is_available,
    cuda_is_bf16_supported,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
    import os

    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)
    cuda_is_available.return_value = False
    cuda_is_bf16_supported.return_value = False

    # Pretend to be rank 1 of a two-process job so main() takes the FSDP code path.
    rank = 1
    os.environ['LOCAL_RANK'] = f'{rank}'
    os.environ['RANK'] = f'{rank}'
    os.environ['WORLD_SIZE'] = '2'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
        "enable_fsdp": True,
    }

    is_initialized.return_value = True
    dist.get_rank.return_value = rank
    dist.get_world_size.return_value = 2

    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        # With world_size=2 each rank's dataloader holds half of the packed batches.
        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2