# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import reduce
import operator

import torch

from megatron import get_args
from megatron import mpu


def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
                 use_ring_exchange=False):
    """Communicate tensors between pipeline stages.

    Used as a helper by the other communication methods in this file, which
    are in turn used in megatron/schedules.py.

    Takes the following arguments:
        tensor_send_next: tensor to send to the next rank (no tensor is sent
                          if set to None).
        tensor_send_prev: tensor to send to the previous rank (no tensor is
                          sent if set to None).
        recv_prev: boolean for whether a tensor should be received from the
                   previous rank.
        recv_next: boolean for whether a tensor should be received from the
                   next rank.
        use_ring_exchange: boolean for whether the
                           torch.distributed.ring_exchange() API should be
                           used.
    Returns:
        (tensor_recv_prev, tensor_recv_next)
    """
    args = get_args()

    # Create placeholder tensors for receive in forward and backward
    # directions if needed.
    tensor_recv_prev = None
    tensor_recv_next = None
    tensor_shape = (args.seq_length, args.micro_batch_size, args.hidden_size)
    if args.scatter_gather_tensors_in_pipeline:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1) // \
            mpu.get_tensor_model_parallel_world_size()
    else:
        tensor_chunk_shape = tensor_shape
    dtype = args.params_dtype
    if args.fp32_residual_connection:
        dtype = torch.float
    if recv_prev:
        tensor_recv_prev = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)
    if recv_next:
        tensor_recv_next = torch.empty(tensor_chunk_shape,
                                       requires_grad=True,
                                       device=torch.cuda.current_device(),
                                       dtype=dtype)

    # Split tensor into smaller chunks if using scatter-gather optimization.
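    # The scatter-gather optimization reduces pipeline-parallel traffic when
    # tensor model parallelism is also enabled: each rank sends only its
    # 1 / tensor-model-parallel-size slice of the flattened activation, and
    # the receiving side reconstructs the full tensor further below with
    # mpu.gather_split_1d_tensor() (a gather across the tensor-model-parallel
    # group). This is also why tensor_chunk_shape above divides the element
    # count by mpu.get_tensor_model_parallel_world_size().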
    if args.scatter_gather_tensors_in_pipeline:
        if tensor_send_next is not None:
            tensor_send_next = mpu.split_tensor_into_1d_equal_chunks(tensor_send_next)
        if tensor_send_prev is not None:
            tensor_send_prev = mpu.split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
    if use_ring_exchange:
        torch.distributed.ring_exchange(tensor_send_prev=tensor_send_prev,
                                        tensor_recv_prev=tensor_recv_prev,
                                        tensor_send_next=tensor_send_next,
                                        tensor_recv_next=tensor_recv_next,
                                        group=mpu.get_pipeline_model_parallel_group())
    else:
        ops = []
        if tensor_send_prev is not None:
            send_prev_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(send_prev_op)
        if tensor_recv_prev is not None:
            recv_prev_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_prev,
                mpu.get_pipeline_model_parallel_prev_rank())
            ops.append(recv_prev_op)
        if tensor_send_next is not None:
            send_next_op = torch.distributed.P2POp(
                torch.distributed.isend, tensor_send_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(send_next_op)
        if tensor_recv_next is not None:
            recv_next_op = torch.distributed.P2POp(
                torch.distributed.irecv, tensor_recv_next,
                mpu.get_pipeline_model_parallel_next_rank())
            ops.append(recv_next_op)
        if len(ops) > 0:
            reqs = torch.distributed.batch_isend_irecv(ops)
            for req in reqs:
                req.wait()
    # To protect against race condition when using batch_isend_irecv().
    torch.cuda.synchronize()

    # If using scatter-gather optimization, gather smaller chunks.
    if args.scatter_gather_tensors_in_pipeline:
        if recv_prev:
            tensor_recv_prev = mpu.gather_split_1d_tensor(
                tensor_recv_prev).view(tensor_shape).requires_grad_()
        if recv_next:
            tensor_recv_next = mpu.gather_split_1d_tensor(
                tensor_recv_next).view(tensor_shape).requires_grad_()

    return tensor_recv_prev, tensor_recv_next
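

# Illustrative sketch only; it is not used elsewhere in this module. It shows,
# in isolation, the torch.distributed.P2POp / batch_isend_irecv pattern that
# _communicate() builds on, assuming that `peer_rank` enqueues a matching
# recv/send pair so the exchange can make progress.
def _example_batched_exchange(tensor_to_send, tensor_to_recv, peer_rank):
    """Minimal batched send/recv with a single peer (illustrative sketch)."""
    ops = [
        torch.distributed.P2POp(torch.distributed.isend, tensor_to_send,
                                peer_rank),
        torch.distributed.P2POp(torch.distributed.irecv, tensor_to_recv,
                                peer_rank),
    ]
    for req in torch.distributed.batch_isend_irecv(ops):
        req.wait()
    # As in _communicate(), synchronize the device to guard against races
    # between the batched transfers and subsequent kernels.
    torch.cuda.synchronize()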


def recv_forward(timers=None):
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('forward-recv').stop()
    return input_tensor


def recv_backward(timers=None):
    """Receive tensor from next rank in pipeline (backward receive)."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=None,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('backward-recv').stop()
    return output_tensor_grad


def send_forward(output_tensor, timers=None):
    """Send tensor to next rank in pipeline (forward send)."""
    if not mpu.is_pipeline_last_stage():
        if timers is not None:
            timers('forward-send').start()
        _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('forward-send').stop()


def send_backward(input_tensor_grad, timers=None):
    """Send tensor to previous rank in pipeline (backward send)."""
    if not mpu.is_pipeline_first_stage():
        if timers is not None:
            timers('backward-send').start()
        _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=False,
            recv_next=False)
        if timers is not None:
            timers('backward-send').stop()


def send_forward_recv_backward(output_tensor, timers=None):
    """Batched send and recv with next rank in pipeline."""
    if mpu.is_pipeline_last_stage():
        output_tensor_grad = None
    else:
        if timers is not None:
            timers('forward-send-backward-recv').start()
        _, output_tensor_grad = _communicate(
            tensor_send_next=output_tensor,
            tensor_send_prev=None,
            recv_prev=False,
            recv_next=True)
        if timers is not None:
            timers('forward-send-backward-recv').stop()
    return output_tensor_grad


def send_backward_recv_forward(input_tensor_grad, timers=None):
    """Batched send and recv with previous rank in pipeline."""
    if mpu.is_pipeline_first_stage():
        input_tensor = None
    else:
        if timers is not None:
            timers('backward-send-forward-recv').start()
        input_tensor, _ = _communicate(
            tensor_send_next=None,
            tensor_send_prev=input_tensor_grad,
            recv_prev=True,
            recv_next=False)
        if timers is not None:
            timers('backward-send-forward-recv').stop()
    return input_tensor


def send_forward_recv_forward(output_tensor, recv_prev, timers=None):
    """Batched recv from previous rank and send to next rank in pipeline."""
    if timers is not None:
        timers('forward-send-forward-recv').start()
    input_tensor, _ = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=None,
        recv_prev=recv_prev,
        recv_next=False)
    if timers is not None:
        timers('forward-send-forward-recv').stop()
    return input_tensor


def send_backward_recv_backward(input_tensor_grad, recv_next, timers=None):
    """Batched recv from next rank and send to previous rank in pipeline."""
    if timers is not None:
        timers('backward-send-backward-recv').start()
    _, output_tensor_grad = _communicate(
        tensor_send_next=None,
        tensor_send_prev=input_tensor_grad,
        recv_prev=False,
        recv_next=recv_next)
    if timers is not None:
        timers('backward-send-backward-recv').stop()
    return output_tensor_grad


def send_forward_backward_recv_forward_backward(
        output_tensor, input_tensor_grad, recv_prev,
        recv_next, timers=None):
    """Batched send and recv with previous and next ranks in pipeline."""
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').start()
    input_tensor, output_tensor_grad = _communicate(
        tensor_send_next=output_tensor,
        tensor_send_prev=input_tensor_grad,
        recv_prev=recv_prev,
        recv_next=recv_next)
    if timers is not None:
        timers('forward-backward-send-forward-backward-recv').stop()
    return input_tensor, output_tensor_grad
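

# Illustrative sketch, not used by Megatron itself: how the helpers above
# compose into one naive microbatch round trip on a pipeline stage.
# `forward_step_func` and `backward_step_func` are hypothetical callables
# standing in for the model's forward pass and for backpropagation through
# this stage; the real scheduling and interleaving live in
# megatron/schedules.py. The first/last-stage special cases are handled
# inside the helpers (they return None or skip the send).
def _example_pipeline_step(forward_step_func, backward_step_func, timers=None):
    """Run one forward/backward microbatch using the p2p helpers (sketch)."""
    # Forward: receive activations from the previous stage, compute, and
    # send the result to the next stage.
    input_tensor = recv_forward(timers=timers)
    output_tensor = forward_step_func(input_tensor)
    send_forward(output_tensor, timers=timers)
    # Backward: receive the gradient of the output from the next stage,
    # backpropagate through this stage, and send the gradient of the input
    # back to the previous stage.
    output_tensor_grad = recv_backward(timers=timers)
    input_tensor_grad = backward_step_func(input_tensor, output_tensor,
                                           output_tensor_grad)
    send_backward(input_tensor_grad, timers=timers)
    return input_tensor_grad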