# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch


import math

import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter

from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_world_size
from .mappings import copy_to_tensor_model_parallel_region
from .mappings import gather_from_tensor_model_parallel_region
from .mappings import reduce_from_tensor_model_parallel_region
from .mappings import scatter_to_tensor_model_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import split_tensor_along_last_dim
from .utils import VocabUtility
from megatron import get_args


# Attributes used to tag tensors with tensor-model-parallel metadata, together
# with the values they take when a tensor is not partitioned.
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}


def param_is_not_tensor_parallel_duplicate(param):
    """Return True if ``param`` should be counted on this rank.

    Parameters tagged with a truthy ``tensor_model_parallel`` attribute are
    partitioned, so every tensor-model-parallel rank holds a distinct piece
    and the function returns True. Untagged (or untruthy) parameters are
    treated as duplicates, and only tensor-model-parallel rank 0 returns True
    for them.
    """
    return (hasattr(param, 'tensor_model_parallel') and
            param.tensor_model_parallel) or (
                get_tensor_model_parallel_rank() == 0)


def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Tag ``tensor`` with tensor-model-parallel partitioning metadata.

    Arguments:
        tensor: tensor (typically a Parameter) to annotate.
        is_parallel: stored as ``tensor.tensor_model_parallel``.
        dim: stored as ``tensor.partition_dim``.
        stride: stored as ``tensor.partition_stride``.

    When assertions are enabled, raises AssertionError if any of these
    attributes is already present, so a tensor cannot be tagged twice.
    """
    # Make sure the attributes are not set.
    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attribute)
    # Set the attributes.
    setattr(tensor, 'tensor_model_parallel', is_parallel)
    setattr(tensor, 'partition_dim', dim)
    setattr(tensor, 'partition_stride', stride)


def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Give ``tensor`` the default value of every missing parallel attribute.

    Attributes that are already present are left untouched.
    """
    def maybe_set(attribute, value):
        if not hasattr(tensor, attribute):
            setattr(tensor, attribute, value)

    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_set(attribute, _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS[attribute])


def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy each parallel attribute present on ``source_tensor`` onto
    ``destination_tensor``.

    Attributes that ``source_tensor`` does not have are not touched on the
    destination.
    """
    def maybe_copy(attribute):
        if hasattr(source_tensor, attribute):
            setattr(destination_tensor, attribute,
                    getattr(source_tensor, attribute))

    for attribute in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        maybe_copy(attribute)


def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize affine weight for model parallel on GPU.

    Tags ``weight`` as tensor-parallel along ``partition_dim``, then runs
    ``init_method`` on it in place inside the forked CUDA RNG tracker
    context (presumably so that the per-rank shards are not initialized
    identically -- confirm against ``.random``).
    """
    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    with get_cuda_rng_tracker().fork():
        init_method(weight)


def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the full (output_size, input_size) master weight on every process
    and copy this rank's chunks of it into ``weight``.

    The master weight is initialized in float32 and then cast to
    ``args.params_dtype``. It is split along ``partition_dim`` into pieces of
    ``per_partition_size // stride`` rows/columns. This rank takes every
    ``world_size``-th piece starting at its own rank, and the pieces are
    concatenated into ``weight``.

    Returns:
        The master weight if ``return_master_weight`` is True, else None.
    """

    set_tensor_model_parallel_attributes(tensor=weight,
                                         is_parallel=True,
                                         dim=partition_dim,
                                         stride=stride)

    # Initialize master weight
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    args = get_args()
    master_weight = master_weight.to(dtype=args.params_dtype)

    # Split and copy
    per_partition_per_stride_size = divide(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size,
                              dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    # With stride > 1 each rank owns several non-contiguous pieces, interleaved
    # round-robin across the ranks.
    my_weight_list = weight_list[rank::world_size]

    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
    if return_master_weight:
        return master_weight
    return None


class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the weight matrix along the vocabulary dimension. This rank
        # owns token ids in [vocab_start_index, vocab_end_index).
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate weights and initialize.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            _initialize_affine_weight_cpu(
                self.weight, self.num_embeddings, self.embedding_dim,
                self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=1)

    def forward(self, input_):
        """Look up ``input_`` ids in the sharded table and reduce across ranks.

        Ids outside this rank's vocabulary range are looked up as local id 0,
        and their output rows are then zeroed. After the reduction, each id
        therefore carries the embedding from the rank that owns it.
        """
        if self.tensor_model_parallel_size > 1:
            # Build the mask.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Mask the input. The subtraction already yields a new tensor, so
            # the in-place write below never modifies the caller's input_.
            masked_input = input_ - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Get the embeddings.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Mask the output embedding.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Reduce across all the model parallel GPUs.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output


class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias.
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization. It only has an
                                     effect with CPU initialization;
                                     otherwise ``self.master_weight`` is None.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Divide the weight matrix along the last dimension of A (the first
        # dimension of the stored, transposed weight).
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.output_size_per_partition, 0, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=0, stride=stride)
            # No master weight is built on GPU; keep the attribute defined so
            # it exists regardless of the initialization path.
            self.master_weight = None

        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # The bias is partitioned along with the output columns.
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        ``output_bias`` is ``self.bias`` when ``skip_bias_add`` is set (the
        bias was not added to ``output``); otherwise it is None.
        """
        # Set up backprop all-reduce.
        input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply.

        bias = self.bias if not self.skip_bias_add else None
        output_parallel = F.linear(input_parallel, self.weight, bias)
        if self.gather_output:
            # All-gather across the partitions.
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias


class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization. It only has an
                                     effect with CPU initialization;
                                     otherwise ``self.master_weight`` is None.
        skip_bias_add: This was added to enable performance optimizations where
                       bias can be fused with other elementwise operations. We
                       skip adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Divide the weight matrix along the last dimension of the stored,
        # transposed weight (the first dimension of A).
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Parameters.
        # Note: torch.nn.functional.linear performs XA^T + b and as a result
        # we allocate the transpose.
        # Initialize weight.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
            self.master_weight = _initialize_affine_weight_cpu(
                self.weight, self.output_size, self.input_size,
                self.input_size_per_partition, 1, init_method,
                stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            _initialize_affine_weight_gpu(self.weight, init_method,
                                          partition_dim=1, stride=stride)
            # No master weight is built on GPU; keep the attribute defined so
            # it exists regardless of the initialization path.
            self.master_weight = None
        if bias:
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)

    def forward(self, input_):
        """Return ``(output, output_bias)``.

        The bias is applied after the cross-partition reduction, so it is
        added once rather than once per rank. With ``skip_bias_add`` it is
        returned as ``output_bias`` instead of being added; otherwise
        ``output_bias`` is None.
        """
        # Set up backprop all-reduce.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce across all the partitions.
        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            output = output_
            output_bias = self.bias
        return output, output_bias