alexnet_test.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.nets.alexnet."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from nets import alexnet
# Short alias for the contrib slim module; used below to enumerate the model
# variables created by alexnet_v2 (slim.get_model_variables()).
slim = tf.contrib.slim
  22. class AlexnetV2Test(tf.test.TestCase):
  23. def testBuild(self):
  24. batch_size = 5
  25. height, width = 224, 224
  26. num_classes = 1000
  27. with self.test_session():
  28. inputs = tf.random_uniform((batch_size, height, width, 3))
  29. logits, _ = alexnet.alexnet_v2(inputs, num_classes)
  30. self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
  31. self.assertListEqual(logits.get_shape().as_list(),
  32. [batch_size, num_classes])
  33. def testFullyConvolutional(self):
  34. batch_size = 1
  35. height, width = 300, 400
  36. num_classes = 1000
  37. with self.test_session():
  38. inputs = tf.random_uniform((batch_size, height, width, 3))
  39. logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
  40. self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
  41. self.assertListEqual(logits.get_shape().as_list(),
  42. [batch_size, 4, 7, num_classes])
  43. def testEndPoints(self):
  44. batch_size = 5
  45. height, width = 224, 224
  46. num_classes = 1000
  47. with self.test_session():
  48. inputs = tf.random_uniform((batch_size, height, width, 3))
  49. _, end_points = alexnet.alexnet_v2(inputs, num_classes)
  50. expected_names = ['alexnet_v2/conv1',
  51. 'alexnet_v2/pool1',
  52. 'alexnet_v2/conv2',
  53. 'alexnet_v2/pool2',
  54. 'alexnet_v2/conv3',
  55. 'alexnet_v2/conv4',
  56. 'alexnet_v2/conv5',
  57. 'alexnet_v2/pool5',
  58. 'alexnet_v2/fc6',
  59. 'alexnet_v2/fc7',
  60. 'alexnet_v2/fc8'
  61. ]
  62. self.assertSetEqual(set(end_points.keys()), set(expected_names))
  63. def testModelVariables(self):
  64. batch_size = 5
  65. height, width = 224, 224
  66. num_classes = 1000
  67. with self.test_session():
  68. inputs = tf.random_uniform((batch_size, height, width, 3))
  69. alexnet.alexnet_v2(inputs, num_classes)
  70. expected_names = ['alexnet_v2/conv1/weights',
  71. 'alexnet_v2/conv1/biases',
  72. 'alexnet_v2/conv2/weights',
  73. 'alexnet_v2/conv2/biases',
  74. 'alexnet_v2/conv3/weights',
  75. 'alexnet_v2/conv3/biases',
  76. 'alexnet_v2/conv4/weights',
  77. 'alexnet_v2/conv4/biases',
  78. 'alexnet_v2/conv5/weights',
  79. 'alexnet_v2/conv5/biases',
  80. 'alexnet_v2/fc6/weights',
  81. 'alexnet_v2/fc6/biases',
  82. 'alexnet_v2/fc7/weights',
  83. 'alexnet_v2/fc7/biases',
  84. 'alexnet_v2/fc8/weights',
  85. 'alexnet_v2/fc8/biases',
  86. ]
  87. model_variables = [v.op.name for v in slim.get_model_variables()]
  88. self.assertSetEqual(set(model_variables), set(expected_names))
  89. def testEvaluation(self):
  90. batch_size = 2
  91. height, width = 224, 224
  92. num_classes = 1000
  93. with self.test_session():
  94. eval_inputs = tf.random_uniform((batch_size, height, width, 3))
  95. logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
  96. self.assertListEqual(logits.get_shape().as_list(),
  97. [batch_size, num_classes])
  98. predictions = tf.argmax(logits, 1)
  99. self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
  100. def testTrainEvalWithReuse(self):
  101. train_batch_size = 2
  102. eval_batch_size = 1
  103. train_height, train_width = 224, 224
  104. eval_height, eval_width = 300, 400
  105. num_classes = 1000
  106. with self.test_session():
  107. train_inputs = tf.random_uniform(
  108. (train_batch_size, train_height, train_width, 3))
  109. logits, _ = alexnet.alexnet_v2(train_inputs)
  110. self.assertListEqual(logits.get_shape().as_list(),
  111. [train_batch_size, num_classes])
  112. tf.get_variable_scope().reuse_variables()
  113. eval_inputs = tf.random_uniform(
  114. (eval_batch_size, eval_height, eval_width, 3))
  115. logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
  116. spatial_squeeze=False)
  117. self.assertListEqual(logits.get_shape().as_list(),
  118. [eval_batch_size, 4, 7, num_classes])
  119. logits = tf.reduce_mean(logits, [1, 2])
  120. predictions = tf.argmax(logits, 1)
  121. self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
  122. def testForward(self):
  123. batch_size = 1
  124. height, width = 224, 224
  125. with self.test_session() as sess:
  126. inputs = tf.random_uniform((batch_size, height, width, 3))
  127. logits, _ = alexnet.alexnet_v2(inputs)
  128. sess.run(tf.global_variables_initializer())
  129. output = sess.run(logits)
  130. self.assertTrue(output.any())
if __name__ == '__main__':
  # Run the test cases defined in this module via TensorFlow's test runner.
  tf.test.main()