# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
import random

import numpy
import torch

import mpu
class IdentityLayer(torch.nn.Module):
    """Module that holds a single learnable weight tensor and returns it from forward()."""

    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight
def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)
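# initialize_distributed() below is typically run under the PyTorch distributed
# launcher, which sets RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT for each
# process and passes --local_rank on the command line, e.g. (the script name
# here is only illustrative):
#
#     python -m torch.distributed.launch --nproc_per_node=2 my_mpu_test.py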
def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
    args = parser.parse_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv('WORLD_SIZE', '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)
def print_separator(message):
    """Print `message` centered in a dashed separator line on rank 0."""
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
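# Example driver (a sketch, not part of the original helpers): a typical mpu
# test initializes torch.distributed once, then sweeps over model-parallel
# sizes. mpu.initialize_model_parallel and mpu.destroy_model_parallel are
# assumed to be the mpu entry points for setting up and tearing down the
# model-parallel groups.
#
#     if __name__ == '__main__':
#         initialize_distributed()
#         world_size = torch.distributed.get_world_size()
#         model_parallel_size = 1
#         while model_parallel_size <= world_size:
#             print_separator('model parallel size: {}'.format(model_parallel_size))
#             mpu.initialize_model_parallel(model_parallel_size)
#             set_random_seed(1234)
#             # ... run the actual checks here ...
#             mpu.destroy_model_parallel()
#             model_parallel_size *= 2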