test_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from unittest.mock import patch

import pytest
import torch
from llama_recipes.data.sampler import LengthBasedBatchSampler
from llama_recipes.finetuning import main
from pytest import approx
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler


def get_fake_dataset():
    return [
        {
            "input_ids": [1],
            "attention_mask": [1],
            "labels": [1],
        }
    ]
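

# The tests below drive llama_recipes.finetuning.main() end to end while patching out
# the heavy pieces (model, tokenizer, dataset, optimizer, scheduler) with mocks, so no
# weights are downloaded and no real training loop runs.

# With run_validation disabled, train() should receive a training DataLoader and no
# eval dataloader, and the model is moved to "cuda" only when CUDA is reported available.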
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_no_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"run_validation": False}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert eval_dataloader is None

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0
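

# Same setup as above, but with run_validation enabled both the train and eval
# dataloaders passed to train() should be DataLoader instances.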
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_with_validation(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"run_validation": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)
    assert isinstance(eval_dataloader, DataLoader)

    if cuda_is_available:
        assert get_model.return_value.to.call_count == 1
        assert get_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_model.return_value.to.call_count == 0
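

# With use_peft enabled, the wrapped PEFT model (not the base model) is the one moved
# to "cuda", and its trainable parameters are printed exactly once.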
@patch("llama_recipes.finetuning.torch.cuda.is_available")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.generate_peft_config")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
def test_finetuning_peft_lora(
    step_lr,
    optimizer,
    get_peft_model,
    gen_peft_config,
    get_dataset,
    tokenizer,
    get_model,
    train,
    cuda,
    cuda_is_available,
):
    kwargs = {"use_peft": True}

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    if cuda_is_available:
        assert get_peft_model.return_value.to.call_count == 1
        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert get_peft_model.return_value.to.call_count == 0

    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
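

# The llama_adapter PEFT method is expected to raise when combined with FSDP. Once
# FSDP is disabled, main() should proceed as far as get_peft_model, which is made to
# raise a sentinel error here so the test can stop before any further setup.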
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.setup")
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
    get_dataset, tokenizer, get_model, train, setup, get_peft_model
):
    kwargs = {
        "use_peft": True,
        "peft_method": "llama_adapter",
        "enable_fsdp": True,
    }

    get_dataset.return_value = get_fake_dataset()
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    # Minimal single-process distributed environment so FSDP setup can be exercised.
    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    with pytest.raises(
        RuntimeError,
        match="Llama_adapter is currently not supported in combination with FSDP",
    ):
        main(**kwargs)

    GET_ME_OUT = "Get me out of here"
    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
    kwargs["enable_fsdp"] = False

    with pytest.raises(
        RuntimeError,
        match=GET_ME_OUT,
    ):
        main(**kwargs)
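

# The weight_decay argument should be forwarded to the AdamW optimizer that main()
# passes to train().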
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.get_peft_model")
@patch("llama_recipes.finetuning.StepLR")
def test_finetuning_weight_decay(
    step_lr, get_peft_model, get_dataset, tokenizer, get_model, train
):
    kwargs = {"weight_decay": 0.01}

    get_dataset.return_value = get_fake_dataset()
    get_model.return_value.parameters.return_value = [torch.ones(1, 1)]
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    optimizer = args[4]

    print(optimizer.state_dict())

    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
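

# The "packing" strategy should yield plain BatchSampler-backed dataloaders, "padding"
# should use LengthBasedBatchSampler, and an unknown strategy should raise ValueError.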
@patch("llama_recipes.finetuning.train")
@patch("llama_recipes.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_recipes.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_recipes.finetuning.get_preprocessed_dataset")
@patch("llama_recipes.finetuning.optim.AdamW")
@patch("llama_recipes.finetuning.StepLR")
def test_batching_strategy(
    step_lr, optimizer, get_dataset, tokenizer, get_model, train
):
    kwargs = {"batching_strategy": "packing"}

    get_dataset.return_value = get_fake_dataset()
    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)

    kwargs["batching_strategy"] = "padding"
    train.reset_mock()
    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]
    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)

    kwargs["batching_strategy"] = "none"

    with pytest.raises(ValueError):
        main(**kwargs)