finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_recipes.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_recipes.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    get_policies,
    print_model_size,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run

def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if type(train_config.quantization) == type(True):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)
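        # For reference only (illustrative sketch, not the code path above): a
        # 4-bit setup built by create_bnb_config is typically equivalent to a
        # transformers BitsAndBytesConfig along these lines; the actual field
        # values come from QUANTIZATION_CONFIG and any --kwargs overrides.
        #
        #   bnb_config = BitsAndBytesConfig(
        #       load_in_4bit=True,
        #       bnb_4bit_quant_type="nf4",
        #       bnb_4bit_compute_dtype=torch.bfloat16,
        #   )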

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None

    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )
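    # Note on the loading logic above: device_map="auto" is only passed for a
    # quantized run without FSDP, so bitsandbytes/accelerate can place the
    # quantized weights itself; under FSDP no device_map is used and the FSDP
    # wrapper further below moves and shards the parameters onto the local
    # device instead.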

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
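        # For reference only (illustrative sketch): with the LoRA peft method,
        # generate_peft_config typically yields a peft LoraConfig roughly like
        # the one below; the real values come from the PEFT config dataclass in
        # llama_recipes.configs and any --kwargs overrides.
        #
        #   from peft import LoraConfig
        #   peft_config = LoraConfig(
        #       r=8,
        #       lora_alpha=32,
        #       target_modules=["q_proj", "v_proj"],
        #       lora_dropout=0.05,
        #       task_type="CAUSAL_LM",
        #   )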

    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,
        # MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
        )
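        # Notes on the FSDP construction above: when PEFT is enabled the custom
        # auto-wrap policy is used, which (per fsdp_auto_wrap_policy) is meant
        # to wrap each decoder/encoder layer in its own FSDP unit and keep the
        # trainable adapter parameters separate from the frozen base weights.
        # sync_module_states together with the to_empty param_init_fn is the
        # low_cpu_fsdp path: ranks other than 0 allocate empty parameter
        # storage and receive the real weights by broadcast from rank 0.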
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processer, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processer, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
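    # StepLR with step_size=1 scales the learning rate by gamma each time
    # scheduler.step() is called, giving a simple exponential decay schedule.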

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
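
# Example invocations (illustrative; the flag names map to fields of the
# train_config / fsdp_config dataclasses and are applied via update_config,
# so check llama_recipes.configs for the authoritative list and defaults;
# the model id and output path below are placeholders):
#
#   # Single GPU, LoRA fine-tuning with 4-bit quantization:
#   python finetuning.py --use_peft --peft_method lora --quantization 4bit \
#       --model_name meta-llama/Llama-3.1-8B --output_dir ./peft-checkpoint
#
#   # Multi-GPU fine-tuning with FSDP via torchrun:
#   torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp \
#       --use_peft --peft_method lora --model_name meta-llama/Llama-3.1-8B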