# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron import get_args
from megatron import mpu
from .module import MegatronModule


class MemoryBuffer:
    """Pre-allocated flat buffer of a single dtype that hands out tensor
    views over sub-ranges of its storage."""

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
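
# Illustrative sketch of how `MemoryBuffer.get` hands out views that alias the
# flat storage (assumes a CUDA device is available, since the buffer is
# allocated on `torch.cuda.current_device()`):
#
#   buf = MemoryBuffer(6, torch.float)
#   view = buf.get(torch.Size([2, 3]), start_index=0)  # 2x3 view over elements 0..5
#   view.fill_(1.0)                                     # write is visible in buf.data
#   assert buf.data.sum().item() == 6.0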


class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP."""

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        pass

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)


class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with the option to use contiguous buffers to store and accumulate
    gradients.

    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32).

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient
            accumulation and the gradient all-reduce in float32. If this
            option is true, we require `use_contiguous_buffers` to be
            true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are explicitly asked to accumulate and all-reduce in fp32,
        # the main grads must live in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return torch.float if \
                    self.accumulate_allreduce_grads_in_fp32 else param.dtype

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # Assume the back-prop order is the reverse of the params order
            # and store the start index for each gradient accordingly: the
            # first parameter is placed at the end of the buffer, the last
            # parameter at the start.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])

            # Backward hook.
            # Accumulation function for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)
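
    # Note on the backward hook above: `param.expand_as(param)` creates a view
    # whose `grad_fn.next_functions[0][0]` is the parameter's gradient
    # accumulator (the `AccumulateGrad` autograd node). Registering a hook on
    # that node means `param_hook` below runs once the gradient for the
    # parameter has been produced during back-prop, at which point it is folded
    # into `param.main_grad` and `param.grad` can be freed. This is a common
    # PyTorch pattern; the exact firing semantics of node hooks may differ
    # slightly across PyTorch versions.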

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook

    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        else:
            # Otherwise, bucketize and all-reduce.
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)
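
# Minimal usage sketch (assumes torch.distributed and the megatron mpu groups
# are already initialized; `model`, `optimizer`, `data_loader`, and `loss_fn`
# are hypothetical names for the wrapped network, its optimizer, the data
# iterator, and the loss function):
#
#   model = DistributedDataParallel(model,
#                                   accumulate_allreduce_grads_in_fp32=True,
#                                   use_contiguous_buffers=True)
#   for batch in data_loader:
#       model.zero_grad_buffer()       # reset the contiguous grad buffers
#       loss = loss_fn(model(batch))   # forward delegates to the wrapped module
#       loss.backward()                # hooks fold grads into param.main_grad
#       model.allreduce_gradients()    # average grads across data-parallel ranks
#       optimizer.step()               # the optimizer is expected to consume
#                                      # param.main_grad when contiguous
#                                      # buffers are used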