collections_test.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for the variable and loss collections populated by slim's inception_v3."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from inception.slim import slim
  def get_variables(scope=None):
    """Forwards `scope` to `slim.variables.get_variables` and returns the result.

    Args:
      scope: passed through unchanged; defaults to None.

    Returns:
      Whatever `slim.variables.get_variables(scope)` returns.
    """
    matching_variables = slim.variables.get_variables(scope)
    return matching_variables
  def get_variables_by_name(name):
    """Forwards `name` to `slim.variables.get_variables_by_name`.

    Args:
      name: passed through unchanged.

    Returns:
      Whatever `slim.variables.get_variables_by_name(name)` returns.
    """
    named_variables = slim.variables.get_variables_by_name(name)
    return named_variables
  class CollectionsTest(tf.test.TestCase):
    """Checks the collections that building slim's inception_v3 populates.

    Every test builds the network on a random-uniform image batch inside
    `self.test_session()` and then asserts on the number of entries found in
    a variable or loss collection.
    """

    def _random_images(self, batch_size=5, height=299, width=299):
      """Returns tf.random_uniform of shape (batch_size, height, width, 3)."""
      return tf.random_uniform((batch_size, height, width, 3))

    def _build_with_batch_norm(self, **model_kwargs):
      """Builds inception_v3 with batch_norm_params set on slim.ops.conv2d.

      Args:
        **model_kwargs: keyword arguments forwarded to inception_v3.
      """
      images = self._random_images()
      with slim.arg_scope([slim.ops.conv2d],
                          batch_norm_params={'decay': 0.9997}):
        slim.inception.inception_v3(images, **model_kwargs)

    def _build_with_losses(self, num_classes, weight_decay):
      """Builds inception_v3 and adds its two cross-entropy losses.

      Args:
        num_classes: width of the random dense labels; also forwarded to
          inception_v3 as `num_classes`.
        weight_decay: `weight_decay` applied to slim.ops.conv2d and
          slim.ops.fc through arg_scope.
      """
      batch_size = 5
      images = self._random_images(batch_size=batch_size)
      dense_labels = tf.random_uniform((batch_size, num_classes))
      with slim.arg_scope([slim.ops.conv2d, slim.ops.fc],
                          weight_decay=weight_decay):
        # num_classes is passed by keyword in every caller so the value can
        # never bind to a different positional parameter of inception_v3.
        logits, end_points = slim.inception.inception_v3(
            images, num_classes=num_classes)
        # Cross entropy loss for the main softmax prediction.
        slim.losses.cross_entropy_loss(logits,
                                       dense_labels,
                                       label_smoothing=0.1,
                                       weight=1.0)
        # Cross entropy loss for the auxiliary softmax head.
        slim.losses.cross_entropy_loss(end_points['aux_logits'],
                                       dense_labels,
                                       label_smoothing=0.1,
                                       weight=0.4,
                                       scope='aux_loss')

    def _assert_counts(self, lookup, expected_counts):
      """Asserts len(lookup(key)) == count for each (key, count), in order.

      Args:
        lookup: callable taking one key (a variable name or a scope).
        expected_counts: iterable of (key, expected_length) pairs.
      """
      for key, expected in expected_counts:
        found = len(lookup(key))
        # The message names the key because the failing line is now shared.
        self.assertEqual(
            found, expected,
            msg='%s(%r): found %d, expected %d' %
            (lookup.__name__, key, found, expected))

    def testVariables(self):
      """With batch norm: 388 variables, broken down by variable name."""
      with self.test_session():
        self._build_with_batch_norm()
        self.assertEqual(len(get_variables()), 388)
        self._assert_counts(get_variables_by_name, (
            ('weights', 98),
            ('biases', 2),
            ('beta', 96),
            ('gamma', 0),
            ('moving_mean', 96),
            ('moving_variance', 96),
        ))

    def testVariablesWithoutBatchNorm(self):
      """With batch_norm_params=None: 196 variables, weights and biases only."""
      with self.test_session():
        images = self._random_images()
        with slim.arg_scope([slim.ops.conv2d],
                            batch_norm_params=None):
          slim.inception.inception_v3(images)
        self.assertEqual(len(get_variables()), 196)
        self._assert_counts(get_variables_by_name, (
            ('weights', 98),
            ('biases', 98),
            ('beta', 0),
            ('gamma', 0),
            ('moving_mean', 0),
            ('moving_variance', 0),
        ))

    def testVariablesByLayer(self):
      """With batch norm: expected variable count under each layer scope."""
      with self.test_session():
        self._build_with_batch_norm()
        self.assertEqual(len(get_variables()), 388)
        self._assert_counts(get_variables, (
            ('conv0', 4),
            ('conv1', 4),
            ('conv2', 4),
            ('conv3', 4),
            ('conv4', 4),
            ('mixed_35x35x256a', 28),
            ('mixed_35x35x288a', 28),
            ('mixed_35x35x288b', 28),
            ('mixed_17x17x768a', 16),
            ('mixed_17x17x768b', 40),
            ('mixed_17x17x768c', 40),
            ('mixed_17x17x768d', 40),
            ('mixed_17x17x768e', 40),
            ('mixed_8x8x2048a', 36),
            ('mixed_8x8x2048b', 36),
            ('logits', 2),
            ('aux_logits', 10),
        ))

    def testVariablesToRestore(self):
      """VARIABLES_TO_RESTORE holds all 388 variables, equal to get_variables()."""
      with self.test_session():
        self._build_with_batch_norm()
        variables_to_restore = tf.get_collection(
            slim.variables.VARIABLES_TO_RESTORE)
        self.assertEqual(len(variables_to_restore), 388)
        self.assertListEqual(variables_to_restore, get_variables())

    def testVariablesToRestoreWithoutLogits(self):
      """With restore_logits=False, VARIABLES_TO_RESTORE holds 384 variables."""
      with self.test_session():
        self._build_with_batch_norm(restore_logits=False)
        variables_to_restore = tf.get_collection(
            slim.variables.VARIABLES_TO_RESTORE)
        self.assertEqual(len(variables_to_restore), 384)

    def testRegularizationLosses(self):
      """With weight decay, there is one regularization loss per 'weights'."""
      with self.test_session():
        images = self._random_images()
        with slim.arg_scope([slim.ops.conv2d, slim.ops.fc],
                            weight_decay=0.00004):
          slim.inception.inception_v3(images)
        losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(losses), len(get_variables_by_name('weights')))

    def testTotalLossWithoutRegularization(self):
      """With weight_decay=0, LOSSES_COLLECTION holds the two entropy losses."""
      with self.test_session():
        self._build_with_losses(num_classes=1001, weight_decay=0)
        losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
        self.assertEqual(len(losses), 2)

    def testTotalLossWithRegularization(self):
      """With weight decay: two entropy losses and 98 regularization losses."""
      with self.test_session():
        self._build_with_losses(num_classes=1000, weight_decay=0.00004)
        losses = tf.get_collection(slim.losses.LOSSES_COLLECTION)
        self.assertEqual(len(losses), 2)
        reg_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
        self.assertEqual(len(reg_losses), 98)
  if __name__ == '__main__':
    # Executed as a script: hand control to TensorFlow's test runner.
    tf.test.main()