123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from commons import print_separator
- from commons import initialize_distributed
- import mpu
- import torch
- import sys
- sys.path.append("../..")
def test_set_cuda_rng_state(tensor_model_parallel_size):
    """Check that `mpu.random._set_cuda_rng_state` rewinds the CUDA generator.

    The test seeds the CUDA generator, snapshots its state, draws random
    tensors, puts the snapshot back and draws again.  It asserts that:

    * drawing numbers changes the generator state but not the snapshot,
    * after the snapshot is put back the same numbers are drawn again,
    * the snapshot tensor handed to the setter is left unmodified.

    Args:
        tensor_model_parallel_size: size handed to
            `mpu.initialize_model_parallel` for this run.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    # The value read back is not used below; the query is retained from the
    # original test flow.
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    # Seed from the named constant instead of repeating the literal.
    torch.cuda.manual_seed(seed)
    # Scratch buffer on the current CUDA device; every draw overwrites it.
    tensor = torch.empty(size, dtype=torch.float32, device='cuda')

    # Snapshot the generator state, plus a private copy used to detect
    # accidental in-place modification of the snapshot.
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # First pass: draw a few tensors and remember the last one.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0, \
        'snapshot of the rng state changed while drawing random numbers'
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0, \
        'generator state did not advance after drawing random numbers'

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0, 'generator state did not advance'

    # Reset the rng state and do the same stuff.
    # NOTE(review): the restore-and-draw sequence runs twice and only the
    # second run is compared; presumably this checks that restoring is
    # repeatable -- confirm before simplifying.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6, \
        'restoring the rng state did not reproduce the random numbers'

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0, 'the state tensor passed to the setter was modified'

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_cuda_rng_tracker(tensor_model_parallel_size):
    """Check that a tracked RNG stream is independent of the default one.

    Reference tensors are drawn from the default CUDA generator seeded with
    `seed_1` and, separately, with `seed_2`.  The default generator is then
    re-seeded with `seed_1`, a tracker entry named 'test' is added with
    `seed_2`, and draws are interleaved between the default generator and
    `fork('test')`.  The test asserts that the interleaved draws match the
    references for both seeds and that the two streams differ from each
    other.

    Args:
        tensor_model_parallel_size: size handed to
            `mpu.initialize_model_parallel` for this run.

    Raises:
        AssertionError: if the streams coincide or a draw does not match its
            reference.
    """
    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    # The value read back is not used below; the query is retained from the
    # original test flow.
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    # Allocate the scratch buffer with the intended 12x21 shape.  The legacy
    # `torch.cuda.FloatTensor(size)` call interprets a list as *data*, which
    # yields a 2-element tensor and relies on `out=` resizing afterwards.
    tensor = torch.empty(size, dtype=torch.float32, device='cuda')

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    # The two streams must differ on both the first and the second draw, so
    # take the smaller of the two maximum differences.
    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6, 'default and tracked streams produced equal tensors'

    # Every interleaved draw must match its reference, so take the largest
    # deviation over all four comparisons.
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6, 'interleaved draws do not match the references'

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):
    """Verify the seeds installed by `mpu.model_parallel_cuda_manual_seed`.

    After seeding with 12345 the test asserts that `torch.cuda.initial_seed()`
    reports 12345 outside the tracker and
    `12345 + 2718 + mpu.get_tensor_model_parallel_rank()` inside the
    tracker's default `fork()`.

    Args:
        tensor_model_parallel_size: size handed to
            `mpu.initialize_model_parallel` for this run.

    Raises:
        AssertionError: if either seed differs from the expected value.
    """
    is_root = torch.distributed.get_rank() == 0
    if is_root:
        banner = '> testing model parallel cuda manual seed with size {} ...'
        print(banner.format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    base_seed = 12345
    # Offset the test expects between the base seed and the forked seed.
    forked_offset = 2718

    mpu.model_parallel_cuda_manual_seed(base_seed)
    default_seed = torch.cuda.initial_seed()
    assert default_seed == base_seed

    with mpu.get_cuda_rng_tracker().fork():
        expected_seed = (base_seed + forked_offset +
                         mpu.get_tensor_model_parallel_rank())
        assert torch.cuda.initial_seed() == expected_seed

    # Clear the tracker, then tear down the model-parallel groups.
    mpu.get_cuda_rng_tracker().reset()
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if is_root:
        print('>> passed the test :-)')
if __name__ == '__main__':
    # Driver: run each test for tensor-model-parallel sizes 1, 2, 4, ... up
    # to the distributed world size, finishing all sizes of one test before
    # moving on to the next test.
    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    test_cases = (
        ('test set rng state', test_set_cuda_rng_state),
        ('test cuda rng tracker', test_cuda_rng_tracker),
        ('test model parallel cuda manual seed',
         test_model_parallel_cuda_manual_seed),
    )

    for title, run_test in test_cases:
        tensor_model_parallel_size = 1
        while tensor_model_parallel_size <= world_size:
            print_separator(title)
            run_test(tensor_model_parallel_size)
            tensor_model_parallel_size *= 2
|