FullyConvolutionalResnet18.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148
  1. import torch
  2. import torch.nn as nn
  3. from torchvision import models
  4. from torch.hub import load_state_dict_from_url
  5. from PIL import Image
  6. import cv2
  7. import numpy as np
  8. from matplotlib import pyplot as plt
  9. from torchvision import transforms
  10. from torchsummary import summary
  11. # Define the architecture by modifying resnet.
  12. # Original code is here
  13. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  14. class FullyConvolutionalResnet18(models.ResNet):
  15. def __init__(self, num_classes=1000, pretrained=False, **kwargs):
  16. # Start with standard resnet18 defined here
  17. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py
  18. super().__init__(block=models.resnet.BasicBlock, layers=[2, 2, 2, 2], num_classes=num_classes, **kwargs)
  19. if pretrained:
  20. state_dict = load_state_dict_from_url(models.resnet.model_urls["resnet18"], progress=True)
  21. self.load_state_dict(state_dict)
  22. # Replace AdaptiveAvgPool2d with standard AvgPool2d
  23. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L153-L154
  24. self.avgpool = nn.AvgPool2d((7, 7))
  25. # Add final Convolution Layer.
  26. self.last_conv = torch.nn.Conv2d(in_channels=self.fc.in_features, out_channels=num_classes, kernel_size=1)
  27. self.last_conv.weight.data.copy_(self.fc.weight.data.view(*self.fc.weight.data.shape, 1, 1))
  28. self.last_conv.bias.data.copy_(self.fc.bias.data)
  29. # Reimplementing forward pass.
  30. # Replacing the following code
  31. # https://github.com/pytorch/vision/blob/b2e95657cd5f389e3973212ba7ddbdcc751a7878/torchvision/models/resnet.py#L197-L213
  32. def _forward_impl(self, x):
  33. # Standard forward for resnet18
  34. x = self.conv1(x)
  35. x = self.bn1(x)
  36. x = self.relu(x)
  37. x = self.maxpool(x)
  38. x = self.layer1(x)
  39. x = self.layer2(x)
  40. x = self.layer3(x)
  41. x = self.layer4(x)
  42. x = self.avgpool(x)
  43. # Notice, there is no forward pass
  44. # through the original fully connected layer.
  45. # Instead, we forward pass through the last conv layer
  46. x = self.last_conv(x)
  47. return x
  48. if __name__ == "__main__":
  49. # Read ImageNet class id to name mapping
  50. with open('imagenet_classes.txt') as f:
  51. labels = [line.strip() for line in f.readlines()]
  52. # Read image
  53. original_image = cv2.imread('camel.jpg')
  54. # Convert original image to RGB format
  55. image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
  56. # Transform input image
  57. # 1. Convert to Tensor
  58. # 2. Subtract mean
  59. # 3. Divide by standard deviation
  60. transform = transforms.Compose([
  61. transforms.ToTensor(), #Convert image to tensor.
  62. transforms.Normalize(
  63. mean=[0.485, 0.456, 0.406], # Subtract mean
  64. std=[0.229, 0.224, 0.225] # Divide by standard deviation
  65. )])
  66. image = transform(image)
  67. image = image.unsqueeze(0)
  68. # Load modified resnet18 model with pretrained ImageNet weights
  69. model = FullyConvolutionalResnet18(pretrained=True).eval()
  70. with torch.no_grad():
  71. # Perform inference.
  72. # Instead of a 1x1000 vector, we will get a
  73. # 1x1000xnxm output ( i.e. a probabibility map
  74. # of size n x m for each 1000 class,
  75. # where n and m depend on the size of the image.)
  76. preds = model(image)
  77. preds = torch.softmax(preds, dim=1)
  78. print('Response map shape : ', preds.shape)
  79. # Find the class with the maximum score in the n x m output map
  80. pred, class_idx = torch.max(preds, dim=1)
  81. print(class_idx)
  82. row_max, row_idx = torch.max(pred, dim=1)
  83. col_max, col_idx = torch.max(row_max, dim=1)
  84. predicted_class = class_idx[0, row_idx[0, col_idx], col_idx]
  85. # Print top predicted class
  86. print('Predicted Class : ', labels[predicted_class], predicted_class)
  87. # Find the n x m score map for the predicted class
  88. score_map = preds[0, predicted_class, :, :].cpu().numpy()
  89. score_map = score_map[0]
  90. # Resize score map to the original image size
  91. score_map = cv2.resize(score_map, (original_image.shape[1], original_image.shape[0]))
  92. # Binarize score map
  93. _, score_map_for_contours = cv2.threshold(score_map, 0.25, 1, type=cv2.THRESH_BINARY)
  94. score_map_for_contours = score_map_for_contours.astype(np.uint8).copy()
  95. # Find the countour of the binary blob
  96. contours, _ = cv2.findContours(score_map_for_contours, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE)
  97. # Find bounding box around the object.
  98. rect = cv2.boundingRect(contours[0])
  99. # Apply score map as a mask to original image
  100. score_map = score_map - np.min(score_map[:])
  101. score_map = score_map / np.max(score_map[:])
  102. score_map = cv2.cvtColor(score_map, cv2.COLOR_GRAY2BGR)
  103. masked_image = (original_image * score_map).astype(np.uint8)
  104. # Display bounding box
  105. cv2.rectangle(masked_image, rect[:2], (rect[0] + rect[2], rect[1] + rect[3]), (0, 0, 255), 2)
  106. # Display images
  107. cv2.imshow("Original Image", original_image)
  108. cv2.imshow("scaled_score_map", score_map)
  109. cv2.imshow("activations_and_bbox", masked_image)
  110. cv2.waitKey(0)