eval_utils.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # coding=utf-8
  2. # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Evaluation utilities."""
  16. import os
  17. import time
  18. from functools import partial
  19. import torch
  20. from megatron import get_args
  21. from megatron import print_rank_last, is_last_rank
  22. from megatron import mpu
  23. from megatron.schedules import get_forward_backward_func
  24. from tasks.finetune_utils import build_data_loader
  25. from tasks.finetune_utils import process_batch
  26. def accuracy_func_provider(single_dataset_provider):
  27. """Provide function that calculates accuracies."""
  28. args = get_args()
  29. # Build dataloaders.
  30. datapaths = args.valid_data
  31. dataloaders = []
  32. for datapath in datapaths:
  33. dataset = single_dataset_provider(datapath)
  34. dataloader = build_data_loader(
  35. dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
  36. drop_last=(mpu.get_data_parallel_world_size() > 1))
  37. dataloaders.append((dataset.dataset_name, dataloader))
  38. def metrics_func(model, epoch, output_predictions=False):
  39. print_rank_last('calculating metrics ...')
  40. correct = 0
  41. total = 0
  42. if output_predictions:
  43. assert mpu.get_data_parallel_world_size() == 1
  44. named_predictions = []
  45. names = 'predictions'
  46. for name, dataloader in dataloaders:
  47. output = calculate_correct_answers(name, model, dataloader,
  48. epoch, output_predictions)
  49. if not output_predictions:
  50. correct_ans, total_count = output
  51. else:
  52. correct_ans, total_count, predictions = output
  53. named_predictions.append((name, predictions))
  54. names += '_' + name
  55. correct += correct_ans
  56. total += total_count
  57. if is_last_rank():
  58. percent = float(correct) * 100.0 / float(total)
  59. print(' >> |epoch: {}| overall: correct / total = {} / {} = '
  60. '{:.4f} %'.format(epoch, correct, total, percent))
  61. if output_predictions and is_last_rank():
  62. assert args.load is not None
  63. filename = os.path.join(args.load, names + '.pt')
  64. torch.save(named_predictions, filename)
  65. return metrics_func
  66. def calculate_correct_answers(name, model, dataloader,
  67. epoch, output_predictions):
  68. """Calculate correct over total answers and return prediction if the
  69. `output_predictions` is true."""
  70. args = get_args()
  71. forward_backward_func = get_forward_backward_func()
  72. start_time = time.time()
  73. for m in model:
  74. m.eval()
  75. saved_micro_batch_size = args.micro_batch_size
  76. saved_global_batch_size = args.global_batch_size
  77. ds = dataloader.dataset
  78. if hasattr(ds, 'sample_multiplier'):
  79. # If our dataset as a sample_multiplier attribute that means
  80. # each "sample" from the dataset actually has multiple samples
  81. # that will collapse into the batch dimension (for example in
  82. # the RACE dataset that has several options), we need to
  83. # account for that when setting the micro batch size.
  84. sample_multiplier = ds.sample_multiplier
  85. else:
  86. sample_multiplier = 1
  87. micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
  88. num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
  89. def loss_func(output_predictions, labels, output_tensor):
  90. logits = output_tensor
  91. loss_dict = {}
  92. # Add output predictions.
  93. if output_predictions:
  94. assert False
  95. loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
  96. logits.float()).data.cpu().numpy().tolist()
  97. loss_dict['labels'] = labels.data.cpu().numpy().tolist()
  98. loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
  99. # Compute the correct answers.
  100. predicted = torch.argmax(logits, dim=-1)
  101. corrects = (predicted == labels)
  102. # Add to the counters.
  103. loss_dict['total'] = labels.size(0)
  104. loss_dict['correct'] = corrects.sum().item()
  105. return 0, loss_dict
  106. # defined inside to capture output_predictions
  107. def correct_answers_forward_step(batch, model):
  108. try:
  109. batch_ = next(batch)
  110. except BaseException:
  111. batch_ = batch
  112. tokens, types, labels, attention_mask = process_batch(batch_)
  113. # Forward model.
  114. args = get_args()
  115. output_tensor = model(tokens, attention_mask, tokentype_ids=types)
  116. return output_tensor, partial(loss_func, output_predictions, labels)
  117. with torch.no_grad():
  118. # For all the batches in the dataset.
  119. total = 0
  120. correct = 0
  121. if output_predictions:
  122. # This option is only possible when data parallel size is 1.
  123. assert mpu.get_data_parallel_world_size() == 1
  124. softmaxes = []
  125. labels = []
  126. ids = []
  127. for _, batch in enumerate(dataloader):
  128. # For evaluation only mode we use drop_last = False to get all the
  129. # samples, which means we might not have a full batch, so we
  130. # adjust batch_size here to actual batch size of data
  131. actual_batch_size = len(batch['label'])
  132. # ... applying sample_multiplier if necessary
  133. args.micro_batch_size = actual_batch_size * sample_multiplier
  134. args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches
  135. loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
  136. optimizer=None, timers=None, forward_only=True)
  137. for loss_dict in loss_dicts:
  138. if output_predictions:
  139. softmaxes.extend(loss_dict['softmaxes'])
  140. labels.extend(loss_dict['labels'])
  141. ids.extend(loss_dict['ids'])
  142. total += loss_dict['total']
  143. correct += loss_dict['correct']
  144. for m in model:
  145. m.train()
  146. args.micro_batch_size = saved_micro_batch_size
  147. args.global_batch_size = saved_global_batch_size
  148. # Reduce.
  149. if mpu.is_pipeline_last_stage():
  150. unreduced = torch.cuda.LongTensor([correct, total])
  151. torch.distributed.all_reduce(unreduced,
  152. group=mpu.get_data_parallel_group())
  153. # Print on screen.
  154. correct_ans = unreduced[0].item()
  155. total_count = unreduced[1].item()
  156. percent = float(correct_ans) * 100.0 / float(total_count)
  157. elapsed_time = time.time() - start_time
  158. print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
  159. '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
  160. epoch, name, correct_ans, total_count,
  161. percent, elapsed_time))
  162. if output_predictions:
  163. return correct_ans, total_count, (softmaxes, labels, ids)
  164. return correct_ans, total_count
  165. if output_predictions:
  166. return 0, 0, ()
  167. return 0, 0