# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

from functools import partial
import sys

import torch

from megatron import get_args, get_num_microbatches
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm
from megatron.utils import check_adlr_autoresume_termination


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    args = get_args()

    tokens = batch['text'].long().cuda().contiguous()
    types = batch['types'].long().cuda().contiguous()
    labels = batch['label'].long().cuda().contiguous()
    attention_mask = batch['padding_mask'].float().cuda().contiguous()
    if args.fp16:
        attention_mask = attention_mask.half()

    return tokens, types, labels, attention_mask
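
# Illustrative sketch (an addition, not part of the upstream file): the batch
# dict is assumed to come from a classification task dataset. The shapes below
# are hypothetical, for a batch of 4 sequences of length 128:
#
#     batch = {
#         'text': torch.zeros(4, 128, dtype=torch.long),         # token ids
#         'types': torch.zeros(4, 128, dtype=torch.long),        # segment ids
#         'label': torch.zeros(4, dtype=torch.long),             # class labels
#         'padding_mask': torch.ones(4, 128, dtype=torch.long),  # 1 = real token
#     }
#     tokens, types, labels, attention_mask = process_batch(batch)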


def cross_entropy_loss_func(labels, output_tensor):
    logits = output_tensor

    # Cross-entropy loss.
    loss_func = torch.nn.CrossEntropyLoss()
    loss = loss_func(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}
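
# Note (added, hedged): `average_losses_across_data_parallel_group` all-reduces
# and averages the scalar across data-parallel replicas, so the logged
# 'lm loss' is the group mean, while the unreduced `loss` returned first is
# what the backward pass actually uses.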


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch.
    timers('batch-generator').start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    tokens, types, labels, attention_mask = process_batch(batch_)
    timers('batch-generator').stop()

    # Forward model.
    output_tensor = model(tokens, attention_mask, tokentype_ids=types)

    return output_tensor, partial(cross_entropy_loss_func, labels)
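
# Added note on the calling convention (an assumption based on how
# megatron.training drives forward steps, not code in this file): the trainer
# calls
#
#     output_tensor, loss_func = forward_step(batch, model)
#     loss, stats = loss_func(output_tensor)
#
# `partial` has already bound `labels`, so the loss function only needs the
# model output. The try/except above lets `batch` be either an iterator over
# a dataloader or an already-materialized batch dict.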


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last,
                      task_collate_fn=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=micro_batch_size,
                                              sampler=sampler,
                                              shuffle=False,
                                              num_workers=num_workers,
                                              drop_last=drop_last,
                                              pin_memory=True,
                                              collate_fn=task_collate_fn)

    return data_loader
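
# Example usage (an illustrative sketch; assumes torch.distributed and the
# megatron mpu groups are already initialized, and `my_dataset` is a
# hypothetical map-style torch Dataset):
#
#     loader = build_data_loader(my_dataset, micro_batch_size=8,
#                                num_workers=2, drop_last=True)
#     for batch in loader:
#         ...  # each data-parallel rank sees a disjoint shard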


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""
    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
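
# Added note: because this is a generator, iteration never raises
# StopIteration; when the wrapped loader is exhausted, a fresh iterator is
# created and the data cycles. An illustrative use:
#
#     valid_iter = _build_infinite_size_dataloader(valid_loader)
#     for _ in range(10 * len(valid_loader)):
#         batch = next(valid_iter)  # cycles through the data 10 times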


def _build_train_valid_dataloaders(train_dataset, valid_dataset,
                                   task_collate_fn=None):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last,
                                         task_collate_fn)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last,
                                          task_collate_fn)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size
    if hasattr(train_dataset, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means each
        # "sample" from the dataset actually has multiple samples that will
        # collapse into the batch dimension (for example, in the RACE
        # dataset, which has several options per question), so we need to
        # account for that when setting the micro batch size.
        args.micro_batch_size *= train_dataset.sample_multiplier
        args.global_batch_size *= train_dataset.sample_multiplier

    return train_dataloader, valid_dataloader
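
# Worked example with assumed numbers: for a RACE-style dataset where
# sample_multiplier == 4 (one row per answer option), a nominal micro batch
# of 8 questions becomes 8 * 4 = 32 rows after the options collapse into the
# batch dimension, so args.micro_batch_size is scaled from 8 to 32 (and
# args.global_batch_size likewise), while the nominal values are preserved in
# args.orig_micro_batch_size / args.orig_global_batch_size.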


def _train(model, optimizer, lr_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    assert get_num_microbatches() == 1, \
        "finetuning with gradient accumulation doesn't currently work"

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch.
    timers('interval-time').start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0('working on epoch {} ...'.format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value.
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            out = train_step(forward_step, batch, model, optimizer,
                             lr_scheduler)
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = out
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(losses_dict, losses_dict_sum,
                                              optimizer.param_groups[0]['lr'],
                                              iteration,
                                              optimizer.get_loss_scale().item(),
                                              report_memory_flag, skipped_iter,
                                              grad_norm, params_norm,
                                              num_zeros_in_grad)

            # Autoresume.
            if args.adlr_autoresume and \
               (iteration % args.adlr_autoresume_interval == 0):
                check_adlr_autoresume_termination(iteration, model,
                                                  optimizer, lr_scheduler)

            # Checkpointing.
            saved_checkpoint = False
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer, lr_scheduler)
                saved_checkpoint = True

            # Evaluation.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = 'iteration {}'.format(iteration)
                evaluate_and_print_results(prefix, forward_step,
                                           valid_dataloader, model,
                                           iteration, False)

            # Exiting based on iterations.
            if args.exit_interval and iteration % args.exit_interval == 0:
                if not saved_checkpoint:
                    save_checkpoint(iteration, model, optimizer, lr_scheduler)
                torch.distributed.barrier()
                print_rank_0('exiting program at iteration {}'.format(
                    iteration))
                sys.exit()

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(train_valid_datasets_provider, model_provider,
             forward_step=_cross_entropy_forward_step,
             end_of_epoch_callback_provider=None,
             task_collate_fn=None):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    assert args.rampup_batch_size is None, \
        'batch size scaling is not supported for finetuning'

    # Train and validation data loaders.
    timers('train/valid/test dataset/dataloader').start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset, task_collate_fn)
    else:
        args.train_iters = 0
    timers('train/valid/test dataset/dataloader').stop()

    # Build callback function.
    timers('callback function').start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers('callback function').stop()

    # Build model, optimizer and learning rate scheduler.
    timers('model and optimizer').start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers('model and optimizer').stop()

    # If a pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers('pretrained checkpoint').start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        original_rng = args.no_load_rng
        args.no_load_rng = True
        _ = load_checkpoint(model, None, None)
        args.load = original_load
        args.no_load_rng = original_rng
        # This is critical when only the model is loaded. We should make
        # sure the main parameters are also updated.
        optimizer.reload_model_params()
    timers('pretrained checkpoint').stop()

    # Print setup timing.
    print_rank_0('done with setups ...')
    timers.log(['train/valid/test dataset/dataloader', 'callback function',
                'model and optimizer', 'pretrained checkpoint'])
    print_rank_0('training ...')

    # Finetune the model.
    if args.epochs > 0:
        _train(model, optimizer, lr_scheduler, forward_step,
               train_dataloader, valid_dataloader, end_of_epoch_callback)
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0('evaluation only mode, setting epoch to -1')
            end_of_epoch_callback(model, epoch=-1, output_predictions=True)

    print_rank_0('done :-)')
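
# Minimal sketch of a task entry point (added illustration; the builder names
# below are hypothetical placeholders, mirroring how a Megatron task script
# would typically wire into `finetune`):
#
#     def train_valid_datasets_provider():
#         # Return (train_dataset, valid_dataset) for the task.
#         return build_my_train_dataset(), build_my_valid_dataset()
#
#     def model_provider(pre_process=True, post_process=True):
#         # Return the task model (signature assumed from megatron.training).
#         return build_my_classification_model(pre_process, post_process)
#
#     def main():
#         finetune(train_valid_datasets_provider, model_provider)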