utils.py 2.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. from tensorflow.keras.layers import (
  2. Activation,
  3. Add,
  4. BatchNormalization,
  5. Conv2D,
  6. )
  7. # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L32
  8. BASE_WEIGHTS_PATH = (
  9. "https://storage.googleapis.com/tensorflow/keras-applications/resnet/"
  10. )
  11. WEIGHTS_HASHES = {
  12. "resnet50": "4d473c1dd8becc155b73f8504c6f6626",
  13. }
  14. # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L262
  15. def stack1(x, filters, blocks, stride1=2, name=None):
  16. """
  17. A set of stacked residual blocks.
  18. Arguments:
  19. x: input tensor.
  20. filters: integer, filters of the bottleneck layer in a block.
  21. blocks: integer, blocks in the stacked blocks.
  22. stride1: default 2, stride of the first layer in the first block.
  23. name: string, stack label.
  24. Returns:
  25. Output tensor for the stacked blocks.
  26. """
  27. x = block1(x, filters, stride=stride1, name=name + "_block1")
  28. for i in range(2, blocks + 1):
  29. x = block1(x, filters, conv_shortcut=False, name=name + "_block" + str(i))
  30. return x
  31. # https://github.com/tensorflow/tensorflow/blob/2b96f3662bd776e277f86997659e61046b56c315/tensorflow/python/keras/applications/resnet.py#L217
  32. def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):
  33. """
  34. A residual block.
  35. Arguments:
  36. x: input tensor.
  37. filters: integer, filters of the bottleneck layer.
  38. kernel_size: default 3, kernel size of the bottleneck layer.
  39. stride: default 1, stride of the first layer.
  40. conv_shortcut: default True, use convolution shortcut if True,
  41. otherwise identity shortcut.
  42. name: string, block label.
  43. Returns:
  44. Output tensor for the residual block.
  45. """
  46. # channels_last format
  47. bn_axis = 3
  48. if conv_shortcut:
  49. shortcut = Conv2D(4 * filters, 1, strides=stride, name=name + "_0_conv")(x)
  50. shortcut = BatchNormalization(
  51. axis=bn_axis, epsilon=1.001e-5, name=name + "_0_bn",
  52. )(shortcut)
  53. else:
  54. shortcut = x
  55. x = Conv2D(filters, 1, strides=stride, name=name + "_1_conv")(x)
  56. x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_1_bn")(x)
  57. x = Activation("relu", name=name + "_1_relu")(x)
  58. x = Conv2D(filters, kernel_size, padding="SAME", name=name + "_2_conv")(x)
  59. x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_2_bn")(x)
  60. x = Activation("relu", name=name + "_2_relu")(x)
  61. x = Conv2D(4 * filters, 1, name=name + "_3_conv")(x)
  62. x = BatchNormalization(axis=bn_axis, epsilon=1.001e-5, name=name + "_3_bn")(x)
  63. x = Add(name=name + "_add")([shortcut, x])
  64. x = Activation("relu", name=name + "_out")(x)
  65. return x