FullyConvolutionalResnet50.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import cv2
  2. import numpy as np
  3. import tensorflow as tf
  4. from tensorflow.keras import Input
  5. from tensorflow.keras.applications import ResNet50
  6. from tensorflow.keras.applications.resnet import preprocess_input
  7. from tensorflow.keras.layers import (
  8. Activation,
  9. AveragePooling2D,
  10. BatchNormalization,
  11. Conv2D,
  12. MaxPooling2D,
  13. ZeroPadding2D,
  14. )
  15. from tensorflow.python.keras.engine import training
  16. from tensorflow.python.keras.utils import data_utils
  17. from utils import (
  18. BASE_WEIGHTS_PATH,
  19. WEIGHTS_HASHES,
  20. stack1,
  21. )
  22. # setting FC weights to the final convolutional layer
  23. def set_conv_weights(model, feature_extractor):
  24. # get pre-trained ResNet50 FC weights
  25. dense_layer_weights = feature_extractor.layers[-1].get_weights()
  26. weights_list = [
  27. tf.reshape(
  28. dense_layer_weights[0], (1, 1, *dense_layer_weights[0].shape),
  29. ).numpy(),
  30. dense_layer_weights[1],
  31. ]
  32. model.get_layer(name="last_conv").set_weights(weights_list)
  33. def fully_convolutional_resnet50(
  34. input_shape, num_classes=1000, pretrained_resnet=True, use_bias=True,
  35. ):
  36. # init input layer
  37. img_input = Input(shape=input_shape)
  38. # define basic model pipeline
  39. x = ZeroPadding2D(padding=((3, 3), (3, 3)), name="conv1_pad")(img_input)
  40. x = Conv2D(64, 7, strides=2, use_bias=use_bias, name="conv1_conv")(x)
  41. x = BatchNormalization(axis=3, epsilon=1.001e-5, name="conv1_bn")(x)
  42. x = Activation("relu", name="conv1_relu")(x)
  43. x = ZeroPadding2D(padding=((1, 1), (1, 1)), name="pool1_pad")(x)
  44. x = MaxPooling2D(3, strides=2, name="pool1_pool")(x)
  45. # the sequence of stacked residual blocks
  46. x = stack1(x, 64, 3, stride1=1, name="conv2")
  47. x = stack1(x, 128, 4, name="conv3")
  48. x = stack1(x, 256, 6, name="conv4")
  49. x = stack1(x, 512, 3, name="conv5")
  50. # add avg pooling layer after feature extraction layers
  51. x = AveragePooling2D(pool_size=7)(x)
  52. # add final convolutional layer
  53. conv_layer_final = Conv2D(
  54. filters=num_classes, kernel_size=1, use_bias=use_bias, name="last_conv",
  55. )(x)
  56. # configure fully convolutional ResNet50 model
  57. model = training.Model(img_input, x)
  58. # load model weights
  59. if pretrained_resnet:
  60. model_name = "resnet50"
  61. # configure full file name
  62. file_name = model_name + "_weights_tf_dim_ordering_tf_kernels_notop.h5"
  63. # get the file hash from TF WEIGHTS_HASHES
  64. file_hash = WEIGHTS_HASHES[model_name][1]
  65. weights_path = data_utils.get_file(
  66. file_name,
  67. BASE_WEIGHTS_PATH + file_name,
  68. cache_subdir="models",
  69. file_hash=file_hash,
  70. )
  71. model.load_weights(weights_path)
  72. # form final model
  73. model = training.Model(inputs=model.input, outputs=[conv_layer_final])
  74. if pretrained_resnet:
  75. # get model with the dense layer for further FC weights extraction
  76. resnet50_extractor = ResNet50(
  77. include_top=True, weights="imagenet", classes=num_classes,
  78. )
  79. # set ResNet50 FC-layer weights to final convolutional layer
  80. set_conv_weights(model=model, feature_extractor=resnet50_extractor)
  81. return model