# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
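
"""Fine-tuning entry point for Llama models.

Loads a Llama text model (LlamaForCausalLM) or a Llama vision model
(MllamaForConditionalGeneration), optionally applies PEFT and/or
bitsandbytes quantization, wraps the model in FSDP (optionally with an
HSDP device mesh) when enable_fsdp is set, builds the training and
validation dataloaders, and runs training. Configuration comes from the
train_config/fsdp_config dataclasses and can be overridden through
command-line keyword arguments parsed by fire.
"""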

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_cookbook.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_cookbook.data.concatenator import ConcatDataset
from llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_cookbook.utils import fsdp_auto_wrap_policy
from llama_cookbook.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_cookbook.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh
from llama_cookbook.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    get_policies,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)


def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_cookbook.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run


def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None

    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if type(train_config.quantization) == type(True):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)

    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None

    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )

    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)

    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()

    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")

    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        if (
            not train_config.use_peft
            and train_config.freeze_LLM_only
            and config.model_type == "mllama"
        ):
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)

        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,
        # MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()
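
        # Note: freezing only the LLM while the vision tower stays trainable mixes
        # frozen and trainable parameters inside the same FSDP units, which FSDP
        # only permits when use_orig_params=True.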
        if train_config.freeze_LLM_only:
            use_orig_params = True
        else:
            use_orig_params = False

        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")

    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processer, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processer, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
        if len(eval_dataloader) == 0:
            raise ValueError(
                "The eval set size is too small for dataloader to load even one batch. "
                f"Please increase the size of eval set. ({len(eval_dataloader)=})"
            )

    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)

    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        [print(f"Key: {k}, Value: {v}") for k, v in results.items()]
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
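
# Example invocations (illustrative only; the model paths, output directory, and
# GPU count below are placeholders, not part of this file):
#
#   Single GPU, PEFT/LoRA with 4-bit quantization:
#     python finetuning.py --use_peft --peft_method lora --quantization 4bit \
#         --model_name /path/to/model --output_dir /path/to/save/peft/model
#
#   Multi-GPU with FSDP via torchrun:
#     torchrun --nnodes 1 --nproc_per_node 4 finetuning.py --enable_fsdp \
#         --model_name /path/to/model --use_peft --peft_method lora \
#         --output_dir /path/to/save/peft/model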