resnet_v1_test.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.nets.resnet_v1."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from nets import resnet_utils
  22. from nets import resnet_v1
  # Short alias for the TF-Slim API (layers, arg_scope, ...) used by the tests.
  slim = tf.contrib.slim
  def create_test_input(batch_size, height, width, channels):
    """Create test input tensor.

    Args:
      batch_size: Images per batch, or `None` if unknown.
      height: Height of each image, or `None` if unknown.
      width: Width of each image, or `None` if unknown.
      channels: Channels per image, or `None` if unknown.

    Returns:
      A float32 placeholder of shape [batch_size, height, width, channels]
      when any argument is `None`. Otherwise a float32 tensor of that shape
      whose entry at spatial position (h, w) equals h + w, repeated across
      every batch element and every channel.
    """
    shape = (batch_size, height, width, channels)
    if None in shape:
      return tf.placeholder(tf.float32, shape)
    # A column of row indices plus a row of column indices broadcasts into a
    # (height, width) grid holding h + w at every position.
    row_ids = np.arange(height).reshape(height, 1)
    col_ids = np.arange(width).reshape(1, width)
    grid = (row_ids + col_ids).reshape(1, height, width, 1)
    return tf.to_float(np.tile(grid, [batch_size, 1, 1, channels]))
  class ResnetUtilsTest(tf.test.TestCase):
    """Tests for resnet_utils helpers combined with the v1 bottleneck unit."""

    def _check_subsample(self, size, expected_values):
      """Checks subsample(., 2) of a size x size ramp against expected_values.

      The input holds 0 .. size * size - 1 in row-major order, reshaped to
      [1, size, size, 1]; expected_values is reshaped to [1, 2, 2, 1].
      """
      ramp = tf.reshape(tf.to_float(tf.range(size * size)),
                        [1, size, size, 1])
      picked = resnet_utils.subsample(ramp, 2)
      wanted = tf.reshape(tf.constant(expected_values), [1, 2, 2, 1])
      with self.test_session():
        self.assertAllClose(picked.eval(), wanted.eval())

    def testSubsampleThreeByThree(self):
      self._check_subsample(3, [0, 2, 6, 8])

    def testSubsampleFourByFour(self):
      self._check_subsample(4, [0, 2, 8, 10])

    def _check_conv2d_same(self, n, n2, dense_values, subsampled_values,
                           strided_values):
      """Compares stride-1 / stride-2 convolutions of an n x n ramp image.

      Creates 'Conv/weights' (a 3x3 ramp built by create_test_input) and
      'Conv/biases' (zeros), then marks the scope for reuse so every
      convolution below shares those two variables.

      Args:
        n: Spatial size of the square input image.
        n2: Spatial size of every stride-2 result.
        dense_values: Expected n x n output of slim.conv2d with stride 1.
        subsampled_values: Expected n2 x n2 result of subsampling the dense
          output by 2; resnet_utils.conv2d_same with stride 2 must match it.
        strided_values: Expected n2 x n2 output of slim.conv2d with stride 2.
      """
      image = create_test_input(1, n, n, 1)
      kernel = tf.reshape(create_test_input(1, 3, 3, 1), [3, 3, 1, 1])
      tf.get_variable('Conv/weights', initializer=kernel)
      tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
      tf.get_variable_scope().reuse_variables()

      def as_image(values, size):
        # Nested Python lists -> float tensor of shape [1, size, size, 1].
        return tf.reshape(tf.to_float(values), [1, size, size, 1])

      dense = slim.conv2d(image, 1, [3, 3], stride=1, scope='Conv')
      pairs = [(dense, as_image(dense_values, n))]
      subsampled_expected = as_image(subsampled_values, n2)
      pairs.append((resnet_utils.subsample(dense, 2), subsampled_expected))
      pairs.append(
          (resnet_utils.conv2d_same(image, 1, 3, stride=2, scope='Conv'),
           subsampled_expected))
      pairs.append((slim.conv2d(image, 1, [3, 3], stride=2, scope='Conv'),
                    as_image(strided_values, n2)))
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        for actual, expected in pairs:
          self.assertAllClose(actual.eval(), expected.eval())

    def testConv2DSameEven(self):
      self._check_conv2d_same(
          n=4, n2=2,
          dense_values=[[14, 28, 43, 26],
                        [28, 48, 66, 37],
                        [43, 66, 84, 46],
                        [26, 37, 46, 22]],
          subsampled_values=[[14, 43],
                             [43, 84]],
          strided_values=[[48, 37],
                          [37, 22]])

    def testConv2DSameOdd(self):
      # For this odd input size the plain stride-2 convolution is expected to
      # equal the subsampled stride-1 output.
      odd_subsampled = [[14, 43, 34],
                        [43, 84, 55],
                        [34, 55, 30]]
      self._check_conv2d_same(
          n=5, n2=3,
          dense_values=[[14, 28, 43, 58, 34],
                        [28, 48, 66, 84, 46],
                        [43, 66, 84, 102, 55],
                        [58, 84, 102, 120, 64],
                        [34, 46, 55, 64, 30]],
          subsampled_values=odd_subsampled,
          strided_values=odd_subsampled)

    def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
      """A plain ResNet without extra layers before or after the ResNet blocks.

      Returns:
        A (net, end_points) pair; end_points is a dict built from the
        'end_points' collection that slim.conv2d writes to inside this scope.
      """
      with tf.variable_scope(scope, values=[inputs]):
        with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
          net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
          return net, dict(tf.get_collection('end_points'))

    def testEndPointsV1(self):
      """Test the end points of a tiny v1 bottleneck network."""
      unit_fn = resnet_v1.bottleneck
      blocks = [
          resnet_utils.Block('block1', unit_fn, [(4, 1, 1), (4, 1, 2)]),
          resnet_utils.Block('block2', unit_fn, [(8, 2, 1), (8, 2, 1)]),
      ]
      images = create_test_input(2, 32, 16, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        _, end_points = self._resnet_plain(images, blocks, scope='tiny')
      expected = [
          'tiny/block1/unit_1/bottleneck_v1/shortcut',
          'tiny/block1/unit_1/bottleneck_v1/conv1',
          'tiny/block1/unit_1/bottleneck_v1/conv2',
          'tiny/block1/unit_1/bottleneck_v1/conv3',
          'tiny/block1/unit_2/bottleneck_v1/conv1',
          'tiny/block1/unit_2/bottleneck_v1/conv2',
          'tiny/block1/unit_2/bottleneck_v1/conv3',
          'tiny/block2/unit_1/bottleneck_v1/shortcut',
          'tiny/block2/unit_1/bottleneck_v1/conv1',
          'tiny/block2/unit_1/bottleneck_v1/conv2',
          'tiny/block2/unit_1/bottleneck_v1/conv3',
          'tiny/block2/unit_2/bottleneck_v1/conv1',
          'tiny/block2/unit_2/bottleneck_v1/conv2',
          'tiny/block2/unit_2/bottleneck_v1/conv3',
      ]
      self.assertItemsEqual(expected, end_points)

    def _stack_blocks_nondense(self, net, blocks):
      """A simplified ResNet Block stacker without output stride control.

      Every unit runs at rate 1 with its own nominal stride.
      """
      for block in blocks:
        with tf.variable_scope(block.scope, 'block', [net]):
          for unit_number, (depth, depth_bottleneck, stride) in enumerate(
              block.args, start=1):
            with tf.variable_scope('unit_%d' % unit_number, values=[net]):
              net = block.unit_fn(net,
                                  depth=depth,
                                  depth_bottleneck=depth_bottleneck,
                                  stride=stride,
                                  rate=1)
      return net

    def _atrousValues(self, bottleneck):
      """Verify the values of dense feature extraction by atrous convolution.

      Dense extraction with stack_blocks_dense() followed by subsampling must
      reproduce, within atol/rtol 1e-4, the features computed at the nominal
      output stride by the plain self._stack_blocks_nondense() above.

      Args:
        bottleneck: The bottleneck function.
      """
      blocks = [
          resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
          resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
          resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
          resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)]),
      ]
      nominal_stride = 8
      # One even and one odd spatial size.
      height, width = 30, 31
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        with slim.arg_scope([slim.batch_norm], is_training=False):
          for output_stride in [1, 2, 4, 8, None]:
            with tf.Graph().as_default(), self.test_session() as sess:
              tf.set_random_seed(0)
              images = create_test_input(1, height, width, 3)
              # Dense feature extraction followed by subsampling.
              dense = resnet_utils.stack_blocks_dense(images, blocks,
                                                      output_stride)
              factor = (1 if output_stride is None
                        else nominal_stride // output_stride)
              dense = resnet_utils.subsample(dense, factor)
              # The reference network below reuses the weights created above.
              tf.get_variable_scope().reuse_variables()
              reference = self._stack_blocks_nondense(images, blocks)
              sess.run(tf.global_variables_initializer())
              dense_value, reference_value = sess.run([dense, reference])
              self.assertAllClose(dense_value, reference_value,
                                  atol=1e-4, rtol=1e-4)

    def testAtrousValuesBottleneck(self):
      self._atrousValues(resnet_v1.bottleneck)
  class ResnetCompleteNetworkTest(tf.test.TestCase):
    """Tests with complete small ResNet v1 networks."""

    def _resnet_small(self,
                      inputs,
                      num_classes=None,
                      is_training=True,
                      global_pool=True,
                      output_stride=None,
                      include_root_block=True,
                      reuse=None,
                      scope='resnet_v1_small'):
      """A shallow and thin ResNet v1 for faster tests.

      block1-block3 each hold two stride-1 bottleneck units followed by one
      stride-2 unit; block4 holds two stride-1 units. Every other argument is
      forwarded unchanged to resnet_v1.resnet_v1, whose result is returned.
      """
      def strided_block(name, depth, depth_bottleneck):
        # Two stride-1 units, then one stride-2 unit.
        units = ([(depth, depth_bottleneck, 1)] * 2 +
                 [(depth, depth_bottleneck, 2)])
        return resnet_utils.Block(name, resnet_v1.bottleneck, units)

      blocks = [
          strided_block('block1', 4, 1),
          strided_block('block2', 8, 2),
          strided_block('block3', 16, 4),
          resnet_utils.Block('block4', resnet_v1.bottleneck, [(32, 8, 1)] * 2),
      ]
      return resnet_v1.resnet_v1(inputs, blocks, num_classes,
                                 is_training=is_training,
                                 global_pool=global_pool,
                                 output_stride=output_stride,
                                 include_root_block=include_root_block,
                                 reuse=reuse,
                                 scope=scope)

    def _check_endpoint_shapes(self, end_points, endpoint_to_shape):
      """Asserts the static shape of each end point named in endpoint_to_shape."""
      for name, shape in endpoint_to_shape.items():
        self.assertListEqual(end_points[name].get_shape().as_list(), shape)

    def testClassificationEndPoints(self):
      num_classes = 10
      images = create_test_input(2, 224, 224, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        logits, end_points = self._resnet_small(images, num_classes,
                                                global_pool=True,
                                                scope='resnet')
      self.assertTrue(logits.op.name.startswith('resnet/logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [2, 1, 1, num_classes])
      self.assertTrue('predictions' in end_points)
      self.assertListEqual(end_points['predictions'].get_shape().as_list(),
                           [2, 1, 1, num_classes])

    def testClassificationShapes(self):
      images = create_test_input(2, 224, 224, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        _, end_points = self._resnet_small(images, 10,
                                           global_pool=True,
                                           scope='resnet')
      self._check_endpoint_shapes(end_points, {
          'resnet/block1': [2, 28, 28, 4],
          'resnet/block2': [2, 14, 14, 8],
          'resnet/block3': [2, 7, 7, 16],
          'resnet/block4': [2, 7, 7, 32]})

    def testFullyConvolutionalEndpointShapes(self):
      images = create_test_input(2, 321, 321, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        _, end_points = self._resnet_small(images, 10,
                                           global_pool=False,
                                           scope='resnet')
      self._check_endpoint_shapes(end_points, {
          'resnet/block1': [2, 41, 41, 4],
          'resnet/block2': [2, 21, 21, 8],
          'resnet/block3': [2, 11, 11, 16],
          'resnet/block4': [2, 11, 11, 32]})

    def testRootlessFullyConvolutionalEndpointShapes(self):
      images = create_test_input(2, 128, 128, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        _, end_points = self._resnet_small(images, 10,
                                           global_pool=False,
                                           include_root_block=False,
                                           scope='resnet')
      self._check_endpoint_shapes(end_points, {
          'resnet/block1': [2, 64, 64, 4],
          'resnet/block2': [2, 32, 32, 8],
          'resnet/block3': [2, 16, 16, 16],
          'resnet/block4': [2, 16, 16, 32]})

    def testAtrousFullyConvolutionalEndpointShapes(self):
      images = create_test_input(2, 321, 321, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        _, end_points = self._resnet_small(images, 10,
                                           global_pool=False,
                                           output_stride=8,
                                           scope='resnet')
      self._check_endpoint_shapes(end_points, {
          'resnet/block1': [2, 41, 41, 4],
          'resnet/block2': [2, 41, 41, 8],
          'resnet/block3': [2, 41, 41, 16],
          'resnet/block4': [2, 41, 41, 32]})

    def testAtrousFullyConvolutionalValues(self):
      """Verify dense feature extraction with atrous convolution."""
      nominal_stride = 32
      for output_stride in [4, 8, 16, 32, None]:
        with slim.arg_scope(resnet_utils.resnet_arg_scope()):
          with tf.Graph().as_default(), self.test_session() as sess:
            tf.set_random_seed(0)
            images = create_test_input(2, 81, 81, 3)
            # Dense feature extraction followed by subsampling.
            dense, _ = self._resnet_small(images, None, is_training=False,
                                          global_pool=False,
                                          output_stride=output_stride)
            factor = (1 if output_stride is None
                      else nominal_stride // output_stride)
            dense = resnet_utils.subsample(dense, factor)
            # The nominal-rate network below reuses the weights just created.
            tf.get_variable_scope().reuse_variables()
            nominal, _ = self._resnet_small(images, None, is_training=False,
                                            global_pool=False)
            sess.run(tf.global_variables_initializer())
            self.assertAllClose(dense.eval(), nominal.eval(),
                                atol=1e-4, rtol=1e-4)

    def testUnknownBatchSize(self):
      batch = 2
      height, width = 65, 65
      num_classes = 10
      inputs = create_test_input(None, height, width, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        logits, _ = self._resnet_small(inputs, num_classes,
                                       global_pool=True,
                                       scope='resnet')
      self.assertTrue(logits.op.name.startswith('resnet/logits'))
      self.assertListEqual(logits.get_shape().as_list(),
                           [None, 1, 1, num_classes])
      images = create_test_input(batch, height, width, 3)
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(logits, {inputs: images.eval()})
        self.assertEqual(result.shape, (batch, 1, 1, num_classes))

    def _check_unknown_height_width(self, output_stride, expected_size):
      """Builds the net on a [2, None, None, 3] placeholder, then feeds 65x65.

      Args:
        output_stride: Forwarded to _resnet_small (None keeps its default).
        expected_size: Expected spatial size of the evaluated output.
      """
      batch = 2
      height, width = 65, 65
      inputs = create_test_input(batch, None, None, 3)
      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
        net, _ = self._resnet_small(inputs, None,
                                    global_pool=False,
                                    output_stride=output_stride)
      self.assertListEqual(net.get_shape().as_list(),
                           [batch, None, None, 32])
      images = create_test_input(batch, height, width, 3)
      with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        result = sess.run(net, {inputs: images.eval()})
        self.assertEqual(result.shape,
                         (batch, expected_size, expected_size, 32))

    def testFullyConvolutionalUnknownHeightWidth(self):
      self._check_unknown_height_width(output_stride=None, expected_size=3)

    def testAtrousFullyConvolutionalUnknownHeightWidth(self):
      self._check_unknown_height_width(output_stride=8, expected_size=9)
  if __name__ == '__main__':
    # Hand control to the TensorFlow test runner for this module's tests.
    tf.test.main()