# data.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from .initialize import get_tensor_model_parallel_group
from .initialize import get_tensor_model_parallel_rank
from .initialize import get_tensor_model_parallel_src_rank


# Each key's shape is packed into a fixed-size, zero-padded slot of this
# length, so broadcast tensors must have strictly fewer than _MAX_DATA_DIM
# dimensions.
_MAX_DATA_DIM = 5


def _check_data_types(keys, data, target_dtype):
    """Check that all the keys have the same target data type."""
    for key in keys:
        assert data[key].dtype == target_dtype, '{} has data type {} which '\
            'is different than {}'.format(key, data[key].dtype, target_dtype)


def _build_key_size_numel_dictionaries(keys, data):
    """Build the size on rank 0 and broadcast."""
    max_dim = _MAX_DATA_DIM
    sizes = [0 for _ in range(max_dim) for _ in keys]

    # Pack the sizes on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        offset = 0
        for key in keys:
            assert data[key].dim() < max_dim, 'you should increase MAX_DATA_DIM'
            size = data[key].size()
            for i, s in enumerate(size):
                sizes[i + offset] = s
            offset += max_dim

    # Move to GPU and broadcast.
    sizes_cuda = torch.cuda.LongTensor(sizes)
    torch.distributed.broadcast(sizes_cuda, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Move back to cpu and unpack.
    sizes_cpu = sizes_cuda.cpu()
    key_size = {}
    key_numel = {}
    total_numel = 0
    offset = 0
    for key in keys:
        i = 0
        size = []
        numel = 1
        while sizes_cpu[offset + i] > 0:
            this_size = sizes_cpu[offset + i]
            size.append(this_size)
            numel *= this_size
            i += 1
        key_size[key] = size
        key_numel[key] = numel
        total_numel += numel
        offset += max_dim

    return key_size, key_numel, total_numel
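
# Packing illustration (comment added for clarity, not in the original file):
# with keys ['text', 'labels'] and _MAX_DATA_DIM = 5, rank-0 tensors of shape
# (4, 512) and (4,) are packed into a single flat list of 2 * 5 entries,
#     sizes = [4, 512, 0, 0, 0,  4, 0, 0, 0, 0]
# so one broadcast ships every shape; the zero padding marks unused dimensions
# and terminates the per-key unpacking loop above.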


def broadcast_data(keys, data, datatype):
    """Broadcast data from rank zero of each model parallel group to the
    members of the same model parallel group.

    Arguments:
        keys: list of keys in the data dictionary to be broadcast
        data: data dictionary of string keys and cpu tensor values.
        datatype: torch data type of all tensors in data associated
                  with keys.
    """
    # Build (key, size) and (key, number of elements) dictionaries along
    # with the total number of elements on all ranks.
    key_size, key_numel, total_numel = _build_key_size_numel_dictionaries(keys,
                                                                          data)

    # Pack on rank zero.
    if get_tensor_model_parallel_rank() == 0:
        # Check that all keys have the same data type.
        _check_data_types(keys, data, datatype)
        # Flatten the data associated with the keys
        flatten_data = torch.cat(
            [data[key].contiguous().view(-1) for key in keys], dim=0).cuda()
    else:
        flatten_data = torch.empty(total_numel,
                                   device=torch.cuda.current_device(),
                                   dtype=datatype)

    # Broadcast
    torch.distributed.broadcast(flatten_data, get_tensor_model_parallel_src_rank(),
                                group=get_tensor_model_parallel_group())

    # Unpack
    output = {}
    offset = 0
    for key in keys:
        size = key_size[key]
        numel = key_numel[key]
        output[key] = flatten_data.narrow(0, offset, numel).view(size)
        offset += numel

    return output
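

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original module): how
# broadcast_data is typically driven from a training step. Only the
# tensor-model-parallel rank 0 pulls a batch of CPU tensors from the data
# iterator; the other ranks pass None and still receive identical GPU tensors.
# The key names ('text', 'labels') and the data_iterator argument are
# assumptions made for this example, not part of this file's API.
def _example_get_batch(data_iterator):
    keys = ['text', 'labels']
    datatype = torch.int64

    if get_tensor_model_parallel_rank() == 0:
        batch = next(data_iterator)  # assumed: dict of CPU LongTensors
        data = {key: batch[key] for key in keys}
    else:
        data = None

    data_gpu = broadcast_data(keys, data, datatype)
    return data_gpu['text'], data_gpu['labels']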