train_utils.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import time
  5. import yaml
  6. from contextlib import nullcontext
  7. from pathlib import Path
  8. from pkg_resources import packaging
  9. from datetime import datetime
  10. import torch
  11. import torch.cuda.nccl as nccl
  12. import torch.distributed as dist
  13. from torch.distributed.fsdp import StateDictType
  14. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  15. from tqdm import tqdm
  16. from transformers import LlamaTokenizer
  17. import json
  18. from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
  19. from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
  20. from llama_recipes.utils.memory_utils import MemoryTrace
  21. from accelerate.utils import is_xpu_available, is_ccl_available
  22. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  23. tokenizer.pad_token_id = 0
  24. tokenizer.padding_side = "left"
  25. # Converting Bytes to Megabytes
  26. def byte2mb(x):
  27. return int(x / 2**20)
  28. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  29. """
  30. Trains the model on the given dataloader
  31. Args:
  32. model: The model to be trained
  33. train_dataloader: The dataloader containing the training data
  34. optimizer: The optimizer used for training
  35. lr_scheduler: The learning rate scheduler
  36. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  37. num_epochs: The number of epochs to train for
  38. local_rank: The rank of the current node in a distributed setting
  39. train_config: The training configuration
  40. eval_dataloader: The dataloader containing the eval data
  41. tokenizer: tokenizer used in the eval for decoding the predicitons
  42. Returns: results dictionary containing average training and validation perplexity and loss
  43. """
  44. # Create a gradient scaler for fp16
  45. if train_config.use_fp16 and train_config.enable_fsdp:
  46. scaler = ShardedGradScaler()
  47. elif train_config.use_fp16 and not train_config.enable_fsdp:
  48. scaler = torch.cuda.amp.GradScaler()
  49. if train_config.enable_fsdp:
  50. world_size = int(os.environ["WORLD_SIZE"])
  51. autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
  52. train_prep = []
  53. train_loss = []
  54. val_prep = []
  55. val_loss =[]
  56. if train_config.save_metrics:
  57. metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
  58. train_step_perplexity = []
  59. train_step_loss = []
  60. val_step_loss = []
  61. val_step_perplexity = []
  62. epoch_times = []
  63. checkpoint_times = []
  64. results = {}
  65. best_val_loss = float("inf")
  66. for epoch in range(train_config.num_epochs):
  67. epoch_start_time = time.perf_counter()
  68. with MemoryTrace() as memtrace: # track the memory usage
  69. model.train()
  70. total_loss = 0.0
  71. total_length = len(train_dataloader)//gradient_accumulation_steps
  72. pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
  73. for step, batch in enumerate(train_dataloader):
  74. for key in batch.keys():
  75. if train_config.enable_fsdp:
  76. if is_xpu_available():
  77. batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
  78. else:
  79. batch[key] = batch[key].to(local_rank)
  80. else:
  81. if is_xpu_available():
  82. batch[key] = batch[key].to('xpu:0')
  83. else:
  84. batch[key] = batch[key].to('cuda:0')
  85. with autocast():
  86. loss = model(**batch).loss
  87. loss = loss / gradient_accumulation_steps
  88. if train_config.save_metrics:
  89. train_step_loss.append(loss.detach().float().item())
  90. train_step_perplexity.append(float(torch.exp(loss.detach().float())))
  91. total_loss += loss.detach().float()
  92. if train_config.use_fp16:
  93. # if fp16 is enabled, use gradient scaler to handle gradient update
  94. scaler.scale(loss).backward()
  95. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  96. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  97. scaler.unscale_(optimizer)
  98. if train_config.enable_fsdp:
  99. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  100. else:
  101. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  102. scaler.step(optimizer)
  103. scaler.update()
  104. optimizer.zero_grad()
  105. pbar.update(1)
  106. else:
  107. # regular backpropagation when fp16 is not used
  108. loss.backward()
  109. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  110. if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
  111. if train_config.enable_fsdp:
  112. model.clip_grad_norm_(train_config.gradient_clipping_threshold)
  113. else:
  114. torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
  115. optimizer.step()
  116. optimizer.zero_grad()
  117. pbar.update(1)
  118. pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
  119. if train_config.save_metrics:
  120. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  121. pbar.close()
  122. epoch_end_time = time.perf_counter()-epoch_start_time
  123. epoch_times.append(epoch_end_time)
  124. # Reducing total_loss across all devices if there's more than one CUDA device
  125. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  126. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  127. elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  128. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  129. train_epoch_loss = total_loss / len(train_dataloader)
  130. if train_config.enable_fsdp:
  131. train_epoch_loss = train_epoch_loss/world_size
  132. train_perplexity = torch.exp(train_epoch_loss)
  133. train_prep.append(float(train_perplexity))
  134. train_loss.append(float(train_epoch_loss))
  135. if not train_config.enable_fsdp or rank==0:
  136. memtrace.print_stats()
  137. # Update the learning rate as needed
  138. lr_scheduler.step()
  139. if train_config.run_validation:
  140. eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
  141. if train_config.save_metrics:
  142. val_step_loss.extend(temp_val_loss)
  143. val_step_perplexity.extend(temp_step_perplexity)
  144. checkpoint_start_time = time.perf_counter()
  145. if train_config.save_model and eval_epoch_loss < best_val_loss:
  146. if train_config.enable_fsdp:
  147. dist.barrier()
  148. if train_config.use_peft:
  149. if train_config.enable_fsdp:
  150. if rank==0:
  151. print(f"we are about to save the PEFT modules")
  152. else:
  153. print(f"we are about to save the PEFT modules")
  154. model.save_pretrained(train_config.output_dir)
  155. if train_config.enable_fsdp:
  156. if rank==0:
  157. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  158. else:
  159. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  160. else:
  161. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  162. save_model_checkpoint(
  163. model, optimizer, rank, train_config, epoch=epoch
  164. )
  165. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  166. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  167. print("=====================================================")
  168. save_model_and_optimizer_sharded(model, rank, train_config)
  169. if train_config.save_optimizer:
  170. save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  171. print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
  172. print("=====================================================")
  173. if not train_config.use_peft and train_config.save_optimizer:
  174. save_optimizer_checkpoint(
  175. model, optimizer, rank, train_config, epoch=epoch
  176. )
  177. print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
  178. print("=====================================================")
  179. if train_config.enable_fsdp:
  180. dist.barrier()
  181. checkpoint_end_time = time.perf_counter() - checkpoint_start_time
  182. checkpoint_times.append(checkpoint_end_time)
  183. if eval_epoch_loss < best_val_loss:
  184. best_val_loss = eval_epoch_loss
  185. if train_config.enable_fsdp:
  186. if rank==0:
  187. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  188. else:
  189. print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
  190. val_loss.append(float(best_val_loss))
  191. val_prep.append(float(eval_ppl))
  192. if train_config.enable_fsdp:
  193. if rank==0:
  194. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  195. else:
  196. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
  197. # Saving the results every epoch to plot later
  198. if train_config.save_metrics:
  199. save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
  200. avg_epoch_time = sum(epoch_times)/ len(epoch_times)
  201. avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
  202. avg_train_prep = sum(train_prep)/len(train_prep)
  203. avg_train_loss = sum(train_loss)/len(train_loss)
  204. if train_config.run_validation:
  205. avg_eval_prep = sum(val_prep)/len(val_prep)
  206. avg_eval_loss = sum(val_loss)/len(val_loss)
  207. results['avg_train_prep'] = avg_train_prep
  208. results['avg_train_loss'] = avg_train_loss
  209. if train_config.run_validation:
  210. results['avg_eval_prep'] = avg_eval_prep
  211. results['avg_eval_loss'] = avg_eval_loss
  212. results["avg_epoch_time"] = avg_epoch_time
  213. results["avg_checkpoint_time"] = avg_checkpoint_time
  214. if train_config.save_metrics:
  215. results["metrics_filename"] = metrics_filename
  216. #saving the training params including fsdp setting for reference.
  217. if train_config.enable_fsdp and not train_config.use_peft:
  218. save_train_params(train_config, fsdp_config, rank)
  219. return results
  220. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  221. """
  222. Evaluates the model on the given dataloader
  223. Args:
  224. model: The model to evaluate
  225. eval_dataloader: The dataloader containing the evaluation data
  226. local_rank: The rank of the current node in a distributed setting
  227. tokenizer: The tokenizer used to decode predictions
  228. Returns: eval_ppl, eval_epoch_loss
  229. """
  230. if train_config.enable_fsdp:
  231. world_size = int(os.environ["WORLD_SIZE"])
  232. model.eval()
  233. eval_preds = []
  234. val_step_loss = []
  235. val_step_perplexity = []
  236. eval_loss = 0.0 # Initialize evaluation loss
  237. with MemoryTrace() as memtrace:
  238. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
  239. for key in batch.keys():
  240. if train_config.enable_fsdp:
  241. batch[key] = batch[key].to(local_rank)
  242. else:
  243. if is_xpu_available():
  244. batch[key] = batch[key].to('xpu:0')
  245. else:
  246. batch[key] = batch[key].to('cuda:0')
  247. # Ensure no gradients are computed for this scope to save memory
  248. with torch.no_grad():
  249. # Forward pass and compute loss
  250. outputs = model(**batch)
  251. loss = outputs.loss
  252. if train_config.save_metrics:
  253. val_step_loss.append(loss.detach().float().item())
  254. val_step_perplexity.append(float(torch.exp(loss.detach().float())))
  255. eval_loss += loss.detach().float()
  256. # Decode predictions and add to evaluation predictions list
  257. preds = torch.argmax(outputs.logits, -1)
  258. eval_preds.extend(
  259. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  260. )
  261. # If there's more than one CUDA device, reduce evaluation loss across all devices
  262. if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
  263. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  264. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  265. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  266. # Compute average loss and perplexity
  267. eval_epoch_loss = eval_loss / len(eval_dataloader)
  268. if train_config.enable_fsdp:
  269. eval_epoch_loss = eval_epoch_loss/world_size
  270. eval_ppl = torch.exp(eval_epoch_loss)
  271. # Print evaluation metrics
  272. if train_config.enable_fsdp:
  273. if local_rank==0:
  274. print(f" {eval_ppl=} {eval_epoch_loss=}")
  275. else:
  276. print(f" {eval_ppl=} {eval_epoch_loss=}")
  277. return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
  278. def freeze_transformer_layers(model, num_layer):
  279. for i, layer in enumerate(model.model.layers):
  280. if i < num_layer:
  281. for param in layer.parameters():
  282. param.requires_grad = False
  283. def check_frozen_layers_peft_model(model):
  284. for i, layer in enumerate(model.base_model.model.model.layers):
  285. for name, param in layer.named_parameters():
  286. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  287. def setup():
  288. """Initialize the process group for distributed training"""
  289. if is_ccl_available():
  290. # distributed training on xpus
  291. dist.init_process_group("ccl")
  292. else:
  293. dist.init_process_group("nccl")
  294. def setup_environ_flags(rank):
  295. """Set environment flags for debugging purposes"""
  296. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  297. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  298. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  299. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  300. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  301. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  302. if rank == 0:
  303. print(f"--> Running with torch dist debug set to detail")
  304. def cleanup():
  305. """Clean up the process group after training"""
  306. dist.destroy_process_group()
  307. def clear_gpu_cache(rank=None):
  308. """Clear the GPU cache for all ranks"""
  309. if rank == 0:
  310. print(f"Clearing GPU cache for all ranks")
  311. if is_xpu_available():
  312. torch.xpu_empty_cache()
  313. else:
  314. torch.cuda.empty_cache()
  315. def get_parameter_dtypes(model):
  316. """Get the data types of model parameters"""
  317. parameter_dtypes = {}
  318. for name, parameter in model.named_parameters():
  319. parameter_dtypes[name] = parameter.dtype
  320. return parameter_dtypes
  321. def print_model_size(model, config, rank: int = 0) -> None:
  322. """
  323. Print model name, the number of trainable parameters and initialization time.
  324. Args:
  325. model: The PyTorch model.
  326. model_name (str): Name of the model.
  327. init_time_start (float): Initialization start time.
  328. init_time_end (float): Initialization end time.
  329. rank (int, optional): Current process's rank. Defaults to 0.
  330. """
  331. if rank == 0:
  332. print(f"--> Model {config.model_name}")
  333. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  334. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  335. def get_policies(cfg, rank):
  336. """Get the policies for mixed precision and fsdp wrapping"""
  337. verify_bfloat_support = ((
  338. torch.version.cuda
  339. and torch.cuda.is_bf16_supported()
  340. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  341. and dist.is_nccl_available()
  342. and nccl.version() >= (2, 10)
  343. ) or
  344. (is_xpu_available()))
  345. mixed_precision_policy = None
  346. wrapping_policy = None
  347. # Mixed precision
  348. if cfg.mixed_precision:
  349. bf16_ready = verify_bfloat_support
  350. if bf16_ready and not cfg.use_fp16:
  351. mixed_precision_policy = bfSixteen
  352. if rank == 0:
  353. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  354. elif cfg.use_fp16:
  355. mixed_precision_policy = fpSixteen
  356. if rank == 0:
  357. print(f"FP16 enabled")
  358. else:
  359. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  360. wrapping_policy = get_llama_wrapper()
  361. return mixed_precision_policy, wrapping_policy
  362. def save_train_params(train_config, fsdp_config, rank):
  363. """
  364. This function saves the train_config and FSDP config into a train_params.yaml.
  365. This will be used by converter script in the inference folder to fetch the HF model name or path.
  366. It also would be hepful as a log for future references.
  367. """
  368. # Convert the train_config and fsdp_config objects to dictionaries,
  369. # converting all values to strings to ensure they can be serialized into a YAML file
  370. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  371. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  372. # Merge the two dictionaries into one
  373. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  374. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  375. folder_name = (
  376. train_config.dist_checkpoint_root_folder
  377. + "/"
  378. + train_config.dist_checkpoint_folder
  379. + "-"
  380. + train_config.model_name
  381. )
  382. save_dir = Path.cwd() / folder_name
  383. # If the directory does not exist, create it
  384. if not os.path.exists(save_dir):
  385. os.makedirs(save_dir)
  386. # Convert the dictionary to a YAML string
  387. config_yaml = yaml.dump(train_params_dict, indent=4)
  388. file_name = os.path.join(save_dir,'train_params.yaml')
  389. # Check if there's a directory with the same name as the file
  390. if os.path.isdir(file_name):
  391. print(f"Error: {file_name} is a directory, not a file.")
  392. else:
  393. # Write the YAML string to the file
  394. with open(file_name, 'w') as f:
  395. f.write(config_yaml)
  396. if rank==0:
  397. print(f"training params are saved in {file_name}")
  398. def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
  399. metrics_data = {
  400. "train_step_loss": train_step_loss,
  401. "train_epoch_loss": train_epoch_loss,
  402. "train_step_perplexity": train_step_ppl,
  403. "train_epoch_perplexity": train_epoch_ppl,
  404. "val_step_loss": val_step_loss,
  405. "val_epoch_loss": val_epoch_loss,
  406. "val_step_perplexity": val_step_ppl,
  407. "val_epoch_perplexity": val_epoch_ppl
  408. }
  409. with open(output_filename, "w") as f:
  410. json.dump(metrics_data, f)