distributed.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from abc import abstractmethod

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from megatron import get_args
from megatron import mpu
from .module import MegatronModule


class MemoryBuffer:

    def __init__(self, numel, dtype):
        self.numel = numel
        self.dtype = dtype
        self.data = torch.zeros(self.numel,
                                dtype=self.dtype,
                                device=torch.cuda.current_device(),
                                requires_grad=False)

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, \
            'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)
        return buffer_tensor
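
# A minimal usage sketch for MemoryBuffer (illustrative only; the sizes and
# dtype below are made up, and a CUDA device is assumed to be available):
#
#   buf = MemoryBuffer(numel=12, dtype=torch.float)
#   # Carve two parameter-sized views out of the flat 1-D storage.
#   grad_a = buf.get(torch.Size([2, 3]), start_index=0)   # elements 0..5
#   grad_b = buf.get(torch.Size([6]), start_index=6)      # elements 6..11
#   # Writes through a view land in the shared underlying buffer.
#   grad_a.fill_(1.0)
#   assert buf.data[:6].sum() == 6.0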


class DistributedDataParallelBase(MegatronModule, ABC):
    """Abstract class for DDP."""

    def __init__(self, module):
        super(DistributedDataParallelBase, self).__init__()
        # Keep a pointer to the model.
        self.module = module

    @abstractmethod
    def allreduce_gradients(self):
        pass

    def forward(self, *inputs, **kwargs):
        return self.module(*inputs, **kwargs)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)
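
# A concrete wrapper only has to supply `allreduce_gradients`; forward and the
# state-dict methods are already delegated to the wrapped module. A
# hypothetical minimal subclass (not part of this file) could look like:
#
#   class NaiveDDP(DistributedDataParallelBase):
#       def allreduce_gradients(self):
#           # Average each parameter's gradient across data-parallel ranks.
#           for param in self.module.parameters():
#               if param.requires_grad and param.grad is not None:
#                   param.grad.data /= mpu.get_data_parallel_world_size()
#                   torch.distributed.all_reduce(
#                       param.grad.data, group=mpu.get_data_parallel_group())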


class DistributedDataParallel(DistributedDataParallelBase):
    """DDP with contiguous buffer options to store and accumulate gradients.
    This class:
        - has the potential to reduce memory fragmentation.
        - provides the option to do the gradient accumulation
          in a type other than the params type (for example fp32).

    Arguments:
        module: input model.
        accumulate_allreduce_grads_in_fp32: if true, do the gradient
            accumulation and the gradient all-reduce all in float32. If this
            option is true, we require `use_contiguous_buffers` to be true too.
        use_contiguous_buffers: if true, use a contiguous buffer to store the
            gradients.
    """

    def __init__(self, module,
                 accumulate_allreduce_grads_in_fp32,
                 use_contiguous_buffers):

        super(DistributedDataParallel, self).__init__(module)

        self.accumulate_allreduce_grads_in_fp32 \
            = accumulate_allreduce_grads_in_fp32
        self.use_contiguous_buffers = use_contiguous_buffers
        # If we are using fp32-accumulate-allreduce explicitly,
        # we need the main grads in a contiguous buffer.
        if self.accumulate_allreduce_grads_in_fp32:
            assert self.use_contiguous_buffers

        # ===================================
        # Rest of this part applies only to
        # the case where we use contiguous buffers.
        # ===================================
        self._grad_buffers = None
        if self.use_contiguous_buffers:
            self._grad_buffers = {}

            # Simple function to define buffer type.
            def _get_buffer_type(param):
                return torch.float if \
                    self.accumulate_allreduce_grads_in_fp32 else param.dtype

            # First calculate total number of elements per type.
            type_num_elements = {}
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] = type_num_elements.get(dtype, 0) \
                                               + param.data.nelement()

            # Allocate the buffer.
            for dtype, num_elements in type_num_elements.items():
                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)

            # Assume the back-prop order is the reverse of the params order,
            # and store the start index for the gradients.
            for param in self.module.parameters():
                if param.requires_grad:
                    dtype = _get_buffer_type(param)
                    type_num_elements[dtype] -= param.data.nelement()
                    param.main_grad = self._grad_buffers[dtype].get(
                        param.data.shape, type_num_elements[dtype])

            # Backward hook.
            # Accumulation functions for the gradients. We need
            # to store them so they don't go out of scope.
            self.grad_accs = []
            # Loop over all the parameters in the model.
            for param in self.module.parameters():
                if param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param))
                    self.grad_accs.append(grad_acc)
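
    # Note on the hook registration above: `param.expand_as(param)` creates a
    # view whose `grad_fn` points back, via `next_functions[0][0]`, at the
    # AccumulateGrad node of the leaf parameter. A hook registered on that node
    # fires once per backward pass, after autograd has produced `param.grad`,
    # which is exactly when the gradient can be accumulated into
    # `param.main_grad` and `param.grad` can be freed.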

    def _make_param_hook(self, param):
        """Create the all-reduce hook for backprop."""
        # Hook used for back-prop.
        def param_hook(*unused):
            # Add the gradient to the buffer.
            if param.grad is not None:
                param.main_grad.add_(param.grad.data)
                # Now we can deallocate grad memory.
                param.grad = None
        return param_hook
    def zero_grad_buffer(self):
        """Set the grad buffer data to zero. Needs to be called at the
        beginning of each iteration."""
        assert self._grad_buffers is not None, 'buffers are not initialized.'
        for _, buffer_ in self._grad_buffers.items():
            buffer_.zero()

    def allreduce_gradients(self):
        """Reduce gradients across data parallel ranks."""
        # If we have buffers, simply reduce the data in the buffer.
        if self._grad_buffers is not None:
            for _, buffer_ in self._grad_buffers.items():
                buffer_.data /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    buffer_.data, group=mpu.get_data_parallel_group())
        else:
            # Otherwise, bucketize and all-reduce.
            buckets = {}
            # Pack the buckets.
            for param in self.module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.type()
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
                    param.main_grad = param.grad

            # For each bucket, all-reduce and copy all-reduced grads.
            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                coalesced /= mpu.get_data_parallel_world_size()
                torch.distributed.all_reduce(
                    coalesced, group=mpu.get_data_parallel_group())
                for buf, synced in zip(grads, _unflatten_dense_tensors(
                        coalesced, grads)):
                    buf.copy_(synced)
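
# A rough sketch of how this wrapper is typically driven from a training loop
# (illustrative only; `data_iterator`, `loss_fn`, `optimizer`, distributed
# initialization, and model construction are assumptions, not shown here):
#
#   model = DistributedDataParallel(model,
#                                   accumulate_allreduce_grads_in_fp32=True,
#                                   use_contiguous_buffers=True)
#   for batch in data_iterator:
#       model.zero_grad_buffer()      # reset the contiguous grad buffers
#       loss = loss_fn(model(batch))  # forward is delegated to the module
#       loss.backward()               # hooks accumulate into param.main_grad
#       model.allreduce_gradients()   # average grads across data-parallel ranks
#       optimizer.step()              # assumes an optimizer that reads param.main_grad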