variables_test.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.variables."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from inception.slim import scopes
  21. from inception.slim import variables
  22. class VariablesTest(tf.test.TestCase):
  23. def testCreateVariable(self):
  24. with self.test_session():
  25. with tf.variable_scope('A'):
  26. a = variables.variable('a', [5])
  27. self.assertEquals(a.op.name, 'A/a')
  28. self.assertListEqual(a.get_shape().as_list(), [5])
  29. def testGetVariableGivenName(self):
  30. with self.test_session():
  31. with tf.variable_scope('A'):
  32. a = variables.variable('a', [5])
  33. with tf.variable_scope('B'):
  34. b = variables.variable('a', [5])
  35. self.assertEquals('a', variables.get_variable_given_name(a))
  36. self.assertEquals('a', variables.get_variable_given_name(b))
  37. def testGetVariableGivenNameScoped(self):
  38. with self.test_session():
  39. with tf.variable_scope('A'):
  40. a = variables.variable('a', [5])
  41. b = variables.variable('b', [5])
  42. self.assertEquals([a], variables.get_variables_by_name('a'))
  43. self.assertEquals([b], variables.get_variables_by_name('b'))
  44. def testGetVariables(self):
  45. with self.test_session():
  46. with tf.variable_scope('A'):
  47. a = variables.variable('a', [5])
  48. with tf.variable_scope('B'):
  49. b = variables.variable('a', [5])
  50. self.assertEquals([a], variables.get_variables('A'))
  51. self.assertEquals([b], variables.get_variables('B'))
  52. def testGetVariablesSuffix(self):
  53. with self.test_session():
  54. with tf.variable_scope('A'):
  55. a = variables.variable('a', [5])
  56. with tf.variable_scope('A'):
  57. b = variables.variable('b', [5])
  58. self.assertEquals([a], variables.get_variables(suffix='a'))
  59. self.assertEquals([b], variables.get_variables(suffix='b'))
  60. def testGetVariableWithSingleVar(self):
  61. with self.test_session():
  62. with tf.variable_scope('parent'):
  63. a = variables.variable('child', [5])
  64. self.assertEquals(a, variables.get_unique_variable('parent/child'))
  65. def testGetVariableWithDistractors(self):
  66. with self.test_session():
  67. with tf.variable_scope('parent'):
  68. a = variables.variable('child', [5])
  69. with tf.variable_scope('child'):
  70. variables.variable('grandchild1', [7])
  71. variables.variable('grandchild2', [9])
  72. self.assertEquals(a, variables.get_unique_variable('parent/child'))
  73. def testGetVariableThrowsExceptionWithNoMatch(self):
  74. var_name = 'cant_find_me'
  75. with self.test_session():
  76. with self.assertRaises(ValueError):
  77. variables.get_unique_variable(var_name)
  78. def testGetThrowsExceptionWithChildrenButNoMatch(self):
  79. var_name = 'parent/child'
  80. with self.test_session():
  81. with tf.variable_scope(var_name):
  82. variables.variable('grandchild1', [7])
  83. variables.variable('grandchild2', [9])
  84. with self.assertRaises(ValueError):
  85. variables.get_unique_variable(var_name)
  86. def testGetVariablesToRestore(self):
  87. with self.test_session():
  88. with tf.variable_scope('A'):
  89. a = variables.variable('a', [5])
  90. with tf.variable_scope('B'):
  91. b = variables.variable('b', [5])
  92. self.assertListEqual([a, b],
  93. tf.get_collection(variables.VARIABLES_TO_RESTORE))
  94. def testGetVariablesToRestorePartial(self):
  95. with self.test_session():
  96. with tf.variable_scope('A'):
  97. a = variables.variable('a', [5])
  98. with tf.variable_scope('B'):
  99. b = variables.variable('b', [5], restore=False)
  100. self.assertListEqual([a, b], variables.get_variables())
  101. self.assertListEqual([a],
  102. tf.get_collection(variables.VARIABLES_TO_RESTORE))
  103. def testReuseVariable(self):
  104. with self.test_session():
  105. with tf.variable_scope('A'):
  106. a = variables.variable('a', [])
  107. with tf.variable_scope('A', reuse=True):
  108. b = variables.variable('a', [])
  109. self.assertEquals(a, b)
  110. self.assertListEqual([a], variables.get_variables())
  111. def testVariableWithDevice(self):
  112. with self.test_session():
  113. with tf.variable_scope('A'):
  114. a = variables.variable('a', [], device='cpu:0')
  115. b = variables.variable('b', [], device='cpu:1')
  116. self.assertDeviceEqual(a.device, 'cpu:0')
  117. self.assertDeviceEqual(b.device, 'cpu:1')
  118. def testVariableWithDeviceFromScope(self):
  119. with self.test_session():
  120. with tf.device('/cpu:0'):
  121. a = variables.variable('a', [])
  122. b = variables.variable('b', [], device='cpu:1')
  123. self.assertDeviceEqual(a.device, 'cpu:0')
  124. self.assertDeviceEqual(b.device, 'cpu:1')
  125. def testVariableCollection(self):
  126. with self.test_session():
  127. a = variables.variable('a', [], collections='A')
  128. b = variables.variable('b', [], collections='B')
  129. self.assertEquals(a, tf.get_collection('A')[0])
  130. self.assertEquals(b, tf.get_collection('B')[0])
  131. def testVariableCollections(self):
  132. with self.test_session():
  133. a = variables.variable('a', [], collections=['A', 'C'])
  134. b = variables.variable('b', [], collections=['B', 'C'])
  135. self.assertEquals(a, tf.get_collection('A')[0])
  136. self.assertEquals(b, tf.get_collection('B')[0])
  137. def testVariableCollectionsWithArgScope(self):
  138. with self.test_session():
  139. with scopes.arg_scope([variables.variable], collections='A'):
  140. a = variables.variable('a', [])
  141. b = variables.variable('b', [])
  142. self.assertListEqual([a, b], tf.get_collection('A'))
  143. def testVariableCollectionsWithArgScopeNested(self):
  144. with self.test_session():
  145. with scopes.arg_scope([variables.variable], collections='A'):
  146. a = variables.variable('a', [])
  147. with scopes.arg_scope([variables.variable], collections='B'):
  148. b = variables.variable('b', [])
  149. self.assertEquals(a, tf.get_collection('A')[0])
  150. self.assertEquals(b, tf.get_collection('B')[0])
  151. def testVariableCollectionsWithArgScopeNonNested(self):
  152. with self.test_session():
  153. with scopes.arg_scope([variables.variable], collections='A'):
  154. a = variables.variable('a', [])
  155. with scopes.arg_scope([variables.variable], collections='B'):
  156. b = variables.variable('b', [])
  157. variables.variable('c', [])
  158. self.assertListEqual([a], tf.get_collection('A'))
  159. self.assertListEqual([b], tf.get_collection('B'))
  160. def testVariableRestoreWithArgScopeNested(self):
  161. with self.test_session():
  162. with scopes.arg_scope([variables.variable], restore=True):
  163. a = variables.variable('a', [])
  164. with scopes.arg_scope([variables.variable], trainable=False,
  165. collections=['A', 'B']):
  166. b = variables.variable('b', [])
  167. c = variables.variable('c', [])
  168. self.assertListEqual([a, b, c],
  169. tf.get_collection(variables.VARIABLES_TO_RESTORE))
  170. self.assertListEqual([a, c], tf.trainable_variables())
  171. self.assertListEqual([b], tf.get_collection('A'))
  172. self.assertListEqual([b], tf.get_collection('B'))
  173. if __name__ == '__main__':
  174. tf.test.main()