123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132 |
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Evaluation utilities."""
- import os
- from functools import partial
- import torch
- from megatron import get_args
- from megatron import print_rank_0, print_rank_last
- from megatron import mpu
- from megatron.schedules import get_forward_backward_func
- from tasks.vision.finetune_utils import build_data_loader
- from tasks.vision.finetune_utils import process_batch
- from torchvision import datasets, transforms
def accuracy_func_provider():
    """Provide function that calculates accuracies.

    Builds a validation dataloader once, over the ``val`` sub-directory of
    the first entry in ``args.data_path``, and returns a closure that
    evaluates a model against it.

    Returns:
        A callable ``metrics_func(model, epoch)`` that counts correct
        predictions via ``calculate_correct_answers`` and prints the
        overall accuracy. The callable itself returns ``None``.
    """
    args = get_args()
    crop_size = args.img_dim

    # Build dataloaders.
    val_data_path = os.path.join(args.data_path[0], "val")

    # Images are normalized with mean/std of 0.5 per channel. The ImageNet
    # statistics (mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # are intentionally not used here.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)

    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        # Only drop the trailing partial batch when running data-parallel.
        drop_last=(mpu.get_data_parallel_world_size() > 1),
    )

    def metrics_func(model, epoch):
        """Evaluate ``model`` on the validation set and print its accuracy."""
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        # Guard against an evaluation that saw no samples (e.g. an empty
        # loader): report 0 % rather than raising ZeroDivisionError.
        if total:
            percent = float(correct) * 100.0 / float(total)
        else:
            percent = 0.0
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func
def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers.

    Args:
        model: Collection of model chunks. Each chunk is switched to eval
            mode for the duration of the pass and back to train mode
            afterwards, even if evaluation raises.
        dataloader: Iterable of batches accepted by ``process_batch``.
        epoch: Not used by this function; kept so existing callers keep
            working.

    Returns:
        A ``(correct, total)`` tuple. On the last pipeline stage the counts
        are summed across the data-parallel group; on every other stage
        ``(0, 0)`` is returned so callers can always unpack the result.
    """
    forward_backward_func = get_forward_backward_func()

    def loss_func(labels, output_tensor):
        """Return a dummy loss plus per-batch correct/total counts."""
        # Compute the correct answers.
        predicted = torch.argmax(output_tensor, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict = {
            'total': labels.size(0),
            'correct': corrects.sum().item(),
        }
        return 0, loss_dict

    def correct_answers_forward_step(batch, model):
        """Run one forward pass and bind the labels into the loss function."""
        # ``batch`` may be an iterator over batches or an already
        # materialized batch; ``next`` raises TypeError for the latter.
        # Only that case is handled, so interrupts and iterator exhaustion
        # are no longer silently swallowed.
        try:
            batch_ = next(batch)
        except TypeError:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        output_tensor = model(images)
        return output_tensor, partial(loss_func, labels)

    total = 0
    correct = 0

    for m in model:
        m.eval()
    try:
        with torch.no_grad():
            # For all the batches in the dataset.
            for batch in dataloader:
                loss_dicts = forward_backward_func(
                    correct_answers_forward_step, batch, model,
                    optimizer=None, timers=None, forward_only=True)
                for loss_dict in loss_dicts:
                    total += loss_dict['total']
                    correct += loss_dict['correct']
    finally:
        # Always hand the model back in train mode.
        for m in model:
            m.train()

    # Only the last pipeline stage takes part in the reduction below; give
    # other stages a well-formed empty result instead of no result at all.
    if not mpu.is_pipeline_last_stage():
        return 0, 0

    # Reduce. ``correct`` accumulates Python floats (``.item()`` on a float
    # tensor), so cast explicitly before building the integer tensor.
    unreduced = torch.cuda.LongTensor([int(correct), int(total)])
    torch.distributed.all_reduce(unreduced,
                                 group=mpu.get_data_parallel_group())

    correct_ans = unreduced[0].item()
    total_count = unreduced[1].item()
    return correct_ans, total_count
|