# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch

import contextlib

import torch
from torch import _C
from torch.cuda import _lazy_call, device as device_ctx_manager
from torch.utils.checkpoint import detach_variable

from megatron import get_args
from megatron.memory import allocate_mem_buff

from .initialize import get_data_parallel_rank
from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size


# Default name for the model parallel rng tracker.
_MODEL_PARALLEL_RNG_TRACKER_NAME = 'model-parallel-rng'


# Memory buffer for checkpointed activations. When it is not None,
# CheckpointFunction splits its first input across the tensor-model-parallel
# group and keeps only this rank's 1D chunk, stored inside this buffer.
_CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = None


def init_checkpointed_activations_memory_buffer():
    """Initialize the memory buffer for the checkpointed activations.

    The buffer holds
        (micro_batch_size * max_position_embeddings * hidden_size
         // tensor_model_parallel_size)
    elements for each of the num_layers // checkpoint_num_layers
    checkpointed groups. Elements are fp16 when ``args.fp16`` is set and
    fp32 otherwise.

    Raises:
        AssertionError: if num_layers is not divisible by
            checkpoint_num_layers, or if the buffer is already allocated.
    """
    global _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER
    args = get_args()

    per_layer = args.micro_batch_size * args.max_position_embeddings * \
                args.hidden_size // args.tensor_model_parallel_size
    assert args.num_layers % args.checkpoint_num_layers == 0, \
        'number of layers is not divisible by checkpoint-num-layers'
    num_checkpointed_layers = args.num_layers // args.checkpoint_num_layers
    numel = per_layer * num_checkpointed_layers
    dtype = torch.half if args.fp16 else torch.float

    assert _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is None, \
        'checkpointed activations memory buffer is already allocated.'
    _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER = allocate_mem_buff(
        'checkpointed activations', numel, dtype, track_usage=False)


def reset_checkpointed_activations_memory_buffer():
    """Reset the memory used for checkpointing (no-op if never allocated)."""
    if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
        _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.reset()


def _set_cuda_rng_state(new_state, device=-1):
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state.
        device: Target device. -1 (the default) selects the current CUDA
            device; an int index, a device string, or a torch.device is
            also accepted.

    This function is adapted from PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for +4 GPU cases.
    """
    if hasattr(_C, '_cuda_setRNGState') and callable(_C._cuda_setRNGState):
        # older PyTorch
        def cb():
            with device_ctx_manager(device):
                _C._cuda_setRNGState(new_state)
    else:
        # newer PyTorch
        if device == -1:
            device = torch.device('cuda')
        elif isinstance(device, str):
            device = torch.device(device)
        elif isinstance(device, int):
            device = torch.device('cuda', device)

        def cb():
            idx = device.index
            if idx is None:
                idx = torch.cuda.current_device()
            default_generator = torch.cuda.default_generators[idx]
            default_generator.set_state(new_state)

    # Scheduled through _lazy_call, as in torch.cuda.set_rng_state.
    _lazy_call(cb)


def split_tensor_into_1d_equal_chunks(tensor):
    """Break a tensor into equal 1D chunks and return this rank's chunk.

    The tensor is flattened and cut into tensor-model-parallel-world-size
    pieces of equal length; the piece indexed by this rank's
    tensor-model-parallel rank is returned as a slice of the flattened
    tensor. If numel is not divisible by the world size, the trailing
    remainder belongs to no chunk.
    """
    data = tensor.view(-1)
    partition_size = torch.numel(data) // get_tensor_model_parallel_world_size()
    start_index = partition_size * get_tensor_model_parallel_rank()
    end_index = start_index + partition_size
    return data[start_index:end_index]


def gather_split_1d_tensor(tensor):
    """Opposite of split_tensor_into_1d_equal_chunks.

    All-gathers each rank's 1D chunk over the tensor-model-parallel group
    into a single flat tensor of world_size * numel elements, allocated on
    the current CUDA device in rank order.
    """
    world_size = get_tensor_model_parallel_world_size()
    numel = torch.numel(tensor)
    numel_gathered = world_size * numel
    gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
                           device=torch.cuda.current_device(),
                           requires_grad=False)
    # Slices of `gathered`, so all_gather writes straight into it.
    chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
    torch.distributed.all_gather(chunks, tensor,
                                 group=get_tensor_model_parallel_group())
    return gathered


class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self):
        # Map from a string name to the cuda rng state.
        self.states_ = {}
        # Seeds are just for book keeping and ensure no seed is set twice.
        self.seeds_ = set()

    def reset(self):
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self):
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        return dict(self.states_)

    def set_states(self, states):
        """Set the rng states. For efficiency purposes, we do not check
        the size of seed for compatibility.

        The given dict is stored as-is (not copied).
        """
        self.states_ = states

    def add(self, name, seed):
        """Track the rng state.

        Seeds the CUDA generator with `seed`, stores the resulting state
        under `name`, then restores the CUDA rng state that was active
        before the call.

        Raises:
            Exception: if `seed` was already used or `name` is already
                tracked. A rejected call registers neither.
        """
        # Check seed is not already used.
        if seed in self.seeds_:
            raise Exception('seed {} already exists'.format(seed))
        # Check that state is not already defined.
        if name in self.states_:
            raise Exception('cuda rng state {} already exists'.format(name))
        # Reserve the seed only after both checks pass, so a rejected call
        # does not leave the seed permanently marked as used.
        self.seeds_.add(seed)

        # Get the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Set the new state and store it.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        # Reset rng state to what it was.
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name=_MODEL_PARALLEL_RNG_TRACKER_NAME):
        """Fork the cuda rng state, perform operations, and exit with
        the original state.

        On exit the advanced state is saved back under `name`, so successive
        forks continue the same random stream.
        """
        # Check if we have added the state
        if name not in self.states_:
            raise Exception('cuda rng state {} is not added'.format(name))
        # Store current rng state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        # Set rng state to the desired one
        _set_cuda_rng_state(self.states_[name])
        # Do the stuff we wanted to do.
        try:
            yield
        finally:
            # Update the current rng state for later use.
            self.states_[name] = torch.cuda.get_rng_state()
            # And set the state to the original state we started with.
            _set_cuda_rng_state(orig_cuda_rng_state)


# RNG tracker object.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker():
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER


def model_parallel_cuda_manual_seed(seed):
    """Initialize model parallel cuda seed.

    This function should be called after the model parallel is
    initialized. Also, no torch.cuda.manual_seed should be called
    after this function. Basically, this is replacement for that
    function.
    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
                       set of model parallel GPUs but different across
                       different model parallel groups. This is used for
                       example for dropout in the non-tensor-model-parallel
                       regions.
        tensor-model-parallel state: This state is different among a set of
                       model parallel GPUs, but the same across data
                       parallel groups. This is used for example for dropout
                       in model parallel regions.
    """
    # 2718 is just for fun and any POSITIVE value will work.
    offset = seed + 2718
    tensor_model_parallel_seed = offset + get_tensor_model_parallel_rank()
    # Data parallel gets the original seed.
    data_parallel_seed = seed

    if torch.distributed.get_rank() == 0:
        print('> initializing model parallel cuda seeds on global rank {}, '
              'model parallel rank {}, and data parallel rank {} with '
              'model parallel seed: {} and data parallel seed: {}'.format(
                  torch.distributed.get_rank(), get_tensor_model_parallel_rank(),
                  get_data_parallel_rank(), tensor_model_parallel_seed,
                  data_parallel_seed), flush=True)
    _CUDA_RNG_STATE_TRACKER.reset()
    # Set the default state.
    torch.cuda.manual_seed(data_parallel_seed)
    # and model parallel state.
    _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME,
                                tensor_model_parallel_seed)


class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
       two main changes:
           1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
           2) the states in the model parallel tracker are also properly
              tracked/set/reset.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        """Run `run_function(*args)` without building a graph.

        Saves the CPU, CUDA and tracker RNG states plus the inputs so that
        backward can recompute the forward pass with identical randomness.
        When the checkpointed-activations buffer is allocated, the caller's
        first input has its `.data` replaced in place by this rank's 1D
        chunk, copied into the buffer; backward gathers it back.
        """
        ctx.run_function = run_function

        # Copy the rng states.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        with torch.no_grad():
            outputs = run_function(*args)

        # Divide hidden states across model parallel group and only keep
        # the chunk corresponding to the current rank.
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            ctx.input_0_shape = args[0].data.shape
            args[0].data = split_tensor_into_1d_equal_chunks(args[0].data)
            args[0].data = _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER.add(
                args[0].data)

        # Store everything.
        ctx.save_for_backward(*args)

        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass under the saved RNG states and
        backpropagate `args` (the output gradients) through it.

        The RNG states active when backward was entered are restored
        afterwards, even if the recomputation raises.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), "
                               "please use .backward() if possible")
        inputs = ctx.saved_tensors
        if _CHECKPOINTED_ACTIVATIONS_MEMORY_BUFFER is not None:
            inputs[0].data = gather_split_1d_tensor(inputs[0].data)
            inputs[0].data = inputs[0].data.view(ctx.input_0_shape)

        # Store the current states.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what it used to be before the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        try:
            # Compute the forward pass.
            detached_inputs = detach_variable(inputs)
            with torch.enable_grad():
                outputs = ctx.run_function(*detached_inputs)
        finally:
            # Set the states back to what they were at the start of this
            # function, also when recomputation fails, so the caller is not
            # left running on the forward-time RNG states.
            torch.set_rng_state(bwd_cpu_rng_state)
            _set_cuda_rng_state(bwd_cuda_rng_state)
            get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        grads = tuple(inp.grad if isinstance(inp, torch.Tensor) else inp
                      for inp in detached_inputs)
        # No gradient for run_function itself.
        return (None,) + grads


def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function, *args)