finetune_utils.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""
import torch
import torch.nn.functional as F
from functools import partial

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group
from megatron.utils import calc_params_l2_norm


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    images = batch[0].cuda().contiguous()
    labels = batch[1].cuda().contiguous()
    return images, labels


def cross_entropy_loss_func(labels, output_tensor):
    """Cross-entropy loss; also returns the data-parallel averaged loss for logging."""
    logits = output_tensor

    # Cross-entropy loss.
    loss = F.cross_entropy(logits.contiguous().float(), labels)

    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])

    return loss, {'lm loss': averaged_loss[0]}


def _cross_entropy_forward_step(batch, model):
    """Simple forward step with cross-entropy loss."""
    timers = get_timers()

    # Get the batch. `batch` may be an iterator or an already-materialized batch.
    timers("batch generator").start()
    try:
        batch_ = next(batch)
    except BaseException:
        batch_ = batch
    images, labels = process_batch(batch_)
    timers("batch generator").stop()

    # Forward model.
    output_tensor = model(images)

    return output_tensor, partial(cross_entropy_loss_func, labels)


def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )

    return data_loader
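

# Illustrative usage sketch (not part of the original module). The dataset and
# the numeric values below are assumptions, chosen only to show the call shape:
#
#     loader = build_data_loader(my_dataset, micro_batch_size=32,
#                                num_workers=4, drop_last=True)
#     for images, labels in loader:
#         ...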


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""
    iterator = iter(dataloader)
    while True:
        try:
            yield next(iterator)
        except StopIteration:
            iterator = iter(dataloader)
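

# Illustrative note (an assumption, not in the original module): the generator
# above never raises StopIteration, so callers simply pull batches with `next`:
#
#     valid_iter = _build_infinite_size_dataloader(valid_dataloader)
#     batch = next(valid_iter)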


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training dataset.
    train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
                                         args.num_workers, not args.keep_last)
    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
                                          args.num_workers, not args.keep_last)
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader


def _train(
    model,
    optimizer,
    lr_scheduler,
    forward_step,
    train_dataloader,
    valid_dataloader,
    end_of_epoch_callback,
):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch.
    timers("interval-time").start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):
            # Ignore the iterations before starting value.
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = train_step(
                forward_step, batch, model, optimizer, lr_scheduler
            )
            iteration += 1

            # Logging.
            params_norm = None
            if args.log_params_norm:
                params_norm = calc_params_l2_norm(model)
            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad,
            )

            # Autoresume.
            if args.adlr_autoresume and (
                iteration % args.adlr_autoresume_interval == 0
            ):
                check_adlr_autoresume_termination(
                    iteration, model, optimizer, lr_scheduler
                )

            # Checkpointing.
            if (
                args.save
                and args.save_interval
                and iteration % args.save_interval == 0
            ):
                save_checkpoint(iteration, model, optimizer, lr_scheduler)

            # Evaluation.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    False,
                )

        # Checkpointing at the end of each epoch.
        if args.save:
            save_checkpoint(iteration, model, optimizer, lr_scheduler)

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step=_cross_entropy_forward_step,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    timers("train/valid/test dataset/dataloder").start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloder").stop()

    # Build callback function.
    timers("callback function").start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer").start()
    model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
    timers("model and optimizer").stop()

    # If a pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint").start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        original_load = args.load
        args.load = args.pretrained_checkpoint
        _ = load_checkpoint(model, None, None, strict=False)
        args.load = original_load
        # This is critical when only the model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()
    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloder",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            lr_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1, output_predictions=True)

    print_rank_0("done :-)")
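

# Illustrative usage sketch (not part of the original module). A task script
# would typically wire `finetune` up with its own dataset and model providers;
# the provider and model names below are hypothetical placeholders:
#
#     def train_valid_datasets_provider():
#         train_ds, valid_ds = build_my_datasets()   # hypothetical helper
#         return train_ds, valid_ds
#
#     def model_provider(pre_process=True, post_process=True):
#         return MyClassificationModel()             # hypothetical model
#
#     finetune(train_valid_datasets_provider, model_provider,
#              forward_step=_cross_entropy_forward_step)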