train_utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from pkg_resources import packaging
  9. from datetime import datetime
  10. import torch
  11. import torch.cuda.nccl as nccl
  12. import torch.distributed as dist
  13. from torch.distributed.fsdp import StateDictType
  14. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  15. from tqdm import tqdm
  16. from transformers import LlamaTokenizer
  17. import json
  18. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  19. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  20. from llama_recipes.utils.memory_utils import MemoryTrace
  21. from accelerate.utils import is_xpu_available, is_ccl_available
  22. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  23. tokenizer.pad_token_id = 0
  24. tokenizer.padding_side = "left"
  25. # Converting Bytes to Megabytes
  26. def byte2mb(x):
  27. return int(x / 2**20)
  28. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
  29. """
  30. Trains the model on the given dataloader
  31. Args:
  32. model: The model to be trained
  33. train_dataloader: The dataloader containing the training data
  34. optimizer: The optimizer used for training
  35. lr_scheduler: The learning rate scheduler
  36. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  37. num_epochs: The number of epochs to train for
  38. local_rank: The rank of the current node in a distributed setting
  39. train_config: The training configuration
  40. eval_dataloader: The dataloader containing the eval data
  41. tokenizer: tokenizer used in the eval for decoding the predicitons
  42. Returns: results dictionary containing average training and validation perplexity and loss
  43. """
  44. # Create a gradient scaler for fp16
  45. if train_config.use_fp16 and train_config.enable_fsdp:
  46. scaler = ShardedGradScaler()
  47. elif train_config.use_fp16 and not train_config.enable_fsdp:
  48. scaler = torch.cuda.amp.GradScaler()
  49. if train_config.enable_fsdp:
  50. world_size = int(os.environ["WORLD_SIZE"])
  51. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  52. train_prep = []
  53. train_loss = []
  54. val_prep = []
  55. val_loss =[]
  56. if train_config.save_metrics:
  57. metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
  58. train_step_perplexity = []
  59. train_step_loss = []
  60. val_step_loss = []
  61. val_step_perplexity = []
  62. epoch_times = []
  63. checkpoint_times = []
  64. results = {}
  65. best_val_loss = float("inf")
  66. for epoch in range(train_config.num_epochs):
  67. epoch_start_time = time.perf_counter()
  68. with MemoryTrace() as memtrace: # track the memory usage
  69. model.train()
  70. total_loss = 0.0
  71. total_length = len(train_dataloader)//gradient_accumulation_steps
  72. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  73. for step, batch in enumerate(train_dataloader):
  74. for key in batch.keys():
  75. if train_config.enable_fsdp:
  76. if is_xpu_available():
  77. batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
  78. else:
  79. batch[key] = batch[key].to(local_rank)
  80. else:
  81. if is_xpu_available():
  82. batch[key] = batch[key].to('xpu:0')
  83. else:
  84. batch[key] = batch[key].to('cuda:0')
  85. with autocast():
  86. loss = model(**batch).loss
  87. loss = loss / gradient_accumulation_steps
  88. if train_config.save_metrics:
  89. train_step_loss.append(loss.detach().float().item())
  90. train_step_perplexity.append(float(torch.exp(loss.detach().float())))
  91. total_loss += loss.detach().float()
  92. if train_config.use_fp16:
  93. # if fp16 is enabled, use gradient scaler to handle gradient update
  94. scaler.scale(loss).backward()
  95. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  96. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  97. scaler.unscale_(optimizer)
  98. if train_config.enable_fsdp:
  99. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  100. else:
  101. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  102. scaler.step(optimizer)
  103. scaler.update()
  104. optimizer.zero_grad()
  105. pbar.update(1)
  106. else:
  107. # regular backpropagation when fp16 is not used
  108. loss.backward()
  109. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  110. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  111. if train_config.enable_fsdp:
  112. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  113. else:
  114. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  115. optimizer.step()
  116. optimizer.zero_grad()
  117. pbar.update(1)
  118. if wandb_run:
  119. if not train_config.enable_fsdp or rank==0:
  120. wandb_run.log({
  121. 'train/epoch': epoch + 1,
  122. 'train/step': epoch * len(train_dataloader) + step,
  123. 'train/loss': loss.detach().float(),
  124. })
  125. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  126. if train_config.save_metrics:
  127. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  128. pbar.close()
  129. epoch_end_time = time.perf_counter()-epoch_start_time
  130. epoch_times.append(epoch_end_time)
  131. # Reducing total_loss across all devices if there's more than one CUDA device
  132. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  133. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  134. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  135. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  136. train_epoch_loss = total_loss / len(train_dataloader)
  137. if train_config.enable_fsdp:
  138. train_epoch_loss = train_epoch_loss/world_size
  139. train_perplexity = torch.exp(train_epoch_loss)
  140. train_prep.append(float(train_perplexity))
  141. train_loss.append(float(train_epoch_loss))
  142. if not train_config.enable_fsdp or rank==0:
  143. memtrace.print_stats()
  144. # Update the learning rate as needed
  145. lr_scheduler.step()
  146. if train_config.run_validation:
  147. eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
  148. if train_config.save_metrics:
  149. val_step_loss.extend(temp_val_loss)
  150. val_step_perplexity.extend(temp_step_perplexity)
  151. checkpoint_start_time = time.perf_counter()
  152. if train_config.save_model and eval_epoch_loss < best_val_loss:
  153. if train_config.enable_fsdp:
  154. dist.barrier()
  155. if train_config.use_peft:
  156. if train_config.enable_fsdp:
  157. if rank==0:
  158. print(f"we are about to save the PEFT modules")
  159. else:
  160. print(f"we are about to save the PEFT modules")
  161. model.save_pretrained(train_config.output_dir)
  162. if train_config.enable_fsdp:
  163. if rank==0:
  164. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  165. else:
  166. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  167. else:
  168. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  169. save_model_checkpoint(
  170. model, optimizer, rank, train_config, epoch=epoch
  171. )
  172. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  173. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  174. print("=====================================================")
  175. save_model_and_optimizer_sharded(model, rank, train_config)
  176. if train_config.save_optimizer:
  177. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  178. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  179. print("=====================================================")
  180. if not train_config.use_peft and train_config.save_optimizer:
  181. save_optimizer_checkpoint(
  182. model, optimizer, rank, train_config, epoch=epoch
  183. )
  184. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  185. print("=====================================================")
  186. if train_config.enable_fsdp:
  187. dist.barrier()
  188. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  189. checkpoint_times.append(checkpoint_end_time)
  190. if eval_epoch_loss < best_val_loss:
  191. best_val_loss = eval_epoch_loss
  192. if train_config.enable_fsdp:
  193. if rank==0:
  194. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  195. else:
  196. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  197. val_loss.append(float(best_val_loss))
  198. val_prep.append(float(eval_ppl))
  199. if train_config.enable_fsdp:
  200. if rank==0:
  201. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  202. else:
  203. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  204. # Saving the results every epoch to plot later
  205. if train_config.save_metrics:
  206. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  207. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  208. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  209. avg_train_prep = sum(train_prep)/len(train_prep)
  210. avg_train_loss = sum(train_loss)/len(train_loss)
  211. if train_config.run_validation:
  212. avg_eval_prep = sum(val_prep)/len(val_prep)
  213. avg_eval_loss = sum(val_loss)/len(val_loss)
  214. results['avg_train_prep'] = avg_train_prep
  215. results['avg_train_loss'] = avg_train_loss
  216. if train_config.run_validation:
  217. results['avg_eval_prep'] = avg_eval_prep
  218. results['avg_eval_loss'] = avg_eval_loss
  219. results["avg_epoch_time"] = avg_epoch_time
  220. results["avg_checkpoint_time"] = avg_checkpoint_time
  221. if train_config.save_metrics:
  222. results["metrics_filename"] = metrics_filename
  223. #saving the training params including fsdp setting for reference.
  224. if train_config.enable_fsdp and not train_config.use_peft:
  225. save_train_params(train_config, fsdp_config, rank)
  226. return results
  227. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
  228. """
  229. Evaluates the model on the given dataloader
  230. Args:
  231. model: The model to evaluate
  232. eval_dataloader: The dataloader containing the evaluation data
  233. local_rank: The rank of the current node in a distributed setting
  234. tokenizer: The tokenizer used to decode predictions
  235. Returns: eval_ppl, eval_epoch_loss
  236. """
  237. if train_config.enable_fsdp:
  238. world_size = int(os.environ["WORLD_SIZE"])
  239. model.eval()
  240. eval_preds = []
  241. val_step_loss = []
  242. val_step_perplexity = []
  243. eval_loss = 0.0 # Initialize evaluation loss
  244. with MemoryTrace() as memtrace:
  245. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  246. for key in batch.keys():
  247. if train_config.enable_fsdp:
  248. batch[key] = batch[key].to(local_rank)
  249. else:
  250. if is_xpu_available():
  251. batch[key] = batch[key].to('xpu:0')
  252. else:
  253. batch[key] = batch[key].to('cuda:0')
  254. # Ensure no gradients are computed for this scope to save memory
  255. with torch.no_grad():
  256. # Forward pass and compute loss
  257. outputs = model(**batch)
  258. loss = outputs.loss
  259. if train_config.save_metrics:
  260. val_step_loss.append(loss.detach().float().item())
  261. val_step_perplexity.append(float(torch.exp(loss.detach().float())))
  262. eval_loss += loss.detach().float()
  263. # Decode predictions and add to evaluation predictions list
  264. preds = torch.argmax(outputs.logits, -1)
  265. eval_preds.extend(
  266. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  267. )
  268. # If there's more than one CUDA device, reduce evaluation loss across all devices
  269. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  270. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  271. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  272. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  273. # Compute average loss and perplexity
  274. eval_epoch_loss = eval_loss / len(eval_dataloader)
  275. if train_config.enable_fsdp:
  276. eval_epoch_loss = eval_epoch_loss/world_size
  277. eval_ppl = torch.exp(eval_epoch_loss)
  278. # Print evaluation metrics
  279. if train_config.enable_fsdp:
  280. if local_rank==0:
  281. print(f" {eval_ppl=} {eval_epoch_loss=}")
  282. else:
  283. print(f" {eval_ppl=} {eval_epoch_loss=}")
  284. if wandb_run:
  285. wandb_run.log({
  286. 'eval/perplexity': eval_ppl,
  287. 'eval/loss': eval_epoch_loss,
  288. }, commit=False)
  289. return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
  290. def freeze_transformer_layers(model, num_layer):
  291. for i, layer in enumerate(model.model.layers):
  292. if i < num_layer:
  293. for param in layer.parameters():
  294. param.requires_grad = False
  295. def check_frozen_layers_peft_model(model):
  296. for i, layer in enumerate(model.base_model.model.model.layers):
  297. for name, param in layer.named_parameters():
  298. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  299. def setup():
  300. """Initialize the process group for distributed training"""
  301. if is_ccl_available():
  302. # distributed training on xpus
  303. dist.init_process_group("ccl")
  304. else:
  305. dist.init_process_group("nccl")
  306. def setup_environ_flags(rank):
  307. """Set environment flags for debugging purposes"""
  308. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  309. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  310. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  311. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  312. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  313. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  314. if rank == 0:
  315. print(f"--> Running with torch dist debug set to detail")
  316. def cleanup():
  317. """Clean up the process group after training"""
  318. dist.destroy_process_group()
  319. def clear_gpu_cache(rank=None):
  320. """Clear the GPU cache for all ranks"""
  321. if rank == 0:
  322. print(f"Clearing GPU cache for all ranks")
  323. if is_xpu_available():
  324. torch.xpu_empty_cache()
  325. else:
  326. torch.cuda.empty_cache()
  327. def get_parameter_dtypes(model):
  328. """Get the data types of model parameters"""
  329. parameter_dtypes = {}
  330. for name, parameter in model.named_parameters():
  331. parameter_dtypes[name] = parameter.dtype
  332. return parameter_dtypes
  333. def print_model_size(model, config, rank: int = 0) -> None:
  334. """
  335. Print model name, the number of trainable parameters and initialization time.
  336. Args:
  337. model: The PyTorch model.
  338. model_name (str): Name of the model.
  339. init_time_start (float): Initialization start time.
  340. init_time_end (float): Initialization end time.
  341. rank (int, optional): Current process's rank. Defaults to 0.
  342. """
  343. if rank == 0:
  344. print(f"--> Model {config.model_name}")
  345. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  346. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  347. def get_policies(cfg, rank):
  348. """Get the policies for mixed precision and fsdp wrapping"""
  349. verify_bfloat_support = ((
  350. torch.version.cuda
  351. and torch.cuda.is_bf16_supported()
  352. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  353. and dist.is_nccl_available()
  354. and nccl.version() >= (2, 10)
  355. ) or
  356. (is_xpu_available()))
  357. mixed_precision_policy = None
  358. wrapping_policy = None
  359. # Mixed precision
  360. if cfg.mixed_precision:
  361. bf16_ready = verify_bfloat_support
  362. if bf16_ready and not cfg.use_fp16:
  363. mixed_precision_policy = bfSixteen
  364. if rank == 0:
  365. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  366. elif cfg.use_fp16:
  367. mixed_precision_policy = fpSixteen
  368. if rank == 0:
  369. print(f"FP16 enabled")
  370. else:
  371. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  372. wrapping_policy = get_llama_wrapper()
  373. return mixed_precision_policy, wrapping_policy
  374. def save_train_params(train_config, fsdp_config, rank):
  375. """
  376. This function saves the train_config and FSDP config into a train_params.yaml.
  377. This will be used by converter script in the inference folder to fetch the HF model name or path.
  378. It also would be hepful as a log for future references.
  379. """
  380. # Convert the train_config and fsdp_config objects to dictionaries,
  381. # converting all values to strings to ensure they can be serialized into a YAML file
  382. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  383. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  384. # Merge the two dictionaries into one
  385. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  386. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  387. folder_name = (
  388. train_config.dist_checkpoint_root_folder
  389. + "/"
  390. + train_config.dist_checkpoint_folder
  391. + "-"
  392. + train_config.model_name
  393. )
  394. save_dir = Path.cwd() / folder_name
  395. # If the directory does not exist, create it
  396. if not os.path.exists(save_dir):
  397. os.makedirs(save_dir)
  398. # Convert the dictionary to a YAML string
  399. config_yaml = yaml.dump(train_params_dict, indent=4)
  400. file_name = os.path.join(save_dir,'train_params.yaml')
  401. # Check if there's a directory with the same name as the file
  402. if os.path.isdir(file_name):
  403. print(f"Error: {file_name} is a directory, not a file.")
  404. else:
  405. # Write the YAML string to the file
  406. with open(file_name, 'w') as f:
  407. f.write(config_yaml)
  408. if rank==0:
  409. print(f"training params are saved in {file_name}")
  410. def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
  411. metrics_data = {
  412. "train_step_loss": train_step_loss,
  413. "train_epoch_loss": train_epoch_loss,
  414. "train_step_perplexity": train_step_ppl,
  415. "train_epoch_perplexity": train_epoch_ppl,
  416. "val_step_loss": val_step_loss,
  417. "val_epoch_loss": val_epoch_loss,
  418. "val_step_perplexity": val_step_ppl,
  419. "val_epoch_perplexity": val_epoch_ppl
  420. }
  421. with open(output_filename, "w") as f:
  422. json.dump(metrics_data, f)