# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.

    `allow_no_cuda` should not be set unless using Megatron for CPU-only
    data processing. In general this argument should not be set unless you
    know what you are doing.

    Returns a function that finalizes distributed initialization
    when args.lazy_mpu_init == True; otherwise returns None.
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals here ...
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        set_tensor_model_parallel_rank(args.rank)
        # ... and return a function for an external DDP manager
        # to call when it has DDP initialized.
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function.
        return None
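
# Example usage (an illustrative sketch, not part of the original module): a
# pretraining entry point calls initialize_megatron() once before building
# models and optimizers. The args_defaults value mirrors what Megatron's GPT
# pretraining script passes; the lazy-init handling below is hypothetical
# driver code.
#
#     from megatron.initialize import initialize_megatron
#
#     finish = initialize_megatron(
#         extra_args_provider=None,
#         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
#     # With args.lazy_mpu_init set, a continuation is returned instead of
#     # None; an external DDP manager calls it once DDP is ready:
#     if finish is not None:
#         finish()
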
def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask).
    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)
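
# Worked example of the constraint check above (illustrative values, assumed
# rather than taken from the original file): with num_attention_heads=16,
# tensor_model_parallel_size=2, and micro_batch_size=4,
# attn_batch_size = (16 / 2) * 4 = 32.0. With seq_length=1024 the constraint
# holds: 16 < 1024 <= 2048, 1024 % 4 == 0, and 32.0 % 4 == 0, so the
# optimized fused softmax kernel can be invoked.
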
def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process.
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size)


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()
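
# Example of the per-stage seed offset in _set_random_seed above (assumed
# illustrative values): with --seed 1234, pipeline stage 0 seeds with 1234,
# stage 1 with 1234 + 100*1 = 1334, stage 2 with 1434, and so on, so random
# state (e.g. dropout masks) differs across pipeline-parallel ranks while
# each run remains reproducible for a given seed.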