train_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import yaml
  7. import fire
  8. import torch
  9. import transformers
  10. from datasets import load_dataset
  11. from tqdm import tqdm
  12. import time
  13. """
  14. Unused imports:
  15. import torch.nn as nn
  16. import bitsandbytes as bnb
  17. """
  18. from torch.nn import functional as F
  19. from peft import (
  20. LoraConfig,
  21. get_peft_model,
  22. get_peft_model_state_dict,
  23. prepare_model_for_int8_training,
  24. set_peft_model_state_dict,
  25. )
  26. from transformers import LlamaForCausalLM, LlamaTokenizer
  27. from torch.distributed.fsdp import StateDictType
  28. import torch.distributed as dist
  29. from pkg_resources import packaging
  30. from .memory_utils import MemoryTrace
  31. import model_checkpointing
  32. import torch.cuda.nccl as nccl
  33. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  34. from pathlib import Path
  35. sys.path.append(str(Path(__file__).resolve().parent.parent))
  36. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  37. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  38. tokenizer.pad_token_id = 0
  39. tokenizer.padding_side = "left"
  40. # Converting Bytes to Megabytes
  41. def byte2mb(x):
  42. return int(x / 2**20)
  43. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  44. """
  45. Trains the model on the given dataloader
  46. Args:
  47. model: The model to be trained
  48. train_dataloader: The dataloader containing the training data
  49. optimizer: The optimizer used for training
  50. lr_scheduler: The learning rate scheduler
  51. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  52. num_epochs: The number of epochs to train for
  53. local_rank: The rank of the current node in a distributed setting
  54. train_config: The training configuration
  55. eval_dataloader: The dataloader containing the eval data
  56. tokenizer: tokenizer used in the eval for decoding the predicitons
  57. Returns: results dictionary containing average training and validation perplexity and loss
  58. """
  59. # Create a gradient scaler for fp16
  60. if train_config.use_fp16 and train_config.enable_fsdp:
  61. scaler = ShardedGradScaler()
  62. elif train_config.use_fp16 and not train_config.enable_fsdp:
  63. scaler = torch.cuda.amp.GradScaler()
  64. if train_config.enable_fsdp:
  65. world_size = int(os.environ["WORLD_SIZE"])
  66. train_prep = []
  67. train_loss = []
  68. val_prep = []
  69. val_loss =[]
  70. results = {}
  71. best_val_loss = float("inf")
  72. epoch_times=[]
  73. for epoch in range(train_config.num_epochs):
  74. start_epoch = time.perf_counter()
  75. with MemoryTrace() as memtrace: # track the memory usage
  76. model.train()
  77. total_loss = 0.0
  78. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  79. for key in batch.keys():
  80. if train_config.enable_fsdp:
  81. batch[key] = batch[key].to(local_rank)
  82. else:
  83. batch[key] = batch[key].to('cuda:0')
  84. loss = model(**batch).loss
  85. loss = loss / gradient_accumulation_steps
  86. total_loss += loss.detach().float()
  87. if train_config.use_fp16:
  88. # if fp16 is enabled, use gradient scaler to handle gradient update
  89. scaler.scale(loss).backward()
  90. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  91. scaler.step(optimizer)
  92. scaler.update()
  93. optimizer.zero_grad()
  94. else:
  95. # regular backpropagation when fp16 is not used
  96. loss.backward()
  97. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  98. optimizer.step()
  99. optimizer.zero_grad()
  100. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  101. end_epoch = time.perf_counter()
  102. epoch_time = end_epoch- start_epoch
  103. print(f"epoch time is {epoch_time}")
  104. print("==================================================")
  105. epoch_times.append(epoch_time)
  106. # Reducing total_loss across all devices if there's more than one CUDA device
  107. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  108. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  109. world_size = int(os.environ["WORLD_SIZE"])
  110. train_epoch_loss = total_loss / len(train_dataloader)
  111. if train_config.enable_fsdp:
  112. train_epoch_loss = train_epoch_loss/world_size
  113. train_perplexity = torch.exp(train_epoch_loss)
  114. train_prep.append(train_perplexity)
  115. train_loss.append(train_epoch_loss)
  116. if train_config.enable_fsdp:
  117. if rank==0:
  118. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  119. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  120. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  121. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  122. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  123. else:
  124. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  125. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  126. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  127. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  128. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  129. # Update the learning rate as needed
  130. lr_scheduler.step()
  131. if train_config.run_validation:
  132. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
  133. if train_config.save_model and eval_epoch_loss < best_val_loss:
  134. if train_config.enable_fsdp:
  135. dist.barrier()
  136. if train_config.use_peft:
  137. if train_config.enable_fsdp:
  138. if rank==0:
  139. print(f"we are about to save the PEFT modules")
  140. else:
  141. print(f"we are about to save the PEFT modules")
  142. model.save_pretrained(train_config.output_dir)
  143. if train_config.enable_fsdp:
  144. if rank==0:
  145. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  146. else:
  147. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  148. else:
  149. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  150. model_checkpointing.save_model_checkpoint(
  151. model, optimizer, rank, train_config, epoch=epoch
  152. )
  153. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  154. print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
  155. print("=====================================================")
  156. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  157. if train_config.save_optimizer:
  158. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  159. print(" Saving the FSDP model checkpoints qnd optimizer using SHARDED_STATE_DICT")
  160. print("=====================================================")
  161. if not train_config.use_peft and train_config.save_optimizer:
  162. model_checkpointing.save_optimizer_checkpoint(
  163. model, optimizer, rank, train_config, epoch=epoch
  164. )
  165. print(" Saving the FSDP model checkpoints qnd optimizer using FULL_STATE_DICT")
  166. print("=====================================================")
  167. if train_config.enable_fsdp:
  168. dist.barrier()
  169. if eval_epoch_loss < best_val_loss:
  170. best_val_loss = eval_epoch_loss
  171. if train_config.enable_fsdp:
  172. if rank==0:
  173. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  174. else:
  175. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  176. val_loss.append(best_val_loss)
  177. val_prep.append(eval_ppl)
  178. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  179. lr_scheduler.step()
  180. avg_epoch_time = sum(epoch_times)/len(epoch_times)
  181. print(f"avg epoch time is {avg_epoch_time}")
  182. print("==========================================")
  183. avg_train_prep = sum(train_prep)/len(train_prep)
  184. avg_train_loss = sum(train_loss)/len(train_loss)
  185. if train_config.run_validation:
  186. avg_eval_prep = sum(val_prep)/len(val_prep)
  187. avg_eval_loss = sum(val_loss)/len(val_loss)
  188. results['avg_train_prep'] = avg_train_prep
  189. results['avg_train_loss'] = avg_train_loss
  190. if train_config.run_validation:
  191. results['avg_eval_prep'] = avg_eval_prep
  192. results['avg_eval_loss'] = avg_eval_loss
  193. #saving the training params including fsdp setting for reference.
  194. if train_config.enable_fsdp and not train_config.use_peft:
  195. save_train_params(train_config, fsdp_config, rank)
  196. return results
  197. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  198. """
  199. Evaluates the model on the given dataloader
  200. Args:
  201. model: The model to evaluate
  202. eval_dataloader: The dataloader containing the evaluation data
  203. local_rank: The rank of the current node in a distributed setting
  204. tokenizer: The tokenizer used to decode predictions
  205. Returns: eval_ppl, eval_epoch_loss
  206. """
  207. if train_config.enable_fsdp:
  208. world_size = int(os.environ["WORLD_SIZE"])
  209. model.eval()
  210. eval_preds = []
  211. eval_loss = 0.0 # Initialize evaluation loss
  212. with MemoryTrace() as memtrace:
  213. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  214. for key in batch.keys():
  215. if train_config.enable_fsdp:
  216. batch[key] = batch[key].to(local_rank)
  217. else:
  218. batch[key] = batch[key].to('cuda:0')
  219. # Ensure no gradients are computed for this scope to save memory
  220. with torch.no_grad():
  221. # Forward pass and compute loss
  222. outputs = model(**batch)
  223. loss = outputs.loss
  224. eval_loss += loss.detach().float()
  225. # Decode predictions and add to evaluation predictions list
  226. preds = torch.argmax(outputs.logits, -1)
  227. eval_preds.extend(
  228. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  229. )
  230. # If there's more than one CUDA device, reduce evaluation loss across all devices
  231. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  232. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  233. world_size = int(os.environ["WORLD_SIZE"])
  234. # Compute average loss and perplexity
  235. eval_epoch_loss = eval_loss / len(eval_dataloader)
  236. if train_config.enable_fsdp:
  237. eval_epoch_loss = eval_epoch_loss/world_size
  238. eval_ppl = torch.exp(eval_epoch_loss)
  239. # Print evaluation metrics
  240. if train_config.enable_fsdp:
  241. if local_rank==0:
  242. print(f" {eval_ppl=} {eval_epoch_loss=}")
  243. else:
  244. print(f" {eval_ppl=} {eval_epoch_loss=}")
  245. return eval_ppl, eval_epoch_loss
  246. def freeze_transformer_layers(model, num_layer):
  247. for i, layer in enumerate(model.model.layers):
  248. if i < num_layer:
  249. for param in layer.parameters():
  250. param.requires_grad = False
  251. def check_frozen_layers_peft_model(model):
  252. for i, layer in enumerate(model.base_model.model.model.layers):
  253. for name, param in layer.named_parameters():
  254. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  255. def setup():
  256. """Initialize the process group for distributed training"""
  257. dist.init_process_group("nccl")
  258. def setup_environ_flags(rank):
  259. """Set environment flags for debugging purposes"""
  260. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  261. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  262. # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  263. # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
  264. # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
  265. # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
  266. if rank == 0:
  267. print(f"--> Running with torch dist debug set to detail")
  268. def cleanup():
  269. """Clean up the process group after training"""
  270. dist.destroy_process_group()
  271. def clear_gpu_cache(rank=None):
  272. """Clear the GPU cache for all ranks"""
  273. if rank == 0:
  274. print(f"Clearing GPU cache for all ranks")
  275. torch.cuda.empty_cache()
  276. def get_parameter_dtypes(model):
  277. """Get the data types of model parameters"""
  278. parameter_dtypes = {}
  279. for name, parameter in model.named_parameters():
  280. parameter_dtypes[name] = parameter.dtype
  281. return parameter_dtypes
  282. def print_model_size(model, config, rank: int = 0) -> None:
  283. """
  284. Print model name, the number of trainable parameters and initialization time.
  285. Args:
  286. model: The PyTorch model.
  287. model_name (str): Name of the model.
  288. init_time_start (float): Initialization start time.
  289. init_time_end (float): Initialization end time.
  290. rank (int, optional): Current process's rank. Defaults to 0.
  291. """
  292. if rank == 0:
  293. print(f"--> Model {config.model_name}")
  294. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  295. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  296. def get_policies(cfg, rank):
  297. """Get the policies for mixed precision and fsdp wrapping"""
  298. verify_bfloat_support = (
  299. torch.version.cuda
  300. and torch.cuda.is_bf16_supported()
  301. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  302. and dist.is_nccl_available()
  303. and nccl.version() >= (2, 10)
  304. )
  305. mixed_precision_policy = None
  306. wrapping_policy = None
  307. # Mixed precision
  308. if cfg.mixed_precision:
  309. bf16_ready = verify_bfloat_support
  310. if bf16_ready and not cfg.use_fp16:
  311. mixed_precision_policy = bfSixteen_mixed
  312. if rank == 0:
  313. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  314. elif cfg.use_fp16:
  315. mixed_precision_policy = fpSixteen
  316. if rank == 0:
  317. print(f"FP16 enabled")
  318. else:
  319. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  320. wrapping_policy = get_llama_wrapper()
  321. return mixed_precision_policy, wrapping_policy
  322. def save_train_params(train_config, fsdp_config, rank):
  323. """
  324. This function saves the train_config and FSDP config into a train_params.yaml.
  325. This will be used by converter script in the inference folder to fetch the HF model name or path.
  326. It also would be hepful as a log for future references.
  327. """
  328. # Convert the train_config and fsdp_config objects to dictionaries,
  329. # converting all values to strings to ensure they can be serialized into a YAML file
  330. train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
  331. fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
  332. # Merge the two dictionaries into one
  333. train_params_dict = {**train_config_dict, **fsdp_config_dict}
  334. # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object
  335. folder_name = (
  336. train_config.dist_checkpoint_root_folder
  337. + "/"
  338. + train_config.dist_checkpoint_folder
  339. + "-"
  340. + train_config.model_name
  341. )
  342. save_dir = Path.cwd() / folder_name
  343. # If the directory does not exist, create it
  344. if not os.path.exists(save_dir):
  345. os.makedirs(save_dir)
  346. # Convert the dictionary to a YAML string
  347. config_yaml = yaml.dump(train_params_dict, indent=4)
  348. file_name = os.path.join(save_dir,'train_params.yaml')
  349. # Check if there's a directory with the same name as the file
  350. if os.path.isdir(file_name):
  351. print(f"Error: {file_name} is a directory, not a file.")
  352. else:
  353. # Write the YAML string to the file
  354. with open(file_name, 'w') as f:
  355. f.write(config_yaml)
  356. if rank==0:
  357. print(f"training params are saved in {file_name}")