inception_v1.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Contains the definition for inception v1 classification network."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. slim = tf.contrib.slim
  21. trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
  22. def inception_v1_base(inputs,
  23. final_endpoint='Mixed_5c',
  24. scope='InceptionV1'):
  25. """Defines the Inception V1 base architecture.
  26. This architecture is defined in:
  27. Going deeper with convolutions
  28. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
  29. Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
  30. http://arxiv.org/pdf/1409.4842v1.pdf.
  31. Args:
  32. inputs: a tensor of size [batch_size, height, width, channels].
  33. final_endpoint: specifies the endpoint to construct the network up to. It
  34. can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
  35. 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
  36. 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
  37. 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
  38. scope: Optional variable_scope.
  39. Returns:
  40. A dictionary from components of the network to the corresponding activation.
  41. Raises:
  42. ValueError: if final_endpoint is not set to one of the predefined values.
  43. """
  44. end_points = {}
  45. with tf.variable_scope(scope, 'InceptionV1', [inputs]):
  46. with slim.arg_scope(
  47. [slim.conv2d, slim.fully_connected],
  48. weights_initializer=trunc_normal(0.01)):
  49. with slim.arg_scope([slim.conv2d, slim.max_pool2d],
  50. stride=1, padding='SAME'):
  51. end_point = 'Conv2d_1a_7x7'
  52. net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
  53. end_points[end_point] = net
  54. if final_endpoint == end_point: return net, end_points
  55. end_point = 'MaxPool_2a_3x3'
  56. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  57. end_points[end_point] = net
  58. if final_endpoint == end_point: return net, end_points
  59. end_point = 'Conv2d_2b_1x1'
  60. net = slim.conv2d(net, 64, [1, 1], scope=end_point)
  61. end_points[end_point] = net
  62. if final_endpoint == end_point: return net, end_points
  63. end_point = 'Conv2d_2c_3x3'
  64. net = slim.conv2d(net, 192, [3, 3], scope=end_point)
  65. end_points[end_point] = net
  66. if final_endpoint == end_point: return net, end_points
  67. end_point = 'MaxPool_3a_3x3'
  68. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  69. end_points[end_point] = net
  70. if final_endpoint == end_point: return net, end_points
  71. end_point = 'Mixed_3b'
  72. with tf.variable_scope(end_point):
  73. with tf.variable_scope('Branch_0'):
  74. branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
  75. with tf.variable_scope('Branch_1'):
  76. branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
  77. branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
  78. with tf.variable_scope('Branch_2'):
  79. branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
  80. branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
  81. with tf.variable_scope('Branch_3'):
  82. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  83. branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
  84. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  85. end_points[end_point] = net
  86. if final_endpoint == end_point: return net, end_points
  87. end_point = 'Mixed_3c'
  88. with tf.variable_scope(end_point):
  89. with tf.variable_scope('Branch_0'):
  90. branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  91. with tf.variable_scope('Branch_1'):
  92. branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  93. branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
  94. with tf.variable_scope('Branch_2'):
  95. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  96. branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
  97. with tf.variable_scope('Branch_3'):
  98. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  99. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  100. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  101. end_points[end_point] = net
  102. if final_endpoint == end_point: return net, end_points
  103. end_point = 'MaxPool_4a_3x3'
  104. net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
  105. end_points[end_point] = net
  106. if final_endpoint == end_point: return net, end_points
  107. end_point = 'Mixed_4b'
  108. with tf.variable_scope(end_point):
  109. with tf.variable_scope('Branch_0'):
  110. branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
  111. with tf.variable_scope('Branch_1'):
  112. branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
  113. branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
  114. with tf.variable_scope('Branch_2'):
  115. branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
  116. branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
  117. with tf.variable_scope('Branch_3'):
  118. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  119. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  120. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  121. end_points[end_point] = net
  122. if final_endpoint == end_point: return net, end_points
  123. end_point = 'Mixed_4c'
  124. with tf.variable_scope(end_point):
  125. with tf.variable_scope('Branch_0'):
  126. branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  127. with tf.variable_scope('Branch_1'):
  128. branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
  129. branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
  130. with tf.variable_scope('Branch_2'):
  131. branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
  132. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  133. with tf.variable_scope('Branch_3'):
  134. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  135. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  136. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  137. end_points[end_point] = net
  138. if final_endpoint == end_point: return net, end_points
  139. end_point = 'Mixed_4d'
  140. with tf.variable_scope(end_point):
  141. with tf.variable_scope('Branch_0'):
  142. branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  143. with tf.variable_scope('Branch_1'):
  144. branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
  145. branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
  146. with tf.variable_scope('Branch_2'):
  147. branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
  148. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  149. with tf.variable_scope('Branch_3'):
  150. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  151. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  152. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  153. end_points[end_point] = net
  154. if final_endpoint == end_point: return net, end_points
  155. end_point = 'Mixed_4e'
  156. with tf.variable_scope(end_point):
  157. with tf.variable_scope('Branch_0'):
  158. branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
  159. with tf.variable_scope('Branch_1'):
  160. branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
  161. branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
  162. with tf.variable_scope('Branch_2'):
  163. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  164. branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
  165. with tf.variable_scope('Branch_3'):
  166. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  167. branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
  168. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  169. end_points[end_point] = net
  170. if final_endpoint == end_point: return net, end_points
  171. end_point = 'Mixed_4f'
  172. with tf.variable_scope(end_point):
  173. with tf.variable_scope('Branch_0'):
  174. branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
  175. with tf.variable_scope('Branch_1'):
  176. branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  177. branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
  178. with tf.variable_scope('Branch_2'):
  179. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  180. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
  181. with tf.variable_scope('Branch_3'):
  182. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  183. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  184. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  185. end_points[end_point] = net
  186. if final_endpoint == end_point: return net, end_points
  187. end_point = 'MaxPool_5a_2x2'
  188. net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point)
  189. end_points[end_point] = net
  190. if final_endpoint == end_point: return net, end_points
  191. end_point = 'Mixed_5b'
  192. with tf.variable_scope(end_point):
  193. with tf.variable_scope('Branch_0'):
  194. branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
  195. with tf.variable_scope('Branch_1'):
  196. branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
  197. branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
  198. with tf.variable_scope('Branch_2'):
  199. branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
  200. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
  201. with tf.variable_scope('Branch_3'):
  202. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  203. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  204. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  205. end_points[end_point] = net
  206. if final_endpoint == end_point: return net, end_points
  207. end_point = 'Mixed_5c'
  208. with tf.variable_scope(end_point):
  209. with tf.variable_scope('Branch_0'):
  210. branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
  211. with tf.variable_scope('Branch_1'):
  212. branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
  213. branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
  214. with tf.variable_scope('Branch_2'):
  215. branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
  216. branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
  217. with tf.variable_scope('Branch_3'):
  218. branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
  219. branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
  220. net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
  221. end_points[end_point] = net
  222. if final_endpoint == end_point: return net, end_points
  223. raise ValueError('Unknown final endpoint %s' % final_endpoint)
  224. def inception_v1(inputs,
  225. num_classes=1000,
  226. is_training=True,
  227. dropout_keep_prob=0.8,
  228. prediction_fn=slim.softmax,
  229. spatial_squeeze=True,
  230. reuse=None,
  231. scope='InceptionV1'):
  232. """Defines the Inception V1 architecture.
  233. This architecture is defined in:
  234. Going deeper with convolutions
  235. Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
  236. Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
  237. http://arxiv.org/pdf/1409.4842v1.pdf.
  238. The default image size used to train this network is 224x224.
  239. Args:
  240. inputs: a tensor of size [batch_size, height, width, channels].
  241. num_classes: number of predicted classes.
  242. is_training: whether is training or not.
  243. dropout_keep_prob: the percentage of activation values that are retained.
  244. prediction_fn: a function to get predictions out of logits.
  245. spatial_squeeze: if True, logits is of shape is [B, C], if false logits is
  246. of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
  247. reuse: whether or not the network and its variables should be reused. To be
  248. able to reuse 'scope' must be given.
  249. scope: Optional variable_scope.
  250. Returns:
  251. logits: the pre-softmax activations, a tensor of size
  252. [batch_size, num_classes]
  253. end_points: a dictionary from components of the network to the corresponding
  254. activation.
  255. """
  256. # Final pooling and prediction
  257. with tf.variable_scope(scope, 'InceptionV1', [inputs, num_classes],
  258. reuse=reuse) as scope:
  259. with slim.arg_scope([slim.batch_norm, slim.dropout],
  260. is_training=is_training):
  261. net, end_points = inception_v1_base(inputs, scope=scope)
  262. with tf.variable_scope('Logits'):
  263. net = slim.avg_pool2d(net, [7, 7], stride=1, scope='MaxPool_0a_7x7')
  264. net = slim.dropout(net,
  265. dropout_keep_prob, scope='Dropout_0b')
  266. logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
  267. normalizer_fn=None, scope='Conv2d_0c_1x1')
  268. if spatial_squeeze:
  269. logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
  270. end_points['Logits'] = logits
  271. end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
  272. return logits, end_points
  273. inception_v1.default_image_size = 224
  274. def inception_v1_arg_scope(weight_decay=0.00004,
  275. use_batch_norm=True):
  276. """Defines the default InceptionV1 arg scope.
  277. Note: Althougth the original paper didn't use batch_norm we found it useful.
  278. Args:
  279. weight_decay: The weight decay to use for regularizing the model.
  280. use_batch_norm: "If `True`, batch_norm is applied after each convolution.
  281. Returns:
  282. An `arg_scope` to use for the inception v3 model.
  283. """
  284. batch_norm_params = {
  285. # Decay for the moving averages.
  286. 'decay': 0.9997,
  287. # epsilon to prevent 0s in variance.
  288. 'epsilon': 0.001,
  289. # collection containing update_ops.
  290. 'updates_collections': tf.GraphKeys.UPDATE_OPS,
  291. }
  292. if use_batch_norm:
  293. normalizer_fn = slim.batch_norm
  294. normalizer_params = batch_norm_params
  295. else:
  296. normalizer_fn = None
  297. normalizer_params = {}
  298. # Set weight_decay for weights in Conv and FC layers.
  299. with slim.arg_scope([slim.conv2d, slim.fully_connected],
  300. weights_regularizer=slim.l2_regularizer(weight_decay)):
  301. with slim.arg_scope(
  302. [slim.conv2d],
  303. weights_initializer=slim.variance_scaling_initializer(),
  304. activation_fn=tf.nn.relu,
  305. normalizer_fn=normalizer_fn,
  306. normalizer_params=normalizer_params) as sc:
  307. return sc