train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import time
import yaml
from contextlib import nullcontext
from pathlib import Path
from pkg_resources import packaging
from datetime import datetime

import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from tqdm import tqdm
from transformers import LlamaTokenizer
import json

from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen, bfSixteen, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from accelerate.utils import is_xpu_available, is_ccl_available


def set_tokenizer_params(tokenizer: LlamaTokenizer):
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"


# Converting Bytes to Megabytes
def byte2mb(x):
    return int(x / 2**20)
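# Illustrative sketch (not part of the original module): byte2mb truncates the
# megabyte count, e.g. byte2mb(3 * 2**20) returns 3 and byte2mb(1_500_000) returns 1.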

def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None, wandb_run=None):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
        train_dataloader: The dataloader containing the training data
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
        num_epochs: The number of epochs to train for (read from train_config)
        local_rank: The rank of the current node in a distributed setting
        train_config: The training configuration
        eval_dataloader: The dataloader containing the eval data
        tokenizer: tokenizer used in the eval for decoding the predictions

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    # Create a gradient scaler for fp16
    if train_config.use_fp16 and train_config.enable_fsdp:
        scaler = ShardedGradScaler()
    elif train_config.use_fp16 and not train_config.enable_fsdp:
        scaler = torch.cuda.amp.GradScaler()
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])

    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext

    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []

    if train_config.save_metrics:
        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
        train_step_perplexity = []
        train_step_loss = []
        val_step_loss = []
        val_step_perplexity = []

    epoch_times = []
    checkpoint_times = []
    results = {}
    best_val_loss = float("inf")
    total_train_steps = 0
    for epoch in range(train_config.num_epochs):
        # stop when the maximum number of training steps is reached
        if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
            if not train_config.enable_fsdp or local_rank == 0:
                print("max training steps reached, stopping training, total_train_steps: ", total_train_steps - 1)
            break
        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
            total_length = len(train_dataloader) // gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
            for step, batch in enumerate(train_dataloader):
                total_train_steps += 1
                # stop when the maximum number of training steps is reached
                if train_config.max_train_step > 0 and total_train_steps > train_config.max_train_step:
                    break
                for key in batch.keys():
                    if train_config.enable_fsdp:
                        if is_xpu_available():
                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
                        else:
                            batch[key] = batch[key].to(local_rank)
                    else:
                        if is_xpu_available():
                            batch[key] = batch[key].to('xpu:0')
                        else:
                            batch[key] = batch[key].to('cuda:0')
                with autocast():
                    loss = model(**batch).loss
                loss = loss / gradient_accumulation_steps
                if train_config.save_metrics:
                    train_step_loss.append(loss.detach().float().item())
                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                total_loss += loss.detach().float()
                if train_config.use_fp16:
                    # if fp16 is enabled, use gradient scaler to handle gradient update
                    scaler.scale(loss).backward()
                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
                            scaler.unscale_(optimizer)
                            if train_config.enable_fsdp:
                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
                            else:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        pbar.update(1)
                else:
                    # regular backpropagation when fp16 is not used
                    loss.backward()
                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
                            if train_config.enable_fsdp:
                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
                            else:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                        optimizer.step()
                        optimizer.zero_grad()
                        pbar.update(1)

                if wandb_run:
                    if not train_config.enable_fsdp or rank == 0:
                        wandb_run.log({
                            'train/epoch': epoch + 1,
                            'train/step': epoch * len(train_dataloader) + step,
                            'train/loss': loss.detach().float(),
                        })

                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")

                if train_config.save_metrics:
                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
            pbar.close()
        epoch_end_time = time.perf_counter() - epoch_start_time
        epoch_times.append(epoch_end_time)
        # Reducing total_loss across all devices if there's more than one CUDA device
        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
        train_epoch_loss = total_loss / len(train_dataloader)
        if train_config.enable_fsdp:
            train_epoch_loss = train_epoch_loss / world_size
        train_perplexity = torch.exp(train_epoch_loss)

        train_prep.append(float(train_perplexity))
        train_loss.append(float(train_epoch_loss))

        if not train_config.enable_fsdp or rank == 0:
            memtrace.print_stats()

        # Update the learning rate as needed
        lr_scheduler.step()

        if train_config.run_validation:
            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run)
            if train_config.save_metrics:
                val_step_loss.extend(temp_val_loss)
                val_step_perplexity.extend(temp_step_perplexity)

            checkpoint_start_time = time.perf_counter()
            if train_config.save_model and eval_epoch_loss < best_val_loss:
                if train_config.enable_fsdp:
                    dist.barrier()
                if train_config.use_peft:
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"we are about to save the PEFT modules")
                    else:
                        print(f"we are about to save the PEFT modules")
                    model.save_pretrained(train_config.output_dir)
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
                    else:
                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
                else:
                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                        save_model_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                        print("=====================================================")
                        save_model_and_optimizer_sharded(model, rank, train_config)
                        if train_config.save_optimizer:
                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                            print("=====================================================")

                    if not train_config.use_peft and train_config.save_optimizer:
                        save_optimizer_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                        print("=====================================================")
                if train_config.enable_fsdp:
                    dist.barrier()
            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
            checkpoint_times.append(checkpoint_end_time)
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                if train_config.enable_fsdp:
                    if rank == 0:
                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                else:
                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
            val_loss.append(float(best_val_loss))
            val_prep.append(float(eval_ppl))
        if train_config.enable_fsdp:
            if rank == 0:
                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
        else:
            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")

        # Saving the results every epoch to plot later
        if train_config.save_metrics:
            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    if train_config.run_validation:
        avg_eval_prep = sum(val_prep) / len(val_prep)
        avg_eval_loss = sum(val_loss) / len(val_loss)

    results['avg_train_prep'] = avg_train_prep
    results['avg_train_loss'] = avg_train_loss
    if train_config.run_validation:
        results['avg_eval_prep'] = avg_eval_prep
        results['avg_eval_loss'] = avg_eval_loss
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time
    if train_config.save_metrics:
        results["metrics_filename"] = metrics_filename

    # saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and not train_config.use_peft and rank == 0:
        save_train_params(train_config, fsdp_config, rank)

    return results
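# Illustrative usage sketch (assumptions, not part of the original module): the
# config objects, dataloaders, and rank variables below are hypothetical
# placeholders that a finetuning driver script is expected to construct.
#
#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
#   lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)
#   results = train(
#       model, train_dataloader, eval_dataloader, tokenizer,
#       optimizer, lr_scheduler,
#       gradient_accumulation_steps=train_config.gradient_accumulation_steps,
#       train_config=train_config, fsdp_config=fsdp_config,
#       local_rank=local_rank, rank=rank, wandb_run=None,
#   )
#   print(results["avg_train_loss"], results.get("avg_eval_loss"))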

def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer, wandb_run):
    """
    Evaluates the model on the given dataloader

    Args:
        model: The model to evaluate
        eval_dataloader: The dataloader containing the evaluation data
        local_rank: The rank of the current node in a distributed setting
        tokenizer: The tokenizer used to decode predictions

    Returns: eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
    """
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    model.eval()
    eval_preds = []
    val_step_loss = []
    val_step_perplexity = []
    eval_loss = 0.0  # Initialize evaluation loss
    total_eval_steps = 0
    with MemoryTrace() as memtrace:
        for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
            total_eval_steps += 1
            # stop when the maximum number of eval steps is reached
            if train_config.max_eval_step > 0 and total_eval_steps > train_config.max_eval_step:
                if not train_config.enable_fsdp or local_rank == 0:
                    print("max eval steps reached, stopping evaluation, total_eval_steps: ", total_eval_steps - 1)
                break
            for key in batch.keys():
                if train_config.enable_fsdp:
                    batch[key] = batch[key].to(local_rank)
                else:
                    if is_xpu_available():
                        batch[key] = batch[key].to('xpu:0')
                    else:
                        batch[key] = batch[key].to('cuda:0')
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                # Forward pass and compute loss
                outputs = model(**batch)
                loss = outputs.loss
                if train_config.save_metrics:
                    val_step_loss.append(loss.detach().float().item())
                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))
                eval_loss += loss.detach().float()
            # Decode predictions and add to evaluation predictions list
            preds = torch.argmax(outputs.logits, -1)
            eval_preds.extend(
                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
            )

    # If there's more than one CUDA device, reduce evaluation loss across all devices
    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)

    # Compute average loss and perplexity
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    if train_config.enable_fsdp:
        eval_epoch_loss = eval_epoch_loss / world_size
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    if train_config.enable_fsdp:
        if local_rank == 0:
            print(f" {eval_ppl=} {eval_epoch_loss=}")
    else:
        print(f" {eval_ppl=} {eval_epoch_loss=}")

    if wandb_run:
        wandb_run.log({
            'eval/perplexity': eval_ppl,
            'eval/loss': eval_epoch_loss,
        }, commit=False)

    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity

def freeze_transformer_layers(model, num_layer):
    for i, layer in enumerate(model.model.layers):
        if i < num_layer:
            for param in layer.parameters():
                param.requires_grad = False
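# Illustrative sketch (assumption, not part of the original module): freezing the
# first 8 decoder layers of a Hugging Face LlamaForCausalLM before training, then
# checking how many parameters still require gradients.
#
#   freeze_transformer_layers(model, num_layer=8)
#   trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
#   print(f"trainable params after freezing: {trainable}")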

def check_frozen_layers_peft_model(model):
    for i, layer in enumerate(model.base_model.model.model.layers):
        for name, param in layer.named_parameters():
            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")

def setup():
    """Initialize the process group for distributed training"""
    if is_ccl_available():
        # distributed training on xpus
        dist.init_process_group("ccl")
    else:
        dist.init_process_group("nccl")


def setup_environ_flags(rank):
    """Set environment flags for debugging purposes"""
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
    # Note this is only available in PyTorch Nightlies (as of July 30, 2023).
    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
    if rank == 0:
        print(f"--> Running with torch dist debug set to detail")


def cleanup():
    """Clean up the process group after training"""
    dist.destroy_process_group()
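# Illustrative lifecycle sketch (assumptions, not part of the original module):
# under a launcher such as torchrun, each worker would typically initialize the
# process group, run training, and tear the group down afterwards. The env vars
# below are the ones torchrun sets; the training call itself is elided.
#
#   local_rank = int(os.environ["LOCAL_RANK"])
#   rank = int(os.environ["RANK"])
#   setup()                      # init_process_group("ccl" or "nccl")
#   setup_environ_flags(rank)
#   torch.cuda.set_device(local_rank)
#   ...                          # build model/dataloaders and call train(...)
#   cleanup()                    # destroy_process_group()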

def clear_gpu_cache(rank=None):
    """Clear the GPU cache for all ranks"""
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    if is_xpu_available():
        torch.xpu.empty_cache()
    else:
        torch.cuda.empty_cache()

def get_parameter_dtypes(model):
    """Get the data types of model parameters"""
    parameter_dtypes = {}
    for name, parameter in model.named_parameters():
        parameter_dtypes[name] = parameter.dtype
    return parameter_dtypes
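# Illustrative sketch (not part of the original module): the returned mapping is
# parameter name -> dtype, e.g. {"model.embed_tokens.weight": torch.float32, ...};
# the exact names and dtypes depend on the model and how it was loaded.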

def print_model_size(model, config, rank: int = 0) -> None:
    """
    Print the model name and the number of trainable parameters.

    Args:
        model: The PyTorch model.
        config: The training configuration, used for its model_name field.
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank == 0:
        print(f"--> Model {config.model_name}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")

def get_policies(cfg, rank):
    """Get the policies for mixed precision and fsdp wrapping"""

    verify_bfloat_support = (
        (
            torch.version.cuda
            and torch.cuda.is_bf16_supported()
            and packaging.version.parse(torch.version.cuda).release >= (11, 0)
            and dist.is_nccl_available()
            and nccl.version() >= (2, 10)
        )
        or is_xpu_available()
    )

    mixed_precision_policy = None
    wrapping_policy = None

    # Mixed precision
    if cfg.mixed_precision:
        bf16_ready = verify_bfloat_support

        if bf16_ready and not cfg.use_fp16:
            mixed_precision_policy = bfSixteen
            if rank == 0:
                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                print(f"FP16 enabled")
        else:
            print(f"bFloat16 support not present. Using FP32, and not mixed precision")

    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy
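# Illustrative sketch (assumptions, not part of the original module): the two
# policies returned above are intended to be passed to FSDP when wrapping the
# model; the fsdp_config and rank variables here are placeholders supplied by
# the surrounding finetuning script.
#
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#   mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
#   model = FSDP(
#       model,
#       auto_wrap_policy=wrapping_policy,
#       mixed_precision=mixed_precision_policy,
#       device_id=torch.cuda.current_device(),
#   )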

def save_train_params(train_config, fsdp_config, rank):
    """
    This function saves the train_config and FSDP config into a train_params.yaml.
    This will be used by the converter script in the inference folder to fetch the HF model name or path.
    It is also helpful as a log for future reference.
    """
    # Convert the train_config and fsdp_config objects to dictionaries,
    # converting all values to strings to ensure they can be serialized into a YAML file
    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
    # Merge the two dictionaries into one
    train_params_dict = {**train_config_dict, **fsdp_config_dict}
    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
    folder_name = (
        train_config.dist_checkpoint_root_folder
        + "/"
        + train_config.dist_checkpoint_folder
        + "-"
        + train_config.model_name
    )

    save_dir = Path.cwd() / folder_name
    # If the directory does not exist, create it
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Convert the dictionary to a YAML string
    config_yaml = yaml.dump(train_params_dict, indent=4)
    file_name = os.path.join(save_dir, 'train_params.yaml')

    # Check if there's a directory with the same name as the file
    if os.path.isdir(file_name):
        print(f"Error: {file_name} is a directory, not a file.")
    else:
        # Write the YAML string to the file
        with open(file_name, 'w') as f:
            f.write(config_yaml)
        if rank == 0:
            print(f"training params are saved in {file_name}")

def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
    metrics_data = {
        "train_step_loss": train_step_loss,
        "train_epoch_loss": train_epoch_loss,
        "train_step_perplexity": train_step_ppl,
        "train_epoch_perplexity": train_epoch_ppl,
        "val_step_loss": val_step_loss,
        "val_epoch_loss": val_epoch_loss,
        "val_step_perplexity": val_step_ppl,
        "val_epoch_perplexity": val_epoch_ppl
    }
    with open(output_filename, "w") as f:
        json.dump(metrics_data, f)
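# Illustrative sketch (not part of the original module): the JSON written above can
# be reloaded later for plotting, e.g.
#
#   with open(metrics_filename) as f:
#       metrics = json.load(f)
#   print(len(metrics["train_step_loss"]), metrics["train_epoch_loss"])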