# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pretrain utilities."""

from datetime import datetime
import math
import sys
import time
# The earliest we can measure the start time.
_TRAIN_START_TIME = time.time()
import torch
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP

from megatron import get_args
from megatron import get_timers
from megatron import get_tensorboard_writer
from megatron import get_current_global_batch_size
from megatron import get_num_microbatches
from megatron import is_last_rank
from megatron import update_num_microbatches
from megatron import mpu
from megatron import print_rank_0
from megatron import print_rank_last
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.model import Float16Module
from megatron.optimizer import get_megatron_optimizer
from megatron.initialize import initialize_megatron
from megatron.initialize import write_args_to_tensorboard
from megatron.learning_rates import AnnealingLR
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import unwrap_model
from megatron.data.data_samplers import build_pretraining_data_loader
from megatron.utils import calc_params_l2_norm
from megatron.schedules import forward_backward_no_pipelining
from megatron.schedules import forward_backward_pipelining_without_interleaving
from megatron.schedules import forward_backward_pipelining_with_interleaving
from megatron.utils import report_memory


def print_datetime(string):
    """Note that this call will sync across all ranks."""
    torch.distributed.barrier()
    time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print_rank_0('[' + string + '] datetime: {} '.format(time_str))


def pretrain(train_valid_test_dataset_provider,
             model_provider,
             forward_step_func,
             extra_args_provider=None,
             args_defaults={}):
    """Main training program.

    This function will run the following in the order provided:
        1) initialize Megatron.
        2) setup model, optimizer and lr schedule using the model_provider.
        3) call train_valid_test_dataset_provider to get train/val/test datasets.
        4) train the model using the forward_step_func.

    Arguments:
        train_valid_test_dataset_provider: a function that takes the size of
            train/valid/test dataset and returns `train, valid, test` datasets.
        model_provider: a function that returns a vanilla version of the
            model. By vanilla we mean a simple model on cpu with no fp16 or ddp.
        forward_step_func: a function that takes a `data iterator` and `model`,
            and returns a `loss` scalar with a dictionary with key:values being
            the info we would like to monitor during training, for example
            `lm-loss: value`. We also require that this function add
            `batch generator` to the timers class.
        extra_args_provider: a function that takes a parser and adds arguments
            to it. It is used for programs to add their own arguments.
        args_defaults: a dictionary from argument-name to argument-value. It
            is used to set already-parsed arguments.
    """

    # Initialize and get arguments, timers, and Tensorboard writer.
    initialize_megatron(extra_args_provider=extra_args_provider,
                        args_defaults=args_defaults)

    # Adjust the startup time so it reflects the largest value.
    # This will be closer to what scheduler will see (outside of
    # image ... launches.
    global _TRAIN_START_TIME
    start_time_tensor = torch.cuda.FloatTensor([_TRAIN_START_TIME])
    torch.distributed.all_reduce(start_time_tensor,
                                 op=torch.distributed.ReduceOp.MIN)
    _TRAIN_START_TIME = start_time_tensor.item()
    print_rank_0('time to initialize megatron (seconds): {:.3f}'.format(
        time.time() - _TRAIN_START_TIME))
    print_datetime('after megatron is initialized')

    args = get_args()
    timers = get_timers()

    # Model, optimizer, and learning rate.
    timers('model-and-optimizer-setup').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model-and-optimizer-setup').stop()
    print_datetime('after model, optimizer, and learning rate '
                   'scheduler are built')

    # Data stuff.
    timers('train/valid/test-data-iterators-setup').start()
    if args.virtual_pipeline_model_parallel_size is not None:
        all_data_iterators = [
            build_train_valid_test_data_iterators(train_valid_test_dataset_provider)
            for _ in range(len(model))
        ]
        train_data_iterator = [data_iterators[0] for data_iterators in all_data_iterators]
        valid_data_iterator = [data_iterators[1] for data_iterators in all_data_iterators]
        test_data_iterator = [data_iterators[2] for data_iterators in all_data_iterators]
    else:
        train_data_iterator, valid_data_iterator, test_data_iterator \
            = build_train_valid_test_data_iterators(
                train_valid_test_dataset_provider)
    timers('train/valid/test-data-iterators-setup').stop()
    print_datetime('after dataloaders are built')

    # Print setup timing.
    print_rank_0('done with setup ...')
    timers.log(['model-and-optimizer-setup', 'train/valid/test-data-iterators-setup'])
    print_rank_0('training ...')

    iteration = 0
    if args.do_train and args.train_iters > 0:
        iteration = train(forward_step_func,
                          model, optimizer, lr_scheduler,
                          train_data_iterator, valid_data_iterator)
    print_datetime('after training is done')

    if args.do_valid:
        prefix = 'the end of training for val data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   valid_data_iterator, model,
                                   iteration, False)

    if args.save and iteration != 0:
        save_checkpoint(iteration, model, optimizer, lr_scheduler)

    if args.do_test:
        # Run on test data.
        prefix = 'the end of training for test data'
        evaluate_and_print_results(prefix, forward_step_func,
                                   test_data_iterator, model,
                                   0, True)
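

# Illustrative usage sketch (not part of this module): a pretraining script
# typically defines the three callables described in the `pretrain` docstring
# above and passes them in a single call. The model class and helper names
# below are hypothetical placeholders, not Megatron APIs.
#
#     def model_provider(pre_process=True, post_process=True):
#         return MyTransformerModel(pre_process=pre_process,
#                                   post_process=post_process)
#
#     def train_valid_test_dataset_provider(train_val_test_num_samples):
#         return build_my_datasets(train_val_test_num_samples)
#
#     def forward_step(data_iterator, model):
#         timers = get_timers()
#         timers('batch-generator').start()
#         batch = next(data_iterator)
#         timers('batch-generator').stop()
#         loss = model(batch)
#         return loss, {'lm loss': loss}
#
#     if __name__ == "__main__":
#         pretrain(train_valid_test_dataset_provider, model_provider,
#                  forward_step)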


def update_train_iters(args):

    # For iteration-based training, we don't need to do anything
    if args.train_iters:
        return

    # Constant batch size with sample-based training.
    if args.rampup_batch_size is None:
        args.train_iters = args.train_samples // args.global_batch_size

    else:
        # Sample based training with rampup batch size.
        iterations = 0
        consumed_samples = 0
        # Rampup phase.
        while consumed_samples <= int(args.rampup_batch_size[2]):
            update_num_microbatches(consumed_samples, consistency_check=False)
            consumed_samples += get_current_global_batch_size()
            iterations += 1
        # Reset
        update_num_microbatches(0, consistency_check=False)
        # Constant phase
        # Note that we throw away any partial last batch.
        iterations += (args.train_samples - consumed_samples) // \
                      args.global_batch_size
        args.train_iters = iterations

    print_rank_0('setting training iterations to {}'.format(args.train_iters))
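

# Worked example (hypothetical numbers) for the sample-based branch above:
# with args.train_samples = 1_000_000 and args.global_batch_size = 512 and no
# batch-size ramp-up, train_iters = 1_000_000 // 512 = 1953, and the trailing
# partial batch of 1_000_000 - 1953 * 512 = 64 samples is dropped. With a
# ramp-up, the while-loop counts the variable-size batches consumed until the
# ramp-up sample threshold (rampup_batch_size[2]) is reached, and the constant
# phase covers the remainder at the full global batch size.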


def get_model(model_provider_func):
    """Build the model."""
    args = get_args()

    # Build model.
    if mpu.get_pipeline_model_parallel_world_size() > 1 and \
       args.virtual_pipeline_model_parallel_size is not None:
        model = []
        for i in range(args.virtual_pipeline_model_parallel_size):
            mpu.set_virtual_pipeline_model_parallel_rank(i)
            # Set pre_process and post_process only after virtual rank is set.
            pre_process = mpu.is_pipeline_first_stage()
            post_process = mpu.is_pipeline_last_stage()
            this_model = model_provider_func(
                pre_process=pre_process,
                post_process=post_process
            )
            model.append(this_model)
    else:
        pre_process = mpu.is_pipeline_first_stage()
        post_process = mpu.is_pipeline_last_stage()
        model = model_provider_func(
            pre_process=pre_process,
            post_process=post_process
        )

    if not isinstance(model, list):
        model = [model]

    # Set tensor model parallel attributes if not set.
    # Only parameters that are already tensor model parallel have these
    # attributes set for them. We should make sure the default attributes
    # are set for all params so the optimizer can use them.
    for model_module in model:
        for param in model_module.parameters():
            mpu.set_defaults_if_not_set_tensor_model_parallel_attributes(param)

    # Print number of parameters.
    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on (tensor, pipeline) '
              'model parallel rank ({}, {}): {}'.format(
                  mpu.get_tensor_model_parallel_rank(),
                  mpu.get_pipeline_model_parallel_rank(),
                  sum([sum([p.nelement() for p in model_module.parameters()])
                       for model_module in model])), flush=True)

    # GPU allocation.
    for model_module in model:
        model_module.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16 or args.bf16:
        model = [Float16Module(model_module, args) for model_module in model]

    if args.DDP_impl == 'torch':
        i = torch.cuda.current_device()
        model = [torchDDP(model_module, device_ids=[i], output_device=i,
                          process_group=mpu.get_data_parallel_group())
                 for model_module in model]
        return model

    if args.DDP_impl == 'local':
        model = [LocalDDP(model_module,
                          args.accumulate_allreduce_grads_in_fp32,
                          args.use_contiguous_buffers_in_ddp)
                 for model_module in model]
        return model

    raise NotImplementedError('Unknown DDP implementation specified: {}. '
                              'Exiting.'.format(args.DDP_impl))
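

# Note on the wrapper nesting produced by get_model() above: each entry of the
# returned list is the provider's module, optionally wrapped as
# Float16Module(module, args) when fp16/bf16 is enabled, and then wrapped by
# torchDDP(...) or LocalDDP(...) depending on --DDP-impl. unwrap_model() with
# (torchDDP, LocalDDP, Float16Module) peels these layers back off wherever the
# raw module is needed (see setup_model_and_optimizer and train_step below).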


def get_learning_rate_scheduler(optimizer):
    """Build the learning rate scheduler."""
    args = get_args()

    # Iteration-based training.
    if args.train_iters:
        if args.lr_decay_iters is None:
            args.lr_decay_iters = args.train_iters
        decay_steps = args.lr_decay_iters * args.global_batch_size
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_iters * args.global_batch_size
    # Sample-based training.
    elif args.train_samples:
        # We need to set training iters for later use. Technically
        # we need to adjust the training samples too (due to last
        # batch being incomplete) but we leave it as is for now.
        update_train_iters(args)
        if args.lr_decay_samples is None:
            args.lr_decay_samples = args.train_samples
        decay_steps = args.lr_decay_samples
        if args.lr_warmup_fraction is not None:
            warmup_steps = args.lr_warmup_fraction * decay_steps
        else:
            warmup_steps = args.lr_warmup_samples
    else:
        raise Exception(
            'either train-iters or train-samples should be provided.')

    lr_scheduler = AnnealingLR(
        optimizer,
        max_lr=args.lr,
        min_lr=args.min_lr,
        warmup_steps=warmup_steps,
        decay_steps=decay_steps,
        decay_style=args.lr_decay_style,
        use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
        override_lr_scheduler=args.override_lr_scheduler)

    return lr_scheduler
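

# Worked example (hypothetical numbers) for the iteration-based branch above.
# The scheduler operates in units of samples: with args.train_iters = 100_000,
# args.global_batch_size = 256 and args.lr_warmup_fraction = 0.01,
#     decay_steps  = 100_000 * 256        = 25_600_000 samples
#     warmup_steps = 0.01 * 25_600_000    = 256_000 samples
# which matches the per-step increment passed to lr_scheduler.step() in
# train_step().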


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()

    model = get_model(model_provider_func)
    unwrapped_model = unwrap_model(model,
                                   (torchDDP, LocalDDP, Float16Module))
    optimizer = get_megatron_optimizer(unwrapped_model)
    lr_scheduler = get_learning_rate_scheduler(optimizer)

    if args.load is not None:
        timers = get_timers()
        # Extra barrier is added to make sure all ranks report the
        # max time.
        torch.distributed.barrier()
        timers('load-checkpoint').start()
        args.iteration = load_checkpoint(model, optimizer, lr_scheduler)
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
    else:
        args.iteration = 0

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'

    # get model without FP16 and/or TorchDDP wrappers
    if args.iteration == 0 and len(unwrapped_model) == 1 \
        and hasattr(unwrapped_model[0], 'init_state_dict_from_bert'):
        print_rank_0("Initializing ICT from pretrained BERT model")
        unwrapped_model[0].init_state_dict_from_bert()
        if args.fp16:
            optimizer.reload_model_params()

    return model, optimizer, lr_scheduler


def train_step(forward_step_func, data_iterator,
               model, optimizer, lr_scheduler):
    """Single training step."""
    args = get_args()
    timers = get_timers()

    # Set grad to zero.
    if args.DDP_impl == 'local' and args.use_contiguous_buffers_in_ddp:
        for partition in model:
            partition.zero_grad_buffer()
    else:
        optimizer.zero_grad()

    if mpu.get_pipeline_model_parallel_world_size() > 1:
        if args.virtual_pipeline_model_parallel_size is not None:
            forward_backward_func = forward_backward_pipelining_with_interleaving
            assert get_num_microbatches() % args.pipeline_model_parallel_size == 0, \
                'number of microbatches is not divisible by pipeline-parallel ' \
                'size when using interleaved schedule'
        else:
            forward_backward_func = forward_backward_pipelining_without_interleaving
    else:
        forward_backward_func = forward_backward_no_pipelining
    losses_reduced = forward_backward_func(
        forward_step_func, data_iterator, model,
        optimizer, timers, forward_only=False)

    # All-reduce if needed.
    if args.DDP_impl == 'local':
        timers('backward-params-all-reduce').start()
        for model_module in model:
            model_module.allreduce_gradients()
        timers('backward-params-all-reduce').stop()

    # All-reduce word_embeddings' grad across first and last stages to ensure
    # that word_embeddings parameters stay in sync.
    # This should only run for models that support pipelined model parallelism
    # (BERT and GPT-2).
    timers('backward-embedding-all-reduce').start()
    if (mpu.is_pipeline_first_stage(ignore_virtual=True) or
        mpu.is_pipeline_last_stage(ignore_virtual=True)) and \
            mpu.get_pipeline_model_parallel_world_size() > 1:
        if mpu.is_pipeline_first_stage(ignore_virtual=True):
            unwrapped_model = model[0]
        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
            unwrapped_model = model[-1]
        unwrapped_model = unwrap_model(
            unwrapped_model, (torchDDP, LocalDDP, Float16Module))

        if unwrapped_model.share_word_embeddings:
            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
            if args.DDP_impl == 'local':
                grad = word_embeddings_weight.main_grad
            else:
                grad = word_embeddings_weight.grad
            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
    timers('backward-embedding-all-reduce').stop()

    # Update parameters.
    timers('optimizer').start()
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    timers('optimizer').stop()

    # Update learning rate.
    if update_successful:
        increment = get_num_microbatches() * \
                    args.micro_batch_size * \
                    args.data_parallel_size
        lr_scheduler.step(increment=increment)
        skipped_iter = 0
    else:
        skipped_iter = 1

    if mpu.is_pipeline_last_stage(ignore_virtual=True):
        # Average loss across microbatches.
        loss_reduced = {}
        for key in losses_reduced[0]:
            losses_reduced_for_key = [x[key] for x in losses_reduced]
            loss_reduced[key] = sum(losses_reduced_for_key) / len(losses_reduced_for_key)
        return loss_reduced, skipped_iter, grad_norm, num_zeros_in_grad
    return {}, skipped_iter, grad_norm, num_zeros_in_grad
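

# The learning-rate increment above equals the number of samples processed by
# one global step. Hypothetical example: with 16 microbatches,
# micro_batch_size = 4 and data_parallel_size = 8, a successful step advances
# the sample-based AnnealingLR schedule by 16 * 4 * 8 = 512 samples, matching
# the consumed_train_samples update in train().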


def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
                 loss_scale, report_memory_flag, skipped_iter,
                 grad_norm, params_norm, num_zeros_in_grad):
    """Log training information such as losses, timing, ...."""
    args = get_args()
    timers = get_timers()
    writer = get_tensorboard_writer()

    # Advanced, skipped, and Nan iterations.
    advanced_iters_key = 'advanced iterations'
    skipped_iters_key = 'skipped iterations'
    nan_iters_key = 'nan iterations'
    # Advanced iterations.
    if not skipped_iter:
        total_loss_dict[advanced_iters_key] = total_loss_dict.get(
            advanced_iters_key, 0) + 1
    else:
        if advanced_iters_key not in total_loss_dict:
            total_loss_dict[advanced_iters_key] = 0
    # Skipped iterations.
    total_loss_dict[skipped_iters_key] = total_loss_dict.get(
        skipped_iters_key, 0) + skipped_iter
    # Update losses and set nan iterations
    got_nan = False
    for key in loss_dict:
        if not skipped_iter:
            total_loss_dict[key] = total_loss_dict.get(
                key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]
        else:
            value = loss_dict[key].float().sum().item()
            is_nan = value == float('inf') or \
                     value == -float('inf') or \
                     value != value
            got_nan = got_nan or is_nan
    total_loss_dict[nan_iters_key] = total_loss_dict.get(
        nan_iters_key, 0) + int(got_nan)

    # Logging.
    timers_to_log = []

    def add_to_logging(name):
        if name in timers.timers:
            timers_to_log.append(name)
    add_to_logging('forward-compute')
    add_to_logging('forward-recv')
    add_to_logging('forward-send')
    add_to_logging('forward-backward-send-forward-backward-recv')
    add_to_logging('backward-compute')
    add_to_logging('backward-recv')
    add_to_logging('backward-send')
    add_to_logging('backward-send-forward-recv')
    add_to_logging('backward-send-backward-recv')
    add_to_logging('backward-params-all-reduce')
    add_to_logging('backward-embedding-all-reduce')
    add_to_logging('optimizer-copy-to-main-grad')
    add_to_logging('optimizer-unscale-and-check-inf')
    add_to_logging('optimizer-clip-main-grad')
    add_to_logging('optimizer-copy-main-to-model-params')
    add_to_logging('optimizer')
    add_to_logging('batch-generator')

    # Calculate batch size.
    batch_size = args.micro_batch_size * args.data_parallel_size * \
        get_num_microbatches()

    total_iterations = total_loss_dict[advanced_iters_key] + \
                       total_loss_dict[skipped_iters_key]

    # Tensorboard values.
    if writer and (iteration % args.tensorboard_log_interval == 0) and \
       is_last_rank():
        if args.log_learning_rate_to_tensorboard:
            writer.add_scalar('learning-rate', learning_rate, iteration)
            writer.add_scalar('learning-rate vs samples', learning_rate,
                              args.consumed_train_samples)
        if args.log_batch_size_to_tensorboard:
            writer.add_scalar('batch-size', batch_size, iteration)
            writer.add_scalar('batch-size vs samples', batch_size,
                              args.consumed_train_samples)
        for key in loss_dict:
            writer.add_scalar(key, loss_dict[key], iteration)
            writer.add_scalar(key + ' vs samples', loss_dict[key],
                              args.consumed_train_samples)
        if args.log_loss_scale_to_tensorboard:
            writer.add_scalar('loss-scale', loss_scale, iteration)
            writer.add_scalar('loss-scale vs samples', loss_scale,
                              args.consumed_train_samples)
        if grad_norm is not None:
            writer.add_scalar('grad-norm', grad_norm, iteration)
            writer.add_scalar('grad-norm vs samples', grad_norm,
                              args.consumed_train_samples)
        if num_zeros_in_grad is not None:
            writer.add_scalar('num-zeros', num_zeros_in_grad, iteration)
            writer.add_scalar('num-zeros vs samples', num_zeros_in_grad,
                              args.consumed_train_samples)
        if params_norm is not None:
            writer.add_scalar('params-norm', params_norm, iteration)
            writer.add_scalar('params-norm vs samples', params_norm,
                              args.consumed_train_samples)
        if args.log_timers_to_tensorboard:
            timers.write(timers_to_log, writer, iteration,
                         normalizer=total_iterations)

    if iteration % args.log_interval == 0:
        elapsed_time = timers('interval-time').elapsed()
        elapsed_time_per_iteration = elapsed_time / total_iterations
        if writer and torch.distributed.get_rank() == 0:
            if args.log_timers_to_tensorboard:
                writer.add_scalar('iteration-time',
                                  elapsed_time_per_iteration, iteration)
        log_string = ' iteration {:8d}/{:8d} |'.format(
            iteration, args.train_iters)
        log_string += ' consumed samples: {:12d} |'.format(
            args.consumed_train_samples)
        log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
            elapsed_time_per_iteration * 1000.0)
        log_string += ' learning rate: {:.3E} |'.format(learning_rate)
        log_string += ' global batch size: {:5d} |'.format(batch_size)
        for key in total_loss_dict:
            if key not in [advanced_iters_key, skipped_iters_key,
                           nan_iters_key]:
                avg = total_loss_dict[key].item() / \
                      float(max(1, total_loss_dict[advanced_iters_key]))
                if avg > 0.0:
                    log_string += ' {}: {:.6E} |'.format(key, avg)
                total_loss_dict[key] = torch.cuda.FloatTensor([0.0])
        log_string += ' loss scale: {:.1f} |'.format(loss_scale)
        if grad_norm is not None:
            log_string += ' grad norm: {:.3f} |'.format(grad_norm)
        if num_zeros_in_grad is not None:
            log_string += ' num zeros: {:.1f} |'.format(num_zeros_in_grad)
        if params_norm is not None:
            log_string += ' params norm: {:.3f} |'.format(params_norm)
        log_string += ' number of skipped iterations: {:3d} |'.format(
            total_loss_dict[skipped_iters_key])
        log_string += ' number of nan iterations: {:3d} |'.format(
            total_loss_dict[nan_iters_key])
        total_loss_dict[advanced_iters_key] = 0
        total_loss_dict[skipped_iters_key] = 0
        total_loss_dict[nan_iters_key] = 0
        print_rank_last(log_string)
        if report_memory_flag and learning_rate > 0.:
            # Report memory after optimizer state has been initialized.
            report_memory('(after {} iterations)'.format(iteration))
            report_memory_flag = False
        timers.log(timers_to_log, normalizer=args.log_interval)

    return report_memory_flag


def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
    timers = get_timers()
    # Extra barrier is added to make sure
    # all ranks report the max time.
    torch.distributed.barrier()
    timers('save-checkpoint').start()
    save_checkpoint(iteration, model, optimizer, lr_scheduler)
    torch.distributed.barrier()
    timers('save-checkpoint').stop()
    timers.log(['save-checkpoint'])


def train(forward_step_func, model, optimizer, lr_scheduler,
          train_data_iterator, valid_data_iterator):
    """Train the model function."""
    args = get_args()
    timers = get_timers()

    # Write args to tensorboard
    write_args_to_tensorboard()

    # Turn on training mode which enables dropout.
    for model_module in model:
        model_module.train()

    # Tracking loss.
    total_loss_dict = {}

    # Iterations.
    iteration = args.iteration

    timers('interval-time').start()
    print_datetime('before the start of training step')
    report_memory_flag = True
    while iteration < args.train_iters:
        update_num_microbatches(args.consumed_train_samples)
        loss_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       lr_scheduler)
        iteration += 1
        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
                                       args.micro_batch_size * \
                                       get_num_microbatches()

        # Logging.
        loss_scale = optimizer.get_loss_scale().item()
        params_norm = None
        if args.log_params_norm:
            params_norm = calc_params_l2_norm(model)
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          optimizer.param_groups[0]['lr'],
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)

        # Autoresume
        if args.adlr_autoresume and \
           (iteration % args.adlr_autoresume_interval == 0):
            check_adlr_autoresume_termination(iteration, model, optimizer,
                                              lr_scheduler)

        # Evaluation
        if args.eval_interval and iteration % args.eval_interval == 0 and \
           args.do_valid:
            prefix = 'iteration {}'.format(iteration)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, False)

        # Checkpointing
        saved_checkpoint = False
        if args.save and args.save_interval and \
           iteration % args.save_interval == 0:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     lr_scheduler)
            saved_checkpoint = True

        # Exiting based on duration
        if args.exit_duration_in_mins:
            train_time = (time.time() - _TRAIN_START_TIME) / 60.0
            done_cuda = torch.cuda.IntTensor(
                [train_time > args.exit_duration_in_mins])
            torch.distributed.all_reduce(
                done_cuda, op=torch.distributed.ReduceOp.MAX)
            done = done_cuda.item()
            if done:
                if not saved_checkpoint:
                    save_checkpoint_and_time(iteration, model, optimizer,
                                             lr_scheduler)
                print_datetime('exiting program after {} minutes'.format(train_time))
                sys.exit()

        # Exiting based on iterations
        if args.exit_interval and iteration % args.exit_interval == 0:
            if not saved_checkpoint:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         lr_scheduler)
            torch.distributed.barrier()
            print_datetime('exiting program at iteration {}'.format(iteration))
            sys.exit()

    return iteration


def evaluate(forward_step_func, data_iterator, model, verbose=False):
    """Evaluation."""
    args = get_args()

    # Turn on evaluation mode which disables dropout.
    for model_module in model:
        model_module.eval()

    total_loss_dict = {}

    with torch.no_grad():
        iteration = 0
        while iteration < args.eval_iters:
            iteration += 1
            if verbose and iteration % args.log_interval == 0:
                print_rank_0('Evaluating iter {}/{}'.format(iteration,
                                                            args.eval_iters))

            if mpu.get_pipeline_model_parallel_world_size() > 1:
                if args.virtual_pipeline_model_parallel_size is not None:
                    forward_backward_func = forward_backward_pipelining_with_interleaving
                else:
                    forward_backward_func = forward_backward_pipelining_without_interleaving
            else:
                forward_backward_func = forward_backward_no_pipelining
            loss_dicts = forward_backward_func(
                forward_step_func, data_iterator, model, optimizer=None,
                timers=None, forward_only=True)

            if mpu.is_pipeline_last_stage(ignore_virtual=True):
                # Reduce across processes.
                for loss_dict in loss_dicts:
                    for key in loss_dict:
                        total_loss_dict[key] = total_loss_dict.get(
                            key, torch.cuda.FloatTensor([0.0])) + loss_dict[key]

            args.consumed_valid_samples += mpu.get_data_parallel_world_size() \
                                           * args.micro_batch_size \
                                           * get_num_microbatches()
    # Move model back to the train mode.
    for model_module in model:
        model_module.train()

    for key in total_loss_dict:
        total_loss_dict[key] /= args.eval_iters * get_num_microbatches()

    return total_loss_dict
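

# Normalization note for evaluate() above: the forward-backward schedules
# return one loss dict per microbatch, and losses are accumulated on the last
# pipeline stage for every evaluation iteration, so dividing by
# eval_iters * get_num_microbatches() yields the mean per-microbatch loss.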


def evaluate_and_print_results(prefix, forward_step_func,
                               data_iterator, model,
                               iteration, verbose=False):
    """Helper function to evaluate and dump results on screen."""
    args = get_args()
    writer = get_tensorboard_writer()

    total_loss_dict = evaluate(forward_step_func, data_iterator, model, verbose)
    string = ' validation loss at {} | '.format(prefix)
    for key in total_loss_dict:
        string += '{} value: {:.6E} | '.format(key, total_loss_dict[key].item())
        ppl = math.exp(min(20, total_loss_dict[key].item()))
        string += '{} PPL: {:.6E} | '.format(key, ppl)
        if writer and is_last_rank():
            writer.add_scalar('{} validation'.format(key),
                              total_loss_dict[key].item(),
                              iteration)
            writer.add_scalar('{} validation vs samples'.format(key),
                              total_loss_dict[key].item(),
                              args.consumed_train_samples)
            if args.log_validation_ppl_to_tensorboard:
                writer.add_scalar('{} validation ppl'.format(key), ppl,
                                  iteration)
                writer.add_scalar('{} validation ppl vs samples'.format(key),
                                  ppl, args.consumed_train_samples)

    length = len(string) + 1
    print_rank_last('-' * length)
    print_rank_last(string)
    print_rank_last('-' * length)
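

# The perplexity reported above is exp(loss), with the loss clamped at 20 to
# avoid overflow in the printout; e.g. a validation lm loss of 2.30 is shown
# as a PPL of roughly exp(2.30) ~= 9.97.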


def cyclic_iter(iter):
    while True:
        for x in iter:
            yield x
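

# Minimal usage sketch for cyclic_iter (toy data, not Megatron specific):
#
#     it = iter(cyclic_iter([1, 2, 3]))
#     [next(it) for _ in range(5)]   # -> [1, 2, 3, 1, 2]
#
# This backs the 'cyclic' dataloader type below: the finite dataloader is
# replayed indefinitely instead of raising StopIteration.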


def build_train_valid_test_data_iterators(
        build_train_valid_test_datasets_provider):
    """Build pretraining train, valid, and test data iterators."""
    args = get_args()

    (train_dataloader, valid_dataloader, test_dataloader) = (None, None, None)

    print_rank_0('> building train, validation, and test datasets ...')

    # Backward compatibility, assume fixed batch size.
    if args.iteration > 0 and args.consumed_train_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_train_samples = args.iteration * args.global_batch_size
    if args.iteration > 0 and args.consumed_valid_samples == 0:
        assert args.train_samples is None, \
            'only backward compatibility support for iteration-based training'
        args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
            args.eval_iters * args.global_batch_size

    # Data loader only on rank 0 of each model parallel group.
    if mpu.get_tensor_model_parallel_rank() == 0:

        # Number of train/valid/test samples.
        if args.train_samples:
            train_samples = args.train_samples
        else:
            train_samples = args.train_iters * args.global_batch_size
        eval_iters = (args.train_iters // args.eval_interval + 1) * \
                     args.eval_iters
        test_iters = args.eval_iters
        train_val_test_num_samples = [train_samples,
                                      eval_iters * args.global_batch_size,
                                      test_iters * args.global_batch_size]
        print_rank_0(' > datasets target sizes (minimum size):')
        print_rank_0('    train:      {}'.format(train_val_test_num_samples[0]))
        print_rank_0('    validation: {}'.format(train_val_test_num_samples[1]))
        print_rank_0('    test:       {}'.format(train_val_test_num_samples[2]))

        # Build the datasets.
        train_ds, valid_ds, test_ds = build_train_valid_test_datasets_provider(
            train_val_test_num_samples)

        # Build dataloaders.
        train_dataloader = build_pretraining_data_loader(
            train_ds, args.consumed_train_samples)
        valid_dataloader = build_pretraining_data_loader(
            valid_ds, args.consumed_valid_samples)
        test_dataloader = build_pretraining_data_loader(test_ds, 0)

        # Flags to know if we need to do training/validation/testing.
        do_train = train_dataloader is not None and args.train_iters > 0
        do_valid = valid_dataloader is not None and args.eval_iters > 0
        do_test = test_dataloader is not None and args.eval_iters > 0
        # Need to broadcast num_tokens and num_type_tokens.
        flags = torch.cuda.LongTensor(
            [int(do_train), int(do_valid), int(do_test)])
    else:
        flags = torch.cuda.LongTensor([0, 0, 0])

    # Broadcast num tokens.
    torch.distributed.broadcast(flags,
                                mpu.get_tensor_model_parallel_src_rank(),
                                group=mpu.get_tensor_model_parallel_group())
    args.do_train = flags[0].item()
    args.do_valid = flags[1].item()
    args.do_test = flags[2].item()

    # Build iterators.
    dl_type = args.dataloader_type
    assert dl_type in ['single', 'cyclic']

    if train_dataloader is not None:
        train_data_iterator = iter(train_dataloader) if dl_type == 'single' \
                              else iter(cyclic_iter(train_dataloader))
    else:
        train_data_iterator = None

    if valid_dataloader is not None:
        valid_data_iterator = iter(valid_dataloader) if dl_type == 'single' \
                              else iter(cyclic_iter(valid_dataloader))
    else:
        valid_data_iterator = None

    if test_dataloader is not None:
        test_data_iterator = iter(test_dataloader) if dl_type == 'single' \
                             else iter(cyclic_iter(test_dataloader))
    else:
        test_data_iterator = None

    return train_data_iterator, valid_data_iterator, test_data_iterator