# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import pytest

from dataclasses import dataclass
from contextlib import nullcontext
from unittest.mock import patch

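# These tests exercise the "packing" batching strategy of llama_recipes.finetuning.main
# and check the resulting train/eval dataloader lengths and batch shapes. The fixtures
# setup_tokenizer, setup_processor, llama_version and model_type are assumed to be
# provided by the test suite's conftest.py.
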
@dataclass
class Config:
    model_type: str = "llama"

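# Expected dataloader lengths (number of batches) per model for the packed samsum dataset.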
EXPECTED_SAMPLE_NUMBER = {
    "meta-llama/Llama-2-7b-hf": {
        "train": 96,
        "eval": 42,
    },
    "meta-llama/Meta-Llama-3.1-8B-Instruct": {
        "train": 79,
        "eval": 34,
    },
    "fake_llama": {
        "train": 50,
        "eval": 21,
    },
}

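# The heavy parts of finetuning.main (model/tokenizer loading, optimizer, LR scheduler
# and the train loop itself) are patched out so that only dataset preprocessing and
# dataloader construction are exercised.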
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
def test_packing(
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
    }

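    # Packing is only expected to succeed for the text-only "llama" model type; for any
    # other model type (e.g. the multimodal mllama config) main is expected to raise a
    # ValueError.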
    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]

        batch = next(iter(train_dataloader))

        assert "labels" in batch.keys()
        assert "input_ids" in batch.keys()
        assert "attention_mask" in batch.keys()

        assert batch["labels"][0].size(0) == 4096
        assert batch["input_ids"][0].size(0) == 4096
        assert batch["attention_mask"][0].size(0) == 4096

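# The distributed variant fakes a 2-rank FSDP run (torch.distributed, FSDP and the
# process-group setup are all patched), so each rank should see half of the packed
# samples in its dataloaders.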
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch('llama_recipes.finetuning.train')
@patch('llama_recipes.finetuning.AutoTokenizer')
@patch("llama_recipes.finetuning.AutoConfig.from_pretrained")
@patch("llama_recipes.finetuning.AutoProcessor")
@patch("llama_recipes.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@patch('llama_recipes.finetuning.optim.AdamW')
@patch('llama_recipes.finetuning.StepLR')
@patch('llama_recipes.finetuning.setup')
@patch('llama_recipes.finetuning.FSDP')
@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
@patch('llama_recipes.utils.config_utils.dist')
def test_distributed_packing(
    dist,
    is_initialized,
    fsdp,
    setup,
    step_lr,
    optimizer,
    get_model,
    get_mmodel,
    processor,
    get_config,
    tokenizer,
    train,
    cuda_is_available,
    setup_tokenizer,
    setup_processor,
    llama_version,
    model_type,
):
    import os

    from llama_recipes.finetuning import main

    setup_tokenizer(tokenizer)
    setup_processor(processor)

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [32000 if "Llama-2" in llama_version else 128256]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)
    cuda_is_available.return_value = False

    rank = 1
    os.environ['LOCAL_RANK'] = f'{rank}'
    os.environ['RANK'] = f'{rank}'
    os.environ['WORLD_SIZE'] = '2'
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12345'

    kwargs = {
        "model_name": llama_version,
        "batch_size_training": 8,
        "val_batch_size": 1,
        "use_peft": False,
        "dataset": "samsum_dataset",
        "batching_strategy": "packing",
        "enable_fsdp": True,
    }

    is_initialized.return_value = True
    dist.get_rank.return_value = rank
    dist.get_world_size.return_value = 2

    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    if model_type == "llama":
        assert train.call_count == 1

        args, kwargs = train.call_args
        train_dataloader = args[1]
        eval_dataloader = args[2]

        assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] // 2
        assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] // 2