# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """General utilities."""
- import sys
- import torch
- from torch.nn.parallel import DistributedDataParallel as torchDDP
- from apex.multi_tensor_apply import multi_tensor_applier
- import amp_C
- from megatron import get_args
- from megatron import print_rank_0
- from megatron import get_adlr_autoresume
- from megatron import mpu
- from megatron.model.module import param_is_not_shared
- from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate


def unwrap_model(model, module_instances=(torchDDP,)):
    """Unwrap `model` (or each element of a list of models) from the
    given wrapper classes, e.g. DistributedDataParallel."""
    return_list = True
    if not isinstance(model, list):
        model = [model]
        return_list = False
    unwrapped_model = []
    for model_module in model:
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)
    if not return_list:
        return unwrapped_model[0]
    return unwrapped_model
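
# Illustrative usage (not part of the original module; `net`, `rank`, and the
# wrapped modules are hypothetical): unwrapping a DDP-wrapped module returns
# the underlying module, and lists are unwrapped element-wise.
#   bare = unwrap_model(torchDDP(net, device_ids=[rank]))
#   bare_a, bare_b = unwrap_model([wrapped_a, wrapped_b])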


def calc_params_l2_norm(model):
    """Calculate the L2 norm of model parameters across model-parallel GPUs."""
    args = get_args()
    if not isinstance(model, list):
        model = [model]
    # Count each parameter exactly once: skip shared parameters and
    # tensor-model-parallel duplicates.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                if args.bf16:
                    # Cast bf16 params to fp32 before the fused norm kernel.
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)
    # Calculate the local norm with apex's fused multi-tensor kernel.
    dummy_overflow_buf = torch.cuda.IntTensor([0])
    norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        dummy_overflow_buf,
        [params_data],
        False  # no per-parameter norm
    )
    norm_2 = norm * norm
    # Sum the squared norms across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
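
# For reference, the fused computation above is equivalent to this plain
# (slower, unfused) PyTorch sketch, where the all-reduce sums the squared
# local norms across model-parallel ranks:
#   local_sq = sum((p.float() ** 2).sum() for p in params_data)
#   torch.distributed.all_reduce(local_sq,
#                                group=mpu.get_model_parallel_group())
#   norm = local_sq.item() ** 0.5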


def average_losses_across_data_parallel_group(losses):
    """Average a list of loss tensors across the data-parallel group."""
    averaged_losses = torch.cat(
        [loss.clone().detach().view(1) for loss in losses])
    torch.distributed.all_reduce(averaged_losses,
                                 group=mpu.get_data_parallel_group())
    averaged_losses = averaged_losses / \
        torch.distributed.get_world_size(group=mpu.get_data_parallel_group())
    return averaged_losses
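
# Illustrative usage (hypothetical scalars lm_loss, sop_loss): every
# data-parallel rank passes its local losses and receives the group mean,
#   averaged = average_losses_across_data_parallel_group([lm_loss, sop_loss])
# where averaged[0] is the data-parallel mean of lm_loss and averaged[1]
# that of sop_loss.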


def report_memory(name):
    """Simple GPU memory report."""
    mega_bytes = 1024.0 * 1024.0
    string = name + ' memory (MB)'
    string += ' | allocated: {}'.format(
        torch.cuda.memory_allocated() / mega_bytes)
    string += ' | max allocated: {}'.format(
        torch.cuda.max_memory_allocated() / mega_bytes)
    string += ' | reserved: {}'.format(
        torch.cuda.memory_reserved() / mega_bytes)
    string += ' | max reserved: {}'.format(
        torch.cuda.max_memory_reserved() / mega_bytes)
    # Print from only one rank per data-parallel group to avoid duplicates.
    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), string),
              flush=True)
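
# Example output line (illustrative numbers):
#   [Rank 0] after forward memory (MB) | allocated: 3212.5
#   | max allocated: 4096.0 | reserved: 5120.0 | max reserved: 5120.0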


def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters."""
    index = 0
    rank = torch.distributed.get_rank()
    string = 'iteration, rank, index, tensor-model-parallel, min, max, norm\n'
    # The Megatron optimizer wraps the underlying torch optimizer.
    optimizer_ = optimizer.optimizer
    for param_group in optimizer_.param_groups:
        for param in param_group['params']:
            index += 1
            min_ = param.data.min()
            max_ = param.data.max()
            norm = torch.linalg.norm(param.data)
            string += '{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, int(param.tensor_model_parallel))
            string += '{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm)
    print(string, flush=True)
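
# Example row (illustrative values), following the header above:
#      100,    0,    1,  1, -9.831238E-02, 9.912109E-02, 4.731945E+00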


def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, lr_scheduler):
    """Check for autoresume signal and exit if it is received."""
    from megatron.checkpointing import save_checkpoint
    args = get_args()
    autoresume = get_adlr_autoresume()
    # Add barrier to ensure consistency.
    torch.distributed.barrier()
    if autoresume.termination_requested():
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)
        print_rank_0(">>> autoresume termination request found!")
        if torch.distributed.get_rank() == 0:
            autoresume.request_resume()
        print_rank_0(">>> training terminated. Returning")
        sys.exit(0)


def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for left-to-right models."""

    # Extract batch size and sequence length.
    micro_batch_size, seq_length = data.size()

    # Attention mask (lower triangular). If attention is reset at EOD
    # tokens, every sample needs its own mask; otherwise one shared mask
    # suffices.
    if reset_attention_mask:
        att_mask_batch = micro_batch_size
    else:
        att_mask_batch = 1
    attention_mask = torch.tril(torch.ones(
        (att_mask_batch, seq_length, seq_length), device=data.device)).view(
            att_mask_batch, 1, seq_length, seq_length)

    # Loss mask.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # We need to clone as the ids will be modified based on batch index.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        # Loop through the batches:
        for b in range(micro_batch_size):

            # Find indices where EOD token is.
            eod_index = position_ids[b, data[b] == eod_token]
            # Detach indices from positions if going to modify positions.
            if reset_position_ids:
                eod_index = eod_index.clone()

            # Loop through EOD indices:
            prev_index = 0
            for j in range(eod_index.size()[0]):
                i = eod_index[j]
                # Prevent tokens after an EOD from attending to tokens
                # before (and including) it.
                if reset_attention_mask:
                    attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                # Reset positions so each document restarts at 0.
                if reset_position_ids:
                    position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                    prev_index = i + 1

    # Convert attention mask to binary; True marks positions that are
    # masked out.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
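
# Worked example (illustrative; EOD stands for a hypothetical eod_token id):
# for a single sample [5, 3, EOD, 7, 9] with all three flags set,
#   data = torch.tensor([[5, 3, EOD, 7, 9]], device='cuda')
#   att, loss_mask, pos = get_ltor_masks_and_position_ids(
#       data, EOD, reset_position_ids=True,
#       reset_attention_mask=True, eod_mask_loss=True)
# yields pos == [[0, 1, 2, 0, 1]] (positions restart after the EOD),
# loss_mask == [[1., 1., 0., 1., 1.]] (no loss on the EOD token), and an
# attention mask in which tokens 3 and 4 cannot attend to tokens 0-2.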