train_utils.py

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os
import time
import yaml
from pathlib import Path
from pkg_resources import packaging
import contextlib

import torch
import torch.cuda.nccl as nccl
import torch.distributed as dist
from torch.distributed.fsdp import StateDictType
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
# from torch.utils.flop_counter import FlopCounterMode
from tqdm import tqdm
from transformers import LlamaTokenizer

from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
from llama_recipes.policies import fpSixteen, bfSixteen_mixed, get_llama_wrapper
from llama_recipes.utils.memory_utils import MemoryTrace
from llama_recipes.utils.tflop_counter import FlopCounterMode
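
# maybe_run_profiler wraps the training loop in torch.profiler when cfg.profiler is set.
# With schedule(wait=1, warmup=2, active=3, repeat=1) the profiler skips the first step,
# warms up for two, records the next three, and writes a TensorBoard trace to
# cfg.profile_output_dir; otherwise it yields None and adds no overhead.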
@contextlib.contextmanager
def maybe_run_profiler(cfg, *args, **kwargs):
    use_profiler: bool = cfg.profiler

    if use_profiler:
        with torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            schedule=torch.profiler.schedule(wait=1, warmup=2, active=3, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                cfg.profile_output_dir
            ),
            profile_memory=True,
            with_stack=False,
            record_shapes=True,
        ) as torch_profiler:
            yield torch_profiler
    else:
        # No profiling requested: yield None so callers can still use the context manager.
        yield None
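
# FlopCounterMode accumulates per-operator FLOP counts; the "Global" entry aggregates them
# across all modules, so summing its values gives the total FLOPs for the profiled region.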
def get_total_flops(model):
    return sum([v for _, v in model.flop_counts["Global"].items()])


def set_tokenizer_params(tokenizer: LlamaTokenizer):
    tokenizer.pad_token_id = 0
    tokenizer.padding_side = "left"


# Converting Bytes to Megabytes
def byte2mb(x):
    return int(x / 2**20)

def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
    """
    Trains the model on the given dataloader

    Args:
        model: The model to be trained
        train_dataloader: The dataloader containing the training data
        optimizer: The optimizer used for training
        lr_scheduler: The learning rate scheduler
        gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
        local_rank: The rank of the current node in a distributed setting
        train_config: The training configuration (including num_epochs)
        eval_dataloader: The dataloader containing the eval data
        tokenizer: tokenizer used in the eval for decoding the predictions

    Returns: results dictionary containing average training and validation perplexity and loss
    """
    # Create a gradient scaler for fp16
    if train_config.use_fp16 and train_config.enable_fsdp:
        scaler = ShardedGradScaler()
    elif train_config.use_fp16 and not train_config.enable_fsdp:
        scaler = torch.cuda.amp.GradScaler()
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    train_prep = []
    train_loss = []
    val_prep = []
    val_loss = []
    epoch_times = []
    checkpoint_times = []
    results = {}
    best_val_loss = float("inf")
    for epoch in range(train_config.num_epochs):
        epoch_start_time = time.perf_counter()
        with MemoryTrace() as memtrace:  # track the memory usage
            model.train()
            total_loss = 0.0
            total_length = len(train_dataloader) // gradient_accumulation_steps
            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
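            # Gradient accumulation: each micro-batch loss is scaled by 1/gradient_accumulation_steps
            # and its gradients accumulate; the optimizer only steps every
            # gradient_accumulation_steps batches (or on the last batch of the epoch).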
            with maybe_run_profiler(train_config) as torch_profiler:
                # Measure FLOPs only once per run, on step 3 (after a few warmup steps).
                flop_check_done = False
                for step, batch in enumerate(train_dataloader):
                    for key in batch.keys():
                        if train_config.enable_fsdp:
                            batch[key] = batch[key].to(local_rank)
                        else:
                            batch[key] = batch[key].to('cuda:0')
                    if train_config.flop_counter and step == 3 and not flop_check_done:
                        flop_counter = FlopCounterMode(rank=local_rank)
                        with flop_counter:
                            loss = model(**batch).loss
                            loss = loss / gradient_accumulation_steps
                            total_loss += loss.detach().float()
                            if train_config.use_fp16:
                                # if fp16 is enabled, use gradient scaler to handle gradient update
                                scaler.scale(loss).backward()
                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                                    scaler.step(optimizer)
                                    scaler.update()
                                    optimizer.zero_grad()
                                    pbar.update(1)
                            else:
                                # regular backpropagation when fp16 is not used
                                loss.backward()
                                TFlops = get_total_flops(flop_counter) / 1e12
                                flop_check_done = True
                                if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                                    optimizer.step()
                                    optimizer.zero_grad()
                                    pbar.update(1)
                    else:
                        loss = model(**batch).loss
                        loss = loss / gradient_accumulation_steps
                        total_loss += loss.detach().float()
                        if train_config.use_fp16:
                            # if fp16 is enabled, use gradient scaler to handle gradient update
                            scaler.scale(loss).backward()
                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                                scaler.step(optimizer)
                                scaler.update()
                                optimizer.zero_grad()
                                pbar.update(1)
                        else:
                            # regular backpropagation when fp16 is not used
                            loss.backward()
                            if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                                optimizer.step()
                                optimizer.zero_grad()
                                pbar.update(1)
                    pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
                pbar.close()

        epoch_end_time = time.perf_counter() - epoch_start_time
        epoch_times.append(epoch_end_time)
        # Reducing total_loss across all devices if there's more than one CUDA device
        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
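        # total_loss holds per-step losses already divided by gradient_accumulation_steps;
        # after the all_reduce above (FSDP with multiple devices) it is the sum over ranks,
        # so divide by the number of steps and, under FSDP, by world_size to get the epoch mean.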
        train_epoch_loss = total_loss / len(train_dataloader)
        if train_config.enable_fsdp:
            train_epoch_loss = train_epoch_loss / world_size
        train_perplexity = torch.exp(train_epoch_loss)

        train_prep.append(train_perplexity)
        train_loss.append(train_epoch_loss)

        if train_config.enable_fsdp:
            if rank == 0:
                print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                print(f"CUDA malloc retries: {memtrace.cuda_malloc_retires}")
                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
        else:
            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
            print(f"CUDA malloc retries: {memtrace.cuda_malloc_retires}")
            print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")

        # Update the learning rate as needed
        lr_scheduler.step()

        if train_config.run_validation:
            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
            checkpoint_start_time = time.perf_counter()
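            # Checkpointing: PEFT runs save only the adapter weights via save_pretrained;
            # full fine-tuning under FSDP saves either a FULL_STATE_DICT or a SHARDED_STATE_DICT
            # checkpoint (plus optimizer state when save_optimizer is set), depending on
            # fsdp_config.checkpoint_type.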
            if train_config.save_model and eval_epoch_loss < best_val_loss:
                if train_config.enable_fsdp:
                    dist.barrier()
                if train_config.use_peft:
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"we are about to save the PEFT modules")
                    else:
                        print(f"we are about to save the PEFT modules")
                    model.save_pretrained(train_config.output_dir)
                    if train_config.enable_fsdp:
                        if rank == 0:
                            print(f"PEFT modules are saved in {train_config.output_dir} directory")
                    else:
                        print(f"PEFT modules are saved in {train_config.output_dir} directory")
                else:
                    if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
                        save_model_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                    elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                        print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                        print("=====================================================")
                        save_model_and_optimizer_sharded(model, rank, train_config)
                        if train_config.save_optimizer:
                            save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
                            print(" Saving the FSDP model checkpoints and optimizer using SHARDED_STATE_DICT")
                            print("=====================================================")
                    if not train_config.use_peft and train_config.save_optimizer:
                        save_optimizer_checkpoint(
                            model, optimizer, rank, train_config, epoch=epoch
                        )
                        print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
                        print("=====================================================")
                if train_config.enable_fsdp:
                    dist.barrier()
            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
            checkpoint_times.append(checkpoint_end_time)
            if eval_epoch_loss < best_val_loss:
                best_val_loss = eval_epoch_loss
                if train_config.enable_fsdp:
                    if rank == 0:
                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                else:
                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
            val_loss.append(best_val_loss)
            val_prep.append(eval_ppl)

        if train_config.enable_fsdp:
            if rank == 0:
                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
        else:
            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")

    avg_epoch_time = sum(epoch_times) / len(epoch_times)
    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
    avg_train_prep = sum(train_prep) / len(train_prep)
    avg_train_loss = sum(train_loss) / len(train_loss)
    if train_config.run_validation:
        avg_eval_prep = sum(val_prep) / len(val_prep)
        avg_eval_loss = sum(val_loss) / len(val_loss)

    results['avg_train_prep'] = avg_train_prep
    results['avg_train_loss'] = avg_train_loss
    if train_config.run_validation:
        results['avg_eval_prep'] = avg_eval_prep
        results['avg_eval_loss'] = avg_eval_loss
    results["avg_epoch_time"] = avg_epoch_time
    results["avg_checkpoint_time"] = avg_checkpoint_time
    if train_config.flop_counter:
        results["model_flops"] = TFlops

    # saving the training params including fsdp setting for reference.
    if train_config.enable_fsdp and not train_config.use_peft:
        save_train_params(train_config, fsdp_config, rank)

    return results

def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer):
    """
    Evaluates the model on the given dataloader

    Args:
        model: The model to evaluate
        train_config: The training configuration
        eval_dataloader: The dataloader containing the evaluation data
        local_rank: The rank of the current node in a distributed setting
        tokenizer: The tokenizer used to decode predictions

    Returns: eval_ppl, eval_epoch_loss
    """
    if train_config.enable_fsdp:
        world_size = int(os.environ["WORLD_SIZE"])
    model.eval()
    eval_preds = []
    eval_loss = 0.0  # Initialize evaluation loss
    with MemoryTrace() as memtrace:
        for step, batch in enumerate(tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
            for key in batch.keys():
                if train_config.enable_fsdp:
                    batch[key] = batch[key].to(local_rank)
                else:
                    batch[key] = batch[key].to('cuda:0')
            # Ensure no gradients are computed for this scope to save memory
            with torch.no_grad():
                # Forward pass and compute loss
                outputs = model(**batch)
                loss = outputs.loss
                eval_loss += loss.detach().float()
            # Decode predictions and add to evaluation predictions list
            preds = torch.argmax(outputs.logits, -1)
            eval_preds.extend(
                tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
            )
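            # Note: eval_preds is accumulated here (e.g. for inspection) but is not returned;
            # only the loss-based metrics below are.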

    # If there's more than one CUDA device, reduce evaluation loss across all devices
    if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)

    # Compute average loss and perplexity
    eval_epoch_loss = eval_loss / len(eval_dataloader)
    if train_config.enable_fsdp:
        eval_epoch_loss = eval_epoch_loss / world_size
    eval_ppl = torch.exp(eval_epoch_loss)

    # Print evaluation metrics
    if train_config.enable_fsdp:
        if local_rank == 0:
            print(f" {eval_ppl=} {eval_epoch_loss=}")
    else:
        print(f" {eval_ppl=} {eval_epoch_loss=}")

    return eval_ppl, eval_epoch_loss

def freeze_transformer_layers(model, num_layer):
    """Freeze the first `num_layer` decoder layers of the model."""
    for i, layer in enumerate(model.model.layers):
        if i < num_layer:
            for param in layer.parameters():
                param.requires_grad = False


def check_frozen_layers_peft_model(model):
    """Print the requires_grad flag of every parameter in each layer of a PEFT-wrapped model."""
    for i, layer in enumerate(model.base_model.model.model.layers):
        for name, param in layer.named_parameters():
            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")

def setup():
    """Initialize the process group for distributed training"""
    dist.init_process_group("nccl")

def setup_environ_flags(rank):
    """Set environment flags for debugging purposes"""
    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
    # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
    # This flag helps with CUDA memory fragmentation that can lead to OOM in some cases.
    # Note this is only available in PyTorch nightlies (as of July 30, 2023).
    # os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    if rank == 0:
        print(f"--> Running with torch dist debug set to detail")

def cleanup():
    """Clean up the process group after training"""
    dist.destroy_process_group()


def clear_gpu_cache(rank=None):
    """Clear the GPU cache for all ranks"""
    if rank == 0:
        print(f"Clearing GPU cache for all ranks")
    torch.cuda.empty_cache()


def get_parameter_dtypes(model):
    """Get the data types of model parameters"""
    parameter_dtypes = {}
    for name, parameter in model.named_parameters():
        parameter_dtypes[name] = parameter.dtype
    return parameter_dtypes

def print_model_size(model, config, rank: int = 0) -> None:
    """
    Print the model name and the number of trainable parameters.

    Args:
        model: The PyTorch model.
        config: Config object whose model_name field identifies the model.
        rank (int, optional): Current process's rank. Defaults to 0.
    """
    if rank == 0:
        print(f"--> Model {config.model_name}")
        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
def get_policies(cfg, rank):
    """Get the policies for mixed precision and fsdp wrapping"""
    verify_bfloat_support = (
        torch.version.cuda
        and torch.cuda.is_bf16_supported()
        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
        and dist.is_nccl_available()
        and nccl.version() >= (2, 10)
    )

    mixed_precision_policy = None
    wrapping_policy = None

    # Mixed precision
    if cfg.mixed_precision:
        bf16_ready = verify_bfloat_support

        if bf16_ready and not cfg.use_fp16:
            mixed_precision_policy = bfSixteen_mixed
            if rank == 0:
                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
        elif cfg.use_fp16:
            mixed_precision_policy = fpSixteen
            if rank == 0:
                print(f"FP16 enabled")
        else:
            print(f"bFloat16 support not present. Using FP32, and not mixed precision")

    wrapping_policy = get_llama_wrapper()
    return mixed_precision_policy, wrapping_policy

def save_train_params(train_config, fsdp_config, rank):
    """
    This function saves the train_config and FSDP config into a train_params.yaml.
    This will be used by the converter script in the inference folder to fetch the HF model name or path.
    It is also helpful as a log for future reference.
    """
    # Convert the train_config and fsdp_config objects to dictionaries,
    # converting all values to strings to ensure they can be serialized into a YAML file
    train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
    fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
    # Merge the two dictionaries into one
    train_params_dict = {**train_config_dict, **fsdp_config_dict}
    # Construct the folder name (following FSDP checkpointing style) using properties of the train_config object
    folder_name = (
        train_config.dist_checkpoint_root_folder
        + "/"
        + train_config.dist_checkpoint_folder
        + "-"
        + train_config.model_name
    )

    save_dir = Path.cwd() / folder_name
    # If the directory does not exist, create it
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    # Convert the dictionary to a YAML string
    config_yaml = yaml.dump(train_params_dict, indent=4)
    file_name = os.path.join(save_dir, 'train_params.yaml')

    # Check if there's a directory with the same name as the file
    if os.path.isdir(file_name):
        print(f"Error: {file_name} is a directory, not a file.")
    else:
        # Write the YAML string to the file
        with open(file_name, 'w') as f:
            f.write(config_yaml)
        if rank == 0:
            print(f"training params are saved in {file_name}")