# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank


_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, '{} has data type {} which '\
            'is different than {}'.format(key, data[key].dtype, target_dtype)


def _build_key_size_numel_dictionaries(keys, data):
    """Build the size on rank 0 and broadcast."""
    max_dim = _MAX_DATA_DIM
    sizes = [0 for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, 'you should increase _MAX_DATA_DIM'
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Move back to CPU and unpack.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        i = 0
        size = []
        numel = 1
        # Each key's slot is zero-padded, so read sizes until the first zero.
        while sizes_cpu[offset + i] > 0:
            this_size = sizes_cpu[offset + i]
            size.append(this_size)
            numel *= this_size
            i += 1
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel
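

# Packing layout example: assuming _MAX_DATA_DIM == 5 and, purely for
# illustration, keys = ['text', 'labels'] with shapes (4, 1024) and (4,),
# rank zero packs
#     sizes = [4, 1024, 0, 0, 0,   4, 0, 0, 0, 0]
# i.e. each key occupies a fixed slot of max_dim entries, zero-padded, so
# every rank can recover the per-key shapes from a single broadcast.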


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcast
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys.
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast.
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack.
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
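

# Illustrative usage sketch (a hypothetical helper, not part of the original
# API): it assumes torch.distributed and the tensor-model-parallel groups are
# already initialized, a CUDA device has been selected for this rank, and the
# iterator on rank zero yields dicts of CPU int64 tensors keyed by 'text' and
# 'labels'.
def _example_broadcast_usage(data_iterator):
    """Hypothetical get_batch-style helper built around broadcast_data."""
    keys = ['text', 'labels']
    datatype = torch.int64
    # Only tensor-model-parallel rank zero reads from the iterator; the other
    # ranks pass None and receive identical tensors through the broadcast.
    if get_tensor_model_parallel_rank() == 0:
        data = next(data_iterator)
    else:
        data = None
    data_b = broadcast_data(keys, data, datatype)
    return data_b['text'], data_b['labels']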