# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Contains the definition of the Inception Resnet V2 architecture.
  16. As described in http://arxiv.org/abs/1602.07261.
  17. Inception-v4, Inception-ResNet and the Impact of Residual Connections
  18. on Learning
  19. Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
  20. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

slim = tf.contrib.slim


def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 35x35 resnet block."""
  with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
    with tf.variable_scope('Branch_2'):
      tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
      tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
      tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
    # Concatenate the branches along the channel axis, then project back to
    # the input depth with a linear 1x1 convolution.
    mixed = tf.concat(axis=3,
                      values=[tower_conv, tower_conv1_1, tower_conv2_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net
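

# Note: the `net += scale * up` update in block35 above (and in block17 and
# block8 below) adds a scaled residual branch to the identity shortcut.
# Following the paper's "Scaling of the Residuals" discussion, scaling the
# residuals down (by factors roughly between 0.1 and 0.3) was found to
# stabilize training of these very deep residual Inception networks.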


def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 17x17 resnet block."""
  with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
                                  scope='Conv2d_0b_1x7')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
                                  scope='Conv2d_0c_7x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net


def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
  """Builds the 8x8 resnet block."""
  with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
    with tf.variable_scope('Branch_0'):
      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
    with tf.variable_scope('Branch_1'):
      tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
      tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
                                  scope='Conv2d_0b_1x3')
      tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
                                  scope='Conv2d_0c_3x1')
    mixed = tf.concat(axis=3, values=[tower_conv, tower_conv1_2])
    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
                     activation_fn=None, scope='Conv2d_1x1')
    net += scale * up
    if activation_fn:
      net = activation_fn(net)
  return net
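

# A minimal sketch of applying one of these blocks on its own (the placeholder
# shape and batch size are illustrative assumptions, not part of the original
# file). Each block preserves the shape of its input:
#
#   feature_map = tf.placeholder(tf.float32, [8, 35, 35, 320])
#   with slim.arg_scope(inception_resnet_v2_arg_scope()):
#     out = block35(feature_map, scale=0.17)  # still 8 x 35 x 35 x 320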


def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
                        dropout_keep_prob=0.8,
                        reuse=None,
                        scope='InceptionResnetV2'):
  """Creates the Inception Resnet V2 model.

  Args:
    inputs: a 4-D tensor of size [batch_size, height, width, 3].
    num_classes: number of predicted classes.
    is_training: whether the model is being trained.
    dropout_keep_prob: float, the fraction to keep before the final layer.
    reuse: whether or not the network and its variables should be reused. To
      be able to reuse them, 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the logits outputs of the model.
    end_points: the set of end_points from the inception model.
  """
  end_points = {}

  with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
                          stride=1, padding='SAME'):
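        # The spatial sizes in the comments below (149 x 149, 73 x 73, and so
        # on) assume the default 299 x 299 input; other input sizes shift them
        # accordingly.
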
        # 149 x 149 x 32
        net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
                          scope='Conv2d_1a_3x3')
        end_points['Conv2d_1a_3x3'] = net
        # 147 x 147 x 32
        net = slim.conv2d(net, 32, 3, padding='VALID',
                          scope='Conv2d_2a_3x3')
        end_points['Conv2d_2a_3x3'] = net
        # 147 x 147 x 64
        net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
        end_points['Conv2d_2b_3x3'] = net
        # 73 x 73 x 64
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_3a_3x3')
        end_points['MaxPool_3a_3x3'] = net
        # 73 x 73 x 80
        net = slim.conv2d(net, 80, 1, padding='VALID',
                          scope='Conv2d_3b_1x1')
        end_points['Conv2d_3b_1x1'] = net
        # 71 x 71 x 192
        net = slim.conv2d(net, 192, 3, padding='VALID',
                          scope='Conv2d_4a_3x3')
        end_points['Conv2d_4a_3x3'] = net
        # 35 x 35 x 192
        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                              scope='MaxPool_5a_3x3')
        end_points['MaxPool_5a_3x3'] = net
        # 35 x 35 x 320
        with tf.variable_scope('Mixed_5b'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
          with tf.variable_scope('Branch_1'):
            tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
                                        scope='Conv2d_0b_5x5')
          with tf.variable_scope('Branch_2'):
            tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
            tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
                                        scope='Conv2d_0c_3x3')
          with tf.variable_scope('Branch_3'):
            tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
                                         scope='AvgPool_0a_3x3')
            tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
                                       scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[tower_conv, tower_conv1_1,
                                          tower_conv2_2, tower_pool_1])

        end_points['Mixed_5b'] = net
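
        # Ten Inception-ResNet-A modules (the paper's name for the 35 x 35
        # residual blocks), each with its residual branch scaled by 0.17.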
        net = slim.repeat(net, 10, block35, scale=0.17)
        # 17 x 17 x 1088 (384 + 384 + 320 channels after the concat below)
        with tf.variable_scope('Mixed_6a'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
                                     scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_1'):
            tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
                                        stride=2, padding='VALID',
                                        scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_2'):
            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                                         scope='MaxPool_1a_3x3')
          net = tf.concat(axis=3,
                          values=[tower_conv, tower_conv1_2, tower_pool])

        end_points['Mixed_6a'] = net
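
        # Twenty Inception-ResNet-B modules (the 17 x 17 residual blocks),
        # each with its residual branch scaled by 0.10.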
        net = slim.repeat(net, 20, block17, scale=0.10)
        # Auxiliary tower
        with tf.variable_scope('AuxLogits'):
          aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
                                scope='Conv2d_1a_3x3')
          aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
          aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
                            padding='VALID', scope='Conv2d_2a_5x5')
          aux = slim.flatten(aux)
          aux = slim.fully_connected(aux, num_classes, activation_fn=None,
                                     scope='Logits')
          end_points['AuxLogits'] = aux
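
        # The auxiliary logits are intended only as an extra classification
        # loss during training (a standard Inception-family trick to improve
        # gradient flow in deep networks); they are returned in end_points
        # but are not part of the main logits path.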
        with tf.variable_scope('Mixed_7a'):
          with tf.variable_scope('Branch_0'):
            tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
                                       padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_1'):
            tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
                                        padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_2'):
            tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
            tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
                                        scope='Conv2d_0b_3x3')
            tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
                                        padding='VALID', scope='Conv2d_1a_3x3')
          with tf.variable_scope('Branch_3'):
            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
                                         scope='MaxPool_1a_3x3')
          net = tf.concat(axis=3, values=[tower_conv_1, tower_conv1_1,
                                          tower_conv2_2, tower_pool])

        end_points['Mixed_7a'] = net
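
        # Nine more 8 x 8 residual blocks (the paper's Inception-ResNet-C
        # modules), then one final block applied without the trailing ReLU
        # (activation_fn=None).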
        net = slim.repeat(net, 9, block8, scale=0.20)
        net = block8(net, activation_fn=None)
        net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
        end_points['Conv2d_7b_1x1'] = net

        with tf.variable_scope('Logits'):
          end_points['PrePool'] = net
          net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
                                scope='AvgPool_1a_8x8')
          net = slim.flatten(net)
          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
                             scope='Dropout')
          end_points['PreLogitsFlatten'] = net
          logits = slim.fully_connected(net, num_classes, activation_fn=None,
                                        scope='Logits')
          end_points['Logits'] = logits
          end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')

    return logits, end_points
inception_resnet_v2.default_image_size = 299
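
# 299 x 299 is the input size the architecture was designed around; at this
# size the final feature map entering the 'Logits' scope is 8 x 8, matching
# the 'AvgPool_1a_8x8' scope name above.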


def inception_resnet_v2_arg_scope(weight_decay=0.00004,
                                  batch_norm_decay=0.9997,
                                  batch_norm_epsilon=0.001):
  """Yields the scope with the default parameters for inception_resnet_v2.

  Args:
    weight_decay: the weight decay for weights variables.
    batch_norm_decay: decay for the moving averages of the batch_norm
      statistics (the running means and variances).
    batch_norm_epsilon: small float added to variance to avoid dividing by
      zero.

  Returns:
    an arg_scope with the parameters needed for inception_resnet_v2.
  """
  # Set weight_decay for weights in conv2d and fully_connected layers.
  with slim.arg_scope([slim.conv2d, slim.fully_connected],
                      weights_regularizer=slim.l2_regularizer(weight_decay),
                      biases_regularizer=slim.l2_regularizer(weight_decay)):
    batch_norm_params = {
        'decay': batch_norm_decay,
        'epsilon': batch_norm_epsilon,
    }
    # Set activation_fn and parameters for batch_norm.
    with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
                        normalizer_fn=slim.batch_norm,
                        normalizer_params=batch_norm_params) as scope:
      return scope
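

# Example usage (a minimal sketch, assuming a TF 1.x environment where
# tf.contrib.slim is available; the placeholder name and shapes are
# illustrative, not part of the original file):
#
#   images = tf.placeholder(tf.float32, [None, 299, 299, 3])
#   with slim.arg_scope(inception_resnet_v2_arg_scope()):
#     logits, end_points = inception_resnet_v2(images, num_classes=1001,
#                                              is_training=False)
#   predictions = end_points['Predictions']  # softmax over the 1001 classes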