# model_deploy_test.py (23 KB)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for model_deploy.

Exercises DeploymentConfig device/scope placement, create_clones,
optimize_clones and deploy.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from slim.models import model_deploy

# Short alias for the TF-Slim library used throughout the tests below.
# NOTE(review): tf.contrib exists only in TensorFlow 1.x.
slim = tf.contrib.slim
  23. class DeploymentConfigTest(tf.test.TestCase):
  24. def testDefaults(self):
  25. deploy_config = model_deploy.DeploymentConfig()
  26. self.assertEqual(slim.get_variables(), [])
  27. self.assertEqual(deploy_config.caching_device(), None)
  28. self.assertDeviceEqual(deploy_config.clone_device(0), '')
  29. self.assertEqual(deploy_config.clone_scope(0), '')
  30. self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
  31. self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
  32. self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
  33. def testCPUonly(self):
  34. deploy_config = model_deploy.DeploymentConfig(clone_on_cpu=True)
  35. self.assertEqual(deploy_config.caching_device(), None)
  36. self.assertDeviceEqual(deploy_config.clone_device(0), 'CPU:0')
  37. self.assertEqual(deploy_config.clone_scope(0), '')
  38. self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
  39. self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
  40. self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
  41. def testMultiGPU(self):
  42. deploy_config = model_deploy.DeploymentConfig(num_clones=2)
  43. self.assertEqual(deploy_config.caching_device(), None)
  44. self.assertDeviceEqual(deploy_config.clone_device(0), 'GPU:0')
  45. self.assertDeviceEqual(deploy_config.clone_device(1), 'GPU:1')
  46. self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
  47. self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
  48. self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
  49. self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
  50. self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
  51. def testPS(self):
  52. deploy_config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)
  53. self.assertDeviceEqual(deploy_config.clone_device(0),
  54. '/job:worker')
  55. self.assertEqual(deploy_config.clone_scope(0), '')
  56. self.assertDeviceEqual(deploy_config.optimizer_device(),
  57. '/job:worker/device:CPU:0')
  58. self.assertDeviceEqual(deploy_config.inputs_device(),
  59. '/job:worker/device:CPU:0')
  60. with tf.device(deploy_config.variables_device()):
  61. a = tf.Variable(0)
  62. b = tf.Variable(0)
  63. c = tf.no_op()
  64. d = slim.variable('a', [],
  65. caching_device=deploy_config.caching_device())
  66. self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
  67. self.assertDeviceEqual(a.device, a.value().device)
  68. self.assertDeviceEqual(b.device, '/job:ps/task:0/device:CPU:0')
  69. self.assertDeviceEqual(b.device, b.value().device)
  70. self.assertDeviceEqual(c.device, '')
  71. self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
  72. self.assertDeviceEqual(d.value().device, '')
  73. def testMultiGPUPS(self):
  74. deploy_config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=1)
  75. self.assertEqual(deploy_config.caching_device()(tf.no_op()), '')
  76. self.assertDeviceEqual(deploy_config.clone_device(0),
  77. '/job:worker/device:GPU:0')
  78. self.assertDeviceEqual(deploy_config.clone_device(1),
  79. '/job:worker/device:GPU:1')
  80. self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
  81. self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
  82. self.assertDeviceEqual(deploy_config.optimizer_device(),
  83. '/job:worker/device:CPU:0')
  84. self.assertDeviceEqual(deploy_config.inputs_device(),
  85. '/job:worker/device:CPU:0')
  86. def testReplicasPS(self):
  87. deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
  88. num_ps_tasks=2)
  89. self.assertDeviceEqual(deploy_config.clone_device(0),
  90. '/job:worker')
  91. self.assertEqual(deploy_config.clone_scope(0), '')
  92. self.assertDeviceEqual(deploy_config.optimizer_device(),
  93. '/job:worker/device:CPU:0')
  94. self.assertDeviceEqual(deploy_config.inputs_device(),
  95. '/job:worker/device:CPU:0')
  96. def testReplicasMultiGPUPS(self):
  97. deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
  98. num_clones=2,
  99. num_ps_tasks=2)
  100. self.assertDeviceEqual(deploy_config.clone_device(0),
  101. '/job:worker/device:GPU:0')
  102. self.assertDeviceEqual(deploy_config.clone_device(1),
  103. '/job:worker/device:GPU:1')
  104. self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
  105. self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
  106. self.assertDeviceEqual(deploy_config.optimizer_device(),
  107. '/job:worker/device:CPU:0')
  108. self.assertDeviceEqual(deploy_config.inputs_device(),
  109. '/job:worker/device:CPU:0')
  110. def testVariablesPS(self):
  111. deploy_config = model_deploy.DeploymentConfig(num_ps_tasks=2)
  112. with tf.device(deploy_config.variables_device()):
  113. a = tf.Variable(0)
  114. b = tf.Variable(0)
  115. c = tf.no_op()
  116. d = slim.variable('a', [],
  117. caching_device=deploy_config.caching_device())
  118. self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
  119. self.assertDeviceEqual(a.device, a.value().device)
  120. self.assertDeviceEqual(b.device, '/job:ps/task:1/device:CPU:0')
  121. self.assertDeviceEqual(b.device, b.value().device)
  122. self.assertDeviceEqual(c.device, '')
  123. self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
  124. self.assertDeviceEqual(d.value().device, '')
  125. def LogisticClassifier(inputs, labels, scope=None, reuse=None):
  126. with tf.variable_scope(scope, 'LogisticClassifier', [inputs, labels],
  127. reuse=reuse):
  128. predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid,
  129. scope='fully_connected')
  130. slim.losses.log_loss(predictions, labels)
  131. return predictions
  132. def BatchNormClassifier(inputs, labels, scope=None, reuse=None):
  133. with tf.variable_scope(scope, 'BatchNormClassifier', [inputs, labels],
  134. reuse=reuse):
  135. inputs = slim.batch_norm(inputs, decay=0.1)
  136. predictions = slim.fully_connected(inputs, 1,
  137. activation_fn=tf.sigmoid,
  138. scope='fully_connected')
  139. slim.losses.log_loss(predictions, labels)
  140. return predictions
  141. class CreatecloneTest(tf.test.TestCase):
  142. def setUp(self):
  143. # Create an easy training set:
  144. np.random.seed(0)
  145. self._inputs = np.zeros((16, 4))
  146. self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
  147. self._logdir = self.get_temp_dir()
  148. for i in range(16):
  149. j = int(2 * self._labels[i] + np.random.randint(0, 2))
  150. self._inputs[i, j] = 1
  151. def testCreateLogisticClassifier(self):
  152. g = tf.Graph()
  153. with g.as_default():
  154. tf.set_random_seed(0)
  155. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  156. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  157. model_fn = LogisticClassifier
  158. clone_args = (tf_inputs, tf_labels)
  159. deploy_config = model_deploy.DeploymentConfig(num_clones=1)
  160. self.assertEqual(slim.get_variables(), [])
  161. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  162. clone = clones[0]
  163. self.assertEqual(len(slim.get_variables()), 2)
  164. for v in slim.get_variables():
  165. self.assertDeviceEqual(v.device, 'CPU:0')
  166. self.assertDeviceEqual(v.value().device, 'CPU:0')
  167. self.assertEqual(clone.outputs.op.name,
  168. 'LogisticClassifier/fully_connected/Sigmoid')
  169. self.assertEqual(clone.scope, '')
  170. self.assertDeviceEqual(clone.device, '')
  171. self.assertEqual(len(slim.losses.get_losses()), 1)
  172. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  173. self.assertEqual(update_ops, [])
  174. def testCreateSingleclone(self):
  175. g = tf.Graph()
  176. with g.as_default():
  177. tf.set_random_seed(0)
  178. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  179. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  180. model_fn = BatchNormClassifier
  181. clone_args = (tf_inputs, tf_labels)
  182. deploy_config = model_deploy.DeploymentConfig(num_clones=1)
  183. self.assertEqual(slim.get_variables(), [])
  184. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  185. clone = clones[0]
  186. self.assertEqual(len(slim.get_variables()), 5)
  187. for v in slim.get_variables():
  188. self.assertDeviceEqual(v.device, 'CPU:0')
  189. self.assertDeviceEqual(v.value().device, 'CPU:0')
  190. self.assertEqual(clone.outputs.op.name,
  191. 'BatchNormClassifier/fully_connected/Sigmoid')
  192. self.assertEqual(clone.scope, '')
  193. self.assertDeviceEqual(clone.device, '')
  194. self.assertEqual(len(slim.losses.get_losses()), 1)
  195. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  196. self.assertEqual(len(update_ops), 2)
  197. def testCreateMulticlone(self):
  198. g = tf.Graph()
  199. with g.as_default():
  200. tf.set_random_seed(0)
  201. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  202. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  203. model_fn = BatchNormClassifier
  204. clone_args = (tf_inputs, tf_labels)
  205. num_clones = 4
  206. deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
  207. self.assertEqual(slim.get_variables(), [])
  208. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  209. self.assertEqual(len(slim.get_variables()), 5)
  210. for v in slim.get_variables():
  211. self.assertDeviceEqual(v.device, 'CPU:0')
  212. self.assertDeviceEqual(v.value().device, 'CPU:0')
  213. self.assertEqual(len(clones), num_clones)
  214. for i, clone in enumerate(clones):
  215. self.assertEqual(
  216. clone.outputs.op.name,
  217. 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
  218. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
  219. self.assertEqual(len(update_ops), 2)
  220. self.assertEqual(clone.scope, 'clone_%d/' % i)
  221. self.assertDeviceEqual(clone.device, 'GPU:%d' % i)
  222. def testCreateOnecloneWithPS(self):
  223. g = tf.Graph()
  224. with g.as_default():
  225. tf.set_random_seed(0)
  226. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  227. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  228. model_fn = BatchNormClassifier
  229. clone_args = (tf_inputs, tf_labels)
  230. deploy_config = model_deploy.DeploymentConfig(num_clones=1,
  231. num_ps_tasks=1)
  232. self.assertEqual(slim.get_variables(), [])
  233. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  234. self.assertEqual(len(clones), 1)
  235. clone = clones[0]
  236. self.assertEqual(clone.outputs.op.name,
  237. 'BatchNormClassifier/fully_connected/Sigmoid')
  238. self.assertDeviceEqual(clone.device, '/job:worker')
  239. self.assertEqual(clone.scope, '')
  240. self.assertEqual(len(slim.get_variables()), 5)
  241. for v in slim.get_variables():
  242. self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
  243. self.assertDeviceEqual(v.device, v.value().device)
  244. def testCreateMulticloneWithPS(self):
  245. g = tf.Graph()
  246. with g.as_default():
  247. tf.set_random_seed(0)
  248. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  249. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  250. model_fn = BatchNormClassifier
  251. clone_args = (tf_inputs, tf_labels)
  252. deploy_config = model_deploy.DeploymentConfig(num_clones=2,
  253. num_ps_tasks=2)
  254. self.assertEqual(slim.get_variables(), [])
  255. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  256. self.assertEqual(len(slim.get_variables()), 5)
  257. for i, v in enumerate(slim.get_variables()):
  258. t = i % 2
  259. self.assertDeviceEqual(v.device, '/job:ps/task:%d/device:CPU:0' % t)
  260. self.assertDeviceEqual(v.device, v.value().device)
  261. self.assertEqual(len(clones), 2)
  262. for i, clone in enumerate(clones):
  263. self.assertEqual(
  264. clone.outputs.op.name,
  265. 'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
  266. self.assertEqual(clone.scope, 'clone_%d/' % i)
  267. self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % i)
  268. class OptimizeclonesTest(tf.test.TestCase):
  269. def setUp(self):
  270. # Create an easy training set:
  271. np.random.seed(0)
  272. self._inputs = np.zeros((16, 4))
  273. self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
  274. self._logdir = self.get_temp_dir()
  275. for i in range(16):
  276. j = int(2 * self._labels[i] + np.random.randint(0, 2))
  277. self._inputs[i, j] = 1
  278. def testCreateLogisticClassifier(self):
  279. g = tf.Graph()
  280. with g.as_default():
  281. tf.set_random_seed(0)
  282. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  283. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  284. model_fn = LogisticClassifier
  285. clone_args = (tf_inputs, tf_labels)
  286. deploy_config = model_deploy.DeploymentConfig(num_clones=1)
  287. self.assertEqual(slim.get_variables(), [])
  288. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  289. self.assertEqual(len(slim.get_variables()), 2)
  290. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  291. self.assertEqual(update_ops, [])
  292. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  293. total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
  294. optimizer)
  295. self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
  296. self.assertEqual(total_loss.op.name, 'total_loss')
  297. for g, v in grads_and_vars:
  298. self.assertDeviceEqual(g.device, '')
  299. self.assertDeviceEqual(v.device, 'CPU:0')
  300. def testCreateSingleclone(self):
  301. g = tf.Graph()
  302. with g.as_default():
  303. tf.set_random_seed(0)
  304. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  305. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  306. model_fn = BatchNormClassifier
  307. clone_args = (tf_inputs, tf_labels)
  308. deploy_config = model_deploy.DeploymentConfig(num_clones=1)
  309. self.assertEqual(slim.get_variables(), [])
  310. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  311. self.assertEqual(len(slim.get_variables()), 5)
  312. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  313. self.assertEqual(len(update_ops), 2)
  314. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  315. total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
  316. optimizer)
  317. self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
  318. self.assertEqual(total_loss.op.name, 'total_loss')
  319. for g, v in grads_and_vars:
  320. self.assertDeviceEqual(g.device, '')
  321. self.assertDeviceEqual(v.device, 'CPU:0')
  322. def testCreateMulticlone(self):
  323. g = tf.Graph()
  324. with g.as_default():
  325. tf.set_random_seed(0)
  326. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  327. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  328. model_fn = BatchNormClassifier
  329. clone_args = (tf_inputs, tf_labels)
  330. num_clones = 4
  331. deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
  332. self.assertEqual(slim.get_variables(), [])
  333. clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
  334. self.assertEqual(len(slim.get_variables()), 5)
  335. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  336. self.assertEqual(len(update_ops), num_clones * 2)
  337. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  338. total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
  339. optimizer)
  340. self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
  341. self.assertEqual(total_loss.op.name, 'total_loss')
  342. for g, v in grads_and_vars:
  343. self.assertDeviceEqual(g.device, '')
  344. self.assertDeviceEqual(v.device, 'CPU:0')
  345. def testCreateMulticloneCPU(self):
  346. g = tf.Graph()
  347. with g.as_default():
  348. tf.set_random_seed(0)
  349. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  350. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  351. model_fn = BatchNormClassifier
  352. model_args = (tf_inputs, tf_labels)
  353. num_clones = 4
  354. deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
  355. clone_on_cpu=True)
  356. self.assertEqual(slim.get_variables(), [])
  357. clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
  358. self.assertEqual(len(slim.get_variables()), 5)
  359. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  360. self.assertEqual(len(update_ops), num_clones * 2)
  361. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  362. total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
  363. optimizer)
  364. self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
  365. self.assertEqual(total_loss.op.name, 'total_loss')
  366. for g, v in grads_and_vars:
  367. self.assertDeviceEqual(g.device, '')
  368. self.assertDeviceEqual(v.device, 'CPU:0')
  369. def testCreateOnecloneWithPS(self):
  370. g = tf.Graph()
  371. with g.as_default():
  372. tf.set_random_seed(0)
  373. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  374. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  375. model_fn = BatchNormClassifier
  376. model_args = (tf_inputs, tf_labels)
  377. deploy_config = model_deploy.DeploymentConfig(num_clones=1,
  378. num_ps_tasks=1)
  379. self.assertEqual(slim.get_variables(), [])
  380. clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
  381. self.assertEqual(len(slim.get_variables()), 5)
  382. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  383. self.assertEqual(len(update_ops), 2)
  384. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  385. total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
  386. optimizer)
  387. self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
  388. self.assertEqual(total_loss.op.name, 'total_loss')
  389. for g, v in grads_and_vars:
  390. self.assertDeviceEqual(g.device, '/job:worker')
  391. self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
  392. class DeployTest(tf.test.TestCase):
  393. def setUp(self):
  394. # Create an easy training set:
  395. np.random.seed(0)
  396. self._inputs = np.zeros((16, 4))
  397. self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
  398. self._logdir = self.get_temp_dir()
  399. for i in range(16):
  400. j = int(2 * self._labels[i] + np.random.randint(0, 2))
  401. self._inputs[i, j] = 1
  402. def testLocalTrainOp(self):
  403. g = tf.Graph()
  404. with g.as_default():
  405. tf.set_random_seed(0)
  406. tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
  407. tf_labels = tf.constant(self._labels, dtype=tf.float32)
  408. model_fn = BatchNormClassifier
  409. model_args = (tf_inputs, tf_labels)
  410. deploy_config = model_deploy.DeploymentConfig(num_clones=2,
  411. clone_on_cpu=True)
  412. optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  413. self.assertEqual(slim.get_variables(), [])
  414. model = model_deploy.deploy(deploy_config, model_fn, model_args,
  415. optimizer=optimizer)
  416. update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  417. self.assertEqual(len(update_ops), 4)
  418. self.assertEqual(len(model.clones), 2)
  419. self.assertEqual(model.total_loss.op.name, 'total_loss')
  420. self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op')
  421. self.assertEqual(model.train_op.op.name, 'train_op')
  422. with tf.Session() as sess:
  423. sess.run(tf.initialize_all_variables())
  424. moving_mean = tf.contrib.framework.get_variables_by_name(
  425. 'moving_mean')[0]
  426. moving_variance = tf.contrib.framework.get_variables_by_name(
  427. 'moving_variance')[0]
  428. initial_loss = sess.run(model.total_loss)
  429. initial_mean, initial_variance = sess.run([moving_mean,
  430. moving_variance])
  431. self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0])
  432. self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0])
  433. for _ in range(10):
  434. sess.run(model.train_op)
  435. final_loss = sess.run(model.total_loss)
  436. self.assertLess(final_loss, initial_loss / 10.0)
  437. final_mean, final_variance = sess.run([moving_mean,
  438. moving_variance])
  439. self.assertAllClose(final_mean, [0.125, 0.25, 0.375, 0.25])
  440. self.assertAllClose(final_variance, [0.109375, 0.1875,
  441. 0.234375, 0.1875])
  442. def testNoSummariesOnGPU(self):
  443. with tf.Graph().as_default():
  444. deploy_config = model_deploy.DeploymentConfig(num_clones=2)
  445. # clone function creates a fully_connected layer with a regularizer loss.
  446. def ModelFn():
  447. inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
  448. reg = tf.contrib.layers.l2_regularizer(0.001)
  449. tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
  450. model = model_deploy.deploy(
  451. deploy_config, ModelFn,
  452. optimizer=tf.train.GradientDescentOptimizer(1.0))
  453. # The model summary op should have a few summary inputs and all of them
  454. # should be on the CPU.
  455. self.assertTrue(model.summary_op.op.inputs)
  456. for inp in model.summary_op.op.inputs:
  457. self.assertEqual('/device:CPU:0', inp.device)
  458. def testNoSummariesOnGPUForEvals(self):
  459. with tf.Graph().as_default():
  460. deploy_config = model_deploy.DeploymentConfig(num_clones=2)
  461. # clone function creates a fully_connected layer with a regularizer loss.
  462. def ModelFn():
  463. inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
  464. reg = tf.contrib.layers.l2_regularizer(0.001)
  465. tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
  466. # No optimizer here, it's an eval.
  467. model = model_deploy.deploy(deploy_config, ModelFn)
  468. # The model summary op should have a few summary inputs and all of them
  469. # should be on the CPU.
  470. self.assertTrue(model.summary_op.op.inputs)
  471. for inp in model.summary_op.op.inputs:
  472. self.assertEqual('/device:CPU:0', inp.device)
  473. if __name__ == '__main__':
  474. tf.test.main()