# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron import get_args
from megatron import mpu
from .module import MegatronModule


class MemoryBuffer:
    """Pre-allocated, contiguous 1-D buffer from which gradient tensors
    are handed out as views."""

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor


class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP."""

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        pass

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)


class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with the option to use contiguous buffers to store and
    accumulate gradients. This class:

        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation in a type
          other than the params type (for example fp32).

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient
            accumulation and the gradient all-reduce all in float32.
            If this option is true, we require `use_contiguous_buffers`
            to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store
            the gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly,
        # this means we need main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
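        # Layout note (illustrative): the code below first sums the number of
        # elements per dtype, allocates one MemoryBuffer per dtype, and then
        # counts the totals back down to zero to assign each param a start
        # index. For example, three fp16 params with 4, 2, and 3 elements (in
        # registration order) share one 9-element fp16 buffer at start
        # indices 5, 3, and 0 respectively, so the last-registered param sits
        # at the front. This matches the assumption that back-prop produces
        # gradients in roughly the reverse of the registration order.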
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return torch.float if \
                    self.accumulate_allreduce_grads_in_fp32 else param.dtype

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # Assume the back-prop order is the reverse of the params order,
            # store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])

            # Backward hook.
            # Accumulation functions for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        else:
            # Otherwise, bucketize and all-reduce.
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)
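

# ----------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the module). The snippet below
# shows the flatten / all-reduce / unflatten pattern used by the fallback path
# of `allreduce_gradients`, run standalone on a single-process "gloo" group so
# it needs neither GPUs nor megatron initialization. `fake_grads` is a
# hypothetical stand-in for the per-param gradient tensors.
#
#     import os
#     import torch
#     import torch.distributed as dist
#     from torch._utils import (_flatten_dense_tensors,
#                               _unflatten_dense_tensors)
#
#     os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
#     os.environ.setdefault('MASTER_PORT', '29500')
#     dist.init_process_group('gloo', rank=0, world_size=1)
#
#     fake_grads = [torch.ones(4), torch.full((2, 3), 2.0)]
#     coalesced = _flatten_dense_tensors(fake_grads)
#     coalesced /= dist.get_world_size()
#     dist.all_reduce(coalesced)
#     for buf, synced in zip(fake_grads,
#                            _unflatten_dense_tensors(coalesced, fake_grads)):
#         buf.copy_(synced)
#     # With world_size=1 the averaged grads equal the originals.
#     assert torch.equal(fake_grads[0], torch.ones(4))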