# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_


def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output


def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
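

# ---------------------------------------------------------------------------
# The autograd functions below pair the primitives above so that each forward
# communication pattern has its conjugate in the backward pass:
#
#   copy    : identity forward    <->  all-reduce backward
#   reduce  : all-reduce forward  <->  identity backward
#   scatter : split forward       <->  gather backward
#   gather  : gather forward      <->  split backward
# ---------------------------------------------------------------------------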


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
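

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only; the function below is a hypothetical
# example, not part of this module): composing the helpers around a
# tensor-parallel matmul. `weight_shard` is assumed to be this rank's slice
# of a weight matrix whose output dimension is partitioned across tensor
# parallel ranks.
#
#   def column_parallel_forward(input_, weight_shard):
#       # Identity in forward; all-reduces the gradient in backward, since
#       # every tensor parallel rank consumes the same replicated input.
#       input_parallel = copy_to_tensor_model_parallel_region(input_)
#       # Each rank computes its slice of the output with its weight shard.
#       output_parallel = torch.matmul(input_parallel, weight_shard)
#       # Gathers the per-rank slices along the last dimension in forward;
#       # splits the gradient in backward.
#       return gather_from_tensor_model_parallel_region(output_parallel)
# ---------------------------------------------------------------------------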