train_utils.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import os
  4. import sys
  5. from typing import List
  6. import fire
  7. import torch
  8. import transformers
  9. from datasets import load_dataset
  10. from tqdm import tqdm
  11. import time
  12. """
  13. Unused imports:
  14. import torch.nn as nn
  15. import bitsandbytes as bnb
  16. """
  17. from torch.nn import functional as F
  18. from peft import (
  19. LoraConfig,
  20. get_peft_model,
  21. get_peft_model_state_dict,
  22. prepare_model_for_int8_training,
  23. set_peft_model_state_dict,
  24. )
  25. from transformers import LlamaForCausalLM, LlamaTokenizer
  26. from torch.distributed.fsdp import StateDictType
  27. import torch.distributed as dist
  28. from pkg_resources import packaging
  29. from .memory_utils import MemoryTrace
  30. import model_checkpointing
  31. import torch.cuda.nccl as nccl
  32. from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
  33. from pathlib import Path
  34. sys.path.append(str(Path(__file__).resolve().parent.parent))
  35. from policies import bfSixteen, fpSixteen,bfSixteen_mixed, get_llama_wrapper
  36. def set_tokenizer_params(tokenizer: LlamaTokenizer):
  37. tokenizer.pad_token_id = 0
  38. tokenizer.padding_side = "left"
  39. # Converting Bytes to Megabytes
  40. def byte2mb(x):
  41. return int(x / 2**20)
  42. def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
  43. """
  44. Trains the model on the given dataloader
  45. Args:
  46. model: The model to be trained
  47. train_dataloader: The dataloader containing the training data
  48. optimizer: The optimizer used for training
  49. lr_scheduler: The learning rate scheduler
  50. gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation
  51. num_epochs: The number of epochs to train for
  52. local_rank: The rank of the current node in a distributed setting
  53. train_config: The training configuration
  54. eval_dataloader: The dataloader containing the eval data
  55. tokenizer: tokenizer used in the eval for decoding the predicitons
  56. Returns: results dictionary containing average training and validation perplexity and loss
  57. """
  58. # Create a gradient scaler for fp16
  59. if train_config.use_fp16 and train_config.enable_fsdp:
  60. scaler = ShardedGradScaler()
  61. elif train_config.use_fp16 and not train_config.enable_fsdp:
  62. scaler = torch.cuda.amp.GradScaler()
  63. train_prep = []
  64. train_loss = []
  65. val_prep = []
  66. val_loss =[]
  67. results = {}
  68. best_val_loss = float("inf")
  69. epoch_times=[]
  70. for epoch in range(train_config.num_epochs):
  71. start_epoch = time.perf_counter()
  72. with MemoryTrace() as memtrace: # track the memory usage
  73. model.train()
  74. total_loss = 0.0
  75. data_set_len = 0
  76. for step, batch in enumerate(tqdm(train_dataloader,colour="blue", desc=f"Training Epoch{epoch}")):
  77. for key in batch.keys():
  78. if train_config.enable_fsdp:
  79. batch[key] = batch[key].to(local_rank)
  80. else:
  81. batch[key] = batch[key].to('cuda:0')
  82. loss = model(**batch).loss
  83. loss = loss / gradient_accumulation_steps
  84. total_loss += loss.detach().float()
  85. first_key = next(iter(batch))
  86. data_set_len += len(batch[first_key])
  87. if train_config.use_fp16:
  88. # if fp16 is enabled, use gradient scaler to handle gradient update
  89. scaler.scale(loss).backward()
  90. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  91. scaler.step(optimizer)
  92. scaler.update()
  93. optimizer.zero_grad()
  94. else:
  95. # regular backpropagation when fp16 is not used
  96. loss.backward()
  97. if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
  98. optimizer.step()
  99. optimizer.zero_grad()
  100. print(f"\n step {step} is completed and loss is {loss.detach().float()}")
  101. end_epoch = time.perf_counter()
  102. epoch_time = end_epoch- start_epoch
  103. print(f"epoch time is {epoch_time}")
  104. print("==================================================")
  105. epoch_times.append(epoch_time)
  106. # Reducing total_loss across all devices if there's more than one CUDA device
  107. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  108. dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
  109. world_size = int(os.environ["WORLD_SIZE"])
  110. train_epoch_loss = total_loss / len(train_dataloader)
  111. train_epoch_loss = train_epoch_loss/world_size
  112. train_perplexity = torch.exp(train_epoch_loss)
  113. train_prep.append(train_perplexity)
  114. train_loss.append(train_epoch_loss)
  115. print(f"Max CUDA memory allocated was {memtrace.peak} GB")
  116. print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
  117. print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
  118. print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
  119. print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
  120. # Update the learning rate as needed
  121. lr_scheduler.step()
  122. if train_config.run_validation:
  123. eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, rank, tokenizer)
  124. if train_config.save_model and eval_epoch_loss < best_val_loss:
  125. if train_config.use_peft:
  126. print(f"we are in the saving the PEFT modules")
  127. model.save_pretrained(train_config.output_dir)
  128. print(f"PEFT modules are saved in {train_config.output_dir} directory")
  129. else:
  130. if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
  131. model_checkpointing.save_model_checkpoint(
  132. model, optimizer, rank, train_config, epoch=1
  133. )
  134. elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
  135. print(" we are about to save the models *******")
  136. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config)
  137. if train_config.save_optimizer:
  138. model_checkpointing.save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
  139. if not train_config.use_peft and train_config.save_optimizer:
  140. model_checkpointing.save_optimizer_checkpoint(
  141. model, optimizer, rank, train_config, epoch=1
  142. )
  143. if local_rank == 0 and eval_epoch_loss < best_val_loss:
  144. best_val_loss = eval_epoch_loss
  145. print(f"best eval loss on epoch {epoch} is {best_val_loss}")
  146. val_loss.append(best_val_loss)
  147. val_prep.append(eval_ppl)
  148. print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}")
  149. lr_scheduler.step()
  150. avg_epoch_time = sum(epoch_times)/len(epoch_times)
  151. print("avg epoch time is {avg_epoch_time}")
  152. print("==========================================")
  153. avg_train_prep = sum(train_prep)/len(train_prep)
  154. avg_train_loss = sum(train_loss)/len(train_loss)
  155. if train_config.run_validation:
  156. avg_eval_prep = sum(val_prep)/len(val_prep)
  157. avg_eval_loss = sum(val_loss)/len(val_loss)
  158. results['avg_train_prep'] = avg_train_prep
  159. results['avg_train_loss'] = avg_train_loss
  160. if train_config.run_validation:
  161. results['avg_eval_prep'] = avg_eval_prep
  162. results['avg_eval_loss'] = avg_eval_loss
  163. return results
  164. def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
  165. """
  166. Evaluates the model on the given dataloader
  167. Args:
  168. model: The model to evaluate
  169. eval_dataloader: The dataloader containing the evaluation data
  170. local_rank: The rank of the current node in a distributed setting
  171. tokenizer: The tokenizer used to decode predictions
  172. Returns: eval_ppl, eval_epoch_loss
  173. """
  174. model.eval()
  175. eval_preds = []
  176. eval_loss = 0.0 # Initialize evaluation loss
  177. eval_dataset_len = 0
  178. with MemoryTrace() as memtrace:
  179. for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
  180. for key in batch.keys():
  181. if train_config.enable_fsdp:
  182. batch[key] = batch[key].to(local_rank)
  183. else:
  184. batch[key] = batch[key].to('cuda:0')
  185. # Ensure no gradients are computed for this scope to save memory
  186. with torch.no_grad():
  187. # Forward pass and compute loss
  188. outputs = model(**batch)
  189. loss = outputs.loss
  190. eval_loss += loss.detach().float()
  191. first_key = next(iter(batch))
  192. eval_dataset_len+= len(batch[first_key])
  193. # Decode predictions and add to evaluation predictions list
  194. preds = torch.argmax(outputs.logits, -1)
  195. eval_preds.extend(
  196. tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
  197. )
  198. # If there's more than one CUDA device, reduce evaluation loss across all devices
  199. if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
  200. dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
  201. world_size = int(os.environ["WORLD_SIZE"])
  202. # Compute average loss and perplexity
  203. eval_epoch_loss = eval_loss / len(eval_dataloader)
  204. eval_epoch_loss = eval_epoch_loss/world_size
  205. eval_ppl = torch.exp(eval_epoch_loss)
  206. # Print evaluation metrics
  207. print(f" {eval_ppl=} {eval_epoch_loss=}")
  208. return eval_ppl, eval_epoch_loss
  209. def freeze_transformer_layers(model, num_layer):
  210. for i, layer in enumerate(model.model.layers):
  211. if i < num_layer:
  212. for param in layer.parameters():
  213. param.requires_grad = False
  214. def check_frozen_layers_peft_model(model):
  215. for i, layer in enumerate(model.base_model.model.model.layers):
  216. for name, param in layer.named_parameters():
  217. print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
  218. def setup():
  219. """Initialize the process group for distributed training"""
  220. dist.init_process_group("nccl")
  221. def setup_environ_flags(rank):
  222. """Set environment flags for debugging purposes"""
  223. os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
  224. os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
  225. os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
  226. if rank == 0:
  227. print(f"--> Running with torch dist debug set to detail")
  228. def cleanup():
  229. """Clean up the process group after training"""
  230. dist.destroy_process_group()
  231. def clear_gpu_cache(rank=None):
  232. """Clear the GPU cache for all ranks"""
  233. if rank == 0:
  234. print(f"Clearing GPU cache for all ranks")
  235. torch.cuda.empty_cache()
  236. def get_parameter_dtypes(model):
  237. """Get the data types of model parameters"""
  238. parameter_dtypes = {}
  239. for name, parameter in model.named_parameters():
  240. parameter_dtypes[name] = parameter.dtype
  241. return parameter_dtypes
  242. def print_model_size(model, config, rank: int = 0) -> None:
  243. """
  244. Print model name, the number of trainable parameters and initialization time.
  245. Args:
  246. model: The PyTorch model.
  247. model_name (str): Name of the model.
  248. init_time_start (float): Initialization start time.
  249. init_time_end (float): Initialization end time.
  250. rank (int, optional): Current process's rank. Defaults to 0.
  251. """
  252. if rank == 0:
  253. print(f"--> Model {config.model_name}")
  254. total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
  255. print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
  256. def get_policies(cfg, rank):
  257. """Get the policies for mixed precision and fsdp wrapping"""
  258. verify_bfloat_support = (
  259. torch.version.cuda
  260. and torch.cuda.is_bf16_supported()
  261. and packaging.version.parse(torch.version.cuda).release >= (11, 0)
  262. and dist.is_nccl_available()
  263. and nccl.version() >= (2, 10)
  264. )
  265. mixed_precision_policy = None
  266. wrapping_policy = None
  267. # Mixed precision
  268. if cfg.mixed_precision:
  269. bf16_ready = verify_bfloat_support
  270. if bf16_ready and not cfg.use_fp16:
  271. mixed_precision_policy = bfSixteen_mixed
  272. if rank == 0:
  273. print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
  274. elif cfg.use_fp16:
  275. mixed_precision_policy = fpSixteen
  276. if rank == 0:
  277. print(f"FP16 enabled")
  278. else:
  279. print(f"bFloat16 support not present. Using FP32, and not mixed precision")
  280. wrapping_policy = get_llama_wrapper()
  281. return mixed_precision_policy, wrapping_policy