# inception_model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inception-v3 expressed in TensorFlow-Slim.

  Usage:

  # Parameters for BatchNorm.
  batch_norm_params = {
      # Decay for the batch_norm moving averages.
      'decay': BATCHNORM_MOVING_AVERAGE_DECAY,
      # epsilon to prevent 0s in variance.
      'epsilon': 0.001,
  }
  # Set weight_decay for weights in Conv and FC layers.
  with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], weight_decay=0.00004):
    with slim.arg_scope([slim.ops.conv2d],
                        stddev=0.1,
                        activation=tf.nn.relu,
                        batch_norm_params=batch_norm_params):
      # Force all Variables to reside on the CPU.
      with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
        logits, endpoints = slim.inception.inception_v3(
            images,
            dropout_keep_prob=0.8,
            num_classes=num_classes,
            is_training=for_training,
            restore_logits=restore_logits,
            scope=scope)
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from inception.slim import ops
from inception.slim import scopes


def inception_v3(inputs,
                 dropout_keep_prob=0.8,
                 num_classes=1000,
                 is_training=True,
                 restore_logits=True,
                 scope=''):
  """Latest Inception from http://arxiv.org/abs/1512.00567.

    "Rethinking the Inception Architecture for Computer Vision"

    Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens,
    Zbigniew Wojna

  Args:
    inputs: a tensor of size [batch_size, height, width, channels].
    dropout_keep_prob: dropout keep_prob.
    num_classes: number of predicted classes.
    is_training: whether is training or not.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
    scope: Optional scope for op_scope.

  Returns:
    a tuple of (logits, end_points), where end_points is a dict of named
    activations that includes the 'logits', 'aux_logits' and 'predictions'
    Tensors.
  """
  # end_points will collect relevant activations for external use, for example
  # summaries or losses.
  end_points = {}
  with tf.op_scope([inputs], scope, 'inception_v3'):
    with scopes.arg_scope([ops.conv2d, ops.fc, ops.batch_norm, ops.dropout],
                          is_training=is_training):
      with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
                            stride=1, padding='VALID'):
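        # Note: this arg_scope makes stride=1 and padding='VALID' the default
        # for the conv/pool ops below; individual calls override these where a
        # stride-2 or 'SAME' layer is needed.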
  75. # 299 x 299 x 3
  76. end_points['conv0'] = ops.conv2d(inputs, 32, [3, 3], stride=2,
  77. scope='conv0')
  78. # 149 x 149 x 32
  79. end_points['conv1'] = ops.conv2d(end_points['conv0'], 32, [3, 3],
  80. scope='conv1')
  81. # 147 x 147 x 32
  82. end_points['conv2'] = ops.conv2d(end_points['conv1'], 64, [3, 3],
  83. padding='SAME', scope='conv2')
  84. # 147 x 147 x 64
  85. end_points['pool1'] = ops.max_pool(end_points['conv2'], [3, 3],
  86. stride=2, scope='pool1')
  87. # 73 x 73 x 64
  88. end_points['conv3'] = ops.conv2d(end_points['pool1'], 80, [1, 1],
  89. scope='conv3')
  90. # 71 x 71 x 80.
  91. end_points['conv4'] = ops.conv2d(end_points['conv3'], 192, [3, 3],
  92. scope='conv4')
  93. # 69 x 69 x 192.
  94. end_points['pool2'] = ops.max_pool(end_points['conv4'], [3, 3],
  95. stride=2, scope='pool2')
  96. # 35 x 35 x 192.
  97. net = end_points['pool2']
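        # The layers above form the network stem: plain convolutions and max
        # pooling that reduce the 299 x 299 x 3 input to a 35 x 35 x 192
        # feature map before the Inception modules below.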
      # Inception blocks
      with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
                            stride=1, padding='SAME'):
        # mixed: 35 x 35 x 256.
        with tf.variable_scope('mixed_35x35x256a'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 64, [1, 1])
          with tf.variable_scope('branch5x5'):
            branch5x5 = ops.conv2d(net, 48, [1, 1])
            branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 64, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 32, [1, 1])
          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
          end_points['mixed_35x35x256a'] = net
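        # Each mixed block runs several branches on the same input and
        # concatenates their outputs along the channel axis (dimension 3 in
        # this tf.concat call order): 64 + 64 + 96 + 32 = 256 channels here.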
        # mixed_1: 35 x 35 x 288.
        with tf.variable_scope('mixed_35x35x288a'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 64, [1, 1])
          with tf.variable_scope('branch5x5'):
            branch5x5 = ops.conv2d(net, 48, [1, 1])
            branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 64, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
          end_points['mixed_35x35x288a'] = net
        # mixed_2: 35 x 35 x 288.
        with tf.variable_scope('mixed_35x35x288b'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 64, [1, 1])
          with tf.variable_scope('branch5x5'):
            branch5x5 = ops.conv2d(net, 48, [1, 1])
            branch5x5 = ops.conv2d(branch5x5, 64, [5, 5])
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 64, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
          end_points['mixed_35x35x288b'] = net
        # mixed_3: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768a'):
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 384, [3, 3], stride=2, padding='VALID')
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 64, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 96, [3, 3],
                                      stride=2, padding='VALID')
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
          net = tf.concat(3, [branch3x3, branch3x3dbl, branch_pool])
          end_points['mixed_17x17x768a'] = net
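        # mixed_17x17x768a is a grid-reduction block: the stride-2 branches
        # and the stride-2 max pool shrink the grid from 35 x 35 to 17 x 17
        # while widening the concatenated channels to 384 + 96 + 288 = 768.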
        # mixed_4: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768b'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 192, [1, 1])
          with tf.variable_scope('branch7x7'):
            branch7x7 = ops.conv2d(net, 128, [1, 1])
            branch7x7 = ops.conv2d(branch7x7, 128, [1, 7])
            branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
          with tf.variable_scope('branch7x7dbl'):
            branch7x7dbl = ops.conv2d(net, 128, [1, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [1, 7])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 128, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
          end_points['mixed_17x17x768b'] = net
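        # The 17 x 17 blocks factorize 7 x 7 convolutions into stacked 1 x 7
        # and 7 x 1 convolutions, following the paper cited in the docstring
        # ("Rethinking the Inception Architecture for Computer Vision").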
        # mixed_5: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768c'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 192, [1, 1])
          with tf.variable_scope('branch7x7'):
            branch7x7 = ops.conv2d(net, 160, [1, 1])
            branch7x7 = ops.conv2d(branch7x7, 160, [1, 7])
            branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
          with tf.variable_scope('branch7x7dbl'):
            branch7x7dbl = ops.conv2d(net, 160, [1, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [1, 7])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
          end_points['mixed_17x17x768c'] = net
        # mixed_6: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768d'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 192, [1, 1])
          with tf.variable_scope('branch7x7'):
            branch7x7 = ops.conv2d(net, 160, [1, 1])
            branch7x7 = ops.conv2d(branch7x7, 160, [1, 7])
            branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
          with tf.variable_scope('branch7x7dbl'):
            branch7x7dbl = ops.conv2d(net, 160, [1, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [1, 7])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 160, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
          end_points['mixed_17x17x768d'] = net
        # mixed_7: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768e'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 192, [1, 1])
          with tf.variable_scope('branch7x7'):
            branch7x7 = ops.conv2d(net, 192, [1, 1])
            branch7x7 = ops.conv2d(branch7x7, 192, [1, 7])
            branch7x7 = ops.conv2d(branch7x7, 192, [7, 1])
          with tf.variable_scope('branch7x7dbl'):
            branch7x7dbl = ops.conv2d(net, 192, [1, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [7, 1])
            branch7x7dbl = ops.conv2d(branch7x7dbl, 192, [1, 7])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
          end_points['mixed_17x17x768e'] = net
        # Auxiliary Head logits
        aux_logits = tf.identity(end_points['mixed_17x17x768e'])
        with tf.variable_scope('aux_logits'):
          aux_logits = ops.avg_pool(aux_logits, [5, 5], stride=3,
                                    padding='VALID')
          aux_logits = ops.conv2d(aux_logits, 128, [1, 1], scope='proj')
          # Shape of feature map before the final layer.
          shape = aux_logits.get_shape()
          aux_logits = ops.conv2d(aux_logits, 768, shape[1:3], stddev=0.01,
                                  padding='VALID')
          aux_logits = ops.flatten(aux_logits)
          aux_logits = ops.fc(aux_logits, num_classes, activation=None,
                              stddev=0.001, restore=restore_logits)
          end_points['aux_logits'] = aux_logits
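        # The auxiliary head above attaches a second, smaller classifier to
        # the 17 x 17 feature map; callers can add a (typically down-weighted)
        # loss on end_points['aux_logits'] during training.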
        # mixed_8: 8 x 8 x 1280.
        # Note: the scope name below still says 17x17, but the stride-2
        # branches in this block reduce the grid to 8 x 8.
        with tf.variable_scope('mixed_17x17x1280a'):
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 192, [1, 1])
            branch3x3 = ops.conv2d(branch3x3, 320, [3, 3], stride=2,
                                   padding='VALID')
          with tf.variable_scope('branch7x7x3'):
            branch7x7x3 = ops.conv2d(net, 192, [1, 1])
            branch7x7x3 = ops.conv2d(branch7x7x3, 192, [1, 7])
            branch7x7x3 = ops.conv2d(branch7x7x3, 192, [7, 1])
            branch7x7x3 = ops.conv2d(branch7x7x3, 192, [3, 3],
                                     stride=2, padding='VALID')
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
          net = tf.concat(3, [branch3x3, branch7x7x3, branch_pool])
          end_points['mixed_17x17x1280a'] = net
        # mixed_9: 8 x 8 x 2048.
        with tf.variable_scope('mixed_8x8x2048a'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 320, [1, 1])
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 384, [1, 1])
            branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
                                      ops.conv2d(branch3x3, 384, [3, 1])])
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 448, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
            branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
                                         ops.conv2d(branch3x3dbl, 384, [3, 1])])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
          end_points['mixed_8x8x2048a'] = net
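        # In the 8 x 8 blocks the final 3 x 3 convolutions are expanded into
        # parallel 1 x 3 and 3 x 1 convolutions whose outputs are
        # concatenated, giving 320 + 768 + 768 + 192 = 2048 output channels.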
        # mixed_10: 8 x 8 x 2048.
        with tf.variable_scope('mixed_8x8x2048b'):
          with tf.variable_scope('branch1x1'):
            branch1x1 = ops.conv2d(net, 320, [1, 1])
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 384, [1, 1])
            branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
                                      ops.conv2d(branch3x3, 384, [3, 1])])
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 448, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
            branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
                                         ops.conv2d(branch3x3dbl, 384, [3, 1])])
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
          net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
          end_points['mixed_8x8x2048b'] = net
        # Final pooling and prediction
        with tf.variable_scope('logits'):
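          # Pool over the full remaining 8 x 8 spatial extent (global average
          # pooling), then apply dropout and a fully connected layer to get
          # the class logits.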
          shape = net.get_shape()
          net = ops.avg_pool(net, shape[1:3], padding='VALID', scope='pool')
          # 1 x 1 x 2048
          net = ops.dropout(net, dropout_keep_prob, scope='dropout')
          net = ops.flatten(net, scope='flatten')
          # 2048
          logits = ops.fc(net, num_classes, activation=None, scope='logits',
                          restore=restore_logits)
          # 1000
          end_points['logits'] = logits
          end_points['predictions'] = tf.nn.softmax(logits, name='predictions')
      return logits, end_points
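

# A minimal usage sketch: builds an inference graph for a batch of
# 299 x 299 RGB images by calling inception_v3 directly, without the
# arg_scope setup shown in the module docstring. The placeholder name, batch
# size, and bare-bones configuration here are illustrative only.
if __name__ == '__main__':
  images = tf.placeholder(tf.float32, [32, 299, 299, 3], name='images')
  logits, end_points = inception_v3(images, is_training=False)
  # Besides 'logits', end_points exposes 'aux_logits', 'predictions' and the
  # intermediate mixed_* activations.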