utils.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""General utilities."""

import sys

import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron import get_args
from megatron import print_rank_0
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP,)):
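    """Strip wrapper modules (e.g. torchDDP) from a model or list of models,
    returning the underlying module(s) in the same form (single module or
    list) as the input.
    """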
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model


def calc_params_l2_norm(model):
    """Calculate l2 norm of parameters."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate norm.
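    # apex's multi_tensor_applier requires an int buffer to use as a
    # no-op / overflow flag; its value is not inspected here.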
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5


def average_losses_across_data_parallel_group(losses):
    """Reduce a tensor of losses across all GPUs."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())

    return averaged_losses


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
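    # Step down from the optimizer wrapper to the underlying torch optimizer.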
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)


def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position id for left to right model."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
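    # Lower-triangular mask of shape [att_mask_batch, 1, seq_length, seq_length];
    # the singleton dimension broadcasts over attention heads downstream.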
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()
    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices of the EOD tokens.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach the indices from position_ids if positions will be modified.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Mask out attention across the EOD boundary.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary:
    attention_mask = (attention_mask < 0.5)
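    # True entries mark positions that must not be attended to: future tokens
    # and, when reset_attention_mask is set, tokens before a preceding EOD.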

    return attention_mask, loss_mask, position_ids
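

if __name__ == "__main__":
    # Illustrative usage sketch only (not part of the original module): build
    # masks for a tiny CPU batch. The token ids and the EOD id below are made
    # up purely for demonstration.
    demo_eod = 0
    demo_data = torch.tensor([[5, 7, demo_eod, 9, 4, demo_eod, 2, 3]])
    demo_attn, demo_loss_mask, demo_pos = get_ltor_masks_and_position_ids(
        demo_data,
        demo_eod,
        reset_position_ids=True,
        reset_attention_mask=True,
        eod_mask_loss=True)
    # Expected: attention mask of shape [1, 1, 8, 8], loss mask zeroed at the
    # EOD positions, and position ids that restart after each EOD.
    print(demo_attn.shape, demo_loss_mask, demo_pos)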