# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Input/output checkpointing."""

import os
import random
import sys

import numpy as np
import torch

from megatron import (get_args,
                      mpu,
                      print_rank_0,
                      update_num_microbatches,
                      utils)

# Version of the checkpoint currently being handled; None until it is set.
_CHECKPOINT_VERSION = None


def set_checkpoint_version(value):
    """Record the checkpoint version for this process.

    If a version was already recorded, `value` must be equal to it;
    otherwise an AssertionError is raised.
    """
    global _CHECKPOINT_VERSION
    if _CHECKPOINT_VERSION is not None:
        assert _CHECKPOINT_VERSION == value, \
            "checkpoint versions do not match"
    _CHECKPOINT_VERSION = value


def get_checkpoint_version():
    """Return the recorded checkpoint version (None if never set)."""
    return _CHECKPOINT_VERSION


def check_checkpoint_args(checkpoint_args):
    """Ensure fixed arguments for a model are the same for the input
    arguments and the one retrieved from checkpoint.

    Raises an AssertionError naming the first argument whose value differs.
    """
    args = get_args()

    def _compare(arg_name, old_arg_name=None):
        # `old_arg_name` is the attribute name used on the checkpoint side
        # when it differs from the current argument name.
        checkpoint_attr = arg_name if old_arg_name is None else old_arg_name
        checkpoint_value = getattr(checkpoint_args, checkpoint_attr)
        args_value = getattr(args, arg_name)
        error_message = '{} value from checkpoint ({}) is not equal to the ' \
                        'input argument value ({}).'.format(
                            arg_name, checkpoint_value, args_value)
        assert checkpoint_value == args_value, error_message

    _compare('num_layers')
    _compare('hidden_size')
    _compare('num_attention_heads')
    if args.vocab_file:
        _compare('max_position_embeddings')
        _compare('make_vocab_size_divisible_by')
        _compare('padded_vocab_size')
        _compare('tokenizer_type')

    checkpoint_version = get_checkpoint_version()
    if checkpoint_version < 3.0:
        # Checkpoints older than 3.0 store the tensor-parallel size under
        # the name `model_parallel_size`.
        _compare('tensor_model_parallel_size',
                 old_arg_name='model_parallel_size')
    else:
        _compare('tensor_model_parallel_size')
        _compare('pipeline_model_parallel_size')


def ensure_directory_exists(filename):
    """Build filename's path if it does not already exist.

    Uses `exist_ok=True` so that several processes creating the same
    directory concurrently do not fail, and does nothing when `filename`
    has no directory component.
    """
    dirname = os.path.dirname(filename)
    if dirname:
        os.makedirs(dirname, exist_ok=True)


def get_checkpoint_name(checkpoints_path, iteration, release=False):
    """A unified checkpoint name.

    The file lives under `<checkpoints_path>/release` when `release` is
    true, otherwise under `<checkpoints_path>/iter_<iteration, 7 digits>`,
    in a sub-directory named after the model-parallel rank(s).
    """
    directory = 'release' if release else 'iter_{:07d}'.format(iteration)

    # Use both the tensor and pipeline MP rank.
    if mpu.get_pipeline_model_parallel_world_size() == 1:
        # Without pipeline parallelism only the tensor rank is encoded.
        rank_directory = 'mp_rank_{:02d}'.format(
            mpu.get_tensor_model_parallel_rank())
    else:
        rank_directory = 'mp_rank_{:02d}_{:03d}'.format(
            mpu.get_tensor_model_parallel_rank(),
            mpu.get_pipeline_model_parallel_rank())
    return os.path.join(checkpoints_path, directory, rank_directory,
                        'model_optim_rng.pt')


def get_checkpoint_tracker_filename(checkpoints_path):
    """Tracker file records the latest checkpoint during
    training to restart from."""
    return os.path.join(checkpoints_path, 'latest_checkpointed_iteration.txt')


def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    """Save a model checkpoint.

    Writes the state dict to the path given by `get_checkpoint_name`
    under `args.save`, then records `iteration` in the tracker file.
    `model` is expected to be a sequence of model chunks after unwrapping.
    """
    args = get_args()

    # Only rank zero of the data parallel writes to the disk.
    model = utils.unwrap_model(model)

    print_rank_0('saving checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    distributed = torch.distributed.is_initialized()

    if not distributed or mpu.get_data_parallel_rank() == 0:

        # Arguments, iteration, and model.
        state_dict = {}
        state_dict['args'] = args
        state_dict['checkpoint_version'] = 3.0
        state_dict['iteration'] = iteration
        if len(model) == 1:
            state_dict['model'] = model[0].state_dict_for_save_checkpoint()
        else:
            # One entry per model chunk, keyed 'model0', 'model1', ...
            for i, model_chunk in enumerate(model):
                mpu.set_virtual_pipeline_model_parallel_rank(i)
                state_dict['model%d' % i] = \
                    model_chunk.state_dict_for_save_checkpoint()

        # Optimizer stuff.
        if not args.no_save_optim:
            if optimizer is not None:
                state_dict['optimizer'] = optimizer.state_dict()
            if lr_scheduler is not None:
                state_dict['lr_scheduler'] = lr_scheduler.state_dict()

        # RNG states.
        if not args.no_save_rng:
            state_dict['random_rng_state'] = random.getstate()
            state_dict['np_rng_state'] = np.random.get_state()
            state_dict['torch_rng_state'] = torch.get_rng_state()
            state_dict['cuda_rng_state'] = torch.cuda.get_rng_state()
            state_dict['rng_tracker_states'] \
                = mpu.get_cuda_rng_tracker().get_states()

        # Save.
        checkpoint_name = get_checkpoint_name(args.save, iteration)
        ensure_directory_exists(checkpoint_name)
        torch.save(state_dict, checkpoint_name)

    # Wait so everyone is done (necessary)
    if distributed:
        torch.distributed.barrier()

    print_rank_0(' successfully saved checkpoint at iteration {:7d} to {}'.format(
        iteration, args.save))

    # And update the latest iteration
    if not distributed or torch.distributed.get_rank() == 0:
        tracker_filename = get_checkpoint_tracker_filename(args.save)
        with open(tracker_filename, 'w') as f:
            f.write(str(iteration))

    # Wait so everyone is done (not necessary)
    if distributed:
        torch.distributed.barrier()


def _transpose_first_dim(t, num_splits, num_splits_first, model):
    """Reorder the first dimension of `t` into [np, num_splits, hn] order.

    Arguments:
        t: tensor whose first dimension is a flattened combination of
            num_splits, np (attention heads per partition) and hn (hidden
            size per attention head).
        num_splits: number of fused projections packed in `t`.
        num_splits_first: True when the source layout is
            [num_splits, np, hn]; False when it is [np, hn, num_splits].
        model: (possibly wrapped) model from which np and hn are read.

    Returns a tensor with the same shape as `t`.
    """
    input_shape = t.size()
    # We use a self_attention module but the values extracted aren't
    # specific to self attention so should work for cross attention as well
    while hasattr(model, 'module'):
        model = model.module
    attention_module = model.language_model.encoder.layers[0].self_attention
    hidden_size_per_attention_head = \
        attention_module.hidden_size_per_attention_head
    num_attention_heads_per_partition = \
        attention_module.num_attention_heads_per_partition
    if num_splits_first:
        # [num_splits * np * hn, h]
        # -->(view) [num_splits, np, hn, h]
        # -->(transpose) [np, num_splits, hn, h]
        # -->(view) [np * num_splits * hn, h]
        intermediate_shape = \
            (num_splits, num_attention_heads_per_partition,
             hidden_size_per_attention_head) + input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(0, 1).contiguous()
    else:
        # [np * hn * num_splits, h]
        # -->(view) [np, hn, num_splits, h]
        # -->(transpose) [np, num_splits, hn, h]
        # -->(view) [np * num_splits * hn, h]
        intermediate_shape = \
            (num_attention_heads_per_partition,
             hidden_size_per_attention_head,
             num_splits) + \
            input_shape[1:]

        t = t.view(*intermediate_shape)
        t = t.transpose(1, 2).contiguous()
    t = t.view(*input_shape)

    return t


def fix_query_key_value_ordering(model, checkpoint_version):
    """Fix up query/key/value matrix ordering if checkpoint
    version is smaller than 2.0

    Parameters are rewritten in place. `model` may be a module or a
    single-element list of modules. The process exits if a parameter needs
    fixing and `checkpoint_version` is neither 0 nor 1.0.
    """
    if checkpoint_version >= 2.0:
        return

    if isinstance(model, list):
        assert len(model) == 1
        model = model[0]

    # (parameter-name suffixes, number of fused projections)
    fused_parameter_kinds = (
        (('.query_key_value.weight', '.query_key_value.bias'), 3),
        (('.key_value.weight', '.key_value.bias'), 2),
    )

    for name, param in model.named_parameters():
        for suffixes, num_splits in fused_parameter_kinds:
            if not name.endswith(suffixes):
                continue
            # The version is only validated once a parameter actually needs
            # fixing, so a model without fused parameters never exits here.
            if checkpoint_version == 0:
                num_splits_first = True
            elif checkpoint_version == 1.0:
                num_splits_first = False
            else:
                print_rank_0(f"Invalid checkpoint version {checkpoint_version}.")
                sys.exit()
            fixed_param = _transpose_first_dim(
                param.data, num_splits, num_splits_first, model)
            param.data.copy_(fixed_param)

    print_rank_0(" succesfully fixed query-key-values ordering for"
                 " checkpoint version {}".format(checkpoint_version))


def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True):
    """Load a model checkpoint and return the iteration.

    Arguments:
        model: model (or list of model chunks) to load weights into.
        optimizer: optimizer to restore, or None.
        lr_scheduler: learning-rate scheduler to restore, or None.
        load_arg (str): name of the attribute on the global args that holds
            the directory to load from.
        strict (bool): whether to strictly enforce that the keys in
            :attr:`state_dict` of the checkpoint match the names of
            parameters and buffers in model.

    Returns 0 when no tracker file exists, when fine-tuning, or when the
    checkpoint is a release checkpoint; otherwise the stored iteration.
    """
    args = get_args()
    load_dir = getattr(args, load_arg)

    model = utils.unwrap_model(model)

    # Read the tracker file and set the iteration.
    tracker_filename = get_checkpoint_tracker_filename(load_dir)

    # If no tracker file, return iteration zero.
    if not os.path.isfile(tracker_filename):
        print_rank_0('WARNING: could not find the metadata file {} '.format(
            tracker_filename))
        print_rank_0(' will not load any checkpoints and will start from '
                     'random')
        return 0

    # Otherwise, read the tracker file and either set the iteration or
    # mark it as a release checkpoint.
    iteration = 0
    release = False
    with open(tracker_filename, 'r') as f:
        metastring = f.read().strip()
    try:
        iteration = int(metastring)
    except ValueError:
        release = metastring == 'release'
        if not release:
            print_rank_0('ERROR: Invalid metadata file {}. Exiting'.format(
                tracker_filename))
            sys.exit()

    assert iteration > 0 or release, 'error parsing metadata file {}'.format(
        tracker_filename)

    # Checkpoint.
    checkpoint_name = get_checkpoint_name(load_dir, iteration, release)
    # Report the directory actually used, which differs from args.load
    # whenever `load_arg` is not 'load'.
    print_rank_0(f' loading checkpoint from {load_dir} at iteration {iteration}')

    # Load the checkpoint.
    # NOTE(review): torch.load unpickles the file, so only checkpoints from
    # trusted sources should be loaded.
    try:
        state_dict = torch.load(checkpoint_name, map_location='cpu')
    except ModuleNotFoundError:
        from megatron.fp16_deprecated import loss_scaler
        # For backward compatibility.
        print_rank_0(' > deserializing using the old code structure ...')
        # Temporarily alias the old module paths so that objects pickled
        # under them can be resolved, then remove the aliases again.
        sys.modules['fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
            'megatron.fp16_deprecated.loss_scaler']
        state_dict = torch.load(checkpoint_name, map_location='cpu')
        sys.modules.pop('fp16.loss_scaler', None)
        sys.modules.pop('megatron.fp16.loss_scaler', None)
    except BaseException as e:
        print_rank_0('could not load the checkpoint')
        print_rank_0(e)
        sys.exit()

    # set checkpoint version
    set_checkpoint_version(state_dict.get('checkpoint_version', 0))

    # Set iteration.
    if args.finetune or release:
        iteration = 0
    else:
        try:
            iteration = state_dict['iteration']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
            except KeyError:
                print_rank_0('A metadata file exists but unable to load '
                             'iteration from checkpoint {}, exiting'.format(
                                 checkpoint_name))
                sys.exit()

    # Check arguments.
    assert args.consumed_train_samples == 0
    assert args.consumed_valid_samples == 0
    if 'args' in state_dict:
        checkpoint_args = state_dict['args']
        check_checkpoint_args(checkpoint_args)
        args.consumed_train_samples = getattr(checkpoint_args,
                                              'consumed_train_samples', 0)
        update_num_microbatches(consumed_samples=args.consumed_train_samples)
        args.consumed_valid_samples = getattr(checkpoint_args,
                                              'consumed_valid_samples', 0)
    else:
        print_rank_0('could not find arguments in the checkpoint ...')

    # Model.
    if len(model) == 1:
        model[0].load_state_dict(state_dict['model'], strict=strict)
    else:
        for i, model_chunk in enumerate(model):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            model_chunk.load_state_dict(state_dict['model%d' % i],
                                        strict=strict)

    # Fix up query/key/value matrix ordering if needed
    checkpoint_version = get_checkpoint_version()
    print_rank_0(f' checkpoint version {checkpoint_version}')
    fix_query_key_value_ordering(model, checkpoint_version)

    # Optimizer.
    if not release and not args.finetune and not args.no_load_optim:
        try:
            if optimizer is not None:
                optimizer.load_state_dict(state_dict['optimizer'])
            if lr_scheduler is not None:
                lr_scheduler.load_state_dict(state_dict['lr_scheduler'])
        except KeyError:
            print_rank_0('Unable to load optimizer from checkpoint {}. '
                         'Specify --no-load-optim or --finetune to prevent '
                         'attempting to load the optimizer state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # rng states.
    if not release and not args.finetune and not args.no_load_rng:
        try:
            random.setstate(state_dict['random_rng_state'])
            np.random.set_state(state_dict['np_rng_state'])
            torch.set_rng_state(state_dict['torch_rng_state'])
            torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
            # Check for empty states array
            if not state_dict['rng_tracker_states']:
                raise KeyError
            mpu.get_cuda_rng_tracker().set_states(
                state_dict['rng_tracker_states'])
        except KeyError:
            print_rank_0('Unable to load rng state from checkpoint {}. '
                         'Specify --no-load-rng or --finetune to prevent '
                         'attempting to load the rng state, '
                         'exiting ...'.format(checkpoint_name))
            sys.exit()

    # Some utilities want to load a checkpoint without distributed being initialized
    if torch.distributed.is_initialized():
        torch.distributed.barrier()

    print_rank_0(f' successfully loaded checkpoint from {load_dir} '
                 f'at iteration {iteration}')

    return iteration


def load_biencoder_checkpoint(model, only_query_model=False,
                              only_context_model=False, custom_load_path=None):
    """
    selectively load retrieval models for indexing/retrieving
    from saved checkpoints

    Arguments:
        model: model (single-element list after unwrapping) to load into.
        only_query_model: drop the 'context_model' entry before loading.
        only_context_model: drop the 'query_model' entry before loading.
        custom_load_path: directory to load from; defaults to `args.load`.

    Returns the unwrapped model. Requires torch.distributed to be
    initialized, since it calls `barrier()` unconditionally.
    """
    args = get_args()

    model = utils.unwrap_model(model)

    load_path = custom_load_path if custom_load_path is not None else args.load

    tracker_filename = get_checkpoint_tracker_filename(load_path)
    with open(tracker_filename, 'r') as f:
        iteration = int(f.read().strip())

    checkpoint_name = get_checkpoint_name(load_path, iteration, False)
    if mpu.get_data_parallel_rank() == 0:
        print('global rank {} is loading checkpoint {}'.format(
            torch.distributed.get_rank(), checkpoint_name))

    # NOTE(review): torch.load unpickles the file; load trusted files only.
    state_dict = torch.load(checkpoint_name, map_location='cpu')
    ret_state_dict = state_dict['model']

    if only_query_model:
        ret_state_dict.pop('context_model')
    if only_context_model:
        ret_state_dict.pop('query_model')

    assert len(model) == 1
    model[0].load_state_dict(ret_state_dict)
    torch.distributed.barrier()

    if mpu.get_data_parallel_rank() == 0:
        print(' successfully loaded {}'.format(checkpoint_name))

    return model