finetuning.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import dataclasses
import os
import random
from collections import Counter
from warnings import warn

import fire
import numpy as np
import torch
import torch.optim as optim
from accelerate.utils import is_xpu_available
from llama_cookbook.configs import (
    fsdp_config as FSDP_CONFIG,
    quantization_config as QUANTIZATION_CONFIG,
    train_config as TRAIN_CONFIG,
)
from llama_cookbook.data.concatenator import ConcatDataset
from llama_cookbook.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
from llama_cookbook.utils import fsdp_auto_wrap_policy
from llama_cookbook.utils.config_utils import (
    check_fsdp_config,
    generate_dataset_config,
    generate_peft_config,
    get_dataloader_kwargs,
    update_config,
)
from llama_cookbook.utils.dataset_utils import (
    get_custom_data_collator,
    get_preprocessed_dataset,
)
from llama_cookbook.utils.fsdp_utils import hsdp_device_mesh, get_policies
from llama_cookbook.utils.train_utils import (
    clear_gpu_cache,
    freeze_transformer_layers,
    freeze_LLM_only,
    print_model_size,
    print_frozen_model_status,
    setup,
    setup_environ_flags,
    train,
)
from peft import get_peft_model, PeftModel
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
from torch.optim.lr_scheduler import StepLR
from transformers import (
    AutoConfig,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
    LlamaForCausalLM,
    MllamaForConditionalGeneration,
)
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.mllama.modeling_mllama import (
    MllamaCrossAttentionDecoderLayer,
    MllamaSelfAttentionDecoderLayer,
    MllamaVisionEncoderLayer,
)
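
# Note: main() is exposed through fire.Fire at the bottom of this file, so any
# field of train_config / fsdp_config can be overridden from the command line.
# Illustrative single-GPU invocation (flag names mirror the config dataclass
# fields referenced in this file, not a CLI defined here):
#   python finetuning.py --model_name <hf-model-or-path> --use_peft --quantization 4bit
# Multi-GPU FSDP runs are launched with torchrun, which supplies the
# LOCAL_RANK/RANK/WORLD_SIZE environment variables read in main().
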
def setup_wandb(train_config, fsdp_config, **kwargs):
    try:
        import wandb
    except ImportError:
        raise ImportError(
            "You are trying to use wandb which is not currently installed. "
            "Please install it using pip install wandb"
        )
    from llama_cookbook.configs import wandb_config as WANDB_CONFIG

    wandb_config = WANDB_CONFIG()
    update_config(wandb_config, **kwargs)
    init_dict = dataclasses.asdict(wandb_config)
    run = wandb.init(**init_dict)
    run.config.update(train_config)
    run.config.update(fsdp_config, allow_val_change=True)
    return run

def main(**kwargs):
    # Update the configuration for the training and sharding process
    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
    update_config((train_config, fsdp_config), **kwargs)

    # Set the seeds for reproducibility
    if is_xpu_available():
        torch.xpu.manual_seed(train_config.seed)
    torch.manual_seed(train_config.seed)
    random.seed(train_config.seed)
    np.random.seed(train_config.seed)

    if train_config.enable_fsdp:
        setup()
        # torchrun specific
        local_rank = int(os.environ["LOCAL_RANK"])
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        if is_xpu_available():
            torch.xpu.set_device(local_rank)
        elif torch.cuda.is_available():
            torch.cuda.set_device(local_rank)
        clear_gpu_cache(local_rank)
        setup_environ_flags(rank)

    wandb_run = None
    if train_config.use_wandb:
        if not train_config.enable_fsdp or rank == 0:
            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)

    # setting quantization configs
    bnb_config = None
    if train_config.quantization:
        if type(train_config.quantization) == type(True):
            warn(
                "Quantization (--quantization) is a boolean, please specify quantization as '4bit' or '8bit'. Defaulting to '8bit' but this might change in the future.",
                FutureWarning,
            )
            train_config.quantization = "8bit"

        if train_config.quantization == "8bit" and train_config.enable_fsdp:
            raise ValueError(
                "8bit quantization is not supported with FSDP, please use 4bit quantization"
            )

        quant_config = QUANTIZATION_CONFIG()
        update_config(quant_config, **kwargs)
        bnb_config = quant_config.create_bnb_config(train_config.quantization)
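
    # Note: as enforced above, '4bit' is the only quantization mode that can be
    # combined with FSDP ('8bit' raises, and a bare boolean is coerced to
    # '8bit' with a FutureWarning); the resulting bnb_config is handed to
    # from_pretrained below.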
    # Load the pre-trained model and setup its configuration
    use_cache = False if train_config.enable_fsdp else None

    config = AutoConfig.from_pretrained(train_config.model_name)
    if config.model_type == "mllama":
        is_vision = True
        model = MllamaForConditionalGeneration.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
        processor = AutoProcessor.from_pretrained(
            train_config.model_name
            if train_config.tokenizer_name is None
            else train_config.tokenizer_name
        )
        processor.tokenizer.padding_side = "right"
        model.supports_gradient_checkpointing = True
        model.language_model.supports_gradient_checkpointing = True
    elif config.model_type == "llama":
        is_vision = False
        model = LlamaForCausalLM.from_pretrained(
            train_config.model_name,
            quantization_config=bnb_config,
            use_cache=use_cache,
            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
            device_map=(
                "auto"
                if train_config.quantization and not train_config.enable_fsdp
                else None
            ),
            torch_dtype=torch.float16 if train_config.use_fp16 else "auto",
        )
    else:
        raise ValueError(
            f"Model type {config.model_type} is not supported. Please use llama or mllama model."
        )
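
    # is_vision, set above from config.model_type, later decides whether the
    # AutoProcessor or the plain tokenizer preprocesses the dataset, and which
    # decoder/encoder layer classes the FSDP auto-wrap policy targets.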
    # Load the tokenizer and add special tokens
    tokenizer = AutoTokenizer.from_pretrained(
        train_config.model_name
        if train_config.tokenizer_name is None
        else train_config.tokenizer_name
    )
    if not tokenizer.pad_token_id:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # If there is a mismatch between tokenizer vocab size and embedding matrix,
    # throw a warning and then expand the embedding matrix
    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
        print(
            "WARNING: Resizing the embedding matrix to match the tokenizer vocab size."
        )
        model.resize_token_embeddings(len(tokenizer))

    print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

    # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled
    if (
        train_config.enable_fsdp
        and fsdp_config.pure_bf16
        and not train_config.quantization
    ):
        model.to(torch.bfloat16)
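
    # When pure_bf16 is enabled the whole (unquantized) model is cast to
    # bfloat16 here, and the FSDP wrapper below skips its mixed_precision
    # policy accordingly.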
    if train_config.use_peft:
        # Load the pre-trained peft model checkpoint and setup its configuration
        if train_config.from_peft_checkpoint:
            model = PeftModel.from_pretrained(
                model, train_config.from_peft_checkpoint, is_trainable=True
            )
            peft_config = model.peft_config
        # Generate the peft config and start fine-tuning from original model
        else:
            peft_config = generate_peft_config(train_config, kwargs)
            model = get_peft_model(model, peft_config)
        if wandb_run:
            wandb_run.config.update(peft_config)
        model.print_trainable_parameters()
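
    # Resuming from a PEFT checkpoint keeps the existing adapters trainable
    # (is_trainable=True); otherwise a fresh adapter is attached from the
    # config produced by generate_peft_config.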
    hsdp_device_mesh_plan = None
    if (
        fsdp_config.hsdp
        and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD
    ):
        hsdp_device_mesh_plan = hsdp_device_mesh(
            replica_group_size=fsdp_config.replica_group_size,
            sharding_group_size=fsdp_config.sharding_group_size,
        )
        print("HSDP device mesh is ready")
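
    # With HYBRID_SHARD, parameters are sharded within each sharding group and
    # replicated across replica groups (group sizes set by sharding_group_size
    # and replica_group_size); the resulting mesh is passed to FSDP below via
    # device_mesh.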
    # setting up FSDP if enable_fsdp is enabled
    if train_config.enable_fsdp:
        check_fsdp_config(fsdp_config)

        if not train_config.use_peft and train_config.freeze_layers:
            freeze_transformer_layers(model, train_config.num_freeze_layers)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        if (
            not train_config.use_peft
            and train_config.freeze_LLM_only
            and config.model_type == "mllama"
        ):
            freeze_LLM_only(model)
            # print model size and frozen layers after freezing layers
            print_frozen_model_status(
                model, train_config, rank if train_config.enable_fsdp else 0
            )

        mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)

        # Create the FSDP wrapper for MllamaSelfAttentionDecoderLayer,
        # MllamaCrossAttentionDecoderLayer and MllamaVisionEncoderLayer in vision models
        if is_vision:
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(
                model,
                [
                    MllamaSelfAttentionDecoderLayer,
                    MllamaCrossAttentionDecoderLayer,
                    MllamaVisionEncoderLayer,
                ],
            )
        else:
            # Create the FSDP wrapper for LlamaDecoderLayer in text models
            my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, [LlamaDecoderLayer])

        device_id = 0
        if is_xpu_available():
            device_id = torch.xpu.current_device()
        elif torch.cuda.is_available():
            device_id = torch.cuda.current_device()

        if train_config.freeze_LLM_only:
            use_orig_params = True
        else:
            use_orig_params = False

        model = FSDP(
            model,
            auto_wrap_policy=(
                my_auto_wrapping_policy if train_config.use_peft else wrapping_policy
            ),
            cpu_offload=(
                CPUOffload(offload_params=True)
                if fsdp_config.fsdp_cpu_offload
                else None
            ),
            mixed_precision=(
                mixed_precision_policy if not fsdp_config.pure_bf16 else None
            ),
            sharding_strategy=fsdp_config.sharding_strategy,
            device_mesh=hsdp_device_mesh_plan,
            device_id=device_id,
            limit_all_gathers=True,
            sync_module_states=train_config.low_cpu_fsdp,
            param_init_fn=(
                (
                    lambda module: module.to_empty(
                        device=torch.device("cuda"), recurse=False
                    )
                )
                if train_config.low_cpu_fsdp and rank != 0
                else None
            ),
            use_orig_params=use_orig_params,
        )
        if fsdp_config.fsdp_activation_checkpointing:
            model.enable_input_require_grads()
            model.gradient_checkpointing_enable()
            apply_fsdp_checkpointing(model)
    elif not train_config.quantization and not train_config.enable_fsdp:
        if is_xpu_available():
            model.to("xpu:0")
        elif torch.cuda.is_available():
            model.to("cuda")
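
    # When low_cpu_fsdp is set, ranks other than 0 materialize parameters as
    # empty tensors via param_init_fn and then receive rank 0's weights through
    # sync_module_states during FSDP wrapping, which reduces peak memory use
    # when sharding large checkpoints.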
    dataset_config = generate_dataset_config(train_config, kwargs)
    if is_vision:
        dataset_processer = processor
    else:
        dataset_processer = tokenizer

    # Load and preprocess the dataset for training and validation
    dataset_train = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="train",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Training Set Length = {len(dataset_train)}")

    dataset_val = get_preprocessed_dataset(
        dataset_processer,
        dataset_config,
        split="test",
    )
    if not train_config.enable_fsdp or rank == 0:
        print(f"--> Validation Set Length = {len(dataset_val)}")

    if train_config.batching_strategy == "packing":
        if is_vision:
            raise ValueError("Packing is not supported for vision datasets")
        else:
            dataset_train = ConcatDataset(
                dataset_train, chunk_size=train_config.context_length
            )

    train_dl_kwargs = get_dataloader_kwargs(
        train_config, dataset_train, dataset_processer, "train"
    )
    print("length of dataset_train", len(dataset_train))

    custom_data_collator = get_custom_data_collator(dataset_processer, dataset_config)
    if custom_data_collator:
        print("custom_data_collator is used")
        train_dl_kwargs["collate_fn"] = custom_data_collator

    # Create DataLoaders for the training and validation dataset
    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        pin_memory=True,
        **train_dl_kwargs,
    )
    print(f"--> Num of Training Set Batches loaded = {len(train_dataloader)}")

    eval_dataloader = None
    if train_config.run_validation:
        if train_config.batching_strategy == "packing":
            if is_vision:
                raise ValueError("Packing is not supported for vision datasets")
            else:
                dataset_val = ConcatDataset(
                    dataset_val, chunk_size=train_config.context_length
                )

        val_dl_kwargs = get_dataloader_kwargs(
            train_config, dataset_val, dataset_processer, "val"
        )
        if custom_data_collator:
            val_dl_kwargs["collate_fn"] = custom_data_collator

        eval_dataloader = torch.utils.data.DataLoader(
            dataset_val,
            num_workers=train_config.num_workers_dataloader,
            pin_memory=True,
            **val_dl_kwargs,
        )
        if len(eval_dataloader) == 0:
            raise ValueError(
                f"The eval set size is too small for dataloader to load even one batch. Please increase the size of eval set. ({len(eval_dataloader)=})"
            )
        else:
            print(f"--> Num of Validation Set Batches loaded = {len(eval_dataloader)}")
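
    # The 'packing' strategy above concatenates tokenized samples into fixed
    # chunks of train_config.context_length via ConcatDataset and is rejected
    # for vision datasets; the remaining batching details (samplers, collators,
    # batch size) come from the kwargs returned by get_dataloader_kwargs.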
    # Initialize the optimizer and learning rate scheduler
    if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision":
        optimizer = AnyPrecisionAdamW(
            model.parameters(),
            lr=train_config.lr,
            momentum_dtype=torch.bfloat16,
            variance_dtype=torch.bfloat16,
            use_kahan_summation=False,
            weight_decay=train_config.weight_decay,
        )
    else:
        optimizer = optim.AdamW(
            model.parameters(),
            lr=train_config.lr,
            weight_decay=train_config.weight_decay,
        )
    scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
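
    # AnyPrecisionAdamW keeps its momentum/variance state in bfloat16, which
    # roughly halves optimizer-state memory versus fp32 AdamW; it is only
    # selected when pure_bf16 is on and fsdp_config.optimizer == "anyprecision".
    # StepLR with step_size=1 multiplies the learning rate by gamma on every
    # scheduler.step() call.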
    results = train(
        model,
        train_dataloader,
        eval_dataloader,
        tokenizer,
        optimizer,
        scheduler,
        train_config.gradient_accumulation_steps,
        train_config,
        fsdp_config if train_config.enable_fsdp else None,
        local_rank if train_config.enable_fsdp else None,
        rank if train_config.enable_fsdp else None,
        wandb_run,
    )
    if not train_config.enable_fsdp or rank == 0:
        for k, v in results.items():
            print(f"Key: {k}, Value: {v}")
        if train_config.use_wandb:
            for k, v in results.items():
                wandb_run.summary[k] = v


if __name__ == "__main__":
    fire.Fire(main)
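
# Illustrative multi-GPU launch (flag names other than --enable_fsdp, --use_peft
# and --model_name, e.g. --peft_method, are assumed from the llama_cookbook
# config defaults and may differ in your installation):
#   torchrun --nnodes 1 --nproc_per_node 4 finetuning.py \
#       --enable_fsdp --use_peft --peft_method lora --model_name <hf-model-or-path>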