receptive_field_backpropagation.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. from collections import namedtuple
  2. import cv2
  3. import numpy as np
  4. import tensorflow as tf
  5. import tensorflow.keras.backend as K
  6. from tensorflow.keras.applications.resnet import preprocess_input
  7. from tensorflow.keras.layers import (
  8. BatchNormalization,
  9. Conv2D,
  10. )
  11. from FullyConvolutionalResnet50 import fully_convolutional_resnet50
  12. Rect = namedtuple("Rect", "x1 y1 x2 y2")
  13. def backprop_receptive_field(
  14. image, predicted_class, scoremap, use_max_activation=False,
  15. ):
  16. model = fully_convolutional_resnet50(
  17. input_shape=(image.shape[-3:]), pretrained_resnet=False,
  18. )
  19. for module in model.layers:
  20. try:
  21. if isinstance(module, Conv2D):
  22. conv_weights = np.full(module.get_weights()[0].shape, 0.005)
  23. if len(module.get_weights()) > 1:
  24. conv_biases = np.full(module.get_weights()[1].shape, 0.0)
  25. module.set_weights([conv_weights, conv_biases])
  26. # cases when use_bias = False
  27. else:
  28. module.set_weights([conv_weights])
  29. if isinstance(module, BatchNormalization):
  30. # weight sequence: gamma, beta, running mean, running variance
  31. bn_weights = [
  32. module.get_weights()[0],
  33. module.get_weights()[1],
  34. np.full(module.get_weights()[2].shape, 0.0),
  35. np.full(module.get_weights()[3].shape, 1.0),
  36. ]
  37. module.set_weights(bn_weights)
  38. except:
  39. pass
  40. input = tf.ones_like(image)
  41. out = model.predict(image)
  42. receptive_field_mask = tf.Variable(tf.zeros_like(out))
  43. if not use_max_activation:
  44. receptive_field_mask[:, :, :, predicted_class].assign(scoremap)
  45. else:
  46. scoremap_max_row_values = tf.math.reduce_max(scoremap, axis=1)
  47. max_row_id = tf.math.argmax(scoremap, axis=1)
  48. max_col_id = tf.math.argmax(scoremap_max_row_values, axis=1).numpy()[0]
  49. max_row_id = max_row_id[0, max_col_id].numpy()
  50. print(
  51. "Coords of the max activation:", max_row_id, max_col_id,
  52. )
  53. # update gradient
  54. receptive_field_mask = tf.tensor_scatter_nd_update(
  55. receptive_field_mask, [(0, max_row_id, max_col_id, 0)], [1],
  56. )
  57. grads = []
  58. with tf.GradientTape() as tf_gradient_tape:
  59. tf_gradient_tape.watch(input)
  60. # get the predictions
  61. preds = model(input)
  62. # apply the mask
  63. pseudo_loss = preds * receptive_field_mask
  64. pseudo_loss = K.mean(pseudo_loss)
  65. # get gradient
  66. grad = tf_gradient_tape.gradient(pseudo_loss, input)
  67. grad = tf.transpose(grad, perm=[0, 3, 1, 2])
  68. grads.append(grad)
  69. return grads[0][0, 0]
  70. def find_rect(activations):
  71. # Dilate and erode the activations to remove grid-like artifacts
  72. kernel = np.ones((5, 5), np.uint8)
  73. activations = cv2.dilate(activations, kernel=kernel)
  74. activations = cv2.erode(activations, kernel=kernel)
  75. # Binarize the activations
  76. _, activations = cv2.threshold(activations, 0.65, 1, type=cv2.THRESH_BINARY)
  77. activations = activations.astype(np.uint8).copy()
  78. # Find the contour of the binary blob
  79. contours, _ = cv2.findContours(
  80. activations, mode=cv2.RETR_EXTERNAL, method=cv2.CHAIN_APPROX_SIMPLE,
  81. )
  82. # Find bounding box around the object.
  83. rect = cv2.boundingRect(contours[0])
  84. return Rect(rect[0], rect[1], rect[0] + rect[2], rect[1] + rect[3])
  85. def normalize(activations):
  86. activations = activations - np.min(activations[:])
  87. activations = activations / np.max(activations[:])
  88. return activations
  89. def visualize_activations(image, activations, show_bounding_rect=False):
  90. activations = normalize(activations)
  91. activations_multichannel = np.stack([activations, activations, activations], axis=2)
  92. masked_image = (image * activations_multichannel).astype(np.uint8)
  93. if show_bounding_rect:
  94. rect = find_rect(activations.numpy())
  95. cv2.rectangle(
  96. masked_image,
  97. (rect.x1, rect.y1),
  98. (rect.x2, rect.y2),
  99. color=(0, 0, 255),
  100. thickness=2,
  101. )
  102. return masked_image
  103. def run_resnet_inference(original_image):
  104. # read ImageNet class ids to a list of labels
  105. with open("imagenet_classes.txt") as f:
  106. labels = [line.strip() for line in f.readlines()]
  107. # convert image to the RGB format
  108. image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
  109. # pre-process image
  110. image = preprocess_input(image)
  111. # convert image to NCHW tf.tensor
  112. image = tf.expand_dims(image, 0)
  113. # load resnet50 model with pretrained ImageNet weights
  114. model = fully_convolutional_resnet50(input_shape=(image.shape[-3:]))
  115. # Perform inference.
  116. # Instead of a 1×1000 vector, we will get a
  117. # 1×1000×n×m output ( i.e. a probability map
  118. # of size n × m for each 1000 class,
  119. # where n and m depend on the size of the image).
  120. preds = model.predict(image)
  121. preds = tf.transpose(preds, perm=[0, 3, 1, 2])
  122. preds = tf.nn.softmax(preds, axis=1)
  123. print("Response map shape : ", preds.shape)
  124. # find class with the maximum score in the n × m output map
  125. pred = tf.math.reduce_max(preds, axis=1)
  126. class_idx = tf.math.argmax(preds, axis=1)
  127. row_max = tf.math.reduce_max(pred, axis=1)
  128. row_idx = tf.math.argmax(pred, axis=1)
  129. col_idx = tf.math.argmax(row_max, axis=1)
  130. predicted_class = tf.gather_nd(
  131. class_idx, (0, tf.gather_nd(row_idx, (0, col_idx[0])), col_idx[0]),
  132. )
  133. # print the top predicted class
  134. print("Predicted Class : ", labels[predicted_class], predicted_class)
  135. # find the n × m score map for the predicted class
  136. score_map = tf.expand_dims(preds[0, predicted_class, :, :], 0).numpy()
  137. print("Score Map shape : ", score_map.shape)
  138. # compute the receptive filed for max activation pixel
  139. receptive_field_max_activation = backprop_receptive_field(
  140. image,
  141. scoremap=score_map,
  142. predicted_class=predicted_class,
  143. use_max_activation=True,
  144. )
  145. # compute the receptive filed for the whole image
  146. receptive_field_image = backprop_receptive_field(
  147. image,
  148. scoremap=score_map,
  149. predicted_class=predicted_class,
  150. use_max_activation=False,
  151. )
  152. # resize score map to the original image size
  153. score_map = score_map[0]
  154. score_map = cv2.resize(
  155. score_map, (original_image.shape[1], original_image.shape[0]),
  156. )
  157. # display the images
  158. cv2.imshow("Original Image", original_image)
  159. cv2.imshow(
  160. "Score map: activations and bbox",
  161. visualize_activations(original_image, score_map),
  162. )
  163. cv2.imshow(
  164. "receptive_field_max_activation",
  165. visualize_activations(
  166. original_image, receptive_field_max_activation, show_bounding_rect=True,
  167. ),
  168. )
  169. cv2.imshow(
  170. "receptive_field_the_whole_image",
  171. visualize_activations(
  172. original_image, receptive_field_image, show_bounding_rect=True,
  173. ),
  174. )
  175. cv2.waitKey(0)
  176. def main():
  177. # read the image
  178. image_path = "camel.jpg"
  179. image = cv2.imread(image_path)
  180. # run inference
  181. run_resnet_inference(image)
  182. if __name__ == "__main__":
  183. main()