# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import get_tensor_model_parallel_group, get_tensor_model_parallel_world_size, get_tensor_model_parallel_rank
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the input tensor across the tensor-model-parallel group.

    Args:
        input_: tensor to reduce across ranks.

    Returns:
        The all-reduced tensor. When ``input_`` is already contiguous the
        reduction is done in place and ``input_`` itself is returned;
        otherwise a contiguous copy is reduced and returned.
    """
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # Collective backends (e.g. NCCL) require contiguous tensors.
    # .contiguous() returns the same tensor when it is already contiguous,
    # so the common case still reduces in place.
    input_ = input_.contiguous()

    # All-reduce (torch.distributed.all_reduce defaults to a SUM reduction).
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_


def _split(input_):
    """Split the tensor along its last dimension and keep this rank's slice.

    Args:
        input_: tensor to partition across the tensor-model-parallel group.

    Returns:
        A contiguous tensor holding the chunk at index
        ``get_tensor_model_parallel_rank()``, or ``input_`` unchanged when
        the tensor-model-parallel world size is 1.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output


def _gather(input_):
    """Gather tensors from all tensor-model-parallel ranks and concatenate
    them along the last dimension.

    Args:
        input_: this rank's local tensor. Every rank is expected to pass a
            tensor of the same shape, because the receive buffers are
            allocated with ``torch.empty_like(input_)``.

    Returns:
        A contiguous tensor whose last dimension is ``world_size`` times that
        of ``input_``, with slices ordered by rank, or ``input_`` unchanged
        when the tensor-model-parallel world size is 1.
    """
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # all_gather requires contiguous tensors; making the input contiguous
    # also makes the empty_like() receive buffers below contiguous.
    # No-op when the input is already contiguous.
    input_ = input_.contiguous()

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    # This rank's slot reuses the local tensor instead of a fresh buffer.
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # torch.cat of contiguous inputs already returns a new contiguous tensor,
    # so no extra .contiguous() call is needed.
    output = torch.cat(tensor_list, dim=last_dim)

    return output


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region.

    Forward is the identity; backward all-reduces the incoming gradient.
    """

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Every rank consumed the same input, so its gradient is the
        # all-reduce of the per-rank gradients.
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region.

    Forward all-reduces; backward passes the gradient through unchanged.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to this rank.

    Forward splits along the last dimension; backward gathers the gradient
    chunks back along the last dimension.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate.

    Forward gathers along the last dimension; backward keeps only this
    rank's slice of the gradient.
    """

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    """Identity in forward; all-reduce of the gradient in backward."""
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    """All-reduce in forward; identity for the gradient in backward."""
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    """Keep this rank's last-dim slice in forward; gather the gradient in backward."""
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    """Gather along the last dim in forward; keep this rank's gradient slice in backward."""
    return _GatherFromModelParallelRegion.apply(input_)