inception_v3_test.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Tests for nets.inception_v3."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from nets import inception
  22. slim = tf.contrib.slim
  23. class InceptionV3Test(tf.test.TestCase):
  24. def testBuildClassificationNetwork(self):
  25. batch_size = 5
  26. height, width = 299, 299
  27. num_classes = 1000
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, end_points = inception.inception_v3(inputs, num_classes)
  30. self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
  31. self.assertListEqual(logits.get_shape().as_list(),
  32. [batch_size, num_classes])
  33. self.assertTrue('Predictions' in end_points)
  34. self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
  35. [batch_size, num_classes])
  36. def testBuildBaseNetwork(self):
  37. batch_size = 5
  38. height, width = 299, 299
  39. inputs = tf.random_uniform((batch_size, height, width, 3))
  40. final_endpoint, end_points = inception.inception_v3_base(inputs)
  41. self.assertTrue(final_endpoint.op.name.startswith(
  42. 'InceptionV3/Mixed_7c'))
  43. self.assertListEqual(final_endpoint.get_shape().as_list(),
  44. [batch_size, 8, 8, 2048])
  45. expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
  46. 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
  47. 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
  48. 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
  49. 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
  50. self.assertItemsEqual(end_points.keys(), expected_endpoints)
  51. def testBuildOnlyUptoFinalEndpoint(self):
  52. batch_size = 5
  53. height, width = 299, 299
  54. endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
  55. 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
  56. 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
  57. 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
  58. 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
  59. for index, endpoint in enumerate(endpoints):
  60. with tf.Graph().as_default():
  61. inputs = tf.random_uniform((batch_size, height, width, 3))
  62. out_tensor, end_points = inception.inception_v3_base(
  63. inputs, final_endpoint=endpoint)
  64. self.assertTrue(out_tensor.op.name.startswith(
  65. 'InceptionV3/' + endpoint))
  66. self.assertItemsEqual(endpoints[:index+1], end_points)
  67. def testBuildAndCheckAllEndPointsUptoMixed7c(self):
  68. batch_size = 5
  69. height, width = 299, 299
  70. inputs = tf.random_uniform((batch_size, height, width, 3))
  71. _, end_points = inception.inception_v3_base(
  72. inputs, final_endpoint='Mixed_7c')
  73. endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
  74. 'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
  75. 'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
  76. 'MaxPool_3a_3x3': [batch_size, 73, 73, 64],
  77. 'Conv2d_3b_1x1': [batch_size, 73, 73, 80],
  78. 'Conv2d_4a_3x3': [batch_size, 71, 71, 192],
  79. 'MaxPool_5a_3x3': [batch_size, 35, 35, 192],
  80. 'Mixed_5b': [batch_size, 35, 35, 256],
  81. 'Mixed_5c': [batch_size, 35, 35, 288],
  82. 'Mixed_5d': [batch_size, 35, 35, 288],
  83. 'Mixed_6a': [batch_size, 17, 17, 768],
  84. 'Mixed_6b': [batch_size, 17, 17, 768],
  85. 'Mixed_6c': [batch_size, 17, 17, 768],
  86. 'Mixed_6d': [batch_size, 17, 17, 768],
  87. 'Mixed_6e': [batch_size, 17, 17, 768],
  88. 'Mixed_7a': [batch_size, 8, 8, 1280],
  89. 'Mixed_7b': [batch_size, 8, 8, 2048],
  90. 'Mixed_7c': [batch_size, 8, 8, 2048]}
  91. self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
  92. for endpoint_name in endpoints_shapes:
  93. expected_shape = endpoints_shapes[endpoint_name]
  94. self.assertTrue(endpoint_name in end_points)
  95. self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
  96. expected_shape)
  97. def testModelHasExpectedNumberOfParameters(self):
  98. batch_size = 5
  99. height, width = 299, 299
  100. inputs = tf.random_uniform((batch_size, height, width, 3))
  101. with slim.arg_scope(inception.inception_v3_arg_scope()):
  102. inception.inception_v3_base(inputs)
  103. total_params, _ = slim.model_analyzer.analyze_vars(
  104. slim.get_model_variables())
  105. self.assertAlmostEqual(21802784, total_params)
  106. def testBuildEndPoints(self):
  107. batch_size = 5
  108. height, width = 299, 299
  109. num_classes = 1000
  110. inputs = tf.random_uniform((batch_size, height, width, 3))
  111. _, end_points = inception.inception_v3(inputs, num_classes)
  112. self.assertTrue('Logits' in end_points)
  113. logits = end_points['Logits']
  114. self.assertListEqual(logits.get_shape().as_list(),
  115. [batch_size, num_classes])
  116. self.assertTrue('AuxLogits' in end_points)
  117. aux_logits = end_points['AuxLogits']
  118. self.assertListEqual(aux_logits.get_shape().as_list(),
  119. [batch_size, num_classes])
  120. self.assertTrue('Mixed_7c' in end_points)
  121. pre_pool = end_points['Mixed_7c']
  122. self.assertListEqual(pre_pool.get_shape().as_list(),
  123. [batch_size, 8, 8, 2048])
  124. self.assertTrue('PreLogits' in end_points)
  125. pre_logits = end_points['PreLogits']
  126. self.assertListEqual(pre_logits.get_shape().as_list(),
  127. [batch_size, 1, 1, 2048])
  128. def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
  129. batch_size = 5
  130. height, width = 299, 299
  131. num_classes = 1000
  132. inputs = tf.random_uniform((batch_size, height, width, 3))
  133. _, end_points = inception.inception_v3(inputs, num_classes)
  134. endpoint_keys = [key for key in end_points.keys()
  135. if key.startswith('Mixed') or key.startswith('Conv')]
  136. _, end_points_with_multiplier = inception.inception_v3(
  137. inputs, num_classes, scope='depth_multiplied_net',
  138. depth_multiplier=0.5)
  139. for key in endpoint_keys:
  140. original_depth = end_points[key].get_shape().as_list()[3]
  141. new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
  142. self.assertEqual(0.5 * original_depth, new_depth)
  143. def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
  144. batch_size = 5
  145. height, width = 299, 299
  146. num_classes = 1000
  147. inputs = tf.random_uniform((batch_size, height, width, 3))
  148. _, end_points = inception.inception_v3(inputs, num_classes)
  149. endpoint_keys = [key for key in end_points.keys()
  150. if key.startswith('Mixed') or key.startswith('Conv')]
  151. _, end_points_with_multiplier = inception.inception_v3(
  152. inputs, num_classes, scope='depth_multiplied_net',
  153. depth_multiplier=2.0)
  154. for key in endpoint_keys:
  155. original_depth = end_points[key].get_shape().as_list()[3]
  156. new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
  157. self.assertEqual(2.0 * original_depth, new_depth)
  158. def testRaiseValueErrorWithInvalidDepthMultiplier(self):
  159. batch_size = 5
  160. height, width = 299, 299
  161. num_classes = 1000
  162. inputs = tf.random_uniform((batch_size, height, width, 3))
  163. with self.assertRaises(ValueError):
  164. _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1)
  165. with self.assertRaises(ValueError):
  166. _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0)
  167. def testHalfSizeImages(self):
  168. batch_size = 5
  169. height, width = 150, 150
  170. num_classes = 1000
  171. inputs = tf.random_uniform((batch_size, height, width, 3))
  172. logits, end_points = inception.inception_v3(inputs, num_classes)
  173. self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
  174. self.assertListEqual(logits.get_shape().as_list(),
  175. [batch_size, num_classes])
  176. pre_pool = end_points['Mixed_7c']
  177. self.assertListEqual(pre_pool.get_shape().as_list(),
  178. [batch_size, 3, 3, 2048])
  179. def testUnknownImageShape(self):
  180. tf.reset_default_graph()
  181. batch_size = 2
  182. height, width = 299, 299
  183. num_classes = 1000
  184. input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
  185. with self.test_session() as sess:
  186. inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
  187. logits, end_points = inception.inception_v3(inputs, num_classes)
  188. self.assertListEqual(logits.get_shape().as_list(),
  189. [batch_size, num_classes])
  190. pre_pool = end_points['Mixed_7c']
  191. feed_dict = {inputs: input_np}
  192. tf.initialize_all_variables().run()
  193. pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
  194. self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
  195. def testUnknowBatchSize(self):
  196. batch_size = 1
  197. height, width = 299, 299
  198. num_classes = 1000
  199. inputs = tf.placeholder(tf.float32, (None, height, width, 3))
  200. logits, _ = inception.inception_v3(inputs, num_classes)
  201. self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
  202. self.assertListEqual(logits.get_shape().as_list(),
  203. [None, num_classes])
  204. images = tf.random_uniform((batch_size, height, width, 3))
  205. with self.test_session() as sess:
  206. sess.run(tf.initialize_all_variables())
  207. output = sess.run(logits, {inputs: images.eval()})
  208. self.assertEquals(output.shape, (batch_size, num_classes))
  209. def testEvaluation(self):
  210. batch_size = 2
  211. height, width = 299, 299
  212. num_classes = 1000
  213. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  214. logits, _ = inception.inception_v3(eval_inputs, num_classes,
  215. is_training=False)
  216. predictions = tf.argmax(logits, 1)
  217. with self.test_session() as sess:
  218. sess.run(tf.initialize_all_variables())
  219. output = sess.run(predictions)
  220. self.assertEquals(output.shape, (batch_size,))
  221. def testTrainEvalWithReuse(self):
  222. train_batch_size = 5
  223. eval_batch_size = 2
  224. height, width = 150, 150
  225. num_classes = 1000
  226. train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
  227. inception.inception_v3(train_inputs, num_classes)
  228. eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
  229. logits, _ = inception.inception_v3(eval_inputs, num_classes,
  230. is_training=False, reuse=True)
  231. predictions = tf.argmax(logits, 1)
  232. with self.test_session() as sess:
  233. sess.run(tf.initialize_all_variables())
  234. output = sess.run(predictions)
  235. self.assertEquals(output.shape, (eval_batch_size,))
  236. def testLogitsNotSqueezed(self):
  237. num_classes = 25
  238. images = tf.random_uniform([1, 299, 299, 3])
  239. logits, _ = inception.inception_v3(images,
  240. num_classes=num_classes,
  241. spatial_squeeze=False)
  242. with self.test_session() as sess:
  243. tf.initialize_all_variables().run()
  244. logits_out = sess.run(logits)
  245. self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
  246. if __name__ == '__main__':
  247. tf.test.main()