# ops_test.py (21 KB)
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for slim.ops."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from tensorflow.python.ops import control_flow_ops
  22. from inception.slim import losses
  23. from inception.slim import ops
  24. from inception.slim import scopes
  25. from inception.slim import variables
  26. class ConvTest(tf.test.TestCase):
  27. def testCreateConv(self):
  28. height, width = 3, 3
  29. with self.test_session():
  30. images = tf.random_uniform((5, height, width, 3), seed=1)
  31. output = ops.conv2d(images, 32, [3, 3])
  32. self.assertEquals(output.op.name, 'Conv/Relu')
  33. self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
  34. def testCreateConvCreatesWeightsAndBiasesVars(self):
  35. height, width = 3, 3
  36. images = tf.random_uniform((5, height, width, 3), seed=1)
  37. with self.test_session():
  38. self.assertFalse(variables.get_variables('conv1/weights'))
  39. self.assertFalse(variables.get_variables('conv1/biases'))
  40. ops.conv2d(images, 32, [3, 3], scope='conv1')
  41. self.assertTrue(variables.get_variables('conv1/weights'))
  42. self.assertTrue(variables.get_variables('conv1/biases'))
  43. def testCreateConvWithScope(self):
  44. height, width = 3, 3
  45. with self.test_session():
  46. images = tf.random_uniform((5, height, width, 3), seed=1)
  47. output = ops.conv2d(images, 32, [3, 3], scope='conv1')
  48. self.assertEquals(output.op.name, 'conv1/Relu')
  49. def testCreateConvWithoutActivation(self):
  50. height, width = 3, 3
  51. with self.test_session():
  52. images = tf.random_uniform((5, height, width, 3), seed=1)
  53. output = ops.conv2d(images, 32, [3, 3], activation=None)
  54. self.assertEquals(output.op.name, 'Conv/BiasAdd')
  55. def testCreateConvValid(self):
  56. height, width = 3, 3
  57. with self.test_session():
  58. images = tf.random_uniform((5, height, width, 3), seed=1)
  59. output = ops.conv2d(images, 32, [3, 3], padding='VALID')
  60. self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 32])
  61. def testCreateConvWithWD(self):
  62. height, width = 3, 3
  63. with self.test_session() as sess:
  64. images = tf.random_uniform((5, height, width, 3), seed=1)
  65. ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
  66. wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
  67. self.assertEquals(wd.op.name, 'Conv/weights/Regularizer/L2Loss/value')
  68. sess.run(tf.initialize_all_variables())
  69. self.assertTrue(sess.run(wd) <= 0.01)
  70. def testReuseConvWithWD(self):
  71. height, width = 3, 3
  72. with self.test_session():
  73. images = tf.random_uniform((5, height, width, 3), seed=1)
  74. ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
  75. self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
  76. tf.get_variable_scope().reuse_variables()
  77. ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
  78. self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
  79. def testConvWithBatchNorm(self):
  80. height, width = 3, 3
  81. with self.test_session():
  82. images = tf.random_uniform((5, height, width, 3), seed=1)
  83. with scopes.arg_scope([ops.conv2d], batch_norm_params={}):
  84. net = ops.conv2d(images, 32, [3, 3], scope='conv1')
  85. net = ops.conv2d(net, 32, [3, 3], scope='conv2')
  86. self.assertEquals(len(tf.get_collection('moving_vars')), 4)
  87. self.assertEquals(len(variables.get_variables('conv1/BatchNorm')), 3)
  88. self.assertEquals(len(variables.get_variables('conv2/BatchNorm')), 3)
  89. class FCTest(tf.test.TestCase):
  90. def testCreateFC(self):
  91. height, width = 3, 3
  92. with self.test_session():
  93. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  94. output = ops.fc(inputs, 32)
  95. self.assertEquals(output.op.name, 'FC/Relu')
  96. self.assertListEqual(output.get_shape().as_list(), [5, 32])
  97. def testCreateFCWithScope(self):
  98. height, width = 3, 3
  99. with self.test_session():
  100. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  101. output = ops.fc(inputs, 32, scope='fc1')
  102. self.assertEquals(output.op.name, 'fc1/Relu')
  103. def testCreateFcCreatesWeightsAndBiasesVars(self):
  104. height, width = 3, 3
  105. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  106. with self.test_session():
  107. self.assertFalse(variables.get_variables('fc1/weights'))
  108. self.assertFalse(variables.get_variables('fc1/biases'))
  109. ops.fc(inputs, 32, scope='fc1')
  110. self.assertTrue(variables.get_variables('fc1/weights'))
  111. self.assertTrue(variables.get_variables('fc1/biases'))
  112. def testReuseVars(self):
  113. height, width = 3, 3
  114. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  115. with self.test_session():
  116. ops.fc(inputs, 32, scope='fc1')
  117. self.assertEquals(len(variables.get_variables('fc1')), 2)
  118. tf.get_variable_scope().reuse_variables()
  119. ops.fc(inputs, 32, scope='fc1')
  120. self.assertEquals(len(variables.get_variables('fc1')), 2)
  121. def testNonReuseVars(self):
  122. height, width = 3, 3
  123. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  124. with self.test_session():
  125. ops.fc(inputs, 32)
  126. self.assertEquals(len(variables.get_variables('FC')), 2)
  127. ops.fc(inputs, 32)
  128. self.assertEquals(len(variables.get_variables('FC')), 4)
  129. def testCreateFCWithoutActivation(self):
  130. height, width = 3, 3
  131. with self.test_session():
  132. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  133. output = ops.fc(inputs, 32, activation=None)
  134. self.assertEquals(output.op.name, 'FC/xw_plus_b')
  135. def testCreateFCWithWD(self):
  136. height, width = 3, 3
  137. with self.test_session() as sess:
  138. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  139. ops.fc(inputs, 32, weight_decay=0.01)
  140. wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
  141. self.assertEquals(wd.op.name, 'FC/weights/Regularizer/L2Loss/value')
  142. sess.run(tf.initialize_all_variables())
  143. self.assertTrue(sess.run(wd) <= 0.01)
  144. def testReuseFCWithWD(self):
  145. height, width = 3, 3
  146. with self.test_session():
  147. inputs = tf.random_uniform((5, height * width * 3), seed=1)
  148. ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
  149. self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
  150. tf.get_variable_scope().reuse_variables()
  151. ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
  152. self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
  153. def testFCWithBatchNorm(self):
  154. height, width = 3, 3
  155. with self.test_session():
  156. images = tf.random_uniform((5, height * width * 3), seed=1)
  157. with scopes.arg_scope([ops.fc], batch_norm_params={}):
  158. net = ops.fc(images, 32, scope='fc1')
  159. net = ops.fc(net, 32, scope='fc2')
  160. self.assertEquals(len(tf.get_collection('moving_vars')), 4)
  161. self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
  162. self.assertEquals(len(variables.get_variables('fc2/BatchNorm')), 3)
  163. class MaxPoolTest(tf.test.TestCase):
  164. def testCreateMaxPool(self):
  165. height, width = 3, 3
  166. with self.test_session():
  167. images = tf.random_uniform((5, height, width, 3), seed=1)
  168. output = ops.max_pool(images, [3, 3])
  169. self.assertEquals(output.op.name, 'MaxPool/MaxPool')
  170. self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
  171. def testCreateMaxPoolWithScope(self):
  172. height, width = 3, 3
  173. with self.test_session():
  174. images = tf.random_uniform((5, height, width, 3), seed=1)
  175. output = ops.max_pool(images, [3, 3], scope='pool1')
  176. self.assertEquals(output.op.name, 'pool1/MaxPool')
  177. def testCreateMaxPoolSAME(self):
  178. height, width = 3, 3
  179. with self.test_session():
  180. images = tf.random_uniform((5, height, width, 3), seed=1)
  181. output = ops.max_pool(images, [3, 3], padding='SAME')
  182. self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
  183. def testCreateMaxPoolStrideSAME(self):
  184. height, width = 3, 3
  185. with self.test_session():
  186. images = tf.random_uniform((5, height, width, 3), seed=1)
  187. output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
  188. self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
  189. class AvgPoolTest(tf.test.TestCase):
  190. def testCreateAvgPool(self):
  191. height, width = 3, 3
  192. with self.test_session():
  193. images = tf.random_uniform((5, height, width, 3), seed=1)
  194. output = ops.avg_pool(images, [3, 3])
  195. self.assertEquals(output.op.name, 'AvgPool/AvgPool')
  196. self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
  197. def testCreateAvgPoolWithScope(self):
  198. height, width = 3, 3
  199. with self.test_session():
  200. images = tf.random_uniform((5, height, width, 3), seed=1)
  201. output = ops.avg_pool(images, [3, 3], scope='pool1')
  202. self.assertEquals(output.op.name, 'pool1/AvgPool')
  203. def testCreateAvgPoolSAME(self):
  204. height, width = 3, 3
  205. with self.test_session():
  206. images = tf.random_uniform((5, height, width, 3), seed=1)
  207. output = ops.avg_pool(images, [3, 3], padding='SAME')
  208. self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])
  209. def testCreateAvgPoolStrideSAME(self):
  210. height, width = 3, 3
  211. with self.test_session():
  212. images = tf.random_uniform((5, height, width, 3), seed=1)
  213. output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
  214. self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
  215. class OneHotEncodingTest(tf.test.TestCase):
  216. def testOneHotEncodingCreate(self):
  217. with self.test_session():
  218. labels = tf.constant([0, 1, 2])
  219. output = ops.one_hot_encoding(labels, num_classes=3)
  220. self.assertEquals(output.op.name, 'OneHotEncoding/SparseToDense')
  221. self.assertListEqual(output.get_shape().as_list(), [3, 3])
  222. def testOneHotEncoding(self):
  223. with self.test_session():
  224. labels = tf.constant([0, 1, 2])
  225. one_hot_labels = tf.constant([[1, 0, 0],
  226. [0, 1, 0],
  227. [0, 0, 1]])
  228. output = ops.one_hot_encoding(labels, num_classes=3)
  229. self.assertAllClose(output.eval(), one_hot_labels.eval())
  230. class DropoutTest(tf.test.TestCase):
  231. def testCreateDropout(self):
  232. height, width = 3, 3
  233. with self.test_session():
  234. images = tf.random_uniform((5, height, width, 3), seed=1)
  235. output = ops.dropout(images)
  236. self.assertEquals(output.op.name, 'Dropout/dropout/mul_1')
  237. output.get_shape().assert_is_compatible_with(images.get_shape())
  238. def testCreateDropoutNoTraining(self):
  239. height, width = 3, 3
  240. with self.test_session():
  241. images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
  242. output = ops.dropout(images, is_training=False)
  243. self.assertEquals(output, images)
  244. class FlattenTest(tf.test.TestCase):
  245. def testFlatten4D(self):
  246. height, width = 3, 3
  247. with self.test_session():
  248. images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
  249. output = ops.flatten(images)
  250. self.assertEquals(output.get_shape().num_elements(),
  251. images.get_shape().num_elements())
  252. self.assertEqual(output.get_shape()[0], images.get_shape()[0])
  253. def testFlatten3D(self):
  254. height, width = 3, 3
  255. with self.test_session():
  256. images = tf.random_uniform((5, height, width), seed=1, name='images')
  257. output = ops.flatten(images)
  258. self.assertEquals(output.get_shape().num_elements(),
  259. images.get_shape().num_elements())
  260. self.assertEqual(output.get_shape()[0], images.get_shape()[0])
  261. def testFlattenBatchSize(self):
  262. height, width = 3, 3
  263. with self.test_session() as sess:
  264. images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
  265. inputs = tf.placeholder(tf.int32, (None, height, width, 3))
  266. output = ops.flatten(inputs)
  267. self.assertEquals(output.get_shape().as_list(),
  268. [None, height * width * 3])
  269. output = sess.run(output, {inputs: images.eval()})
  270. self.assertEquals(output.size,
  271. images.get_shape().num_elements())
  272. self.assertEqual(output.shape[0], images.get_shape()[0])
  273. class BatchNormTest(tf.test.TestCase):
  274. def testCreateOp(self):
  275. height, width = 3, 3
  276. with self.test_session():
  277. images = tf.random_uniform((5, height, width, 3), seed=1)
  278. output = ops.batch_norm(images)
  279. self.assertTrue(output.op.name.startswith('BatchNorm/batchnorm'))
  280. self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
  281. def testCreateVariables(self):
  282. height, width = 3, 3
  283. with self.test_session():
  284. images = tf.random_uniform((5, height, width, 3), seed=1)
  285. ops.batch_norm(images, scale=True)
  286. beta = variables.get_variables_by_name('beta')[0]
  287. gamma = variables.get_variables_by_name('gamma')[0]
  288. self.assertEquals(beta.op.name, 'BatchNorm/beta')
  289. self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
  290. moving_mean = tf.get_collection('moving_vars')[0]
  291. moving_variance = tf.get_collection('moving_vars')[1]
  292. self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
  293. self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
  294. def testMovingAverageVariables(self):
  295. height, width = 3, 3
  296. with self.test_session():
  297. images = tf.random_uniform((5, height, width, 3), seed=1)
  298. ops.batch_norm(images, scale=True)
  299. moving_mean = tf.moving_average_variables()[0]
  300. moving_variance = tf.moving_average_variables()[1]
  301. self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
  302. self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
  303. def testUpdateOps(self):
  304. height, width = 3, 3
  305. with self.test_session():
  306. images = tf.random_uniform((5, height, width, 3), seed=1)
  307. ops.batch_norm(images)
  308. update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
  309. update_moving_mean = update_ops[0]
  310. update_moving_variance = update_ops[1]
  311. self.assertEquals(update_moving_mean.op.name,
  312. 'BatchNorm/AssignMovingAvg')
  313. self.assertEquals(update_moving_variance.op.name,
  314. 'BatchNorm/AssignMovingAvg_1')
  315. def testReuseVariables(self):
  316. height, width = 3, 3
  317. with self.test_session():
  318. images = tf.random_uniform((5, height, width, 3), seed=1)
  319. ops.batch_norm(images, scale=True, scope='bn')
  320. tf.get_variable_scope().reuse_variables()
  321. ops.batch_norm(images, scale=True, scope='bn')
  322. beta = variables.get_variables_by_name('beta')
  323. gamma = variables.get_variables_by_name('gamma')
  324. self.assertEquals(len(beta), 1)
  325. self.assertEquals(len(gamma), 1)
  326. moving_vars = tf.get_collection('moving_vars')
  327. self.assertEquals(len(moving_vars), 2)
  328. def testReuseUpdateOps(self):
  329. height, width = 3, 3
  330. with self.test_session():
  331. images = tf.random_uniform((5, height, width, 3), seed=1)
  332. ops.batch_norm(images, scope='bn')
  333. self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
  334. tf.get_variable_scope().reuse_variables()
  335. ops.batch_norm(images, scope='bn')
  336. self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
  337. def testCreateMovingVars(self):
  338. height, width = 3, 3
  339. with self.test_session():
  340. images = tf.random_uniform((5, height, width, 3), seed=1)
  341. _ = ops.batch_norm(images, moving_vars='moving_vars')
  342. moving_mean = tf.get_collection('moving_vars',
  343. 'BatchNorm/moving_mean')
  344. self.assertEquals(len(moving_mean), 1)
  345. self.assertEquals(moving_mean[0].op.name, 'BatchNorm/moving_mean')
  346. moving_variance = tf.get_collection('moving_vars',
  347. 'BatchNorm/moving_variance')
  348. self.assertEquals(len(moving_variance), 1)
  349. self.assertEquals(moving_variance[0].op.name, 'BatchNorm/moving_variance')
  350. def testComputeMovingVars(self):
  351. height, width = 3, 3
  352. with self.test_session() as sess:
  353. image_shape = (10, height, width, 3)
  354. image_values = np.random.rand(*image_shape)
  355. expected_mean = np.mean(image_values, axis=(0, 1, 2))
  356. expected_var = np.var(image_values, axis=(0, 1, 2))
  357. images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
  358. output = ops.batch_norm(images, decay=0.1)
  359. update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
  360. with tf.control_dependencies(update_ops):
  361. barrier = tf.no_op(name='gradient_barrier')
  362. output = control_flow_ops.with_dependencies([barrier], output)
  363. # Initialize all variables
  364. sess.run(tf.initialize_all_variables())
  365. moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
  366. moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
  367. mean, variance = sess.run([moving_mean, moving_variance])
  368. # After initialization moving_mean == 0 and moving_variance == 1.
  369. self.assertAllClose(mean, [0] * 3)
  370. self.assertAllClose(variance, [1] * 3)
  371. for _ in range(10):
  372. sess.run([output])
  373. mean = moving_mean.eval()
  374. variance = moving_variance.eval()
  375. # After 10 updates with decay 0.1 moving_mean == expected_mean and
  376. # moving_variance == expected_var.
  377. self.assertAllClose(mean, expected_mean)
  378. self.assertAllClose(variance, expected_var)
  379. def testEvalMovingVars(self):
  380. height, width = 3, 3
  381. with self.test_session() as sess:
  382. image_shape = (10, height, width, 3)
  383. image_values = np.random.rand(*image_shape)
  384. expected_mean = np.mean(image_values, axis=(0, 1, 2))
  385. expected_var = np.var(image_values, axis=(0, 1, 2))
  386. images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
  387. output = ops.batch_norm(images, decay=0.1, is_training=False)
  388. update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
  389. with tf.control_dependencies(update_ops):
  390. barrier = tf.no_op(name='gradient_barrier')
  391. output = control_flow_ops.with_dependencies([barrier], output)
  392. # Initialize all variables
  393. sess.run(tf.initialize_all_variables())
  394. moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
  395. moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
  396. mean, variance = sess.run([moving_mean, moving_variance])
  397. # After initialization moving_mean == 0 and moving_variance == 1.
  398. self.assertAllClose(mean, [0] * 3)
  399. self.assertAllClose(variance, [1] * 3)
  400. # Simulate assigment from saver restore.
  401. init_assigns = [tf.assign(moving_mean, expected_mean),
  402. tf.assign(moving_variance, expected_var)]
  403. sess.run(init_assigns)
  404. for _ in range(10):
  405. sess.run([output], {images: np.random.rand(*image_shape)})
  406. mean = moving_mean.eval()
  407. variance = moving_variance.eval()
  408. # Although we feed different images, the moving_mean and moving_variance
  409. # shouldn't change.
  410. self.assertAllClose(mean, expected_mean)
  411. self.assertAllClose(variance, expected_var)
  412. def testReuseVars(self):
  413. height, width = 3, 3
  414. with self.test_session() as sess:
  415. image_shape = (10, height, width, 3)
  416. image_values = np.random.rand(*image_shape)
  417. expected_mean = np.mean(image_values, axis=(0, 1, 2))
  418. expected_var = np.var(image_values, axis=(0, 1, 2))
  419. images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
  420. output = ops.batch_norm(images, decay=0.1, is_training=False)
  421. update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
  422. with tf.control_dependencies(update_ops):
  423. barrier = tf.no_op(name='gradient_barrier')
  424. output = control_flow_ops.with_dependencies([barrier], output)
  425. # Initialize all variables
  426. sess.run(tf.initialize_all_variables())
  427. moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
  428. moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
  429. mean, variance = sess.run([moving_mean, moving_variance])
  430. # After initialization moving_mean == 0 and moving_variance == 1.
  431. self.assertAllClose(mean, [0] * 3)
  432. self.assertAllClose(variance, [1] * 3)
  433. # Simulate assigment from saver restore.
  434. init_assigns = [tf.assign(moving_mean, expected_mean),
  435. tf.assign(moving_variance, expected_var)]
  436. sess.run(init_assigns)
  437. for _ in range(10):
  438. sess.run([output], {images: np.random.rand(*image_shape)})
  439. mean = moving_mean.eval()
  440. variance = moving_variance.eval()
  441. # Although we feed different images, the moving_mean and moving_variance
  442. # shouldn't change.
  443. self.assertAllClose(mean, expected_mean)
  444. self.assertAllClose(variance, expected_var)
  445. if __name__ == '__main__':
  446. tf.test.main()