- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Pretrain utilities."""
- from datetime import datetime
- import nvtx
- import math
- import sys
- import time
- # The earliest we can measure the start time.
- _TRAIN_START_TIME = time.time()
- import torch
- from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
- from megatron import get_args
- from megatron import get_timers
- from megatron import get_tensorboard_writer
- from megatron import get_current_global_batch_size
- from megatron import get_num_microbatches
- from megatron import is_last_rank
- from megatron import update_num_microbatches
- from megatron import mpu
- from megatron import print_rank_0
- from megatron import print_rank_last
- from megatron.checkpointing import load_checkpoint
- from megatron.checkpointing import save_checkpoint
- from megatron.model import Float16Module
- from megatron.optimizer import get_megatron_optimizer
- from megatron.initialize import initialize_megatron
- from megatron.initialize import write_args_to_tensorboard
- from megatron.learning_rates import AnnealingLR
- from megatron.model import DistributedDataParallel as LocalDDP
- from megatron.utils import check_adlr_autoresume_termination
- from megatron.utils import unwrap_model
- from megatron.data.data_samplers import build_pretraining_data_loader
- from megatron.utils import calc_params_l2_norm
- from megatron.schedules import forward_backward_no_pipelining
- from megatron.schedules import forward_backward_pipelining_without_interleaving
- from megatron.schedules import forward_backward_pipelining_with_interleaving
- from megatron.utils import report_memory
- def print_datetime(string):
- """Note that this call will sync across all ranks."""
- torch.distributed.barrier()
- time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
- print_rank_0('[' + string + '] datetime: {} '.format(time_str))
- def pretrain(train_valid_test_dataset_provider,
- model_provider,
- forward_step_func,
- extra_args_provider=None,
- args_defaults={}):
- """Main training program.
- This function will run the following in the order provided:
- 1) initialize Megatron.
- 2) setup model, optimizer and lr schedule using the model_provider.
- 3) call train_valid_test_dataset_provider to get train/val/test datasets.
- 4) train the model using the forward_step_func.
- Arguments:
- train_valid_test_dataset_provider: a function that takes the size of
- train/valid/test dataset and returns `train, valid, test` datasets.
- model_provider: a function that returns a vanilla version of the
- model. By vanilla we mean a simple model on CPU with no fp16 or DDP wrappers.
- forward_step_func: a function that takes a `data iterator` and `model`,
- and returns a `loss` scalar with a dictionary with key:values being
- the info we would like to monitor during training, for example
- `lm-loss: value`. We also require that this function add
- `batch generator` to the timers class.
- extra_args_provider: a function that takes a parser and adds arguments
- to it. It is used for programs to add their own arguments.
- args_defaults: a dictionary from argument-name to argument-value. It is
- used to set already-parsed arguments.
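- Example (illustrative sketch; the provider and forward-step functions are
- placeholders supplied by the calling pretraining script, e.g. pretrain_gpt.py):
- >>> pretrain(train_valid_test_datasets_provider, model_provider,
- ...          forward_step, args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})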
- """
- with nvtx.annotate("initialize", color="cyan"):
- # Initialize and get arguments, timers, and TensorBoard writer.
- initialize_megatron(extra_args_provider=extra_args_provider,
- args_defaults=args_defaults)
- # Adjust the start time to the earliest value across ranks so the reported
- # initialization time reflects the slowest rank. This will be closer to
- # what the scheduler will see (outside of image ... launches).
- global _TRAIN_START_TIME
- start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
- torch.distributed.all_reduce(start_time_tensor,
- op=torch.distributed.ReduceOp.MIN)
- _TRAIN_START_TIME = start_time_tensor.item()
- print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
- time.time() - _TRAIN_START_TIME))
- print_datetime('after megatron is initialized')
- args = get_args()
- timers = get_timers()
- # Model, optimizer, and learning rate.
- timers('model-and-optimizer-setup').start()
- model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
- timers('model-and-optimizer-setup').stop()
- print_datetime('after model, optimizer, and learning rate '
- 'scheduler are built')
- # Data stuff.
- timers('train/valid/test-data-iterators-setup').start()
- with nvtx.annotate("data_loading", color="orange"):
- if args.virtual_pipeline_model_parallel_size is not None:
- all_data_iterators = [
- build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
- for _ in range(len(model))
- ]
- train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
- valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
- test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
- else:
- train_data_iterator, valid_data_iterator, test_data_iterator \
- = build_train_valid_test_data_iterators(
- train_valid_test_dataset_provider)
- timers('train/valid/test-data-iterators-setup').stop()
- print_datetime('after dataloaders are built')
- # Print setup timing.
- print_rank_0('done with setup ...')
- timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
- print_rank_0('training ...')
- iteration = 0
- with nvtx.annotate("training", color="blue"):
- if args.do_train and args.train_iters > 0:
- iteration = train(forward_step_func,
- model, optimizer, lr_scheduler,
- train_data_iterator, valid_data_iterator)
- print_datetime('after training is done')
- if args.do_valid:
- prefix = 'the end of training for val data'
- evaluate_and_print_results(prefix, forward_step_func,
- valid_data_iterator, model,
- iteration, False)
- with nvtx.annotate("checkpointing", color="yellow"):
- if args.save and iteration != 0:
- save_checkpoint(iteration, model, optimizer, lr_scheduler)
- with nvtx.annotate("do_test", color="darkgreen"):
- if args.do_test:
- # Run on test data.
- prefix = 'the end of training for test data'
- evaluate_and_print_results(prefix, forward_step_func,
- test_data_iterator, model,
- 0, True)
- def update_train_iters(args):
- # For iteration-based training, we don't need to do anything.
- if args.train_iters:
- return
- # Constant batch size with sample-based training.
- if args.rampup_batch_size is None:
- args.train_iters = args.train_samples // args.global_batch_size
- else:
- # Sample based training with rampup batch size.
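- # Note (assumed CLI format, not defined in this file): rampup_batch_size is a
- # triplet <start batch size, batch size increment, ramp-up samples>; index [2]
- # below is the number of samples over which the global batch size is ramped up.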
- iterations = 0
- consumed_samples = 0
- # Rampup phase.
- while consumed_samples <= int(args.rampup_batch_size[2]):
- update_num_microbatches(consumed_samples, consistency_check=False)
- consumed_samples += get_current_global_batch_size()
- iterations += 1
- # Reset
- update_num_microbatches(0, consistency_check=False)
- # Constant phase
- # Note that we throw away any partial last batch.
- iterations += (args.train_samples - consumed_samples) // \
- args.global_batch_size
- args.train_iters = iterations
- print_rank_0('setting training iterations to {}'.format(args.train_iters))
- def get_model(model_provider_func):
- """Build the model."""
- args = get_args()
- # Build model.
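- # With the interleaved pipeline schedule (virtual pipeline parallelism), each
- # pipeline rank holds virtual_pipeline_model_parallel_size model chunks, so
- # `model` is built as a list of modules.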
- if mpu.get_pipeline_model_parallel_world_size() > 1 and \
- args.virtual_pipeline_model_parallel_size is not None:
- model = []
- for i in range(args.virtual_pipeline_model_parallel_size):
- mpu.set_virtual_pipeline_model_parallel_rank(i)
- # Set pre_process and post_process only after virtual rank is set.
- pre_process = mpu.is_pipeline_first_stage()
- post_process = mpu.is_pipeline_last_stage()
- this_model = model_provider_func(
- pre_process=pre_process,
- post_process=post_process
- )
- model.append(this_model)
- else:
- pre_process = mpu.is_pipeline_first_stage()
- post_process = mpu.is_pipeline_last_stage()
- model = model_provider_func(
- pre_process=pre_process,
- post_process=post_process
- )
- if not isinstance(model, list):
- model = [model]
- # Set tensor model parallel attributes if not set.
- # Only parameters that are already tensor model parallel have these
- # attributes set for them. We should make sure the default attributes
- # are set for all params so the optimizer can use them.
- for model_module in model:
- for param in model_module.parameters():
- mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)
- # Print number of parameters.
- if mpu.get_data_parallel_rank() == 0:
- print(' > number of parameters on (tensor, pipeline) '
- 'model parallel rank ({}, {}): {}'.format(
- mpu.get_tensor_model_parallel_rank(),
- mpu.get_pipeline_model_parallel_rank(),
- sum([sum([p.nelement() for p in model_module.parameters()])
- for model_module in model])), flush=True)
- # GPU allocation.
- for model_module in model:
- model_module.cuda(torch.cuda.current_device())
- # Fp16 conversion.
- if args.fp16 or args.bf16:
- model = [Float16Module(model_module, args) for model_module in model]
- if args.DDP_impl == 'torch':
- i = torch.cuda.current_device()
- model = [torchDDP(model_module, device_ids=[i], output_device=i,
- process_group=mpu.get_data_parallel_group())
- for model_module in model]
- return model
- if args.DDP_impl == 'local':
- model = [LocalDDP(model_module,
- args.accumulate_allreduce_grads_in_fp32,
- args.use_contiguous_buffers_in_ddp)
- for model_module in model]
- return model
- raise NotImplementedError('Unknown DDP implementation specified: {}. '
- 'Exiting.'.format(args.DDP_impl))
- def get_learning_rate_scheduler(optimizer):
- """Build the learning rate scheduler."""
- args = get_args()
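- # In both branches below, warmup_steps and decay_steps are expressed in
- # samples; iteration-based settings are converted by multiplying by the
- # global batch size, so AnnealingLR can be stepped by consumed samples.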
- # Iteration-based training.
- if args.train_iters:
- if args.lr_decay_iters is None:
- args.lr_decay_iters = args.train_iters
- decay_steps = args.lr_decay_iters * args.global_batch_size
- if args.lr_warmup_fraction is not None:
- warmup_steps = args.lr_warmup_fraction * decay_steps
- else:
- warmup_steps = args.lr_warmup_iters * args.global_batch_size
- # Sample-based training.
- elif args.train_samples:
- # We need to set training iters for later use. Technically
- # we need to adjust the training samples too (due to last
- # batch being incomplete) but we leave it as is for now.
- update_train_iters(args)
- if args.lr_decay_samples is None:
- args.lr_decay_samples = args.train_samples
- decay_steps = args.lr_decay_samples
- if args.lr_warmup_fraction is not None:
- warmup_steps = args.lr_warmup_fraction * decay_steps
- else:
- warmup_steps = args.lr_warmup_samples
- else:
- raise Exception(
- 'either train-iters or train-samples should be provided.')
- lr_scheduler = AnnealingLR(
- optimizer,
- max_lr=args.lr,
- min_lr=args.min_lr,
- warmup_steps=warmup_steps,
- decay_steps=decay_steps,
- decay_style=args.lr_decay_style,
- use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
- override_lr_scheduler=args.override_lr_scheduler)
- return lr_scheduler
- def setup_model_and_optimizer(model_provider_func):
- """Setup model and optimizer."""
- args = get_args()
- model = get_model(model_provider_func)
- unwrapped_model = unwrap_model(model,
- (torchDDP, LocalDDP, Float16Module))
- optimizer = get_megatron_optimizer(unwrapped_model)
- lr_scheduler = get_learning_rate_scheduler(optimizer)
- if args.load is not None:
- timers = get_timers()
- # Extra barrier is added to make sure all ranks report the
- # max time.
- torch.distributed.barrier()
- timers('load-checkpoint').start()
- args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
- torch.distributed.barrier()
- timers('load-checkpoint').stop()
- timers.log(['load-checkpoint'])
- else:
- args.iteration = 0
- # We only support local DDP with multiple micro-batches.
- if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
- assert args.DDP_impl == 'local'
- # get model without FP16 and/or TorchDDP wrappers
- if args.iteration == 0 and len(unwrapped_model) == 1 \
- and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
- print_rank_0("Initializing ICT from pretrained BERT model")
- unwrapped_model[0].init_state_dict_from_bert()
- if args.fp16:
- optimizer.reload_model_params()
- return model, optimizer, lr_scheduler
- def train_step(forward_step_func, data_iterator,
- model, optimizer, lr_scheduler):
- """Single training step."""
- args = get_args()
- timers = get_timers()
- # Set grad to zero.
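- # With local DDP and contiguous buffers, gradients accumulate into a single
- # preallocated buffer per model chunk, so that buffer is zeroed directly
- # instead of zeroing per-parameter grads through the optimizer.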
- if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
- for partition in model:
- partition.zero_grad_buffer()
- else:
- optimizer.zero_grad()
- if mpu.get_pipeline_model_parallel_world_size() > 1:
- if args.virtual_pipeline_model_parallel_size is not None:
- forward_backward_func = forward_backward_pipelining_with_interleaving
- assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
- 'number of microbatches is not divisible by pipeline-parallel ' \
- 'size when using interleaved schedule'
- else:
- forward_backward_func = forward_backward_pipelining_without_interleaving
- else:
- forward_backward_func = forward_backward_no_pipelining
- losses_reduced = forward_backward_func(
- forward_step_func, data_iterator, model,
- optimizer, timers, forward_only=False)
- # All-reduce if needed.
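- # With local DDP, the data-parallel gradient all-reduce is deferred until the
- # backward passes of all microbatches have completed, then issued once here.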
- if args.DDP_impl == 'local':
- timers('backward-params-all-reduce').start()
- for model_module in model:
- model_module.allreduce_gradients()
- timers('backward-params-all-reduce').stop()
- # All-reduce word_embeddings' grad across first and last stages to ensure
- # that word_embeddings parameters stay in sync.
- # This should only run for models that support pipelined model parallelism
- # (BERT and GPT-2).
- timers('backward-embedding-all-reduce').start()
- if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
- mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
- mpu.get_pipeline_model_parallel_world_size() > 1:
- if mpu.is_pipeline_first_stage(ignore_virtual=True):
- unwrapped_model = model[0]
- elif mpu.is_pipeline_last_stage(ignore_virtual=True):
- unwrapped_model = model[-1]
- unwrapped_model = unwrap_model(
- unwrapped_model, (torchDDP, LocalDDP, Float16Module))
- if unwrapped_model.share_word_embeddings:
- word_embeddings_weight = unwrapped_model.word_embeddings_weight()
- if args.DDP_impl == 'local':
- grad = word_embeddings_weight.main_grad
- else:
- grad = word_embeddings_weight.grad
- torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
- timers('backward-embedding-all-reduce').stop()
- # Update parameters.
- timers('optimizer').start()
- update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
- timers('optimizer').stop()
- # Update learning rate.
- if update_successful:
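- # Step the learning-rate scheduler by the number of samples in this global
- # batch: num_microbatches * micro_batch_size * data_parallel_size.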
- increment = get_num_microbatches() * \
- args.micro_batch_size * \
- args.data_parallel_size
- lr_scheduler.step(increment=increment)
- skipped_iter = 0
- else:
- skipped_iter = 1
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- # Average loss across microbatches.
- loss_reduced = {}
- for key in losses_reduced[0]:
- losses_reduced_for_key = [x[key] for x in losses_reduced]
- loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
- return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
- return {}, skipped_iter, grad_norm, num_zeros_in_grad
- def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
- loss_scale, report_memory_flag, skipped_iter,
- grad_norm, params_norm, num_zeros_in_grad):
- """Log training information such as losses, timing, ...."""
- args = get_args()
- timers = get_timers()
- writer = get_tensorboard_writer()
- # Advanced, skipped, and Nan iterations.
- advanced_iters_key = 'advanced iterations'
- skipped_iters_key = 'skipped iterations'
- nan_iters_key = 'nan iterations'
- # Advanced iterations.
- if not skipped_iter:
- total_loss_dict[advanced_iters_key] = total_loss_dict.get(
- advanced_iters_key, 0) + 1
- else:
- if advanced_iters_key not in total_loss_dict:
- total_loss_dict[advanced_iters_key] = 0
- # Skipped iterations.
- total_loss_dict[skipped_iters_key] = total_loss_dict.get(
- skipped_iters_key, 0) + skipped_iter
- # Update losses and set nan iterations
- got_nan = False
- for key in loss_dict:
- if not skipped_iter:
- total_loss_dict[key] = total_loss_dict.get(
- key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
- else:
- value = loss_dict[key].float().sum().item()
- is_nan = value == float('inf') or \
- value == -float('inf') or \
- value != value
- got_nan = got_nan or is_nan
- total_loss_dict[nan_iters_key] = total_loss_dict.get(
- nan_iters_key, 0) + int(got_nan)
- # Logging.
- timers_to_log = []
- def add_to_logging(name):
- if name in timers.timers:
- timers_to_log.append(name)
- add_to_logging('forward-compute')
- add_to_logging('forward-recv')
- add_to_logging('forward-send')
- add_to_logging('forward-backward-send-forward-backward-recv')
- add_to_logging('backward-compute')
- add_to_logging('backward-recv')
- add_to_logging('backward-send')
- add_to_logging('backward-send-forward-recv')
- add_to_logging('backward-send-backward-recv')
- add_to_logging('backward-params-all-reduce')
- add_to_logging('backward-embedding-all-reduce')
- add_to_logging('optimizer-copy-to-main-grad')
- add_to_logging('optimizer-unscale-and-check-inf')
- add_to_logging('optimizer-clip-main-grad')
- add_to_logging('optimizer-copy-main-to-model-params')
- add_to_logging('optimizer')
- add_to_logging('batch-generator')
- # Calculate the global batch size (samples processed per iteration).
- batch_size = args.micro_batch_size * args.data_parallel_size * \
- get_num_microbatches()
- total_iterations = total_loss_dict[advanced_iters_key] + \
- total_loss_dict[skipped_iters_key]
- # Tensorboard values.
- if writer and (iteration % args.tensorboard_log_interval == 0) and \
- is_last_rank():
- if args.log_learning_rate_to_tensorboard:
- writer.add_scalar('learning-rate', learning_rate, iteration)
- writer.add_scalar('learning-rate vs samples', learning_rate,
- args.consumed_train_samples)
- if args.log_batch_size_to_tensorboard:
- writer.add_scalar('batch-size', batch_size, iteration)
- writer.add_scalar('batch-size vs samples', batch_size,
- args.consumed_train_samples)
- for key in loss_dict:
- writer.add_scalar(key, loss_dict[key], iteration)
- writer.add_scalar(key + ' vs samples', loss_dict[key],
- args.consumed_train_samples)
- if args.log_loss_scale_to_tensorboard:
- writer.add_scalar('loss-scale', loss_scale, iteration)
- writer.add_scalar('loss-scale vs samples', loss_scale,
- args.consumed_train_samples)
- if grad_norm is not None:
- writer.add_scalar('grad-norm', grad_norm, iteration)
- writer.add_scalar('grad-norm vs samples', grad_norm,
- args.consumed_train_samples)
- if num_zeros_in_grad is not None:
- writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
- writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
- args.consumed_train_samples)
- if params_norm is not None:
- writer.add_scalar('params-norm', params_norm, iteration)
- writer.add_scalar('params-norm vs samples', params_norm,
- args.consumed_train_samples)
- if args.log_timers_to_tensorboard:
- timers.write(timers_to_log, writer, iteration,
- normalizer=total_iterations)
- if iteration % args.log_interval == 0:
- elapsed_time = timers('interval-time').elapsed()
- elapsed_time_per_iteration = elapsed_time / total_iterations
- if writer and torch.distributed.get_rank() == 0:
- if args.log_timers_to_tensorboard:
- writer.add_scalar('iteration-time',
- elapsed_time_per_iteration, iteration)
- log_string = ' iteration {:8d}/{:8d} |'.format(
- iteration, args.train_iters)
- log_string += ' consumed samples: {:12d} |'.format(
- args.consumed_train_samples)
- log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
- elapsed_time_per_iteration * 1000.0)
- log_string += ' learning rate: {:.3E} |'.format(learning_rate)
- log_string += ' global batch size: {:5d} |'.format(batch_size)
- for key in total_loss_dict:
- if key not in [advanced_iters_key, skipped_iters_key,
- nan_iters_key]:
- avg = total_loss_dict[key].item() / \
- float(max(1, total_loss_dict[advanced_iters_key]))
- if avg > 0.0:
- log_string += ' {}: {:.6E} |'.format(key, avg)
- total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
- log_string += ' loss scale: {:.1f} |'.format(loss_scale)
- if grad_norm is not None:
- log_string += ' grad norm: {:.3f} |'.format(grad_norm)
- if num_zeros_in_grad is not None:
- log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
- if params_norm is not None:
- log_string += ' params norm: {:.3f} |'.format(params_norm)
- log_string += ' number of skipped iterations: {:3d} |'.format(
- total_loss_dict[skipped_iters_key])
- log_string += ' number of nan iterations: {:3d} |'.format(
- total_loss_dict[nan_iters_key])
- total_loss_dict[advanced_iters_key] = 0
- total_loss_dict[skipped_iters_key] = 0
- total_loss_dict[nan_iters_key] = 0
- print_rank_last(log_string)
- if report_memory_flag and learning_rate > 0.:
- # Report memory after optimizer state has been initialized.
- report_memory('(after {} iterations)'.format(iteration))
- report_memory_flag = False
- timers.log(timers_to_log, normalizer=args.log_interval)
- return report_memory_flag
- def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
- timers = get_timers()
- # Extra barrier is added to make sure
- # all ranks report the max time.
- torch.distributed.barrier()
- timers('save-checkpoint').start()
- save_checkpoint(iteration, model, optimizer, lr_scheduler)
- torch.distributed.barrier()
- timers('save-checkpoint').stop()
- timers.log(['save-checkpoint'])
- def train(forward_step_func, model, optimizer, lr_scheduler,
- train_data_iterator, valid_data_iterator):
- """Train the model function."""
- args = get_args()
- timers = get_timers()
- # Write args to tensorboard
- write_args_to_tensorboard()
- # Turn on training mode which enables dropout.
- for model_module in model:
- model_module.train()
- # Tracking loss.
- total_loss_dict = {}
- # Iterations.
- iteration = args.iteration
- timers('interval-time').start()
- print_datetime('before the start of training step')
- report_memory_flag = True
- while iteration < args.train_iters:
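- # Recompute the number of microbatches each iteration, since the global
- # batch size may still be ramping up as samples are consumed.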
- update_num_microbatches(args.consumed_train_samples)
- loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
- train_step(forward_step_func,
- train_data_iterator,
- model,
- optimizer,
- lr_scheduler)
- iteration += 1
- args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
- args.micro_batch_size * \
- get_num_microbatches()
- # Logging.
- loss_scale = optimizer.get_loss_scale().item()
- params_norm = None
- if args.log_params_norm:
- params_norm = calc_params_l2_norm(model)
- report_memory_flag = training_log(loss_dict, total_loss_dict,
- optimizer.param_groups[0]['lr'],
- iteration, loss_scale,
- report_memory_flag, skipped_iter,
- grad_norm, params_norm, num_zeros_in_grad)
- # Autoresume
- if args.adlr_autoresume and \
- (iteration % args.adlr_autoresume_interval == 0):
- check_adlr_autoresume_termination(iteration, model, optimizer,
- lr_scheduler)
- # Evaluation
- if args.eval_interval and iteration % args.eval_interval == 0 and \
- args.do_valid:
- prefix = 'iteration {}'.format(iteration)
- evaluate_and_print_results(prefix, forward_step_func,
- valid_data_iterator, model,
- iteration, False)
- # Checkpointing
- saved_checkpoint = False
- if args.save and args.save_interval and \
- iteration % args.save_interval == 0:
- save_checkpoint_and_time(iteration, model, optimizer,
- lr_scheduler)
- saved_checkpoint = True
- # Exiting based on duration
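- # All-reduce with MAX so every rank agrees to exit as soon as any single
- # rank exceeds the time limit.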
- if args.exit_duration_in_mins:
- train_time = (time.time() - _TRAIN_START_TIME) / 60.0
- done_cuda = torch.cuda.IntTensor(
- [train_time > args.exit_duration_in_mins])
- torch.distributed.all_reduce(
- done_cuda, op=torch.distributed.ReduceOp.MAX)
- done = done_cuda.item()
- if done:
- if not saved_checkpoint:
- save_checkpoint_and_time(iteration, model, optimizer,
- lr_scheduler)
- print_datetime('exiting program after {} minutes'.format(train_time))
- sys.exit()
- # Exiting based on iterations
- if args.exit_interval and iteration % args.exit_interval == 0:
- if not saved_checkpoint:
- save_checkpoint_and_time(iteration, model, optimizer,
- lr_scheduler)
- torch.distributed.barrier()
- print_datetime('exiting program at iteration {}'.format(iteration))
- sys.exit()
- return iteration
- def evaluate(forward_step_func, data_iterator, model, verbose=False):
- """Evaluation."""
- args = get_args()
- # Turn on evaluation mode which disables dropout.
- for model_module in model:
- model_module.eval()
- total_loss_dict = {}
- with torch.no_grad():
- iteration = 0
- while iteration < args.eval_iters:
- iteration += 1
- if verbose and iteration % args.log_interval == 0:
- print_rank_0('Evaluating iter {}/{}'.format(iteration,
- args.eval_iters))
- if mpu.get_pipeline_model_parallel_world_size() > 1:
- if args.virtual_pipeline_model_parallel_size is not None:
- forward_backward_func = forward_backward_pipelining_with_interleaving
- else:
- forward_backward_func = forward_backward_pipelining_without_interleaving
- else:
- forward_backward_func = forward_backward_no_pipelining
- loss_dicts = forward_backward_func(
- forward_step_func, data_iterator, model, optimizer=None,
- timers=None, forward_only=True)
- if mpu.is_pipeline_last_stage(ignore_virtual=True):
- # Reduce across processes.
- for loss_dict in loss_dicts:
- for key in loss_dict:
- total_loss_dict[key] = total_loss_dict.get(
- key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
- args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
- * args.micro_batch_size \
- * get_num_microbatches()
- # Move model back to the train mode.
- for model_module in model:
- model_module.train()
- for key in total_loss_dict:
- total_loss_dict[key] /= args.eval_iters * get_num_microbatches()
- return total_loss_dict
- def evaluate_and_print_results(prefix, forward_step_func,
- data_iterator, model,
- iteration, verbose=False):
- """Helper function to evaluate and dump results on screen."""
- args = get_args()
- writer = get_tensorboard_writer()
- total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
- string = ' validation loss at {} | '.format(prefix)
- for key in total_loss_dict:
- string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
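- # Perplexity is exp(cross-entropy loss); the loss is clamped at 20 to keep
- # exp() from overflowing.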
- ppl = math.exp(min(20, total_loss_dict[key].item()))
- string += '{} PPL: {:.6E} | '.format(key, ppl)
- if writer and is_last_rank():
- writer.add_scalar('{} validation'.format(key),
- total_loss_dict[key].item(),
- iteration)
- writer.add_scalar('{} validation vs samples'.format(key),
- total_loss_dict[key].item(),
- args.consumed_train_samples)
- if args.log_validation_ppl_to_tensorboard:
- writer.add_scalar('{} validation ppl'.format(key), ppl,
- iteration)
- writer.add_scalar('{} validation ppl vs samples'.format(key),
- ppl, args.consumed_train_samples)
- length = len(string) + 1
- print_rank_last('-' * length)
- print_rank_last(string)
- print_rank_last('-' * length)
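- # Cycle endlessly over a finite iterable; used when args.dataloader_type is
- # 'cyclic' so the data iterator never raises StopIteration.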
- def cyclic_iter(iter):
- while True:
- for x in iter:
- yield x
- def build_train_valid_test_data_iterators(
- build_train_valid_test_datasets_provider):
- """XXX"""
- args = get_args()
- (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)
- print_rank_0('> building train, validation, and test datasets ...')
- # Backward compatibility, assume fixed batch size.
- if args.iteration > 0 and args.consumed_train_samples == 0:
- assert args.train_samples is None, \
- 'only backward compatibility support for iteration-based training'
- args.consumed_train_samples = args.iteration * args.global_batch_size
- if args.iteration > 0 and args.consumed_valid_samples == 0:
- assert args.train_samples is None, \
- 'only backward compatibility support for iteration-based training'
- args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
- args.eval_iters * args.global_batch_size
- # Data loader only on rank 0 of each model parallel group.
- if mpu.get_tensor_model_parallel_rank() == 0:
- # Number of train/valid/test samples.
- if args.train_samples:
- train_samples = args.train_samples
- else:
- train_samples = args.train_iters * args.global_batch_size
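- # Validation runs every eval_interval iterations plus once at the end of
- # training; each run consumes eval_iters iterations' worth of samples.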
- eval_iters = (args.train_iters // args.eval_interval + 1) * \
- args.eval_iters
- test_iters = args.eval_iters
- train_val_test_num_samples = [train_samples,
- eval_iters * args.global_batch_size,
- test_iters * args.global_batch_size]
- print_rank_0(' > datasets target sizes (minimum size):')
- print_rank_0(' train: {}'.format(train_val_test_num_samples[0]))
- print_rank_0(' validation: {}'.format(train_val_test_num_samples[1]))
- print_rank_0(' test: {}'.format(train_val_test_num_samples[2]))
- # Build the datasets.
- train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
- train_val_test_num_samples)
- # Build dataloaders.
- train_dataloader = build_pretraining_data_loader(
- train_ds, args.consumed_train_samples)
- valid_dataloader = build_pretraining_data_loader(
- valid_ds, args.consumed_valid_samples)
- test_dataloader = build_pretraining_data_loader(test_ds, 0)
- # Flags to know if we need to do training/validation/testing.
- do_train = train_dataloader is not None and args.train_iters > 0
- do_valid = valid_dataloader is not None and args.eval_iters > 0
- do_test = test_dataloader is not None and args.eval_iters > 0
- # Broadcast the do_train/do_valid/do_test flags from tensor-parallel rank 0.
- flags = torch.cuda.LongTensor(
- [int(do_train), int(do_valid), int(do_test)])
- else:
- flags = torch.cuda.LongTensor([0, 0, 0])
- # Broadcast the flags.
- torch.distributed.broadcast(flags,
- mpu.get_tensor_model_parallel_src_rank(),
- group=mpu.get_tensor_model_parallel_group())
- args.do_train = flags[0].item()
- args.do_valid = flags[1].item()
- args.do_test = flags[2].item()
- # Build iterators.
- dl_type = args.dataloader_type
- assert dl_type in ['single', 'cyclic']
- if train_dataloader is not None:
- train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(train_dataloader))
- else:
- train_data_iterator = None
- if valid_dataloader is not None:
- valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(valid_dataloader))
- else:
- valid_data_iterator = None
- if test_dataloader is not None:
- test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
- else iter(cyclic_iter(test_dataloader))
- else:
- test_data_iterator = None
- return train_data_iterator, valid_data_iterator, test_data_iterator