mappings.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch

from .initialize import (
    get_tensor_model_parallel_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from .utils import split_tensor_along_last_dim


def _reduce(input_):
    """All-reduce the input tensor across the model parallel group."""
    # Bypass the function if we are using only 1 GPU.
    if get_tensor_model_parallel_world_size() == 1:
        return input_

    # All-reduce.
    torch.distributed.all_reduce(input_, group=get_tensor_model_parallel_group())

    return input_


def _split(input_):
    """Split the tensor along its last dimension and keep the
    corresponding slice."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Split along last dimension.
    input_list = split_tensor_along_last_dim(input_, world_size)

    # Note: torch.split does not create contiguous tensors by default.
    rank = get_tensor_model_parallel_rank()
    output = input_list[rank].contiguous()

    return output


def _gather(input_):
    """Gather tensors and concatenate along the last dimension."""
    world_size = get_tensor_model_parallel_world_size()
    # Bypass the function if we are using only 1 GPU.
    if world_size == 1:
        return input_

    # Size and dimension.
    last_dim = input_.dim() - 1
    rank = get_tensor_model_parallel_rank()

    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    tensor_list[rank] = input_
    torch.distributed.all_gather(tensor_list, input_, group=get_tensor_model_parallel_group())

    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=last_dim).contiguous()

    return output
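

# Taken together, _split and _gather are inverses along the last dimension.
# For example (illustrative shapes only), with a tensor model parallel world
# size of 4, _split maps a [batch, seq, hidden] tensor to the rank-local
# [batch, seq, hidden / 4] slice, and _gather concatenates the four slices
# back into [batch, seq, hidden].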


class _CopyToModelParallelRegion(torch.autograd.Function):
    """Pass the input to the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return input_

    @staticmethod
    def forward(ctx, input_):
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        return _reduce(grad_output)


class _ReduceFromModelParallelRegion(torch.autograd.Function):
    """All-reduce the input from the model parallel region."""

    @staticmethod
    def symbolic(graph, input_):
        return _reduce(input_)

    @staticmethod
    def forward(ctx, input_):
        return _reduce(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


class _ScatterToModelParallelRegion(torch.autograd.Function):
    """Split the input and keep only the chunk corresponding to the rank."""

    @staticmethod
    def symbolic(graph, input_):
        return _split(input_)

    @staticmethod
    def forward(ctx, input_):
        return _split(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _gather(grad_output)


class _GatherFromModelParallelRegion(torch.autograd.Function):
    """Gather the input from the model parallel region and concatenate."""

    @staticmethod
    def symbolic(graph, input_):
        return _gather(input_)

    @staticmethod
    def forward(ctx, input_):
        return _gather(input_)

    @staticmethod
    def backward(ctx, grad_output):
        return _split(grad_output)


# -----------------
# Helper functions.
# -----------------

def copy_to_tensor_model_parallel_region(input_):
    return _CopyToModelParallelRegion.apply(input_)


def reduce_from_tensor_model_parallel_region(input_):
    return _ReduceFromModelParallelRegion.apply(input_)


def scatter_to_tensor_model_parallel_region(input_):
    return _ScatterToModelParallelRegion.apply(input_)


def gather_from_tensor_model_parallel_region(input_):
    return _GatherFromModelParallelRegion.apply(input_)
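

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original file). In
# Megatron-style tensor parallelism these helpers wrap the sharded linear
# layers so that forward and backward use conjugate communication ops. The
# weight names below are placeholders for this sketch.
#
#   # Column-parallel linear: weight sharded along the output dimension.
#   input_parallel = copy_to_tensor_model_parallel_region(input_)       # identity fwd / all-reduce bwd
#   output_parallel = torch.nn.functional.linear(input_parallel, column_sharded_weight)
#   output = gather_from_tensor_model_parallel_region(output_parallel)  # all-gather fwd / split bwd
#
#   # Row-parallel linear: weight sharded along the input dimension.
#   input_parallel = scatter_to_tensor_model_parallel_region(input_)    # split fwd / all-gather bwd
#   output_parallel = torch.nn.functional.linear(input_parallel, row_sharded_weight)
#   output = reduce_from_tensor_model_parallel_region(output_parallel)  # all-reduce fwd / identity bwd
# ---------------------------------------------------------------------------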