inception_v1.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the definition for inception v1 classification network."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from nets import inception_utils

slim = tf.contrib.slim
trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)


def inception_v1_base(inputs,
                      final_endpoint='Mixed_5c',
                      scope='InceptionV1'):
  """Defines the Inception V1 base architecture.

  This architecture is defined in:
    Going deeper with convolutions
    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
    http://arxiv.org/pdf/1409.4842v1.pdf.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    final_endpoint: specifies the endpoint to construct the network up to. It
      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
      'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
      'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
    scope: Optional variable_scope.

  Returns:
    A tuple of the final feature tensor and a dictionary from components of
    the network to the corresponding activation.

  Raises:
    ValueError: if final_endpoint is not set to one of the predefined values.
  """
  end_points = {}
  with tf.variable_scope(scope, 'InceptionV1', [inputs]):
    with slim.arg_scope(
        [slim.conv2d, slim.fully_connected],
        weights_initializer=trunc_normal(0.01)):
      with slim.arg_scope([slim.conv2d, slim.max_pool2d],
                          stride=1, padding='SAME'):
        end_point = 'Conv2d_1a_7x7'
        net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool_2a_3x3'
        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Conv2d_2b_1x1'
        net = slim.conv2d(net, 64, [1, 1], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Conv2d_2c_3x3'
        net = slim.conv2d(net, 192, [3, 3], scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool_3a_3x3'
        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_3b'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_3c'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool_4a_3x3'
        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4b'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4c'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4d'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4e'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_4f'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'MaxPool_5a_2x2'
        net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point)
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_5b'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points

        end_point = 'Mixed_5c'
        with tf.variable_scope(end_point):
          with tf.variable_scope('Branch_0'):
            branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
          with tf.variable_scope('Branch_1'):
            branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
            branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_2'):
            branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
          with tf.variable_scope('Branch_3'):
            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
          net = tf.concat(axis=3, values=[branch_0, branch_1, branch_2, branch_3])
        end_points[end_point] = net
        if final_endpoint == end_point: return net, end_points
    raise ValueError('Unknown final endpoint %s' % final_endpoint)
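

# Illustrative sketch, not part of the original file: one common way to use
# inception_v1_base is as a feature extractor that stops at an intermediate
# endpoint. The helper name below is hypothetical and the helper is never
# called by this module; it assumes the TF1-style graph API used elsewhere in
# this file.
def _example_inception_v1_base_usage():
  """Builds the base network up to 'Mixed_4f' and returns the feature map."""
  images = tf.placeholder(tf.float32, [None, 224, 224, 3])
  with slim.arg_scope(inception_utils.inception_arg_scope()):
    mixed_4f, end_points = inception_v1_base(images, final_endpoint='Mixed_4f')
  # For 224x224 inputs, 'Mixed_4f' should be a [batch, 14, 14, 832] tensor
  # (256 + 320 + 128 + 128 channels concatenated across the four branches).
  return mixed_4f, end_points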


def inception_v1(inputs,
                 num_classes=1000,
                 is_training=True,
                 dropout_keep_prob=0.8,
                 prediction_fn=slim.softmax,
                 spatial_squeeze=True,
                 reuse=None,
                 scope='InceptionV1'):
  """Defines the Inception V1 architecture.

  This architecture is defined in:
    Going deeper with convolutions
    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
    http://arxiv.org/pdf/1409.4842v1.pdf.

  The default image size used to train this network is 224x224.

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    num_classes: number of predicted classes.
    is_training: whether the network is being trained.
    dropout_keep_prob: the fraction of activation values that are retained.
    prediction_fn: a function to get predictions out of logits.
    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is of
      shape [B, 1, 1, C], where B is batch_size and C is number of classes.
    reuse: whether or not the network and its variables should be reused. To be
      able to reuse 'scope' must be given.
    scope: Optional variable_scope.

  Returns:
    logits: the pre-softmax activations, a tensor of size
      [batch_size, num_classes]
    end_points: a dictionary from components of the network to the corresponding
      activation.
  """
  # Final pooling and prediction
  with tf.variable_scope(scope, 'InceptionV1', [inputs, num_classes],
                         reuse=reuse) as scope:
    with slim.arg_scope([slim.batch_norm, slim.dropout],
                        is_training=is_training):
      net, end_points = inception_v1_base(inputs, scope=scope)
      with tf.variable_scope('Logits'):
        net = slim.avg_pool2d(net, [7, 7], stride=1, scope='MaxPool_0a_7x7')
        net = slim.dropout(net,
                           dropout_keep_prob, scope='Dropout_0b')
        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
                             normalizer_fn=None, scope='Conv2d_0c_1x1')
        if spatial_squeeze:
          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')

        end_points['Logits'] = logits
        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  return logits, end_points


inception_v1.default_image_size = 224

inception_v1_arg_scope = inception_utils.inception_arg_scope
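

# Minimal usage sketch (an illustrative assumption, not part of the original
# file): builds the full classifier under the standard arg scope and inspects
# the output shapes. Runs only when this module is executed directly.
if __name__ == '__main__':
  with tf.Graph().as_default():
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with slim.arg_scope(inception_v1_arg_scope()):
      logits, end_points = inception_v1(images, num_classes=1000,
                                        is_training=False)
    # With spatial_squeeze=True (the default), logits should be [batch, 1000].
    print(logits.get_shape())
    # end_points maps endpoint names ('Conv2d_1a_7x7', ..., 'Mixed_5c',
    # 'Logits', 'Predictions') to the corresponding tensors.
    print(sorted(end_points.keys()))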