eval_utils.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

from collections import OrderedDict
import math
import numpy as np
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader

def task_collate_fn(batch_data):
    # Collate a list of per-sample dicts into batched tensors by first
    # grouping each field across samples.
    batch_size = len(batch_data)
    tensorized = OrderedDict()
    for d in batch_data:
        for k, v in d.items():
            tensorized.setdefault(k, []).append(v)

    tensorized['query'] = torch.LongTensor(tensorized['query'])
    tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
    tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
    tensorized['query_pad_mask'] = \
        torch.LongTensor(tensorized['query_pad_mask'])
    tensorized['context'] = torch.LongTensor(tensorized['context'])
    tensorized['context_mask'] = \
        torch.LongTensor(tensorized['context_mask'])
    tensorized['context_types'] = \
        torch.LongTensor(tensorized['context_types'])
    tensorized['context_pad_mask'] = \
        torch.LongTensor(tensorized['context_pad_mask'])

    # Hard negatives arrive as one array per sample; concatenate them
    # along the batch dimension.
    if 'neg_context' in tensorized:
        tensorized['neg_context'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context']))
        tensorized['neg_context_mask'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
        tensorized['neg_context_types'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_types']))

    return tensorized
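
# A minimal, hedged sketch of how task_collate_fn might be exercised. The toy
# token ids and the helper name _example_task_collate are illustrative
# assumptions, not part of this module's API; only the dict keys match the
# fields handled above.
def _example_task_collate():
    sample = {'query': [101, 2054, 102], 'query_mask': [1, 1, 1],
              'query_types': [0, 0, 0], 'query_pad_mask': [1, 1, 1],
              'context': [101, 3000, 102], 'context_mask': [1, 1, 1],
              'context_types': [0, 0, 0], 'context_pad_mask': [1, 1, 1]}
    # Two copies of the same sample -> tensors of shape (2, 3).
    batch = task_collate_fn([sample, sample])
    assert batch['query'].shape == (2, 3)
    return batch
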
def process_batch(batch):
    """Process batch and produce inputs for the model."""
    query_tokens = batch['query'].long().cuda()
    query_mask = (batch['query_mask'] < 0.5).cuda()
    query_types = batch['query_types'].long().cuda()
    query_pad_mask = batch['query_pad_mask'].long().cuda()

    context_tokens = batch['context'].long().cuda()
    context_mask = (batch['context_mask'] < 0.5).cuda()
    context_types = batch['context_types'].long().cuda()
    context_pad_mask = batch['context_pad_mask'].long().cuda()

    if 'neg_context' in batch:
        neg_context_tokens = batch['neg_context'].long().cuda()
        neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
        neg_context_types = batch['neg_context_types'].long().cuda()
    else:
        neg_context_tokens = None
        neg_context_mask = None
        neg_context_types = None

    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_pad_mask, \
           context_tokens, context_mask, context_types, context_pad_mask, \
           neg_context_tokens, neg_context_mask, neg_context_types, reference
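
# A hedged, CPU-only illustration of the `< 0.5` comparisons above. They
# appear to invert the masks, assuming the usual convention that the dataset
# emits 1 for real tokens and 0 for padding while the model expects True at
# positions to mask out; the helper name and toy values are illustrative.
def _example_mask_inversion():
    attention_mask = torch.LongTensor([[1, 1, 1, 0, 0]])  # 1 = real token
    inverted = attention_mask < 0.5                       # True = padding
    assert inverted.tolist() == [[False, False, False, True, True]]
    return inverted
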
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide a function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build the validation dataloader.
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
            names = 'predictions'
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("time taken to calculate metrics: {:.3f}".format(
                time.time() - start_time))
        else:
            raise AssertionError("{} task not supported".format(args.task))

    return metrics_func
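
# A hedged sketch of how the provider is typically wired up; it cannot run
# outside an initialized Megatron job, so it is left as a comment. The names
# nq_dataset_provider and model_list are hypothetical placeholders:
#
#     metrics_func = accuracy_func_provider(nq_dataset_provider,
#                                           rank0sampler=True)
#     metrics_func(model_list, epoch=0)
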
def retrieval_loss(model, dataloader):
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in
                       args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
            context_tokens, context_mask, context_types, _, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(
                query_tokens, query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            retrieval_scores = torch.matmul(query_logits,
                                            torch.transpose(context_logits,
                                                            0, 1))
            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                                     k=softmax_scores.shape[1],
                                                     sorted=True)

            def topk_accuracy(k):
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k])
                          for i in range(local_batch_size)])])

            def get_rank():
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0]
                          for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k)
                         for k in args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group(
                [rank, *topk_accs])

            # Create stats_dict with the retrieval rank and all specified
            # top-k accuracies.
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in
                             zip(args.retriever_report_topk_accuracies,
                                 losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
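
# A minimal, self-contained sketch of the in-batch scoring used above: each
# query's positive context shares its row index, so the label for query i is
# simply i, and top-k accuracy asks whether i appears among the k
# highest-scoring contexts for that query. The helper name, shapes, and the
# random logits are illustrative assumptions (CPU-only; no data parallelism).
def _example_in_batch_topk(k=2, batch=4, dim=8):
    query_logits = torch.randn(batch, dim)
    context_logits = torch.randn(batch + 3, dim)  # positives + extra negatives
    scores = torch.matmul(query_logits, context_logits.t())
    labels = torch.arange(batch)
    _, topk_idx = torch.topk(scores, k=k, dim=1)
    hits = (topk_idx == labels.unsqueeze(1)).any(dim=1)
    # Fraction of queries whose positive context ranks in the top k.
    return hits.float().mean().item()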