test_initialize.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
sys.path.append("../..")

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
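
# NOTE: the sys.path manipulation above assumes this file lives in mpu/tests/,
# so that "../.." puts the directory containing the mpu package on the import
# path; commons.py is expected to sit next to this file.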


def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                      torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == mpu.get_tensor_model_parallel_world_size()
    assert rank == mpu.get_tensor_model_parallel_rank()
    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    # Use the clamped size so the expected rank matches the groups actually built.
    rank = torch.distributed.get_rank() // tensor_model_parallel_size_
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
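
# Rank layout implied by the checks above (illustrative example, assuming
# world_size=8 and tensor_model_parallel_size=2):
#   tensor-parallel groups: {0,1} {2,3} {4,5} {6,7}  -> tp_rank = global_rank % 2
#   data-parallel groups:   {0,2,4,6} {1,3,5,7}      -> dp_rank = global_rank // 2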


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                     torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(tensor_model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
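
# Example launch (illustrative; the exact entry point depends on your PyTorch
# version and setup). Something along these lines runs the tests on 4 GPUs:
#
#   python -m torch.distributed.launch --nproc_per_node=4 test_initialize.py
#
# initialize_distributed() from commons is expected to pick up the rank and
# world-size environment variables set by the launcher.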