finetuning.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_cookbook.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_cookbook.data.concatenator import ConcatDataset
from llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_cookbook.utils import fsdp_auto_wrap_policy
from llama_cookbook.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_cookbook.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh, get_policies
from llama_cookbook.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)

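# setup_wandb(): builds the wandb config dataclass, lets CLI kwargs override it via
# update_config, and starts a Weights & Biases run that also records the training
# and FSDP configs. It raises a clear ImportError if wandb is not installed.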
def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_cookbook.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run

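# Every keyword argument passed on the command line (via fire) is routed through
# update_config() onto the train/FSDP/quantization/wandb config dataclasses, so any
# config field can be overridden as a --flag. Illustrative invocations (adjust the
# model path and flags to your setup):
#
#   # single-GPU PEFT fine-tuning with 4-bit quantization
#   python finetuning.py --use_peft --quantization 4bit --model_name <path/to/model>
#
#   # multi-GPU FSDP fine-tuning; torchrun provides LOCAL_RANK/RANK/WORLD_SIZE
#   torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp --use_peft \
#       --model_name <path/to/model>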
def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

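    # W&B: with FSDP enabled, only rank 0 creates the run so that a multi-process
    # job does not log the same training run once per rank.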
    wandb_run = None

    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if isinstance(train_config.quantization, bool):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

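        # Build the BitsAndBytesConfig from the quantization config. When FSDP is
        # combined with 4-bit QLoRA, the 4-bit quant storage dtype has to match the
        # compute dtype so that FSDP can flatten and shard the quantized weights.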
        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)
        if train_config.enable_fsdp:
            if train_config.quantization == "4bit":
                bnb_config.bnb_4bit_quant_storage = bnb_config.bnb_4bit_compute_dtype
                from logging import getLogger

                logger = getLogger()
                logger.warning(
                    "FSDP and 4-bit QLoRA enabled. Setting `bnb_4bit_quant_storage` "
                    f"to {bnb_config.bnb_4bit_compute_dtype} for compatibility."
                )

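    # config.model_type selects the branch below: "mllama" loads the vision-language
    # model together with its AutoProcessor, "llama" loads the text-only causal LM,
    # and any other model type is rejected.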
    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None
    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

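    # PEFT: either resume adapters from an existing PEFT checkpoint (kept trainable)
    # or build a fresh peft_config from the CLI/train config and wrap the base model.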
    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

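    # HSDP: when hybrid sharding is requested, build a 2D device mesh that shards
    # parameters within each sharding group and replicates them across replica groups.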
    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

        if not train_config.use_peft and train_config.freeze_LLM_only and config.model_type == "mllama":
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(model, train_config, rank if train_config.enable_fsdp else 0)

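        # get_policies() derives the mixed-precision and default transformer wrapping
        # policies from fsdp_config; for PEFT runs a custom auto-wrap policy over the
        # concrete decoder/encoder layer classes is used instead (see the FSDP call below).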
        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer, MllamaCrossAttentionDecoderLayer, MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

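        # use_orig_params=True is required when a wrapped module mixes frozen and
        # trainable parameters (PEFT adapters or freeze_LLM_only), since FSDP otherwise
        # cannot flatten parameters with different requires_grad into one flat parameter.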
        use_orig_params = train_config.freeze_LLM_only or train_config.use_peft

        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

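    # Data: vision models preprocess batches with the AutoProcessor, text models with
    # the tokenizer. "packing" concatenates samples into fixed-size chunks via
    # ConcatDataset and is only supported for text datasets.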
    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processor = processor
    else:
        dataset_processor = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processor,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processor, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processor, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processor, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )

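    # Optimizer: with pure_bf16 and optimizer="anyprecision", AnyPrecisionAdamW keeps
    # its momentum/variance states in bfloat16; otherwise standard AdamW is used.
    # StepLR decays the learning rate by `gamma` each time scheduler.step() is called.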
    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
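    # Only rank 0 (or the single process in non-FSDP runs) prints the aggregated
    # training results and pushes them to the W&B run summary.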
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)