variables_test.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.variables."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from inception.slim import scopes
  21. from inception.slim import variables
  22. class VariablesTest(tf.test.TestCase):
  23. def testCreateVariable(self):
  24. with self.test_session():
  25. with tf.variable_scope('A'):
  26. a = variables.variable('a', [5])
  27. self.assertEquals(a.op.name, 'A/a')
  28. self.assertListEqual(a.get_shape().as_list(), [5])
  29. def testGetVariables(self):
  30. with self.test_session():
  31. with tf.variable_scope('A'):
  32. a = variables.variable('a', [5])
  33. with tf.variable_scope('B'):
  34. b = variables.variable('a', [5])
  35. self.assertEquals([a, b], variables.get_variables())
  36. self.assertEquals([a], variables.get_variables('A'))
  37. self.assertEquals([b], variables.get_variables('B'))
  38. def testGetVariablesSuffix(self):
  39. with self.test_session():
  40. with tf.variable_scope('A'):
  41. a = variables.variable('a', [5])
  42. with tf.variable_scope('A'):
  43. b = variables.variable('b', [5])
  44. self.assertEquals([a], variables.get_variables(suffix='a'))
  45. self.assertEquals([b], variables.get_variables(suffix='b'))
  46. def testGetVariableWithSingleVar(self):
  47. with self.test_session():
  48. with tf.variable_scope('parent'):
  49. a = variables.variable('child', [5])
  50. self.assertEquals(a, variables.get_unique_variable('parent/child'))
  51. def testGetVariableWithDistractors(self):
  52. with self.test_session():
  53. with tf.variable_scope('parent'):
  54. a = variables.variable('child', [5])
  55. with tf.variable_scope('child'):
  56. variables.variable('grandchild1', [7])
  57. variables.variable('grandchild2', [9])
  58. self.assertEquals(a, variables.get_unique_variable('parent/child'))
  59. def testGetVariableThrowsExceptionWithNoMatch(self):
  60. var_name = 'cant_find_me'
  61. with self.test_session():
  62. with self.assertRaises(ValueError):
  63. variables.get_unique_variable(var_name)
  64. def testGetThrowsExceptionWithChildrenButNoMatch(self):
  65. var_name = 'parent/child'
  66. with self.test_session():
  67. with tf.variable_scope(var_name):
  68. variables.variable('grandchild1', [7])
  69. variables.variable('grandchild2', [9])
  70. with self.assertRaises(ValueError):
  71. variables.get_unique_variable(var_name)
  72. def testGetVariablesToRestore(self):
  73. with self.test_session():
  74. with tf.variable_scope('A'):
  75. a = variables.variable('a', [5])
  76. with tf.variable_scope('B'):
  77. b = variables.variable('a', [5])
  78. self.assertEquals([a, b], variables.get_variables_to_restore())
  79. def testNoneGetVariablesToRestore(self):
  80. with self.test_session():
  81. with tf.variable_scope('A'):
  82. a = variables.variable('a', [5], restore=False)
  83. with tf.variable_scope('B'):
  84. b = variables.variable('a', [5], restore=False)
  85. self.assertEquals([], variables.get_variables_to_restore())
  86. self.assertEquals([a, b], variables.get_variables())
  87. def testGetMixedVariablesToRestore(self):
  88. with self.test_session():
  89. with tf.variable_scope('A'):
  90. a = variables.variable('a', [5])
  91. b = variables.variable('b', [5], restore=False)
  92. with tf.variable_scope('B'):
  93. c = variables.variable('c', [5])
  94. d = variables.variable('d', [5], restore=False)
  95. self.assertEquals([a, b, c, d], variables.get_variables())
  96. self.assertEquals([a, c], variables.get_variables_to_restore())
  97. def testReuseVariable(self):
  98. with self.test_session():
  99. with tf.variable_scope('A'):
  100. a = variables.variable('a', [])
  101. with tf.variable_scope('A', reuse=True):
  102. b = variables.variable('a', [])
  103. self.assertEquals(a, b)
  104. self.assertListEqual([a], variables.get_variables())
  105. def testVariableWithDevice(self):
  106. with self.test_session():
  107. with tf.variable_scope('A'):
  108. a = variables.variable('a', [], device='cpu:0')
  109. b = variables.variable('b', [], device='cpu:1')
  110. self.assertDeviceEqual(a.device, 'cpu:0')
  111. self.assertDeviceEqual(b.device, 'cpu:1')
  112. def testVariableWithDeviceFromScope(self):
  113. with self.test_session():
  114. with tf.device('/cpu:0'):
  115. a = variables.variable('a', [])
  116. b = variables.variable('b', [], device='cpu:1')
  117. self.assertDeviceEqual(a.device, 'cpu:0')
  118. self.assertDeviceEqual(b.device, 'cpu:1')
  119. def testVariableWithDeviceFunction(self):
  120. class DevFn(object):
  121. def __init__(self):
  122. self.counter = -1
  123. def __call__(self, op):
  124. self.counter += 1
  125. return 'cpu:%d' % self.counter
  126. with self.test_session():
  127. with scopes.arg_scope([variables.variable], device=DevFn()):
  128. a = variables.variable('a', [])
  129. b = variables.variable('b', [])
  130. c = variables.variable('c', [], device='cpu:12')
  131. d = variables.variable('d', [])
  132. with tf.device('cpu:99'):
  133. e_init = tf.constant(12)
  134. e = variables.variable('e', initializer=e_init)
  135. self.assertDeviceEqual(a.device, 'cpu:0')
  136. self.assertDeviceEqual(a.initial_value.device, 'cpu:0')
  137. self.assertDeviceEqual(b.device, 'cpu:1')
  138. self.assertDeviceEqual(b.initial_value.device, 'cpu:1')
  139. self.assertDeviceEqual(c.device, 'cpu:12')
  140. self.assertDeviceEqual(c.initial_value.device, 'cpu:12')
  141. self.assertDeviceEqual(d.device, 'cpu:2')
  142. self.assertDeviceEqual(d.initial_value.device, 'cpu:2')
  143. self.assertDeviceEqual(e.device, 'cpu:3')
  144. self.assertDeviceEqual(e.initial_value.device, 'cpu:99')
  145. def testVariableWithReplicaDeviceSetter(self):
  146. with self.test_session():
  147. with tf.device(tf.train.replica_device_setter(ps_tasks=2)):
  148. a = variables.variable('a', [])
  149. b = variables.variable('b', [])
  150. c = variables.variable('c', [], device='cpu:12')
  151. d = variables.variable('d', [])
  152. with tf.device('cpu:99'):
  153. e_init = tf.constant(12)
  154. e = variables.variable('e', initializer=e_init)
  155. # The values below highlight how the replica_device_setter puts initial
  156. # values on the worker job, and how it merges explicit devices.
  157. self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
  158. self.assertDeviceEqual(a.initial_value.device, '/job:worker/cpu:0')
  159. self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
  160. self.assertDeviceEqual(b.initial_value.device, '/job:worker/cpu:0')
  161. self.assertDeviceEqual(c.device, '/job:ps/task:0/cpu:12')
  162. self.assertDeviceEqual(c.initial_value.device, '/job:worker/cpu:12')
  163. self.assertDeviceEqual(d.device, '/job:ps/task:1/cpu:0')
  164. self.assertDeviceEqual(d.initial_value.device, '/job:worker/cpu:0')
  165. self.assertDeviceEqual(e.device, '/job:ps/task:0/cpu:0')
  166. self.assertDeviceEqual(e.initial_value.device, '/job:worker/cpu:99')
  167. def testVariableWithVariableDeviceChooser(self):
  168. with tf.Graph().as_default():
  169. device_fn = variables.VariableDeviceChooser(num_parameter_servers=2)
  170. with scopes.arg_scope([variables.variable], device=device_fn):
  171. a = variables.variable('a', [])
  172. b = variables.variable('b', [])
  173. c = variables.variable('c', [], device='cpu:12')
  174. d = variables.variable('d', [])
  175. with tf.device('cpu:99'):
  176. e_init = tf.constant(12)
  177. e = variables.variable('e', initializer=e_init)
  178. # The values below highlight how the VariableDeviceChooser puts initial
  179. # values on the same device as the variable job.
  180. self.assertDeviceEqual(a.device, '/job:ps/task:0/cpu:0')
  181. self.assertDeviceEqual(a.initial_value.device, a.device)
  182. self.assertDeviceEqual(b.device, '/job:ps/task:1/cpu:0')
  183. self.assertDeviceEqual(b.initial_value.device, b.device)
  184. self.assertDeviceEqual(c.device, '/cpu:12')
  185. self.assertDeviceEqual(c.initial_value.device, c.device)
  186. self.assertDeviceEqual(d.device, '/job:ps/task:0/cpu:0')
  187. self.assertDeviceEqual(d.initial_value.device, d.device)
  188. self.assertDeviceEqual(e.device, '/job:ps/task:1/cpu:0')
  189. self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
  190. def testVariableGPUPlacement(self):
  191. with tf.Graph().as_default():
  192. device_fn = variables.VariableDeviceChooser(placement='gpu:0')
  193. with scopes.arg_scope([variables.variable], device=device_fn):
  194. a = variables.variable('a', [])
  195. b = variables.variable('b', [])
  196. c = variables.variable('c', [], device='cpu:12')
  197. d = variables.variable('d', [])
  198. with tf.device('cpu:99'):
  199. e_init = tf.constant(12)
  200. e = variables.variable('e', initializer=e_init)
  201. # The values below highlight how the VariableDeviceChooser puts initial
  202. # values on the same device as the variable job.
  203. self.assertDeviceEqual(a.device, '/gpu:0')
  204. self.assertDeviceEqual(a.initial_value.device, a.device)
  205. self.assertDeviceEqual(b.device, '/gpu:0')
  206. self.assertDeviceEqual(b.initial_value.device, b.device)
  207. self.assertDeviceEqual(c.device, '/cpu:12')
  208. self.assertDeviceEqual(c.initial_value.device, c.device)
  209. self.assertDeviceEqual(d.device, '/gpu:0')
  210. self.assertDeviceEqual(d.initial_value.device, d.device)
  211. self.assertDeviceEqual(e.device, '/gpu:0')
  212. self.assertDeviceEqual(e.initial_value.device, '/cpu:99')
  213. def testVariableCollection(self):
  214. with self.test_session():
  215. a = variables.variable('a', [], collections='A')
  216. b = variables.variable('b', [], collections='B')
  217. self.assertEquals(a, tf.get_collection('A')[0])
  218. self.assertEquals(b, tf.get_collection('B')[0])
  219. def testVariableCollections(self):
  220. with self.test_session():
  221. a = variables.variable('a', [], collections=['A', 'C'])
  222. b = variables.variable('b', [], collections=['B', 'C'])
  223. self.assertEquals(a, tf.get_collection('A')[0])
  224. self.assertEquals(b, tf.get_collection('B')[0])
  225. def testVariableCollectionsWithArgScope(self):
  226. with self.test_session():
  227. with scopes.arg_scope([variables.variable], collections='A'):
  228. a = variables.variable('a', [])
  229. b = variables.variable('b', [])
  230. self.assertListEqual([a, b], tf.get_collection('A'))
  231. def testVariableCollectionsWithArgScopeNested(self):
  232. with self.test_session():
  233. with scopes.arg_scope([variables.variable], collections='A'):
  234. a = variables.variable('a', [])
  235. with scopes.arg_scope([variables.variable], collections='B'):
  236. b = variables.variable('b', [])
  237. self.assertEquals(a, tf.get_collection('A')[0])
  238. self.assertEquals(b, tf.get_collection('B')[0])
  239. def testVariableCollectionsWithArgScopeNonNested(self):
  240. with self.test_session():
  241. with scopes.arg_scope([variables.variable], collections='A'):
  242. a = variables.variable('a', [])
  243. with scopes.arg_scope([variables.variable], collections='B'):
  244. b = variables.variable('b', [])
  245. variables.variable('c', [])
  246. self.assertListEqual([a], tf.get_collection('A'))
  247. self.assertListEqual([b], tf.get_collection('B'))
  248. def testVariableRestoreWithArgScopeNested(self):
  249. with self.test_session():
  250. with scopes.arg_scope([variables.variable], restore=True):
  251. a = variables.variable('a', [])
  252. with scopes.arg_scope([variables.variable],
  253. trainable=False,
  254. collections=['A', 'B']):
  255. b = variables.variable('b', [])
  256. c = variables.variable('c', [])
  257. self.assertListEqual([a, b, c], variables.get_variables_to_restore())
  258. self.assertListEqual([a, c], tf.trainable_variables())
  259. self.assertListEqual([b], tf.get_collection('A'))
  260. self.assertListEqual([b], tf.get_collection('B'))
  261. class GetVariablesByNameTest(tf.test.TestCase):
  262. def testGetVariableGivenNameScoped(self):
  263. with self.test_session():
  264. with tf.variable_scope('A'):
  265. a = variables.variable('a', [5])
  266. b = variables.variable('b', [5])
  267. self.assertEquals([a], variables.get_variables_by_name('a'))
  268. self.assertEquals([b], variables.get_variables_by_name('b'))
  269. def testGetVariablesByNameReturnsByValueWithScope(self):
  270. with self.test_session():
  271. with tf.variable_scope('A'):
  272. a = variables.variable('a', [5])
  273. matched_variables = variables.get_variables_by_name('a')
  274. # If variables.get_variables_by_name returns the list by reference, the
  275. # following append should persist, and be returned, in subsequent calls
  276. # to variables.get_variables_by_name('a').
  277. matched_variables.append(4)
  278. matched_variables = variables.get_variables_by_name('a')
  279. self.assertEquals([a], matched_variables)
  280. def testGetVariablesByNameReturnsByValueWithoutScope(self):
  281. with self.test_session():
  282. a = variables.variable('a', [5])
  283. matched_variables = variables.get_variables_by_name('a')
  284. # If variables.get_variables_by_name returns the list by reference, the
  285. # following append should persist, and be returned, in subsequent calls
  286. # to variables.get_variables_by_name('a').
  287. matched_variables.append(4)
  288. matched_variables = variables.get_variables_by_name('a')
  289. self.assertEquals([a], matched_variables)
  290. class GlobalStepTest(tf.test.TestCase):
  291. def testStable(self):
  292. with tf.Graph().as_default():
  293. gs = variables.global_step()
  294. gs2 = variables.global_step()
  295. self.assertTrue(gs is gs2)
  296. def testDevice(self):
  297. with tf.Graph().as_default():
  298. with scopes.arg_scope([variables.global_step], device='/gpu:0'):
  299. gs = variables.global_step()
  300. self.assertDeviceEqual(gs.device, '/gpu:0')
  301. def testDeviceFn(self):
  302. class DevFn(object):
  303. def __init__(self):
  304. self.counter = -1
  305. def __call__(self, op):
  306. self.counter += 1
  307. return '/cpu:%d' % self.counter
  308. with tf.Graph().as_default():
  309. with scopes.arg_scope([variables.global_step], device=DevFn()):
  310. gs = variables.global_step()
  311. gs2 = variables.global_step()
  312. self.assertDeviceEqual(gs.device, '/cpu:0')
  313. self.assertEquals(gs, gs2)
  314. self.assertDeviceEqual(gs2.device, '/cpu:0')
  315. def testReplicaDeviceSetter(self):
  316. device_fn = tf.train.replica_device_setter(2)
  317. with tf.Graph().as_default():
  318. with scopes.arg_scope([variables.global_step], device=device_fn):
  319. gs = variables.global_step()
  320. gs2 = variables.global_step()
  321. self.assertEquals(gs, gs2)
  322. self.assertDeviceEqual(gs.device, '/job:ps/task:0')
  323. self.assertDeviceEqual(gs.initial_value.device, '/job:ps/task:0')
  324. self.assertDeviceEqual(gs2.device, '/job:ps/task:0')
  325. self.assertDeviceEqual(gs2.initial_value.device, '/job:ps/task:0')
  326. def testVariableWithVariableDeviceChooser(self):
  327. with tf.Graph().as_default():
  328. device_fn = variables.VariableDeviceChooser()
  329. with scopes.arg_scope([variables.global_step], device=device_fn):
  330. gs = variables.global_step()
  331. gs2 = variables.global_step()
  332. self.assertEquals(gs, gs2)
  333. self.assertDeviceEqual(gs.device, 'cpu:0')
  334. self.assertDeviceEqual(gs.initial_value.device, gs.device)
  335. self.assertDeviceEqual(gs2.device, 'cpu:0')
  336. self.assertDeviceEqual(gs2.initial_value.device, gs2.device)
  337. if __name__ == '__main__':
  338. tf.test.main()