123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Evaluation utilities."""
- from collections import OrderedDict
- import math
- import numpy as np
- import time
- import torch
- import torch.nn.functional as F
- from torch.utils.data import DataLoader
- from megatron import get_args, print_rank_0
- from megatron import mpu
- from megatron.utils import average_losses_across_data_parallel_group
- from tasks.finetune_utils import build_data_loader
def task_collate_fn(batch_data):
    """Collate a list of per-sample dicts into one batch dict.

    Values are first grouped into one list per key, in first-seen key order.
    The query/context fields listed below are then converted to
    ``torch.LongTensor``. When samples carry ``'neg_context'`` fields, each
    of the three negative-context fields is concatenated with
    ``np.concatenate`` before the conversion, so the per-sample arrays are
    joined along their first axis instead of being stacked.

    Args:
        batch_data: iterable of dicts, one per sample. Every sample must
            provide the eight query/context keys converted below.

    Returns:
        An ``OrderedDict`` keyed like the samples. Keys that are not
        converted here (for example ``'reference'``) stay as plain Python
        lists holding one entry per sample.
    """
    # Group the values of every field: {key: [sample0_value, sample1_value, ...]}.
    tensorized = OrderedDict()
    for sample in batch_data:
        for key, value in sample.items():
            tensorized.setdefault(key, []).append(value)

    # Fields that hold exactly one entry per sample.
    for key in ('query', 'query_mask', 'query_types', 'query_pad_mask',
                'context', 'context_mask', 'context_types',
                'context_pad_mask'):
        tensorized[key] = torch.LongTensor(tensorized[key])

    # Negative contexts are optional; the presence of 'neg_context' decides
    # whether all three negative fields are converted.
    if 'neg_context' in tensorized:
        for key in ('neg_context', 'neg_context_mask', 'neg_context_types'):
            tensorized[key] = torch.LongTensor(
                np.concatenate(tensorized[key]))

    return tensorized
def process_batch(batch):
    """Move one collated batch to the GPU and unpack it into model inputs.

    Token, type and pad-mask fields are cast with ``.long()``; the
    ``*_mask`` fields are replaced by the boolean result of ``mask < 0.5``
    (True wherever the stored mask value is below 0.5). Everything except
    ``reference`` is moved with ``.cuda()``.

    Returns:
        A 12-tuple: the four query fields (tokens, mask, types, pad mask),
        the four context fields in the same order, the three
        negative-context fields (tokens, mask, types; all ``None`` when the
        batch has no ``'neg_context'`` key), and ``batch['reference']``
        passed through unchanged.
    """

    def as_ids(key):
        # Integer-valued field: cast to long, then move to the GPU.
        return batch[key].long().cuda()

    def as_flipped_mask(key):
        # Boolean tensor that is True where the stored mask is below 0.5.
        return (batch[key] < 0.5).cuda()

    query_fields = (
        as_ids('query'),
        as_flipped_mask('query_mask'),
        as_ids('query_types'),
        as_ids('query_pad_mask'),
    )
    context_fields = (
        as_ids('context'),
        as_flipped_mask('context_mask'),
        as_ids('context_types'),
        as_ids('context_pad_mask'),
    )

    if 'neg_context' in batch:
        negative_fields = (
            as_ids('neg_context'),
            as_flipped_mask('neg_context_mask'),
            as_ids('neg_context_types'),
        )
    else:
        negative_fields = (None, None, None)

    return query_fields + context_fields + negative_fields \
        + (batch['reference'],)
def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide a function that computes and prints retrieval metrics.

    The validation dataset is built once from ``args.valid_data`` and wrapped
    in a data loader; the returned closure reuses that loader on every call.

    Args:
        single_dataset_provider: callable invoked with ``args.valid_data``.
            The dataset it returns must expose a ``dataset_name`` attribute.
        rank0sampler: when False and the data-parallel world size is greater
            than one, ``drop_last=True`` is passed to ``build_data_loader``.
            Must be truthy if the returned function is later called with
            ``output_predictions=True``.

    Returns:
        ``metrics_func(model, epoch, output_predictions=False)``, which
        prints the metrics through ``print_rank_0`` and returns ``None``.
    """
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = mpu.get_data_parallel_world_size() > 1 and not rank0sampler

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        """Evaluate ``model`` on the validation loader and print the metrics.

        Raises:
            AssertionError: if ``args.task`` is not ``"RET-FINETUNE-NQ"``, or
                if ``output_predictions`` is set without ``rank0sampler``.
        """
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
        _, eval_dataloader = dataloaders

        if args.task != "RET-FINETUNE-NQ":
            raise AssertionError("{} Task not supported".format(args.task))

        start_time = time.time()
        stats_dict, total = retrieval_loss(model, eval_dataloader)

        if total > 0:
            # The statistics are running sums; divide by the number of
            # processed queries to report per-query averages.
            format_string = "".join(
                "|{} = {:.2f}".format(k, v / total)
                for k, v in stats_dict.items())
            print_rank_0("epoch:{}{}".format(epoch, format_string))
        else:
            # The loader produced no batches; dividing by zero here would
            # otherwise abort the run.
            print_rank_0("epoch:{}|no evaluation samples were processed, "
                         "metrics skipped".format(epoch))

        print_rank_0("taken time to calcuate metrics {:.3f}".format(
            time.time() - start_time))

    return metrics_func
def retrieval_loss(model, dataloader):
    """Accumulate retrieval rank and top-k accuracy sums over ``dataloader``.

    For each batch the query and context embeddings returned by the model are
    scored with a dot product (optionally divided by ``sqrt(hidden_size)``).
    Context row ``i`` is treated as the positive for query ``i``, because the
    labels are ``arange(batch_size)``. Negative contexts, when the batch has
    them, are appended after the positive contexts.

    Args:
        model: sequence containing exactly one module. It is called with
            (query tokens, query mask, query types, context tokens, context
            mask, context types) and must return
            ``(query_logits, context_logits)``.
        dataloader: iterable of batches accepted by ``process_batch``.

    Returns:
        ``(stats_dict, total)``. ``stats_dict`` maps ``'rank'`` and one
        ``'top{k}_acc'`` key per entry of
        ``args.retriever_report_topk_accuracies`` to running sums over all
        batches (top-k sums are scaled by 100). ``total`` is the number of
        query rows processed; callers divide by it to get averages.
    """
    args = get_args()
    report_topk = args.retriever_report_topk_accuracies

    total = 0
    stats_dict = dict(rank=0,
                      **{'top{}_acc'.format(k): 0 for k in report_topk})

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    try:
        with torch.no_grad():
            # For all the batches in the dataset.
            for batch in dataloader:
                (query_tokens, query_mask, query_types, _,
                 context_tokens, context_mask, context_types, _,
                 neg_context_tokens, neg_context_mask, neg_context_types,
                 _) = process_batch(batch)

                # process_batch yields None for the negative fields when the
                # batch has no 'neg_context'; torch.cat cannot take None, so
                # only append negatives when they exist.
                if neg_context_tokens is not None:
                    context_tokens = torch.cat(
                        [context_tokens, neg_context_tokens])
                    context_mask = torch.cat(
                        [context_mask, neg_context_mask])
                    context_types = torch.cat(
                        [context_types, neg_context_types])

                # Run the model forward.
                query_logits, context_logits = unwrapped_model(
                    query_tokens, query_mask, query_types,
                    context_tokens, context_mask, context_types)

                retrieval_scores = torch.matmul(
                    query_logits, torch.transpose(context_logits, 0, 1))
                if args.retriever_score_scaling:
                    retrieval_scores = retrieval_scores / \
                        math.sqrt(args.hidden_size)

                local_batch_size = query_logits.shape[0]
                labels = torch.arange(local_batch_size,
                                      device=retrieval_scores.device)

                # Full descending ordering of the contexts for every query.
                softmax_scores = F.softmax(retrieval_scores, dim=1)
                _, sorted_indices = torch.topk(softmax_scores,
                                               k=softmax_scores.shape[1],
                                               sorted=True)

                # hits[i, j] is True when the j-th ranked context of query i
                # is its positive; each row holds exactly one True.
                hits = sorted_indices == labels.unsqueeze(1)

                # Sum over queries of the 0-based position of the positive.
                rank = hits.float().argmax(dim=1).sum().float().view(1)
                # Number of queries whose positive is within the first k.
                topk_accs = [
                    hits[:, :k].any(dim=1).sum().float().view(1)
                    for k in report_topk
                ]

                # NOTE(review): presumably reduces each value across the
                # data-parallel group, as the helper's name says; confirm
                # against megatron.utils.
                losses = average_losses_across_data_parallel_group(
                    [rank, *topk_accs])

                # Per-batch stats: retrieval rank plus all requested top-k
                # accuracies, the latter expressed as percentages.
                batch_stats = {'top{}_acc'.format(k): v * 100
                               for k, v in zip(report_topk, losses[1:])}
                batch_stats['rank'] = losses[0]
                for key, value in batch_stats.items():
                    stats_dict[key] += value

                total += local_batch_size
    finally:
        # Restore training mode even if evaluation raised.
        unwrapped_model.train()

    return stats_dict, total
|