# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- """Evaluation utilities."""
- import os
- import time
- from functools import partial
- import torch
- from megatron import get_args
- from megatron import print_rank_last, is_last_rank
- from megatron import mpu
- from megatron.schedules import get_forward_backward_func
- from tasks.finetune_utils import build_data_loader
- from tasks.finetune_utils import process_batch


def accuracy_func_provider(single_dataset_provider):
    """Provide a function that calculates accuracies."""
    args = get_args()

    # Build dataloaders.
    datapaths = args.valid_data
    dataloaders = []
    for datapath in datapaths:
        dataset = single_dataset_provider(datapath)
        dataloader = build_data_loader(
            dataset, args.orig_micro_batch_size, num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1))
        dataloaders.append((dataset.dataset_name, dataloader))

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_last('calculating metrics ...')
        correct = 0
        total = 0
        if output_predictions:
            # Predictions can only be collected when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            named_predictions = []
            names = 'predictions'
        for name, dataloader in dataloaders:
            output = calculate_correct_answers(name, model, dataloader,
                                               epoch, output_predictions)
            if not output_predictions:
                correct_ans, total_count = output
            else:
                correct_ans, total_count, predictions = output
                named_predictions.append((name, predictions))
                names += '_' + name
            correct += correct_ans
            total += total_count
        if is_last_rank():
            percent = float(correct) * 100.0 / float(total)
            print(' >> |epoch: {}| overall: correct / total = {} / {} = '
                  '{:.4f} %'.format(epoch, correct, total, percent))

        if output_predictions and is_last_rank():
            assert args.load is not None
            filename = os.path.join(args.load, names + '.pt')
            torch.save(named_predictions, filename)

    return metrics_func
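
# Usage sketch (illustrative; not part of the original module). A task's
# entry point typically wraps a dataset builder and hands the result to the
# generic finetune loop as the end-of-epoch callback provider. The names
# below are hypothetical and the exact `finetune` signature may differ
# across repo versions:
#
#     from tasks.finetune_utils import finetune
#
#     def metrics_func_provider():
#         def single_dataset_provider(datapath):
#             return MyClassificationDataset(datapath)  # hypothetical
#         return accuracy_func_provider(single_dataset_provider)
#
#     finetune(train_valid_datasets_provider, model_provider,
#              end_of_epoch_callback_provider=metrics_func_provider)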


def calculate_correct_answers(name, model, dataloader,
                              epoch, output_predictions):
    """Calculate correct over total answers, and return predictions if
    `output_predictions` is true."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    start_time = time.time()
    for m in model:
        m.eval()
    saved_micro_batch_size = args.micro_batch_size
    saved_global_batch_size = args.global_batch_size

    ds = dataloader.dataset
    if hasattr(ds, 'sample_multiplier'):
        # If our dataset has a sample_multiplier attribute, it means each
        # "sample" from the dataset actually contains multiple samples that
        # will collapse into the batch dimension (for example, the RACE
        # dataset has several options per question), and we need to account
        # for that when setting the micro batch size.
        sample_multiplier = ds.sample_multiplier
    else:
        sample_multiplier = 1
    micro_batch_size_times_data_parallel = args.orig_micro_batch_size * args.data_parallel_size
    num_micro_batches = args.orig_global_batch_size // micro_batch_size_times_data_parallel
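    # Worked example with illustrative numbers (not from the original
    # source): if orig_micro_batch_size=8, data_parallel_size=4 and
    # orig_global_batch_size=64, each step covers 8 * 4 = 32 samples across
    # the data-parallel group, so num_micro_batches = 64 // 32 = 2
    # micro-batches are run per global batch.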

    def loss_func(output_predictions, labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Add output predictions.
        if output_predictions:
            # NOTE: the `assert False` below makes this branch unreachable
            # as written; `batch` resolves to the dataloader loop variable
            # in the enclosing scope, which only works for datasets that
            # provide a 'uid' field.
            assert False
            loss_dict['softmaxes'] = torch.nn.Softmax(dim=-1)(
                logits.float()).data.cpu().numpy().tolist()
            loss_dict['labels'] = labels.data.cpu().numpy().tolist()
            loss_dict['ids'] = batch['uid'].cpu().numpy().tolist()
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels)
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        # Return a dummy loss of 0; only the stats dict is consumed in
        # forward-only evaluation.
        return 0, loss_dict

    # Defined inside to capture output_predictions.
    def correct_answers_forward_step(batch, model):
        try:
            # Some pipeline schedules hand the forward step a data iterator
            # rather than a batch, so try to advance it first.
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        tokens, types, labels, attention_mask = process_batch(batch_)

        # Forward model.
        args = get_args()
        output_tensor = model(tokens, attention_mask, tokentype_ids=types)

        return output_tensor, partial(loss_func, output_predictions, labels)
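
    # Note on the schedule contract: forward_backward_func invokes
    # correct_answers_forward_step(batch, model) and expects back
    # (output_tensor, loss_function); on the last pipeline stage it applies
    # the returned loss function to the output tensor, which is how
    # loss_func above receives the logits.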

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        if output_predictions:
            # This option is only possible when data parallel size is 1.
            assert mpu.get_data_parallel_world_size() == 1
            softmaxes = []
            labels = []
            ids = []
        for _, batch in enumerate(dataloader):
            # In evaluation-only mode we use drop_last = False to get all
            # the samples, which means we might not have a full batch, so
            # adjust the batch size here to the actual batch size of the
            # data, applying sample_multiplier if necessary.
            actual_batch_size = len(batch['label'])
            args.micro_batch_size = actual_batch_size * sample_multiplier
            args.global_batch_size = actual_batch_size * sample_multiplier * num_micro_batches

            loss_dicts = forward_backward_func(correct_answers_forward_step, batch, model,
                                               optimizer=None, timers=None, forward_only=True)

            for loss_dict in loss_dicts:
                if output_predictions:
                    softmaxes.extend(loss_dict['softmaxes'])
                    labels.extend(loss_dict['labels'])
                    ids.extend(loss_dict['ids'])
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()
    args.micro_batch_size = saved_micro_batch_size
    args.global_batch_size = saved_global_batch_size

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        percent = float(correct_ans) * 100.0 / float(total_count)
        elapsed_time = time.time() - start_time
        print_rank_last(' > |epoch: {}| metrics for {}: correct / total '
                        '= {} / {} = {:.4f} %, elapsed time (sec): {:.3f}'.format(
                            epoch, name, correct_ans, total_count,
                            percent, elapsed_time))

        if output_predictions:
            return correct_ans, total_count, (softmaxes, labels, ids)
        return correct_ans, total_count
    if output_predictions:
        return 0, 0, ()
    return 0, 0
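
# Illustration (hypothetical dataset, not part of this module): how
# `sample_multiplier` relates to the batch dimension for a RACE-style
# multiple-choice task.
#
#     class RaceDataset(torch.utils.data.Dataset):
#         sample_multiplier = 4  # four (passage, option) encodings per item
#         dataset_name = 'RACE'
#
#         def __getitem__(self, idx):
#             # Returns tensors shaped [4, seq_len]; a collated batch of B
#             # items is [B, 4, seq_len] and collapses to [B * 4, seq_len]
#             # inside the model, hence micro_batch_size is scaled by
#             # sample_multiplier above.
#             ...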