# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank


# Maximum number of dimensions of a broadcast tensor. Each key is given a
# fixed-width slot of this many entries in the size-metadata broadcast.
_MAX_DATA_DIM = 5

# Marks unused entries of a size slot. Tensor sizes are never negative, so
# (unlike 0) this cannot be confused with a genuine zero-length dimension.
_SIZE_PADDING = -1


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type.

    Raises:
        AssertionError: if ``data[key].dtype != target_dtype`` for any key.
    """
    for key in keys:
        assert data[key].dtype == target_dtype, '{} has data type {} which '\
            'is different than {}'.format(key, data[key].dtype, target_dtype)


def _build_key_size_numel_dictionaries(keys, data):
    """Build the tensor sizes on rank 0 and broadcast them to the group.

    Tensor-model-parallel rank 0 packs the shape of every ``data[key]`` into
    a flat list with one ``_MAX_DATA_DIM``-wide slot per key (unused entries
    hold ``_SIZE_PADDING``) and broadcasts it; every rank then decodes it.
    ``data`` is only read on rank 0.

    Returns:
        (key_size, key_numel, total_numel): ``key_size[key]`` is the shape as
        a list of ints, ``key_numel[key]`` its number of elements, and
        ``total_numel`` the sum of the element counts over all keys.
    """
    max_dim = _MAX_DATA_DIM
    sizes = [_SIZE_PADDING for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            # Slots are padded explicitly rather than zero-terminated, so a
            # tensor may occupy all max_dim entries.
            assert data[key].dim() <= max_dim, 'you should increase MAX_DATA_DIM'
            size = list(data[key].size())
            sizes[offset:offset + len(size)] = size
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = torch.tensor(sizes, dtype=torch.long,
                              device=torch.cuda.current_device())
    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Move back to cpu and unpack. tolist() yields plain Python ints, so the
    # decoded sizes and element counts are ints rather than 0-d tensors.
    sizes_cpu = sizes_cuda.cpu().tolist()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        size = []
        numel = 1
        for this_size in sizes_cpu[offset:offset + max_dim]:
            if this_size == _SIZE_PADDING:
                break
            size.append(this_size)
            numel *= this_size
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcasted
        data: data dictionary of string keys and cpu tensor values. Only
              read on tensor-model-parallel rank 0.
        datatype: torch data type of all tensors in data associated
                  with keys.

    Returns:
        dict mapping each key to a CUDA tensor of dtype ``datatype`` with the
        shape it had on rank 0. The tensors are views into a single flat
        buffer.

    Raises:
        AssertionError (rank 0): if a tensor's dtype differs from
        ``datatype`` or it has more than ``_MAX_DATA_DIM`` dimensions.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack: carve consecutive segments of the flat buffer back into the
    # original shapes, in the same key order used for packing.
    output = {}
    offset = 0
    for key in keys:
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(key_size[key])
        offset += numel

    return output