global_vars.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron global variables."""

import os
import sys
import time

import torch

from megatron.tokenizer import build_tokenizer
from .arguments import parse_args
from .microbatches import build_num_microbatches_calculator

_GLOBAL_ARGS = None
_GLOBAL_NUM_MICROBATCHES_CALCULATOR = None
_GLOBAL_TOKENIZER = None
_GLOBAL_TENSORBOARD_WRITER = None
_GLOBAL_ADLR_AUTORESUME = None
_GLOBAL_TIMERS = None


def get_args():
    """Return arguments."""
    _ensure_var_is_initialized(_GLOBAL_ARGS, 'args')
    return _GLOBAL_ARGS


def get_num_microbatches():
    """Return the current number of microbatches."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get()


def get_current_global_batch_size():
    """Return the current global batch size."""
    return _GLOBAL_NUM_MICROBATCHES_CALCULATOR.get_current_global_batch_size()


def update_num_microbatches(consumed_samples, consistency_check=True):
    """Update the microbatch calculator with the number of consumed samples."""
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR.update(consumed_samples,
                                               consistency_check)
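

# Usage sketch for the microbatch helpers above (illustrative; the
# training-loop variable `consumed_samples` is an assumption, not defined
# in this module):
#
#     update_num_microbatches(consumed_samples)
#     for _ in range(get_num_microbatches()):
#         ...  # one forward/backward pass per microbatch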


def get_tokenizer():
    """Return tokenizer."""
    _ensure_var_is_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    return _GLOBAL_TOKENIZER


def get_tensorboard_writer():
    """Return tensorboard writer. It can be None, so there is no need
    to check whether it is initialized."""
    return _GLOBAL_TENSORBOARD_WRITER


def get_adlr_autoresume():
    """Return the ADLR autoresume object. It can be None, so there is no need
    to check whether it is initialized."""
    return _GLOBAL_ADLR_AUTORESUME


def get_timers():
    """Return timers."""
    _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
    return _GLOBAL_TIMERS


def set_global_variables(extra_args_provider=None, args_defaults={},
                         ignore_unknown_args=False):
    """Set args, tokenizer, tensorboard writer, ADLR autoresume, and timers."""
    args = _parse_args(extra_args_provider=extra_args_provider,
                       defaults=args_defaults,
                       ignore_unknown_args=ignore_unknown_args)
    _build_num_microbatches_calculator(args)
    if args.vocab_file:
        _ = _build_tokenizer(args)
    _set_tensorboard_writer(args)
    _set_adlr_autoresume(args)
    _set_timers()
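

# Initialization sketch (illustrative; `my_extra_args_provider` and the
# contents of the defaults dictionary are assumptions about the calling
# training script, not part of this module):
#
#     set_global_variables(extra_args_provider=my_extra_args_provider,
#                          args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
#     args = get_args()
#     timers = get_timers()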


def _parse_args(extra_args_provider=None, defaults={},
                ignore_unknown_args=False):
    """Parse entire arguments."""
    global _GLOBAL_ARGS
    _ensure_var_is_not_initialized(_GLOBAL_ARGS, 'args')
    _GLOBAL_ARGS = parse_args(extra_args_provider=extra_args_provider,
                              defaults=defaults,
                              ignore_unknown_args=ignore_unknown_args)
    return _GLOBAL_ARGS


def _build_num_microbatches_calculator(args):
    """Initialize the number-of-microbatches calculator."""
    global _GLOBAL_NUM_MICROBATCHES_CALCULATOR
    _ensure_var_is_not_initialized(_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
                                   'num microbatches calculator')
    _GLOBAL_NUM_MICROBATCHES_CALCULATOR = build_num_microbatches_calculator(
        args)


def _build_tokenizer(args):
    """Initialize tokenizer."""
    global _GLOBAL_TOKENIZER
    _ensure_var_is_not_initialized(_GLOBAL_TOKENIZER, 'tokenizer')
    _GLOBAL_TOKENIZER = build_tokenizer(args)
    return _GLOBAL_TOKENIZER


def rebuild_tokenizer(args):
    """Reset the tokenizer and build it again from args."""
    global _GLOBAL_TOKENIZER
    _GLOBAL_TOKENIZER = None
    return _build_tokenizer(args)


def _set_tensorboard_writer(args):
    """Set tensorboard writer (only on the last rank, when tensorboard_dir is set)."""
    global _GLOBAL_TENSORBOARD_WRITER
    _ensure_var_is_not_initialized(_GLOBAL_TENSORBOARD_WRITER,
                                   'tensorboard writer')
    if hasattr(args, 'tensorboard_dir') and \
       args.tensorboard_dir and args.rank == (args.world_size - 1):
        try:
            from torch.utils.tensorboard import SummaryWriter
            print('> setting tensorboard ...')
            _GLOBAL_TENSORBOARD_WRITER = SummaryWriter(
                log_dir=args.tensorboard_dir,
                max_queue=args.tensorboard_queue_size)
        except ModuleNotFoundError:
            print('WARNING: TensorBoard writing requested but is not '
                  'available (are you using PyTorch 1.1.0 or later?), '
                  'no TensorBoard logs will be written.', flush=True)
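

# Consumer-side sketch for the writer (illustrative; `lr` and `iteration`
# are placeholders for values the training loop would supply):
#
#     writer = get_tensorboard_writer()
#     if writer:
#         writer.add_scalar('learning-rate', lr, iteration)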


def _set_adlr_autoresume(args):
    """Initialize ADLR autoresume."""
    global _GLOBAL_ADLR_AUTORESUME
    _ensure_var_is_not_initialized(_GLOBAL_ADLR_AUTORESUME, 'adlr autoresume')
    if args.adlr_autoresume:
        if args.rank == 0:
            print('enabling autoresume ...', flush=True)
        sys.path.append(os.environ.get('SUBMIT_SCRIPTS', '.'))
        try:
            from userlib.auto_resume import AutoResume
        except BaseException:
            print('ADLR autoresume is not available, exiting ...')
            sys.exit()
        _GLOBAL_ADLR_AUTORESUME = AutoResume


def _set_timers():
    """Initialize timers."""
    global _GLOBAL_TIMERS
    _ensure_var_is_not_initialized(_GLOBAL_TIMERS, 'timers')
    _GLOBAL_TIMERS = Timers()


def _ensure_var_is_initialized(var, name):
    """Make sure the input variable is not None."""
    assert var is not None, '{} is not initialized.'.format(name)


def _ensure_var_is_not_initialized(var, name):
    """Make sure the input variable is None (i.e. not yet initialized)."""
    assert var is None, '{} is already initialized.'.format(name)


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, 'timer has already been started'
        torch.cuda.synchronize()
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, 'timer is not started'
        torch.cuda.synchronize()
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, stop it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, restart it.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer."""
        # Currently, torch.utils.tensorboard's add_scalars makes each timer
        # its own run, which pollutes the runs list, so we add each timer
        # as an individual scalar instead.
        assert normalizer > 0.0
        for name in names:
            value = self.timers[name].elapsed(reset=reset) / normalizer
            writer.add_scalar(name + '-time', value, iteration)

    def log(self, names, normalizer=1.0, reset=True):
        """Log a group of timers."""
        assert normalizer > 0.0
        string = 'time (ms)'
        for name in names:
            elapsed_time = self.timers[name].elapsed(
                reset=reset) * 1000.0 / normalizer
            string += ' | {}: {:.2f}'.format(name, elapsed_time)
        if torch.distributed.is_initialized():
            if torch.distributed.get_rank() == (
                    torch.distributed.get_world_size() - 1):
                print(string, flush=True)
        else:
            print(string, flush=True)
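

# Timing sketch (illustrative; the timer name 'forward' and the work it
# brackets are placeholders, not part of this module):
#
#     timers = get_timers()
#     timers('forward').start()
#     ...  # do some work
#     timers('forward').stop()
#     timers.log(['forward'])    # prints "time (ms) | forward: <elapsed>"
#     # or, with a SummaryWriter:
#     # timers.write(['forward'], writer, iteration)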