fully_convolution_resnet_no_pooling.py
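
# Fully convolutional ResNet-18: the final fully connected layer is replaced by an
# equivalent 1x1 convolution, so the network accepts images of arbitrary size and
# produces a spatial map of class scores instead of a single prediction vector.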

import torch
from torch import nn
from torchvision import models
from torch.hub import load_state_dict_from_url
import time
from tqdm import tqdm
import cv2
import numpy as np
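

# ResNet-18 in which the classification head (average pool + fully connected layer)
# is skipped and a 1x1 convolution, initialized from the fc weights, is applied to
# the final feature map instead.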
class Model(models.ResNet):
    def __init__(self, num_classes=1000, pretrained=False, **kwargs):
        super().__init__(block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, **kwargs)
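        # Optionally load ImageNet-pretrained ResNet-18 weights
        # (models.resnet.model_urls is available in older torchvision releases).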
        if pretrained:
            state_dict = load_state_dict_from_url(models.resnet.model_urls["resnet18"], progress=True)
            self.load_state_dict(state_dict)
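        # 1x1 convolution equivalent to the fc layer; copy the fc weights and bias into it.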
        self.last_conv = torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1)
        self.last_conv.weight.data.copy_(self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
        self.last_conv.bias.data.copy_(self.fc.bias.data)

    def _forward_impl(self, x):
        # See note [TorchScript super()]
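        # Identical to the stock ResNet forward pass, except that avgpool, flatten
        # and fc are skipped and the 1x1 convolution runs on the feature map.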
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
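        # The 1x1 convolution produces a [N, num_classes, H', W'] map of class scores.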
        x = self.last_conv(x)
        return x
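

# Run the fully convolutional network on an arbitrarily sized image and visualize
# which spatial locations respond most strongly to the predicted class.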
def surgery():
    original_image = cv2.imread("dog-basset-hound.jpg", cv2.IMREAD_UNCHANGED)
    original_image = cv2.resize(original_image, None, None, fx=1 / 2.0, fy=1 / 2.0)
    cv2.imshow("original", original_image)
    image = original_image.copy()
    model = Model(pretrained=True).eval()
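    # ImageNet mean/std, reversed to BGR to match OpenCV's channel order.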
    imagenet_means = [0.485, 0.456, 0.406][::-1]
    imagenet_stds = [0.229, 0.224, 0.225][::-1]
    image = (image / 255.0 - imagenet_means) / imagenet_stds
    image = torch.tensor(image.transpose(2, 0, 1), dtype=torch.float32)
    image = image.unsqueeze(0)
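    # Forward pass: preds has shape [1, num_classes, H', W'].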
    with torch.no_grad():
        preds = model(image)
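    # Find the spatial location with the highest score and read off its class index.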
    pred, class_idx = torch.max(preds, dim=1)
    row_max, row_idx = torch.max(pred, dim=1)
    col_max, col_idx = torch.max(row_max, dim=1)
    predicted_class = class_idx[0, row_idx[0, col_idx], col_idx]
    print("Most confident class: ", predicted_class.item())
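    # Softmax over the class dimension gives per-location class probabilities.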
    preds = torch.softmax(preds, dim=1)
    score_map = preds[0, predicted_class, :, :].cpu().numpy()
    score_map = score_map[0]
    score_map = np.expand_dims(score_map, -1)
    score_map = np.repeat(score_map, 3, axis=2)
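    # Upsample the score map to the original image size and use it to modulate the
    # image, highlighting the regions that activate the predicted class most strongly.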
    score_map = cv2.resize(score_map, (original_image.shape[1], original_image.shape[0]))
    cv2.imshow("activations", (original_image * score_map).astype(np.uint8))
    cv2.waitKey(0)


if __name__ == "__main__":
    surgery()