# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True).
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization.
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals ...
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # ... and return a function for the external DDP manager
        # to call when it has DDP initialized.
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function.
        return None
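
# Typical driver-side usage (an illustrative sketch, not part of this module's
# API; `my_extra_args` is a placeholder): a training script calls
# initialize_megatron() once at startup. When --lazy-mpu-init is set, the
# returned continuation is invoked only after an external DDP manager has
# brought up torch.distributed:
#
#     continuation = initialize_megatron(extra_args_provider=my_extra_args)
#     if continuation is not None:
#         # ... external framework initializes torch.distributed here ...
#         continuation()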


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask).
    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
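    # Worked example (illustrative numbers, not defaults): with
    # seq_length=2048, num_attention_heads=16, tensor_model_parallel_size=2
    # and micro_batch_size=4, attn_batch_size = (16 / 2) * 4 = 32; since
    # 16 < 2048 <= 2048 and both 2048 and 32 are divisible by 4, the
    # constraint holds and the fused softmax kernel can be used.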
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():

        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()

    else:

        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process.
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)

    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size)
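
# Example of the resulting layout (a sketch with assumed sizes, following the
# semantics of mpu.initialize_model_parallel): with the default environment
# variables the init method above resolves to 'tcp://localhost:6000'; with a
# world size of 16, tensor_model_parallel_size=2 and
# pipeline_model_parallel_size=4, the data-parallel size is 16 / (2 * 4) = 2,
# giving 8 tensor-parallel groups of 2 ranks, 4 pipeline groups of 4 ranks,
# and 8 data-parallel groups of 2 ranks.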


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))
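
# Illustrative example of the per-stage offset above (assumed base seed): with
# --seed 1234, pipeline stage 0 seeds Python/NumPy/torch with 1234, stage 1
# with 1334, stage 2 with 1434, and so on, so random state differs across
# pipeline stages while a given configuration remains reproducible.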


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()