# eval_utils.py

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
import os
from functools import partial

import torch
from torchvision import datasets, transforms

from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch


def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    data_path = args.data_path
    crop_size = args.img_dim

    # Build dataloaders.
    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    val_data_path = os.path.join(data_path[0], "val")
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)
    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
    )

    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func
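

# NOTE (illustrative, not part of the original file): this provider is meant
# to be handed to the vision finetuning driver so that metrics_func runs as
# an end-of-epoch callback, roughly along these lines; the exact function and
# keyword names may differ between Megatron-LM versions:
#
#     from tasks.vision.finetune_utils import finetune
#     finetune(train_valid_datasets_provider, model_provider,
#              end_of_epoch_callback_provider=accuracy_func_provider)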


def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers."""
    args = get_args()
    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # Defined inside so that loss_func is captured in the closure.
    def correct_answers_forward_step(batch, model):
        # The forward-backward helper may pass either a raw batch or an
        # iterator over batches; handle both.
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        for _, batch in enumerate(dataloader):
            loss_dicts = forward_backward_func(correct_answers_forward_step,
                                               batch, model,
                                               optimizer=None, timers=None,
                                               forward_only=True)
            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()

    # Reduce across data-parallel ranks. Only the last pipeline stage holds
    # the logits (and hence the counters), so only it returns the counts.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        return correct_ans, total_count
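

# --------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): the reduction above,
# written with plain torch.distributed on the default process group, to make
# explicit what the all_reduce computes.  Megatron restricts the same
# reduction to the data-parallel group so that ranks holding replicas of the
# same data (e.g. tensor-parallel ranks) are not double-counted.
#
#     counts = torch.tensor([correct, total], device="cuda")
#     torch.distributed.all_reduce(counts)   # element-wise sum over ranks
#     accuracy = 100.0 * counts[0].item() / counts[1].item()
# --------------------------------------------------------------------------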