test_random.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../..")

import torch

import mpu
from commons import initialize_distributed
from commons import print_separator
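
# These tests exercise mpu's CUDA RNG utilities: restoring a saved device RNG state
# via mpu.random._set_cuda_rng_state, the named-state tracker returned by
# mpu.get_cuda_rng_tracker(), and mpu.model_parallel_cuda_manual_seed.  They assume
# torch.distributed is initialized on CUDA devices (see commons.initialize_distributed).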


def test_set_cuda_rng_state(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(seed)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    # The saved copy is untouched, while the live device state has advanced.
    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
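

# get_cuda_rng_tracker() keeps a set of named CUDA RNG states: add(name, seed)
# registers a state seeded with `seed`, and fork(name) is a context manager that
# swaps that state in for the enclosed block and restores the default state on exit.
# The test below interleaves draws from the default stream and a forked stream and
# checks that each matches the tensors generated from the corresponding seed alone.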
def test_cuda_rng_tracker(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors.
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6

    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
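

# As the asserts below check, model_parallel_cuda_manual_seed(seed) seeds the default
# CUDA generator with `seed` and registers a model-parallel tracker state seeded with
# seed + 2718 + the tensor-model-parallel rank, giving each rank a distinct stream
# inside model-parallel regions.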
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_tensor_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
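

# Sweep each test over tensor-model-parallel sizes 1, 2, 4, ... up to the world size.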
if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2