initialize.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron initialization."""

import random
import os
import time

import numpy as np
import torch

from megatron import fused_kernels
from megatron import get_adlr_autoresume
from megatron import get_args
from megatron import get_tensorboard_writer
from megatron import mpu
from megatron.global_vars import set_global_variables
from megatron.mpu import (set_tensor_model_parallel_rank,
                          set_tensor_model_parallel_world_size)


def initialize_megatron(extra_args_provider=None, args_defaults={},
                        ignore_unknown_args=False, allow_no_cuda=False):
    """Set global variables, initialize distributed, and
    set autoresume and random seeds.
    `allow_no_cuda` should not be set unless using megatron for cpu only
    data processing. In general this arg should not be set unless you know
    what you are doing.
    Returns a function to finalize distributed env initialization
    (optionally, only when args.lazy_mpu_init == True).
    """
    if not allow_no_cuda:
        # Make sure cuda is available.
        assert torch.cuda.is_available(), 'Megatron requires CUDA.'

    # Parse args, build tokenizer, and set adlr-autoresume,
    # tensorboard-writer, and timers.
    set_global_variables(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults,
                         ignore_unknown_args=ignore_unknown_args)

    # torch.distributed initialization
    def finish_mpu_init():
        args = get_args()
        # Pytorch distributed.
        _initialize_distributed()

        # Random seeds for reproducibility.
        if args.rank == 0:
            print('> setting random seeds to {} ...'.format(args.seed))
        _set_random_seed(args.seed)

    args = get_args()
    if args.lazy_mpu_init:
        args.use_cpu_initialization = True
        # Delayed initialization of DDP-related stuff.
        # We only set basic DDP globals ...
        set_tensor_model_parallel_world_size(args.tensor_model_parallel_size)
        # ... and return a function for the external DDP manager
        # to call when it has DDP initialized.
        set_tensor_model_parallel_rank(args.rank)
        return finish_mpu_init
    else:
        # Megatron's MPU is the master. Complete initialization right away.
        finish_mpu_init()

        # Initialize memory buffers.
        _initialize_mem_buffs()

        # Autoresume.
        _init_autoresume()

        # Compile dependencies.
        _compile_dependencies()

        # No continuation function.
        return None
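

# Minimal usage sketch (illustrative, not part of this module). A training
# entry point typically calls initialize_megatron() once at startup; the
# args_defaults value below is an assumption for the example, not something
# this module requires:
#
#   from megatron.initialize import initialize_megatron
#
#   # Eager path: completes initialization and returns None.
#   initialize_megatron(args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
#
#   # Lazy path (args.lazy_mpu_init set): the returned finish_mpu_init closure
#   # is invoked later, once an external DDP manager has set up distributed
#   # state.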


def _compile_dependencies():

    args = get_args()

    # =========================
    # Compile dataset C++ code.
    # =========================
    # TODO: move this to ninja
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling dataset index builder ...')
        from megatron.data.dataset_utils import compile_helper
        compile_helper()
        print('>>> done with dataset index builder. Compilation time: {:.3f} '
              'seconds'.format(time.time() - start_time), flush=True)

    # ==================
    # Load fused kernels
    # ==================

    # Custom kernel constraints check.
    seq_len = args.seq_length
    attn_batch_size = \
        (args.num_attention_heads / args.tensor_model_parallel_size) * \
        args.micro_batch_size
    # Constraints on sequence length and attn_batch_size to enable warp based
    # optimization and upper triangular optimization (for causal mask).
    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
        seq_len % 4 == 0 and attn_batch_size % 4 == 0
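    # Worked example (illustrative numbers, not defaults): seq_length=1024,
    # num_attention_heads=16, tensor_model_parallel_size=2, micro_batch_size=4
    # gives attn_batch_size = (16 / 2) * 4 = 32.0; then 16 < 1024 <= 2048,
    # 1024 % 4 == 0 and 32.0 % 4 == 0, so custom_kernel_constraint is True.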
    # Print a warning.
    if not ((args.fp16 or args.bf16) and
            custom_kernel_constraint and
            args.masked_softmax_fusion):
        if args.rank == 0:
            print('WARNING: constraints for invoking optimized'
                  ' fused softmax kernel are not met. We default'
                  ' back to unfused kernel invocations.', flush=True)

    # Always build on rank zero first.
    if torch.distributed.get_rank() == 0:
        start_time = time.time()
        print('> compiling and loading fused kernels ...', flush=True)
        fused_kernels.load(args)
        torch.distributed.barrier()
    else:
        torch.distributed.barrier()
        fused_kernels.load(args)
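        # Non-zero ranks reach this load() only after rank 0 has passed the
        # barrier above, so it is assumed to pick up the build artifacts that
        # rank 0 already compiled rather than recompile from scratch.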
    # Simple barrier to make sure all ranks have passed the
    # compilation phase successfully before moving on to the
    # rest of the program. We think this might ensure that
    # the lock is released.
    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>>> done with compiling and loading fused kernels. '
              'Compilation time: {:.3f} seconds'.format(
                  time.time() - start_time), flush=True)


def _initialize_distributed():
    """Initialize torch.distributed and mpu."""
    args = get_args()

    device_count = torch.cuda.device_count()
    if torch.distributed.is_initialized():
        if args.rank == 0:
            print('torch distributed is already initialized, '
                  'skipping initialization ...', flush=True)
        args.rank = torch.distributed.get_rank()
        args.world_size = torch.distributed.get_world_size()
    else:
        if args.rank == 0:
            print('> initializing torch distributed ...', flush=True)
        # Manually set the device ids.
        if device_count > 0:
            device = args.rank % device_count
            if args.local_rank is not None:
                assert args.local_rank == device, \
                    'expected local-rank to be the same as rank % device-count.'
            else:
                args.local_rank = device
            torch.cuda.set_device(device)
        # Call the init process.
        init_method = 'tcp://'
        master_ip = os.getenv('MASTER_ADDR', 'localhost')
        master_port = os.getenv('MASTER_PORT', '6000')
        init_method += master_ip + ':' + master_port
        torch.distributed.init_process_group(
            backend=args.distributed_backend,
            world_size=args.world_size, rank=args.rank,
            init_method=init_method)
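        # With the environment defaults above, init_method resolves to
        # 'tcp://localhost:6000'.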
    # Set the tensor model-parallel, pipeline model-parallel, and
    # data-parallel communicators.
    if device_count > 0:
        if mpu.model_parallel_is_initialized():
            print('model parallel is already initialized')
        else:
            mpu.initialize_model_parallel(args.tensor_model_parallel_size,
                                          args.pipeline_model_parallel_size,
                                          args.virtual_pipeline_model_parallel_size)
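            # Illustrative sizing (assumed values, not defaults): with
            # world_size=16, tensor_model_parallel_size=2 and
            # pipeline_model_parallel_size=4, mpu derives a data-parallel
            # size of 16 / (2 * 4) = 2.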


def _init_autoresume():
    """Set autoresume start time."""
    autoresume = get_adlr_autoresume()
    if autoresume:
        torch.distributed.barrier()
        autoresume.init()
        torch.distributed.barrier()


def _set_random_seed(seed_):
    """Set random seed for reproducibility."""
    if seed_ is not None and seed_ > 0:
        # Ensure that different pipeline MP stages get different seeds.
        seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
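        # Example (illustrative): with seed_=1234 on pipeline stage 2, every
        # rank in that stage seeds with 1234 + 100 * 2 = 1434.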
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if torch.cuda.device_count() > 0:
            mpu.model_parallel_cuda_manual_seed(seed)
    else:
        # Note: `seed` is undefined on this branch, so format the input value.
        raise ValueError('Seed ({}) should be a positive integer.'.format(seed_))


def write_args_to_tensorboard():
    """Write arguments to tensorboard."""
    args = get_args()
    writer = get_tensorboard_writer()
    if writer:
        for arg in vars(args):
            writer.add_text(arg, str(getattr(args, arg)),
                            global_step=args.iteration)


def _initialize_mem_buffs():
    """Initialize manually allocated static memory."""
    args = get_args()

    # Initialize memory for checkpointed activations.
    if args.distribute_checkpointed_activations:
        mpu.init_checkpointed_activations_memory_buffer()