ops_test.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for slim.ops."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import tensorflow as tf
  21. from tensorflow.python.ops import control_flow_ops
  22. from inception.slim import ops
  23. from inception.slim import scopes
  24. from inception.slim import variables
  class ConvTest(tf.test.TestCase):
    """Tests for ops.conv2d: op names, output shapes, variables, reuse and BN."""

    def testCreateConv(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 3])
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])

    def testCreateSquareConv(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        # A scalar kernel size is expected to mean a square kernel.
        output = ops.conv2d(images, 32, 3)
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])

    def testCreateConvWithTensorShape(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, images.get_shape()[1:3])
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])

    def testCreateFullyConv(self):
      height, width = 6, 6
      with self.test_session():
        images = tf.random_uniform((5, height, width, 32), seed=1)
        output = ops.conv2d(images, 64, images.get_shape()[1:3], padding='VALID')
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])

    def testCreateVerticalConv(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 1])
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(),
                             [5, height, width, 32])

    def testCreateHorizontalConv(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [1, 3])
        self.assertEqual(output.op.name, 'Conv/Relu')
        self.assertListEqual(output.get_shape().as_list(),
                             [5, height, width, 32])

    def testCreateConvWithStride(self):
      height, width = 6, 6
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 3], stride=2)
        self.assertEqual(output.op.name, 'Conv/Relu')
        # Floor division: the module enables true division, so `/` would put
        # floats into what should be an integer shape list.
        self.assertListEqual(output.get_shape().as_list(),
                             [5, height // 2, width // 2, 32])

    def testCreateConvCreatesWeightsAndBiasesVars(self):
      height, width = 3, 3
      images = tf.random_uniform((5, height, width, 3), seed=1)
      with self.test_session():
        self.assertFalse(variables.get_variables('conv1/weights'))
        self.assertFalse(variables.get_variables('conv1/biases'))
        ops.conv2d(images, 32, [3, 3], scope='conv1')
        self.assertTrue(variables.get_variables('conv1/weights'))
        self.assertTrue(variables.get_variables('conv1/biases'))

    def testCreateConvWithScope(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 3], scope='conv1')
        self.assertEqual(output.op.name, 'conv1/Relu')

    def testCreateConvWithoutActivation(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 3], activation=None)
        self.assertEqual(output.op.name, 'Conv/BiasAdd')

    def testCreateConvValid(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.conv2d(images, 32, [3, 3], padding='VALID')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 32])

    def testCreateConvWithWD(self):
      height, width = 3, 3
      with self.test_session() as sess:
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
        wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
        self.assertEqual(wd.op.name,
                         'Conv/weights/Regularizer/L2Regularizer/value')
        sess.run(tf.initialize_all_variables())
        # assertLessEqual reports both values on failure, unlike assertTrue.
        self.assertLessEqual(sess.run(wd), 0.01)

    def testCreateConvWithoutWD(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.conv2d(images, 32, [3, 3], weight_decay=0)
        self.assertEqual(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])

    def testReuseVars(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.conv2d(images, 32, [3, 3], scope='conv1')
        self.assertEqual(len(variables.get_variables()), 2)
        ops.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
        self.assertEqual(len(variables.get_variables()), 2)

    def testNonReuseVars(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.conv2d(images, 32, [3, 3])
        self.assertEqual(len(variables.get_variables()), 2)
        ops.conv2d(images, 32, [3, 3])
        self.assertEqual(len(variables.get_variables()), 4)

    def testReuseConvWithWD(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
        self.assertEqual(len(variables.get_variables()), 2)
        self.assertEqual(
            len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
        ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1',
                   reuse=True)
        # Reusing the weights must not register a second regularization loss.
        self.assertEqual(len(variables.get_variables()), 2)
        self.assertEqual(
            len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)

    def testConvWithBatchNorm(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 32), seed=1)
        with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
          net = ops.conv2d(images, 32, [3, 3])
          net = ops.conv2d(net, 32, [3, 3])
        self.assertEqual(len(variables.get_variables()), 8)
        self.assertEqual(len(variables.get_variables('Conv/BatchNorm')), 3)
        self.assertEqual(len(variables.get_variables('Conv_1/BatchNorm')), 3)

    def testReuseConvWithBatchNorm(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 32), seed=1)
        with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
          net = ops.conv2d(images, 32, [3, 3], scope='Conv')
          net = ops.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
        self.assertEqual(len(variables.get_variables()), 4)
        self.assertEqual(len(variables.get_variables('Conv/BatchNorm')), 3)
        self.assertEqual(len(variables.get_variables('Conv_1/BatchNorm')), 0)
  class FCTest(tf.test.TestCase):
    """Tests for ops.fc: op names, output shapes, variables, reuse and BN."""

    def testCreateFC(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        output = ops.fc(inputs, 32)
        self.assertEqual(output.op.name, 'FC/Relu')
        self.assertListEqual(output.get_shape().as_list(), [5, 32])

    def testCreateFCWithScope(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        output = ops.fc(inputs, 32, scope='fc1')
        self.assertEqual(output.op.name, 'fc1/Relu')

    def testCreateFcCreatesWeightsAndBiasesVars(self):
      height, width = 3, 3
      inputs = tf.random_uniform((5, height * width * 3), seed=1)
      with self.test_session():
        self.assertFalse(variables.get_variables('fc1/weights'))
        self.assertFalse(variables.get_variables('fc1/biases'))
        ops.fc(inputs, 32, scope='fc1')
        self.assertTrue(variables.get_variables('fc1/weights'))
        self.assertTrue(variables.get_variables('fc1/biases'))

    def testReuseVars(self):
      height, width = 3, 3
      inputs = tf.random_uniform((5, height * width * 3), seed=1)
      with self.test_session():
        ops.fc(inputs, 32, scope='fc1')
        self.assertEqual(len(variables.get_variables('fc1')), 2)
        ops.fc(inputs, 32, scope='fc1', reuse=True)
        self.assertEqual(len(variables.get_variables('fc1')), 2)

    def testNonReuseVars(self):
      height, width = 3, 3
      inputs = tf.random_uniform((5, height * width * 3), seed=1)
      with self.test_session():
        ops.fc(inputs, 32)
        self.assertEqual(len(variables.get_variables('FC')), 2)
        # Without reuse the second layer gets its own (uniquified) scope, and
        # the 'FC' prefix query picks up both layers' variables.
        ops.fc(inputs, 32)
        self.assertEqual(len(variables.get_variables('FC')), 4)

    def testCreateFCWithoutActivation(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        output = ops.fc(inputs, 32, activation=None)
        self.assertEqual(output.op.name, 'FC/xw_plus_b')

    def testCreateFCWithWD(self):
      height, width = 3, 3
      with self.test_session() as sess:
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        ops.fc(inputs, 32, weight_decay=0.01)
        wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
        self.assertEqual(wd.op.name,
                         'FC/weights/Regularizer/L2Regularizer/value')
        sess.run(tf.initialize_all_variables())
        # assertLessEqual reports both values on failure, unlike assertTrue.
        self.assertLessEqual(sess.run(wd), 0.01)

    def testCreateFCWithoutWD(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        ops.fc(inputs, 32, weight_decay=0)
        self.assertEqual(
            tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])

    def testReuseFCWithWD(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
        self.assertEqual(len(variables.get_variables()), 2)
        self.assertEqual(
            len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
        ops.fc(inputs, 32, weight_decay=0.01, scope='fc', reuse=True)
        # Reusing the weights must not register a second regularization loss.
        self.assertEqual(len(variables.get_variables()), 2)
        self.assertEqual(
            len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)

    def testFCWithBatchNorm(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        with scopes.arg_scope([ops.fc], batch_norm_params={}):
          net = ops.fc(inputs, 27)
          net = ops.fc(net, 27)
        self.assertEqual(len(variables.get_variables()), 8)
        self.assertEqual(len(variables.get_variables('FC/BatchNorm')), 3)
        self.assertEqual(len(variables.get_variables('FC_1/BatchNorm')), 3)

    def testReuseFCWithBatchNorm(self):
      height, width = 3, 3
      with self.test_session():
        inputs = tf.random_uniform((5, height * width * 3), seed=1)
        with scopes.arg_scope([ops.fc], batch_norm_params={'decay': 0.9}):
          net = ops.fc(inputs, 27, scope='fc1')
          net = ops.fc(net, 27, scope='fc1', reuse=True)
        self.assertEqual(len(variables.get_variables()), 4)
        self.assertEqual(len(variables.get_variables('fc1/BatchNorm')), 3)
  class MaxPoolTest(tf.test.TestCase):
    """Tests for ops.max_pool: op names and VALID/SAME/global output shapes."""

    def testCreateMaxPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.max_pool(images, [3, 3])
        self.assertEqual(output.op.name, 'MaxPool/MaxPool')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])

    def testCreateSquareMaxPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        # A scalar kernel size is expected to mean a square window.
        output = ops.max_pool(images, 3)
        self.assertEqual(output.op.name, 'MaxPool/MaxPool')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])

    def testCreateMaxPoolWithScope(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.max_pool(images, [3, 3], scope='pool1')
        self.assertEqual(output.op.name, 'pool1/MaxPool')

    def testCreateMaxPoolSAME(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.max_pool(images, [3, 3], padding='SAME')
        self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])

    def testCreateMaxPoolStrideSAME(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])

    def testGlobalMaxPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.max_pool(images, images.get_shape()[1:3], stride=1)
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
  class AvgPoolTest(tf.test.TestCase):
    """Tests for ops.avg_pool: op names and VALID/SAME/global output shapes."""

    def testCreateAvgPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.avg_pool(images, [3, 3])
        self.assertEqual(output.op.name, 'AvgPool/AvgPool')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])

    def testCreateSquareAvgPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        # A scalar kernel size is expected to mean a square window.
        output = ops.avg_pool(images, 3)
        self.assertEqual(output.op.name, 'AvgPool/AvgPool')
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])

    def testCreateAvgPoolWithScope(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.avg_pool(images, [3, 3], scope='pool1')
        self.assertEqual(output.op.name, 'pool1/AvgPool')

    def testCreateAvgPoolSAME(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.avg_pool(images, [3, 3], padding='SAME')
        self.assertListEqual(output.get_shape().as_list(), [5, 2, 2, 3])

    def testCreateAvgPoolStrideSAME(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])

    def testGlobalAvgPool(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.avg_pool(images, images.get_shape()[1:3], stride=1)
        self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
  class OneHotEncodingTest(tf.test.TestCase):
    """Tests for ops.one_hot_encoding: op name, shape and encoded values."""

    def testOneHotEncodingCreate(self):
      with self.test_session():
        labels = tf.constant([0, 1, 2])
        output = ops.one_hot_encoding(labels, num_classes=3)
        self.assertEqual(output.op.name, 'OneHotEncoding/SparseToDense')
        self.assertListEqual(output.get_shape().as_list(), [3, 3])

    def testOneHotEncoding(self):
      with self.test_session():
        labels = tf.constant([0, 1, 2])
        one_hot_labels = tf.constant([[1, 0, 0],
                                      [0, 1, 0],
                                      [0, 0, 1]])
        output = ops.one_hot_encoding(labels, num_classes=3)
        self.assertAllClose(output.eval(), one_hot_labels.eval())
  class DropoutTest(tf.test.TestCase):
    """Tests for ops.dropout in training and inference mode."""

    def testCreateDropout(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.dropout(images)
        self.assertEqual(output.op.name, 'Dropout/dropout/mul_1')
        output.get_shape().assert_is_compatible_with(images.get_shape())

    def testCreateDropoutNoTraining(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
        output = ops.dropout(images, is_training=False)
        # Outside training the input tensor itself is expected back. Check
        # identity explicitly instead of relying on Tensor.__eq__, which
        # newer TensorFlow versions make elementwise.
        self.assertIs(output, images)
  class FlattenTest(tf.test.TestCase):
    """Tests for ops.flatten on static-rank inputs and a dynamic batch size."""

    def testFlatten4D(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
        output = ops.flatten(images)
        self.assertEqual(output.get_shape().num_elements(),
                         images.get_shape().num_elements())
        self.assertEqual(output.get_shape()[0], images.get_shape()[0])

    def testFlatten3D(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width), seed=1, name='images')
        output = ops.flatten(images)
        self.assertEqual(output.get_shape().num_elements(),
                         images.get_shape().num_elements())
        self.assertEqual(output.get_shape()[0], images.get_shape()[0])

    def testFlattenBatchSize(self):
      height, width = 3, 3
      with self.test_session() as sess:
        images = tf.random_uniform((5, height, width, 3), seed=1, name='images')
        # The placeholder's dtype matches the float32 values fed below, so the
        # feed is not silently truncated to integers.
        inputs = tf.placeholder(tf.float32, (None, height, width, 3))
        output = ops.flatten(inputs)
        # Statically only the batch dimension is unknown.
        self.assertEqual(output.get_shape().as_list(),
                         [None, height * width * 3])
        output = sess.run(output, {inputs: images.eval()})
        self.assertEqual(output.size,
                         images.get_shape().num_elements())
        self.assertEqual(output.shape[0], images.get_shape()[0])
  class BatchNormTest(tf.test.TestCase):
    """Tests for ops.batch_norm: variables, update ops and moving statistics."""

    def testCreateOp(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        output = ops.batch_norm(images)
        self.assertTrue(output.op.name.startswith('BatchNorm/batchnorm'))
        self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])

    def testCreateVariables(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images)
        beta = variables.get_variables_by_name('beta')[0]
        self.assertEqual(beta.op.name, 'BatchNorm/beta')
        gamma = variables.get_variables_by_name('gamma')
        self.assertEqual(gamma, [])
        moving_mean = tf.moving_average_variables()[0]
        moving_variance = tf.moving_average_variables()[1]
        self.assertEqual(moving_mean.op.name, 'BatchNorm/moving_mean')
        self.assertEqual(moving_variance.op.name, 'BatchNorm/moving_variance')

    def testCreateVariablesWithScale(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, scale=True)
        beta = variables.get_variables_by_name('beta')[0]
        gamma = variables.get_variables_by_name('gamma')[0]
        self.assertEqual(beta.op.name, 'BatchNorm/beta')
        self.assertEqual(gamma.op.name, 'BatchNorm/gamma')
        moving_mean = tf.moving_average_variables()[0]
        moving_variance = tf.moving_average_variables()[1]
        self.assertEqual(moving_mean.op.name, 'BatchNorm/moving_mean')
        self.assertEqual(moving_variance.op.name, 'BatchNorm/moving_variance')

    def testCreateVariablesWithoutCenterWithScale(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, center=False, scale=True)
        beta = variables.get_variables_by_name('beta')
        self.assertEqual(beta, [])
        gamma = variables.get_variables_by_name('gamma')[0]
        self.assertEqual(gamma.op.name, 'BatchNorm/gamma')
        moving_mean = tf.moving_average_variables()[0]
        moving_variance = tf.moving_average_variables()[1]
        self.assertEqual(moving_mean.op.name, 'BatchNorm/moving_mean')
        self.assertEqual(moving_variance.op.name, 'BatchNorm/moving_variance')

    def testCreateVariablesWithoutCenterWithoutScale(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, center=False, scale=False)
        beta = variables.get_variables_by_name('beta')
        self.assertEqual(beta, [])
        gamma = variables.get_variables_by_name('gamma')
        self.assertEqual(gamma, [])
        moving_mean = tf.moving_average_variables()[0]
        moving_variance = tf.moving_average_variables()[1]
        self.assertEqual(moving_mean.op.name, 'BatchNorm/moving_mean')
        self.assertEqual(moving_variance.op.name, 'BatchNorm/moving_variance')

    def testMovingAverageVariables(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, scale=True)
        moving_mean = tf.moving_average_variables()[0]
        moving_variance = tf.moving_average_variables()[1]
        self.assertEqual(moving_mean.op.name, 'BatchNorm/moving_mean')
        self.assertEqual(moving_variance.op.name, 'BatchNorm/moving_variance')

    def testUpdateOps(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images)
        update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
        update_moving_mean = update_ops[0]
        update_moving_variance = update_ops[1]
        self.assertEqual(update_moving_mean.op.name,
                         'BatchNorm/AssignMovingAvg')
        self.assertEqual(update_moving_variance.op.name,
                         'BatchNorm/AssignMovingAvg_1')

    def testReuseVariables(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, scale=True, scope='bn')
        ops.batch_norm(images, scale=True, scope='bn', reuse=True)
        beta = variables.get_variables_by_name('beta')
        gamma = variables.get_variables_by_name('gamma')
        self.assertEqual(len(beta), 1)
        self.assertEqual(len(gamma), 1)
        moving_vars = tf.get_collection('moving_vars')
        self.assertEqual(len(moving_vars), 2)

    def testReuseUpdateOps(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        ops.batch_norm(images, scope='bn')
        self.assertEqual(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
        # Reusing the variables still registers a fresh pair of update ops.
        ops.batch_norm(images, scope='bn', reuse=True)
        self.assertEqual(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)

    def testCreateMovingVars(self):
      height, width = 3, 3
      with self.test_session():
        images = tf.random_uniform((5, height, width, 3), seed=1)
        _ = ops.batch_norm(images, moving_vars='moving_vars')
        moving_mean = tf.get_collection('moving_vars',
                                        'BatchNorm/moving_mean')
        self.assertEqual(len(moving_mean), 1)
        self.assertEqual(moving_mean[0].op.name, 'BatchNorm/moving_mean')
        moving_variance = tf.get_collection('moving_vars',
                                            'BatchNorm/moving_variance')
        self.assertEqual(len(moving_variance), 1)
        self.assertEqual(moving_variance[0].op.name, 'BatchNorm/moving_variance')

    def testComputeMovingVars(self):
      height, width = 3, 3
      with self.test_session() as sess:
        image_shape = (10, height, width, 3)
        image_values = np.random.rand(*image_shape)
        expected_mean = np.mean(image_values, axis=(0, 1, 2))
        expected_var = np.var(image_values, axis=(0, 1, 2))
        images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
        output = ops.batch_norm(images, decay=0.1)
        update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
        # Make evaluating `output` also run the moving-average updates.
        with tf.control_dependencies(update_ops):
          barrier = tf.no_op(name='gradient_barrier')
          output = control_flow_ops.with_dependencies([barrier], output)
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
        moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 3)
        self.assertAllClose(variance, [1] * 3)
        for _ in range(10):
          sess.run([output])
        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # After 10 updates with decay 0.1 moving_mean == expected_mean and
        # moving_variance == expected_var.
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

    def testEvalMovingVars(self):
      height, width = 3, 3
      with self.test_session() as sess:
        image_shape = (10, height, width, 3)
        image_values = np.random.rand(*image_shape)
        expected_mean = np.mean(image_values, axis=(0, 1, 2))
        expected_var = np.var(image_values, axis=(0, 1, 2))
        images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
        output = ops.batch_norm(images, decay=0.1, is_training=False)
        update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
        with tf.control_dependencies(update_ops):
          barrier = tf.no_op(name='gradient_barrier')
          output = control_flow_ops.with_dependencies([barrier], output)
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        moving_mean = variables.get_variables('BatchNorm/moving_mean')[0]
        moving_variance = variables.get_variables('BatchNorm/moving_variance')[0]
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 3)
        self.assertAllClose(variance, [1] * 3)
        # Simulate assignment from saver restore.
        init_assigns = [tf.assign(moving_mean, expected_mean),
                        tf.assign(moving_variance, expected_var)]
        sess.run(init_assigns)
        for _ in range(10):
          sess.run([output], {images: np.random.rand(*image_shape)})
        mean = moving_mean.eval()
        variance = moving_variance.eval()
        # Although we feed different images, the moving_mean and moving_variance
        # shouldn't change.
        self.assertAllClose(mean, expected_mean)
        self.assertAllClose(variance, expected_var)

    def testReuseVars(self):
      """An eval tower built with reuse=True shares the training statistics.

      The training tower creates the variables and the update ops. Running
      those updates must move the shared moving statistics, and running the
      reused eval tower must leave them untouched.
      """
      height, width = 3, 3
      with self.test_session() as sess:
        image_shape = (10, height, width, 3)
        image_values = np.random.rand(*image_shape)
        expected_mean = np.mean(image_values, axis=(0, 1, 2))
        expected_var = np.var(image_values, axis=(0, 1, 2))
        images = tf.constant(image_values, shape=image_shape, dtype=tf.float32)
        # Training tower: creates the variables and registers the update ops.
        ops.batch_norm(images, decay=0.1, scope='BatchNorm')
        update_ops = tf.get_collection(ops.UPDATE_OPS_COLLECTION)
        # Eval tower: must pick up the existing variables, not create new ones.
        output = ops.batch_norm(images, decay=0.1, is_training=False,
                                scope='BatchNorm', reuse=True)
        moving_means = variables.get_variables('BatchNorm/moving_mean')
        moving_variances = variables.get_variables('BatchNorm/moving_variance')
        self.assertEqual(len(moving_means), 1)
        self.assertEqual(len(moving_variances), 1)
        moving_mean, moving_variance = moving_means[0], moving_variances[0]
        # Initialize all variables
        sess.run(tf.initialize_all_variables())
        mean, variance = sess.run([moving_mean, moving_variance])
        # After initialization moving_mean == 0 and moving_variance == 1.
        self.assertAllClose(mean, [0] * 3)
        self.assertAllClose(variance, [1] * 3)
        # 10 training updates with decay 0.1 drive the shared statistics to
        # the batch moments.
        for _ in range(10):
          sess.run(update_ops)
        self.assertAllClose(moving_mean.eval(), expected_mean)
        self.assertAllClose(moving_variance.eval(), expected_var)
        # Running the reused eval tower, even on different images, must not
        # change the shared statistics.
        for _ in range(10):
          sess.run([output], {images: np.random.rand(*image_shape)})
        self.assertAllClose(moving_mean.eval(), expected_mean)
        self.assertAllClose(moving_variance.eval(), expected_var)
  if __name__ == '__main__':
    # Script entry point: run every test case in this module through
    # TensorFlow's test runner.
    tf.test.main()