training.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import nvtx
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()

import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory


def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: {} '.format(time_str))


def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_val_test_data_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It is
            used to set already parsed arguments.
    """

    with nvtx.annotate("initialize", color="cyan"):
        # Initialize and get arguments, timers, and Tensorboard writer.
        initialize_megatron(extra_args_provider=extra_args_provider,
                            args_defaults=args_defaults)

        # Adjust the startup time so it reflects the largest value.
        # This will be closer to what the scheduler will see (outside of
        # image ... launches.
        global _TRAIN_START_TIME
        start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
        torch.distributed.all_reduce(start_time_tensor,
                                     op=torch.distributed.ReduceOp.MIN)
        _TRAIN_START_TIME = start_time_tensor.item()
        print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
            time.time() - _TRAIN_START_TIME))
        print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    with nvtx.annotate("data_loading", color="orange"):
        if args.virtual_pipeline_model_parallel_size is not None:
            all_data_iterators = [
                build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
                for _ in range(len(model))
            ]
            train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
            valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
            test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
        else:
            train_data_iterator, valid_data_iterator, test_data_iterator \
                = build_train_valid_test_data_iterators(
                    train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    with nvtx.annotate("training", color="blue"):
        if args.do_train and args.train_iters > 0:
            iteration = train(forward_step_func,
                              model, optimizer, lr_scheduler,
                              train_data_iterator, valid_data_iterator)
        print_datetime('after training is done')

        if args.do_valid:
            prefix = 'the end of training for val data'
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

    with nvtx.annotate("checkpointing", color="yellow"):
        if args.save and iteration != 0:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

    with nvtx.annotate("do_test", color="darkgreen"):
        if args.do_test:
            # Run on test data.
            prefix = 'the end of training for test data'
            evaluate_and_print_results(prefix, forward_step_func,
                                       test_data_iterator, model,
                                       0, True)
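

# Illustrative sketch (not part of the original file): roughly how an entry
# script wires its callbacks into pretrain(); modeled on scripts such as
# pretrain_gpt.py, but the provider arguments here are hypothetical
# placeholders with the signatures documented in pretrain() above.
def _example_pretrain_entry_point(train_valid_test_datasets_provider,
                                  model_provider,
                                  forward_step_func):
    """Sketch: hand the three user-supplied callbacks to pretrain()."""
    pretrain(train_valid_test_datasets_provider,
             model_provider,
             forward_step_func,
             args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})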


def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))
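

# Illustrative check (not part of the original file): the constant-batch branch
# of update_train_iters above is plain integer division; the numbers here are
# hypothetical, e.g. 10,000,000 samples at a global batch size of 1024 give
# 9765 iterations, with the trailing partial batch dropped.
def _example_constant_batch_train_iters(train_samples=10_000_000,
                                        global_batch_size=1024):
    """Sketch: iterations for sample-based training without batch-size rampup."""
    return train_samples // global_batch_size  # 9765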


def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(
            pre_process=pre_process,
            post_process=post_process
        )

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_ddp)
                 for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler
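

# Illustrative check (not part of the original file): for iteration-based
# training the scheduler's warmup/decay steps are expressed in samples. The
# values below are hypothetical; with lr_decay_iters defaulting to train_iters,
# decay_steps = 500,000 * 1,024 and warmup_steps = 0.01 * decay_steps, mirroring
# the corresponding branch of get_learning_rate_scheduler above.
def _example_iteration_based_lr_steps(train_iters=500_000,
                                      global_batch_size=1024,
                                      lr_warmup_fraction=0.01):
    """Sketch: sample counts handed to AnnealingLR for an iteration-based run."""
    decay_steps = train_iters * global_batch_size
    warmup_steps = lr_warmup_fraction * decay_steps
    return warmup_steps, decay_steps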


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)

    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)

    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
            and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
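

# Illustrative check (not part of the original file): the scheduler increment
# used in train_step above is the number of samples processed by one optimizer
# step, i.e. the global batch size. The sizes below are hypothetical:
# 8 microbatches * micro batch 4 * data-parallel size 32 = 1024 samples.
def _example_samples_per_step(num_microbatches=8,
                              micro_batch_size=4,
                              data_parallel_size=32):
    """Sketch: samples consumed per call to train_step (the global batch size)."""
    return num_microbatches * micro_batch_size * data_parallel_size  # 1024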


def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration


def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              total_loss_dict[key].item(),
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)


def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x
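

# Illustrative usage (not part of the original file): cyclic_iter lets the
# 'cyclic' dataloader type below keep drawing batches after a finite loader is
# exhausted. The list here stands in for a real DataLoader.
def _example_cyclic_iter_usage():
    """Sketch: pull more items than the underlying iterable contains."""
    loader = [0, 1, 2]
    it = iter(cyclic_iter(loader))
    return [next(it) for _ in range(7)]  # [0, 1, 2, 0, 1, 2, 0]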


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build train, valid, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(valid_dataloader))
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
            else iter(cyclic_iter(test_dataloader))
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator