test_finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
from contextlib import nullcontext
from dataclasses import dataclass
from unittest.mock import patch

import pytest
import torch
from llama_cookbook.data.sampler import LengthBasedBatchSampler
from llama_cookbook.finetuning import main
from pytest import approx
from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import BatchSampler


@dataclass
class Config:
    model_type: str = "llama"


def get_fake_dataset():
    return 8192 * [
        {
            "input_ids": [1],
            "attention_mask": [1],
            "labels": [1],
        }
    ]
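

# The tests in this module take a `model_type` argument that is not
# parametrized here; it is expected to be supplied by the surrounding test
# configuration (e.g. a conftest.py fixture). The fixture below is a minimal
# sketch of that assumption, covering the two model types the tests branch on
# ("llama" and "mllama"); the project's own fixture may differ.
@pytest.fixture(params=["llama", "mllama"])
def model_type(request):
    return request.param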


@patch("llama_cookbook.finetuning.torch.cuda.is_available")
@patch("llama_cookbook.finetuning.train")
@patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch("llama_cookbook.finetuning.AutoProcessor.from_pretrained")
@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
@patch("llama_cookbook.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_cookbook.finetuning.get_preprocessed_dataset")
@patch("llama_cookbook.finetuning.generate_peft_config")
@patch("llama_cookbook.finetuning.get_peft_model")
@patch("llama_cookbook.finetuning.optim.AdamW")
@patch("llama_cookbook.finetuning.StepLR")
@pytest.mark.parametrize("cuda_is_available", [True, False])
@pytest.mark.parametrize("run_validation", [True, False])
@pytest.mark.parametrize("use_peft", [True, False])
def test_finetuning(
    step_lr,
    optimizer,
    get_peft_model,
    gen_peft_config,
    get_dataset,
    tokenizer,
    get_config,
    get_model,
    get_processor,
    get_mmodel,
    train,
    cuda,
    cuda_is_available,
    run_validation,
    use_peft,
    model_type,
):
    kwargs = {
        "run_validation": run_validation,
        "use_peft": use_peft,
        "batching_strategy": "packing" if model_type == "llama" else "padding",
    }

    get_dataset.return_value = get_fake_dataset()
    cuda.return_value = cuda_is_available

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader = args[1]
    eval_dataloader = args[2]

    assert isinstance(train_dataloader, DataLoader)

    if run_validation:
        assert isinstance(eval_dataloader, DataLoader)
    else:
        assert eval_dataloader is None

    if use_peft:
        assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
        model = get_peft_model
    elif model_type == "llama":
        model = get_model
    else:
        model = get_mmodel

    if cuda_is_available:
        assert model.return_value.to.call_count == 1
        assert model.return_value.to.call_args.args[0] == "cuda"
    else:
        assert model.return_value.to.call_count == 0


@patch("llama_cookbook.finetuning.get_peft_model")
@patch("llama_cookbook.finetuning.setup")
@patch("llama_cookbook.finetuning.train")
@patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch("llama_cookbook.finetuning.AutoProcessor.from_pretrained")
@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
@patch("llama_cookbook.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_cookbook.finetuning.get_preprocessed_dataset")
def test_finetuning_peft_llama_adapter(
    get_dataset,
    tokenizer,
    get_config,
    get_model,
    get_processor,
    get_mmodel,
    train,
    setup,
    get_peft_model,
    model_type,
):
    kwargs = {
        "use_peft": True,
        "peft_method": "llama_adapter",
        "enable_fsdp": True,
        "batching_strategy": "packing" if model_type == "llama" else "padding",
    }

    get_dataset.return_value = get_fake_dataset()

    get_model.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_mmodel.return_value.get_input_embeddings.return_value.weight.shape = [0]
    get_config.return_value = Config(model_type=model_type)

    os.environ["RANK"] = "0"
    os.environ["LOCAL_RANK"] = "0"
    os.environ["WORLD_SIZE"] = "1"
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12345"

    with pytest.raises(
        RuntimeError,
        match="Llama_adapter is currently not supported in combination with FSDP",
    ):
        main(**kwargs)

    GET_ME_OUT = "Get me out of here"
    get_peft_model.side_effect = RuntimeError(GET_ME_OUT)
    kwargs["enable_fsdp"] = False

    with pytest.raises(
        RuntimeError,
        match=GET_ME_OUT,
    ):
        main(**kwargs)


@patch("llama_cookbook.finetuning.train")
@patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch("llama_cookbook.finetuning.AutoProcessor.from_pretrained")
@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
@patch("llama_cookbook.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_cookbook.finetuning.get_preprocessed_dataset")
@patch("llama_cookbook.finetuning.get_peft_model")
@patch("llama_cookbook.finetuning.StepLR")
def test_finetuning_weight_decay(
    step_lr,
    get_peft_model,
    get_dataset,
    tokenizer,
    get_config,
    get_model,
    get_processor,
    get_mmodel,
    train,
    model_type,
):
    kwargs = {
        "weight_decay": 0.01,
        "batching_strategy": "packing" if model_type == "llama" else "padding",
    }

    get_dataset.return_value = get_fake_dataset()

    model = get_model if model_type == "llama" else get_mmodel
    model.return_value.parameters.return_value = [torch.ones(1, 1)]
    model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    get_config.return_value = Config(model_type=model_type)

    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    optimizer = args[4]

    assert isinstance(optimizer, AdamW)
    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)


@patch("llama_cookbook.finetuning.train")
@patch("llama_cookbook.finetuning.MllamaForConditionalGeneration.from_pretrained")
@patch("llama_cookbook.finetuning.AutoProcessor.from_pretrained")
@patch("llama_cookbook.finetuning.LlamaForCausalLM.from_pretrained")
@patch("llama_cookbook.finetuning.AutoConfig.from_pretrained")
@patch("llama_cookbook.finetuning.AutoTokenizer.from_pretrained")
@patch("llama_cookbook.finetuning.get_preprocessed_dataset")
@patch("llama_cookbook.finetuning.optim.AdamW")
@patch("llama_cookbook.finetuning.StepLR")
def test_batching_strategy(
    step_lr,
    optimizer,
    get_dataset,
    tokenizer,
    get_config,
    get_model,
    get_processor,
    get_mmodel,
    train,
    model_type,
):
    kwargs = {
        "batching_strategy": "packing",
    }

    get_dataset.return_value = get_fake_dataset()

    model = get_model if model_type == "llama" else get_mmodel
    model.return_value.get_input_embeddings.return_value.weight.shape = [0]

    get_config.return_value = Config(model_type=model_type)

    c = nullcontext() if model_type == "llama" else pytest.raises(ValueError)

    with c:
        main(**kwargs)

    assert train.call_count == (1 if model_type == "llama" else 0)

    if model_type == "llama":
        args, kwargs = train.call_args
        train_dataloader, eval_dataloader = args[1:3]

        assert isinstance(train_dataloader.batch_sampler, BatchSampler)
        assert isinstance(eval_dataloader.batch_sampler, BatchSampler)

    kwargs["batching_strategy"] = "padding"
    train.reset_mock()
    main(**kwargs)

    assert train.call_count == 1

    args, kwargs = train.call_args
    train_dataloader, eval_dataloader = args[1:3]

    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)

    kwargs["batching_strategy"] = "none"

    with pytest.raises(ValueError):
        main(**kwargs)
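

# Example invocation (assuming pytest and the llama_cookbook package are
# installed; adjust the path to where this file lives in the repository):
#
#   python -m pytest test_finetuning.py -q
#   python -m pytest test_finetuning.py -k test_batching_strategy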