train_utils.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List, Optional
  6. import yaml
  7. import fire
  8. import torch
  9. import transformers
  10. from datasets import load_dataset
  11. from tqdm import tqdm
  12. import time
  13. """
  14. Unused imports:
  15. import torch.nn as nn
  16. import bitsandbytes as bnb
  17. """
  18. from torch.nn import functional as F
  19. from peft import (
  20. LoraConfig,
  21. get_peft_model,
  22. get_peft_model_state_dict,
  23. prepare_model_for_int8_training,
  24. set_peft_model_state_dict,
  25. )
  26. from transformers import LlamaForCausalLM, LlamaTokenizer
  27. from torch.distributed.fsdp import StateDictType
  28. import torch.distributed as dist
  29. from pkg_resources import packaging
  30. from .memory_utils import MemoryTrace
  31. import model_checkpointing
  32. import torch.cuda.nccl as nccl
  33. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  34. from pathlib import Path
  35. sys.path.append(str(Path(__file__).resolve().parent.parent))
  36. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  37. import torch.autograd.profiler as profiler
  38. from torch.cuda._memory_viz import profile_plot
  39. from pickle import dump
  40. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  41. tokenizer.pad_token_id = 0
  42. tokenizer.padding_side = "left"
  43. # Converting Bytes to Megabytes
  44. def byte2mb(x):
  45. return int(x / 2**20)
  46. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  47. """
  48. Trains the model on the given dataloader
  49. Args:
  50. model: The model to be trained
  51. train_dataloader: The dataloader containing the training data
  52. optimizer: The optimizer used for training
  53. lr_scheduler: The learning rate scheduler
  54. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  55. num_epochs: The number of epochs to train for
  56. local_rank: The rank of the current node in a distributed setting
  57. train_config: The training configuration
  58. eval_dataloader: The dataloader containing the eval data
  59. tokenizer: tokenizer used in the eval for decoding the predicitons
  60. Returns: results dictionary containing average training and validation perplexity and loss
  61. """
  62. # Create a gradient scaler for fp16
  63. torch.cuda.memory._record_memory_history(True,
  64. # keep 100,000 alloc/free events from before the snapshot
  65. trace_alloc_max_entries=100000,
  66. # record stack information for the trace events
  67. trace_alloc_record_context=True)
  68. if train_config.use_fp16 and train_config.enable_fsdp:
  69. scaler = ShardedGradScaler()
  70. elif train_config.use_fp16 and not train_config.enable_fsdp:
  71. scaler = torch.cuda.amp.GradScaler()
  72. if train_config.enable_fsdp:
  73. world_size = int(os.environ["WORLD_SIZE"])
  74. train_prep = []
  75. train_loss = []
  76. val_prep = []
  77. val_loss =[]
  78. results = {}
  79. best_val_loss = float("inf")
  80. epoch_times=[]
  81. for epoch in range(train_config.num_epochs):
  82. start_epoch = time.perf_counter()
  83. with MemoryTrace() as memtrace: # track the memory usage
  84. model.train()
  85. total_loss = 0.0
  86. # if fsdp_config.profile_mem:
  87. # with torch.profiler.profile(
  88. # schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1),
  89. # activities=[profiler.ProfilerActivity.CPU, profiler.ProfilerActivity.CUDA],
  90. # on_trace_ready=torch.profiler.tensorboard_trace_handler('./log/llama2-7b'),
  91. # record_shapes=True,
  92. # profile_memory=True,
  93. # with_stack=True,
  94. # ) as prof:
  95. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  96. if step >10:
  97. break
  98. for key in batch.keys():
  99. if train_config.enable_fsdp:
  100. batch[key] = batch[key].to(local_rank)
  101. else:
  102. batch[key] = batch[key].to('cuda:0')
  103. loss = model(**batch).loss
  104. loss = loss / gradient_accumulation_steps
  105. total_loss += loss.detach().float()
  106. if train_config.use_fp16:
  107. # if fp16 is enabled, use gradient scaler to handle gradient update
  108. scaler.scale(loss).backward()
  109. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  110. scaler.step(optimizer)
  111. scaler.update()
  112. optimizer.zero_grad()
  113. else:
  114. # regular backpropagation when fp16 is not used
  115. loss.backward()
  116. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  117. optimizer.step()
  118. optimizer.zero_grad()
  119. if step == 4:
  120. if rank==0:
  121. snapshot = torch.cuda.memory._snapshot()
  122. with open('snapshot.pickle', 'wb') as f:
  123. dump(snapshot, f)
  124. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  125. end_epoch = time.perf_counter()
  126. epoch_time = end_epoch- start_epoch
  127. print(f"epoch time is {epoch_time}")
  128. print("==================================================")
  129. epoch_times.append(epoch_time)
  130. # Reducing total_loss across all devices if there's more than one CUDA device
  131. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  132. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  133. world_size = int(os.environ["WORLD_SIZE"])
  134. train_epoch_loss = total_loss / len(train_dataloader)
  135. if train_config.enable_fsdp:
  136. train_epoch_loss = train_epoch_loss/world_size
  137. train_perplexity = torch.exp(train_epoch_loss)
  138. train_prep.append(train_perplexity)
  139. train_loss.append(train_epoch_loss)
  140. if train_config.enable_fsdp:
  141. if rank==0:
  142. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  143. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  144. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  145. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  146. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  147. else:
  148. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  149. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  150. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  151. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  152. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  153. # Update the learning rate as needed
  154. lr_scheduler.step()
  155. if train_config.run_validation:
  156. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
  157. if train_config.save_model and eval_epoch_loss < best_val_loss:
  158. if train_config.enable_fsdp:
  159. dist.barrier()
  160. if train_config.use_peft:
  161. if train_config.enable_fsdp:
  162. if rank==0:
  163. print(f"we are about to save the PEFT modules")
  164. else:
  165. print(f"we are about to save the PEFT modules")
  166. model.save_pretrained(train_config.output_dir)
  167. if train_config.enable_fsdp:
  168. if rank==0:
  169. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  170. else:
  171. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  172. else:
  173. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  174. model_checkpointing.save_model_checkpoint(
  175. model, optimizer, rank, train_config, epoch=epoch
  176. )
  177. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  178. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  179. print("=====================================================")
  180. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  181. if train_config.save_optimizer:
  182. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  183. print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
  184. print("=====================================================")
  185. if not train_config.use_peft and train_config.save_optimizer:
  186. model_checkpointing.save_optimizer_checkpoint(
  187. model, optimizer, rank, train_config, epoch=epoch
  188. )
  189. print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
  190. print("=====================================================")
  191. if train_config.enable_fsdp:
  192. dist.barrier()
  193. if eval_epoch_loss < best_val_loss:
  194. best_val_loss = eval_epoch_loss
  195. if train_config.enable_fsdp:
  196. if rank==0:
  197. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  198. else:
  199. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  200. val_loss.append(best_val_loss)
  201. val_prep.append(eval_ppl)
  202. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  203. lr_scheduler.step()
  204. avg_epoch_time = sum(epoch_times)/len(epoch_times)
  205. print(f"avg epoch time is {avg_epoch_time}")
  206. print("==========================================")
  207. avg_train_prep = sum(train_prep)/len(train_prep)
  208. avg_train_loss = sum(train_loss)/len(train_loss)
  209. if train_config.run_validation:
  210. avg_eval_prep = sum(val_prep)/len(val_prep)
  211. avg_eval_loss = sum(val_loss)/len(val_loss)
  212. results['avg_train_prep'] = avg_train_prep
  213. results['avg_train_loss'] = avg_train_loss
  214. if train_config.run_validation:
  215. results['avg_eval_prep'] = avg_eval_prep
  216. results['avg_eval_loss'] = avg_eval_loss
  217. #saving the training params including fsdp setting for reference.
  218. if train_config.enable_fsdp and not train_config.use_peft:
  219. save_train_params(train_config, fsdp_config, rank)
  220. return results
  221. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  222. """
  223. Evaluates the model on the given dataloader
  224. Args:
  225. model: The model to evaluate
  226. eval_dataloader: The dataloader containing the evaluation data
  227. local_rank: The rank of the current node in a distributed setting
  228. tokenizer: The tokenizer used to decode predictions
  229. Returns: eval_ppl, eval_epoch_loss
  230. """
  231. if train_config.enable_fsdp:
  232. world_size = int(os.environ["WORLD_SIZE"])
  233. model.eval()
  234. eval_preds = []
  235. eval_loss = 0.0 # Initialize evaluation loss
  236. with MemoryTrace() as memtrace:
  237. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  238. if step>6:
  239. break
  240. for key in batch.keys():
  241. if train_config.enable_fsdp:
  242. batch[key] = batch[key].to(local_rank)
  243. else:
  244. batch[key] = batch[key].to('cuda:0')
  245. # Ensure no gradients are computed for this scope to save memory
  246. with torch.no_grad():
  247. # Forward pass and compute loss
  248. outputs = model(**batch)
  249. loss = outputs.loss
  250. eval_loss += loss.detach().float()
  251. # Decode predictions and add to evaluation predictions list
  252. preds = torch.argmax(outputs.logits, -1)
  253. eval_preds.extend(
  254. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  255. )
  256. # If there's more than one CUDA device, reduce evaluation loss across all devices
  257. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  258. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  259. world_size = int(os.environ["WORLD_SIZE"])
  260. # Compute average loss and perplexity
  261. eval_epoch_loss = eval_loss / len(eval_dataloader)
  262. if train_config.enable_fsdp:
  263. eval_epoch_loss = eval_epoch_loss/world_size
  264. eval_ppl = torch.exp(eval_epoch_loss)
  265. # Print evaluation metrics
  266. if train_config.enable_fsdp:
  267. if local_rank==0:
  268. print(f" {eval_ppl=} {eval_epoch_loss=}")
  269. else:
  270. print(f" {eval_ppl=} {eval_epoch_loss=}")
  271. return eval_ppl, eval_epoch_loss
  272. def freeze_transformer_layers(model, num_layer):
  273. for i, layer in enumerate(model.model.layers):
  274. if i < num_layer:
  275. for param in layer.parameters():
  276. param.requires_grad = False
  277. def check_frozen_layers_peft_model(model):
  278. for i, layer in enumerate(model.base_model.model.model.layers):
  279. for name, param in layer.named_parameters():
  280. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  281. def setup():
  282. """Initialize the process group for distributed training"""
  283. dist.init_process_group("nccl")
  284. def setup_environ_flags(rank):
  285. """Set environment flags for debugging purposes"""
  286. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  287. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  288. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  289. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  290. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  291. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  292. if rank == 0:
  293. print(f"--> Running with torch dist debug set to detail")
  294. def cleanup():
  295. """Clean up the process group after training"""
  296. dist.destroy_process_group()
  297. def clear_gpu_cache(rank=None):
  298. """Clear the GPU cache for all ranks"""
  299. if rank == 0:
  300. print(f"Clearing GPU cache for all ranks")
  301. torch.cuda.empty_cache()
  302. def get_parameter_dtypes(model):
  303. """Get the data types of model parameters"""
  304. parameter_dtypes = {}
  305. for name, parameter in model.named_parameters():
  306. parameter_dtypes[name] = parameter.dtype
  307. return parameter_dtypes
  308. def print_model_size(model, config, rank: int = 0) -> None:
  309. """
  310. Print model name, the number of trainable parameters and initialization time.
  311. Args:
  312. model: The PyTorch model.
  313. model_name (str): Name of the model.
  314. init_time_start (float): Initialization start time.
  315. init_time_end (float): Initialization end time.
  316. rank (int, optional): Current process's rank. Defaults to 0.
  317. """
  318. if rank == 0:
  319. print(f"--> Model {config.model_name}")
  320. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  321. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  322. def get_policies(cfg, rank):
  323. """Get the policies for mixed precision and fsdp wrapping"""
  324. verify_bfloat_support = (
  325. torch.version.cuda
  326. and torch.cuda.is_bf16_supported()
  327. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  328. and dist.is_nccl_available()
  329. and nccl.version() >= (2, 10)
  330. )
  331. mixed_precision_policy = None
  332. wrapping_policy = None
  333. # Mixed precision
  334. if cfg.mixed_precision:
  335. bf16_ready = verify_bfloat_support
  336. if bf16_ready and not cfg.use_fp16:
  337. mixed_precision_policy = bfSixteen_mixed
  338. if rank == 0:
  339. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  340. elif cfg.use_fp16:
  341. mixed_precision_policy = fpSixteen
  342. if rank == 0:
  343. print(f"FP16 enabled")
  344. else:
  345. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  346. wrapping_policy = get_llama_wrapper()
  347. return mixed_precision_policy, wrapping_policy
  348. def save_train_params(train_config, fsdp_config, rank):
  349. """
  350. This function saves the train_config and FSDP config into a train_params.yaml.
  351. This will be used by converter script in the inference folder to fetch the HF model name or path.
  352. It also would be hepful as a log for future references.
  353. """
  354. # Convert the train_config and fsdp_config objects to dictionaries,
  355. # converting all values to strings to ensure they can be serialized into a YAML file
  356. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  357. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  358. # Merge the two dictionaries into one
  359. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  360. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  361. folder_name = (
  362. train_config.dist_checkpoint_root_folder
  363. + "/"
  364. + train_config.dist_checkpoint_folder
  365. + "-"
  366. + train_config.model_name
  367. )
  368. save_dir = Path.cwd() / folder_name
  369. # If the directory does not exist, create it
  370. if not os.path.exists(save_dir):
  371. os.makedirs(save_dir)
  372. # Convert the dictionary to a YAML string
  373. config_yaml = yaml.dump(train_params_dict, indent=4)
  374. file_name = os.path.join(save_dir,'train_params.yaml')
  375. # Check if there's a directory with the same name as the file
  376. if os.path.isdir(file_name):
  377. print(f"Error: {file_name} is a directory, not a file.")
  378. else:
  379. # Write the YAML string to the file
  380. with open(file_name, 'w') as f:
  381. f.write(config_yaml)
  382. if rank==0:
  383. print(f"training params are saved in {file_name}")
  384. def export_memory_timeline(path: str, device: Optional[str] = None) -> None:
  385. try:
  386. from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
  387. except ImportError:
  388. # Handle the ImportError here, such as providing an alternative implementation or an error message.
  389. print("The required module 'MemoryProfileTimeline' is not available.")
  390. def _memory_profile():
  391. try:
  392. from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
  393. except ImportError:
  394. # Handle the ImportError here, such as providing an alternative implementation or an error message.
  395. print("The required module 'MemoryProfileTimeline' is not available.")
  396. required = ("record_shapes", "profile_memory", "with_stack")
  397. missing = [f"{i}=True" for i in required if not getattr(self, i)]
  398. if missing:
  399. raise ValueError(f"{', '.join(missing)} required for memory profiling.")
  400. assert self.profiler is not None and self.profiler.kineto_results is not None
  401. return MemoryProfile(self.profiler.kineto_results)