p2p_communication.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator

import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between stages. Used as helper method in other
    communication methods that are used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to next rank (no tensor sent if
                          set to None).
        tensor_send_prev: tensor to send to prev rank (no tensor sent if
                          set to None).
        recv_prev: boolean for whether tensor should be received from
                   previous rank.
        recv_next: boolean for whether tensor should be received from
                   next rank.
        use_ring_exchange: boolean for whether torch.distributed.ring_exchange()
                           API should be used.

    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward directions
    # if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
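    # With the scatter-gather optimization enabled, each rank exchanges only a
    # flat 1-D chunk of seq_length * micro_batch_size * hidden_size divided by
    # the tensor-model-parallel world size; otherwise the full 3-D
    # (seq, batch, hidden) activation is exchanged.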
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
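    # Receive buffers are allocated with requires_grad=True so that a received
    # forward activation can participate in this stage's autograd graph (its
    # gradient is what later gets sent back to the previous stage).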

    # Split tensor into smaller chunks if using scatter-gather optimization.
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)

        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
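    # Note: torch.distributed.ring_exchange() is not part of stock PyTorch;
    # this branch assumes a custom PyTorch build that provides it. The default
    # branch below uses the standard batched point-to-point API instead.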
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()

        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()
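    # gather_split_1d_tensor() all-gathers the 1-D chunks from the other
    # tensor-model-parallel ranks into a new buffer, so requires_grad_() has to
    # be re-applied after the view back to (seq, batch, hidden).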

    return tensor_recv_prev, tensor_recv_next
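

# Each helper below is a thin wrapper around _communicate() used by
# megatron/schedules.py: it fixes the send/recv directions for one step of the
# pipeline schedule and becomes a no-op at the pipeline boundary (the first
# stage has no previous rank to receive forward input from, and the last stage
# has no next rank to receive gradients from).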
def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
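

# Usage sketch (illustrative only; forward_step()/backward_step() live in
# megatron/schedules.py and take more arguments than shown). On an interior
# pipeline stage, one 1F1B steady-state step pairs these helpers roughly as:
#
#     input_tensor = recv_forward(timers)
#     output_tensor = forward_step(...)
#     output_tensor_grad = send_forward_recv_backward(output_tensor, timers)
#     input_tensor_grad = backward_step(...)
#     input_tensor = send_backward_recv_forward(input_tensor_grad, timers)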