commons.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random

import numpy
import torch

import mpu


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
    args = parser.parse_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)
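
# Note: RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT are read from the
# environment above; they are typically exported by a distributed launcher
# (e.g. `python -m torch.distributed.launch --nproc_per_node=N script.py`),
# which also passes --local_rank to each spawned process. The defaults only
# cover a single-process run on localhost.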


def print_separator(message):
    """Print a dashed separator line containing `message`, from rank 0 only."""
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
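

# A minimal usage sketch, not part of the original module: it shows how these
# helpers might be combined into a model-parallel smoke test. The model-parallel
# size (the full world size, i.e. no data parallelism) and the layer shape are
# illustrative assumptions.
if __name__ == '__main__':
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    # Assumption for this sketch: use all ranks for model parallelism.
    mpu.initialize_model_parallel(world_size)
    set_random_seed(1234)
    print_separator('identity layer smoke test')
    layer = IdentityLayer((4, 8)).cuda()
    output = layer()
    if torch.distributed.get_rank() == 0:
        print('> output shape: {}'.format(tuple(output.shape)), flush=True)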