finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
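#
# Fine-tuning entry point for Llama (text) and Mllama (vision) models. The overall flow in
# main() is: read CLI overrides into train_config/fsdp_config, seed the RNGs, optionally set
# up torch.distributed and wandb, build an optional bitsandbytes quantization config, load
# the model and tokenizer (or processor for vision models), optionally apply PEFT and/or wrap
# the model with FSDP, build the train/eval dataloaders, and hand everything to
# llama_recipes.utils.train_utils.train.
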
import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_recipes.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_recipes.data.concatenator import ConcatDataset
from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_recipes.utils import fsdp_auto_wrap_policy
from llama_recipes.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_recipes.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_recipes.utils.fsdp_utils import hsdp_device_mesh
from llama_recipes.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    get_policies,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)


def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_recipes.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run


def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # Set up the quantization config
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean; please specify it as '4bit' or '8bit'. Defaulting to '8bit', but this default may change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    # Load the pre-trained model and set up its configuration
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else torch.bfloat16,
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between the tokenizer vocab size and the embedding matrix,
    # print a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if FSDP and pure_bf16 are enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained PEFT model checkpoint and set up its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the PEFT config and start fine-tuning from the original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        if (
            not train_config.use_peft
            and train_config.freeze_LLM_only
            and config.model_type == "mllama"
        ):
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer
        # and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        if train_config.freeze_LLM_only:
            use_orig_params = True
        else:
            use_orig_params = False

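        # Shard the model with FSDP. PEFT runs use the per-layer auto-wrap policy built above,
        # otherwise the policy returned by get_policies() is used. With low_cpu_fsdp, ranks
        # other than 0 re-materialize their parameters via to_empty() and receive the real
        # weights from rank 0 through sync_module_states. use_orig_params=True is set for the
        # freeze_LLM_only case so FSDP can handle a mix of frozen and trainable parameters
        # inside a single wrapped module.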
        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processor = processor
    else:
        dataset_processor = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processor, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processor, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation datasets
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processor, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for the dataloader to load even one batch. Please increase the size of the eval set. ({len(eval_dataloader)=})"
            )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
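
# Example invocations (illustrative only; the exact flag names come from the train_config and
# fsdp_config dataclasses in llama_recipes.configs, so verify them against your installed
# version before running):
#
#   Single-GPU PEFT (LoRA) fine-tuning with 8-bit quantization:
#       python finetuning.py --use_peft --peft_method lora --quantization 8bit \
#           --model_name /path/to/llama/checkpoint --output_dir /path/to/save/peft/model
#
#   Multi-GPU fine-tuning with FSDP via torchrun:
#       torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp \
#           --model_name /path/to/llama/checkpoint --use_peft --peft_method lora \
#           --output_dir /path/to/save/peft/model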