model_deploy.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Deploy Slim models across multiple clones and replicas.
  16. # TODO(sguada) docstring paragraph by (a) motivating the need for the file and
  17. # (b) defining clones.
  18. # TODO(sguada) describe the high-level components of model deployment.
  19. # E.g. "each model deployment is composed of several parts: a DeploymentConfig,
  20. # which captures A, B and C, an input_fn which loads data.. etc
  21. To easily train a model on multiple GPUs or across multiple machines this
  22. module provides a set of helper functions: `create_clones`,
  23. `optimize_clones` and `deploy`.
  24. Usage:
  25. g = tf.Graph()
  26. # Set up DeploymentConfig
  27. config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
  28. # Create the global step on the device storing the variables.
  29. with tf.device(config.variables_device()):
  30. global_step = slim.create_global_step()
  31. # Define the inputs
  32. with tf.device(config.inputs_device()):
  33. images, labels = LoadData(...)
  34. inputs_queue = slim.data.prefetch_queue((images, labels))
  35. # Define the optimizer.
  36. with tf.device(config.optimizer_device()):
  37. optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
  38. # Define the model including the loss.
  39. def model_fn(inputs_queue):
  40. images, labels = inputs_queue.dequeue()
  41. predictions = CreateNetwork(images)
  42. slim.losses.log_loss(predictions, labels)
  43. model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
  44. optimizer=optimizer)
  45. # Run training.
  46. slim.learning.train(model_dp.train_op, my_log_dir,
  47. summary_op=model_dp.summary_op)
  48. The Clone namedtuple holds together the values associated with each call to
  49. model_fn:
  50. * outputs: The return values of the calls to `model_fn()`.
  51. * scope: The scope used to create the clone.
  52. * device: The device used to create the clone.
  53. DeployedModel namedtuple, holds together the values needed to train multiple
  54. clones:
  55. * train_op: An operation that runs the optimizer training op and includes
  56. all the update ops created by `model_fn`. Present only if an optimizer
  57. was specified.
  58. * summary_op: An operation that runs the summaries created by `model_fn`
  59. and by gradient processing.
  60. * total_loss: A `Tensor` that contains the sum of all losses created by
  61. `model_fn` plus the regularization losses.
  62. * clones: List of `Clone` tuples returned by `create_clones()`.
  63. DeploymentConfig parameters:
  64. * num_clones: Number of model clones to deploy in each replica.
  65. * clone_on_cpu: True if clones should be placed on CPU.
  66. * replica_id: Integer. Index of the replica for which the model is
  67. deployed. Usually 0 for the chief replica.
  68. * num_replicas: Number of replicas to use.
  69. * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
  70. * worker_job_name: A name for the worker job.
  71. * ps_job_name: A name for the parameter server job.
  72. TODO(sguada):
  73. - describe side effect to the graph.
  74. - what happens to summaries and update_ops.
  75. - which graph collections are altered.
  76. - write a tutorial on how to use this.
  77. - analyze the possibility of calling deploy more than once.
  78. """
  79. from __future__ import absolute_import
  80. from __future__ import division
  81. from __future__ import print_function
  82. import collections
  83. import tensorflow as tf
  84. from tensorflow.python.ops import control_flow_ops
  85. slim = tf.contrib.slim
  86. __all__ = ['create_clones',
  87. 'deploy',
  88. 'optimize_clones',
  89. 'DeployedModel',
  90. 'DeploymentConfig',
  91. 'Clone',
  92. ]
  # One Clone is produced per call to model_fn() during deployment:
  #   outputs: whatever model_fn() returned.
  #   scope:   the name scope the clone was created under.
  #   device:  the device the clone was created on.
  Clone = collections.namedtuple('Clone', ('outputs', 'scope', 'device'))

  # Value returned by deploy():
  #   train_op:   the `train_op`.
  #   summary_op: the `summary_op`.
  #   total_loss: the loss `Tensor`.
  #   clones:     a list of `Clone` tuples.
  DeployedModel = collections.namedtuple(
      'DeployedModel', ('train_op', 'summary_op', 'total_loss', 'clones'))

  # Default parameters for DeploymentConfig.
  _deployment_params = dict(
      num_clones=1,
      clone_on_cpu=False,
      replica_id=0,
      num_replicas=1,
      num_ps_tasks=0,
      worker_job_name='worker',
      ps_job_name='ps',
  )
  def create_clones(config, model_fn, args=None, kwargs=None):
    """Builds `config.num_clones` clones of a model by calling `model_fn`.

    `model_fn(*args, **kwargs)` is called once per clone. Each call happens
    inside the name scope `config.clone_scope(i)` and on the device
    `config.clone_device(i)`; its return value is recorded, together with that
    scope and device, in a `Clone(outputs, scope, device)` namedtuple.

    Variables are shared between clones: every clone after the first is built
    with `reuse=True` on the current variable scope. While the clones are being
    built, `slim.model_variable` and `slim.variable` are given
    `config.variables_device()` as their device, so when `config` describes a
    deployment with parameter servers the model and global variables are
    created on the `ps` device while the clone operations stay on the `worker`
    device.

    Note: it is assumed that any loss created by `model_fn` is added to the
    tf.GraphKeys.LOSSES collection.

    To recover the losses, summaries or update_ops created by a clone use:
    ```python
      losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
      summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
    ```

    Args:
      config: A DeploymentConfig object.
      model_fn: A callable. Called as `model_fn(*args, **kwargs)`.
      args: Optional list of positional arguments to pass to `model_fn`.
      kwargs: Optional dict of keyword arguments to pass to `model_fn`.

    Returns:
      A list of `Clone` namedtuples, one per clone, in clone-index order.
    """
    call_args = args or []
    call_kwargs = kwargs or {}
    built_clones = []
    # Route the slim variable-creation functions to the variables device for
    # the whole time the clones are being constructed.
    variables_arg_scope = slim.arg_scope(
        [slim.model_variable, slim.variable],
        device=config.variables_device())
    with variables_arg_scope:
      for index in range(config.num_clones):
        # Only the first clone creates variables; the others reuse them.
        reuse = True if index > 0 else None
        with tf.name_scope(config.clone_scope(index)) as clone_scope:
          clone_device = config.clone_device(index)
          with tf.device(clone_device):
            with tf.variable_scope(tf.get_variable_scope(), reuse=reuse):
              outputs = model_fn(*call_args, **call_kwargs)
            built_clones.append(Clone(outputs, clone_scope, clone_device))
    return built_clones
  def _gather_clone_loss(clone, num_clones, regularization_losses):
    """Collects and sums the losses that belong to one clone.

    The losses found in tf.GraphKeys.LOSSES under `clone.scope` are added
    together and, when more than one clone is deployed, divided by
    `num_clones`. The given regularization losses, if any, are added on top.
    All of this arithmetic is placed on `clone.device`; the scalar summaries
    for the two components are created afterwards, outside that device block.

    Args:
      clone: A Clone namedtuple.
      num_clones: The number of clones being deployed.
      regularization_losses: Possibly empty (or None) list of
        regularization losses to add to the clone losses.

    Returns:
      A tensor for the total loss for the clone, or None when the clone has
      neither losses nor regularization losses.
    """
    # Components that get their own summaries, and the value to return.
    model_loss = None
    reg_loss = None
    total = None

    with tf.device(clone.device):
      parts = []
      scoped_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
      if scoped_losses:
        model_loss = tf.add_n(scoped_losses, name='clone_loss')
        if num_clones > 1:
          # Scale so that summing over all clones yields their average.
          model_loss = tf.div(model_loss, 1.0 * num_clones,
                              name='scaled_clone_loss')
        parts.append(model_loss)
      if regularization_losses:
        reg_loss = tf.add_n(regularization_losses,
                            name='regularization_loss')
        parts.append(reg_loss)
      if parts:
        total = tf.add_n(parts)

    # The summaries are deliberately added outside the clone device block.
    if model_loss is not None:
      tf.scalar_summary(clone.scope + '/clone_loss', model_loss,
                        name='clone_loss')
    if reg_loss is not None:
      tf.scalar_summary('regularization_loss', reg_loss,
                        name='regularization_loss')
    return total
  def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
                      **kwargs):
    """Computes the loss and the gradients for a single clone.

    Args:
      optimizer: A tf.Optimizer object.
      clone: A Clone namedtuple.
      num_clones: The number of clones being deployed.
      regularization_losses: Possibly empty (or None) list of
        regularization losses to add to the clone losses.
      **kwargs: Keyword arguments forwarded to `compute_gradients()`.

    Returns:
      A tuple (clone_loss, clone_grads_and_vars).
        - clone_loss: A tensor for the total loss for the clone, or None when
          the clone has no loss.
        - clone_grads_and_vars: The result of
          `optimizer.compute_gradients()` for that loss, computed on the clone
          device, or None when the clone has no loss.
    """
    loss = _gather_clone_loss(clone, num_clones, regularization_losses)
    if loss is None:
      # Nothing to differentiate for this clone.
      return None, None
    with tf.device(clone.device):
      grads_and_vars = optimizer.compute_gradients(loss, **kwargs)
    return loss, grads_and_vars
  def optimize_clones(clones, optimizer,
                      regularization_losses=None,
                      **kwargs):
    """Computes clone losses and gradients for the given list of `Clones`.

    Each clone's loss is gathered and differentiated under the clone's own
    name scope. The per-clone losses are then summed into `total_loss` and the
    per-clone gradients are summed variable by variable.

    Note: The regularization_losses are added to the first clone losses only.

    Args:
      clones: List of `Clones` created by `create_clones()`.
      optimizer: An `Optimizer` object.
      regularization_losses: Optional list of regularization losses. If None
        they are gathered from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]`
        to exclude them.
      **kwargs: Optional keyword arguments to pass to `compute_gradients`.

    Returns:
      A tuple (total_loss, grads_and_vars).
        - total_loss: A Tensor containing the sum of the clone losses (each
          already scaled by 1 / num_clones when there are several clones),
          including the regularization loss.
        - grads_and_vars: A List of tuples (gradient, variable) containing the
          sum of the gradients for each variable.

    Raises:
      ValueError: If no clone produced a loss, so there is nothing to
        optimize.
    """
    if regularization_losses is None:
      regularization_losses = tf.get_collection(
          tf.GraphKeys.REGULARIZATION_LOSSES)
    num_clones = len(clones)
    clones_losses = []
    grads_and_vars = []
    for clone in clones:
      with tf.name_scope(clone.scope):
        clone_loss, clone_grad = _optimize_clone(
            optimizer, clone, num_clones, regularization_losses, **kwargs)
        if clone_loss is not None:
          clones_losses.append(clone_loss)
          grads_and_vars.append(clone_grad)
        # Only use regularization_losses for the first clone.
        regularization_losses = None
    if not clones_losses:
      # tf.add_n() rejects an empty list with a ValueError that does not say
      # what went wrong; fail with the same exception type but name the cause.
      raise ValueError(
          'Cannot optimize clones: no loss was found for any clone. model_fn '
          'must add at least one loss to tf.GraphKeys.LOSSES, or '
          'regularization losses must be available.')
    # Compute the total_loss summing all the clones_losses.
    total_loss = tf.add_n(clones_losses, name='total_loss')
    # Sum the gradients across clones.
    grads_and_vars = _sum_clones_gradients(grads_and_vars)
    return total_loss, grads_and_vars
  def deploy(config,
             model_fn,
             args=None,
             kwargs=None,
             optimizer=None,
             summarize_gradients=False):
    """Deploys a Slim-constructed model across multiple clones.

    `model_fn(*args, **kwargs)` is called `config.num_clones` times through
    `create_clones()`, using the devices and scopes chosen by `config`.

    When `optimizer` is given, the clone gradients are computed, summed and
    applied, and the resulting `train_op` also runs the update ops collected
    from the first clone. Without an optimizer only the clone losses are
    gathered and summed.

    If `config` specifies deployment on multiple replicas then the default
    tensorflow device is set appropriately for each call to `model_fn` and for
    the slim variable creation functions: model and global variables will be
    created on the `ps` device, the clone operations will be on the `worker`
    device.

    Args:
      config: A `DeploymentConfig` object.
      model_fn: A callable. Called as `model_fn(*args, **kwargs)`.
      args: Optional list of positional arguments to pass to `model_fn`.
      kwargs: Optional dict of keyword arguments to pass to `model_fn`.
      optimizer: Optional `Optimizer` object. If passed the model is deployed
        for training with that optimizer.
      summarize_gradients: Whether or not to add summaries for the gradients.

    Returns:
      A `DeployedModel` namedtuple. `train_op` is None when no optimizer was
      passed or no gradients were produced, `total_loss` is None when no
      optimizer was passed and no clone had a loss, and `summary_op` is None
      when there is no summary to merge.
    """
    # Summaries that already exist before any clone is built.
    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))

    clones = create_clones(config, model_fn, args, kwargs)
    first_clone = clones[0]

    # Update ops are taken from the first clone only. These contain, for
    # example, the updates for the batch_norm variables created by model_fn.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)

    train_op = None
    total_loss = None
    with tf.device(config.optimizer_device()):
      if not optimizer:
        # No training requested: just gather and sum the clone losses.
        gathered_losses = []
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES)
        for clone in clones:
          with tf.name_scope(clone.scope):
            loss = _gather_clone_loss(clone, len(clones),
                                      regularization_losses)
            if loss is not None:
              gathered_losses.append(loss)
            # Only use regularization_losses for the first clone.
            regularization_losses = None
        if gathered_losses:
          total_loss = tf.add_n(gathered_losses, name='total_loss')
      else:
        # Place the global step on the device storing the variables.
        with tf.device(config.variables_device()):
          global_step = slim.get_or_create_global_step()

        # Compute the gradients for the clones.
        total_loss, clones_gradients = optimize_clones(clones, optimizer)

        if clones_gradients:
          if summarize_gradients:
            summaries.update(_add_gradients_summaries(clones_gradients))

          # The train op runs the gradient update plus the first clone's
          # update ops, and evaluates to the total loss.
          grad_updates = optimizer.apply_gradients(clones_gradients,
                                                   global_step=global_step)
          update_ops.append(grad_updates)
          update_op = tf.group(*update_ops)
          train_op = control_flow_ops.with_dependencies(
              [update_op], total_loss, name='train_op')

      # Add the summaries from the first clone. These contain the summaries
      # created by model_fn and either optimize_clones() or
      # _gather_clone_loss().
      summaries.update(tf.get_collection(tf.GraphKeys.SUMMARIES,
                                         first_clone.scope))

      if total_loss is not None:
        summaries.add(tf.scalar_summary('total_loss', total_loss,
                                        name='total_loss'))

      # Merge all summaries together, if there are any.
      summary_op = (tf.merge_summary(list(summaries), name='summary_op')
                    if summaries else None)

    return DeployedModel(train_op, summary_op, total_loss, clones)
  344. def _sum_clones_gradients(clone_grads):
  345. """Calculate the sum gradient for each shared variable across all clones.
  346. This function assumes that the clone_grads has been scaled appropriately by
  347. 1 / num_clones.
  348. Args:
  349. clone_grads: A List of List of tuples (gradient, variable), one list per
  350. `Clone`.
  351. Returns:
  352. List of tuples of (gradient, variable) where the gradient has been summed
  353. across all clones.
  354. """
  355. sum_grads = []
  356. for grad_and_vars in zip(*clone_grads):
  357. # Note that each grad_and_vars looks like the following:
  358. # ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN))
  359. grads = []
  360. var = grad_and_vars[0][1]
  361. for g, v in grad_and_vars:
  362. assert v == var
  363. if g is not None:
  364. grads.append(g)
  365. if grads:
  366. if len(grads) > 1:
  367. sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads')
  368. else:
  369. sum_grad = grads[0]
  370. sum_grads.append((sum_grad, var))
  371. return sum_grads
  372. def _add_gradients_summaries(grads_and_vars):
  373. """Add histogram summaries to gradients.
  374. Note: The summaries are also added to the SUMMARIES collection.
  375. Args:
  376. grads_and_vars: A list of gradient to variable pairs (tuples).
  377. Returns:
  378. The _list_ of the added summaries for grads_and_vars.
  379. """
  380. summaries = []
  381. for grad, var in grads_and_vars:
  382. if grad is not None:
  383. if isinstance(grad, tf.IndexedSlices):
  384. grad_values = grad.values
  385. else:
  386. grad_values = grad
  387. summaries.append(tf.histogram_summary(var.op.name + ':gradient',
  388. grad_values))
  389. summaries.append(tf.histogram_summary(var.op.name + ':gradient_norm',
  390. tf.global_norm([grad_values])))
  391. else:
  392. tf.logging.info('Var %s has no gradient', var.op.name)
  393. return summaries
  class DeploymentConfig(object):
    """Configuration for deploying a model with `deploy()`.

    An instance of this class is passed to `deploy()` (or `create_clones()`)
    to specify how the model is deployed: how many clones to build, which
    device and name scope each clone uses, and where variables, inputs and the
    optimizer are placed.
    """

    def __init__(self,
                 num_clones=1,
                 clone_on_cpu=False,
                 replica_id=0,
                 num_replicas=1,
                 num_ps_tasks=0,
                 worker_job_name='worker',
                 ps_job_name='ps'):
      """Create a DeploymentConfig.

      The config describes how to deploy a model across multiple clones and
      replicas. The model will be replicated `num_clones` times in each
      replica. If `clone_on_cpu` is True, each clone will be placed on CPU.

      If `num_ps_tasks` is 0 no job prefix is added to any device and
      `worker_job_name` and `ps_job_name` are not used to build devices.

      If `num_replicas` is greater than 1, `num_ps_tasks` must be positive.
      Whenever `num_replicas` is greater than 1 or `num_ps_tasks` is positive,
      `worker_job_name` and `ps_job_name` must both be non-empty.

      Args:
        num_clones: Number of model clones to deploy in each replica.
        clone_on_cpu: If True clones would be placed on CPU.
        replica_id: Integer.  Index of the replica for which the model is
          deployed.  Usually 0 for the chief replica.
        num_replicas: Number of replicas to use.
        num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
        worker_job_name: A name for the worker job.
        ps_job_name: A name for the parameter server job.

      Raises:
        ValueError: If the arguments are invalid.
      """
      if num_replicas > 1 and num_ps_tasks < 1:
        raise ValueError('When using replicas num_ps_tasks must be positive')
      if num_replicas > 1 or num_ps_tasks > 0:
        if not worker_job_name:
          raise ValueError('Must specify worker_job_name when using replicas')
        if not ps_job_name:
          raise ValueError(
              'Must specify ps_job_name when using parameter server')
      if replica_id >= num_replicas:
        raise ValueError('replica_id must be less than num_replicas')
      self._num_clones = num_clones
      self._clone_on_cpu = clone_on_cpu
      self._replica_id = replica_id
      self._num_replicas = num_replicas
      self._num_ps_tasks = num_ps_tasks
      # Job prefixes are only used when parameter server tasks exist.
      self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
      self._worker_device = (
          '/job:' + worker_job_name if num_ps_tasks > 0 else '')

    @property
    def num_clones(self):
      return self._num_clones

    @property
    def clone_on_cpu(self):
      return self._clone_on_cpu

    @property
    def replica_id(self):
      return self._replica_id

    @property
    def num_replicas(self):
      return self._num_replicas

    @property
    def num_ps_tasks(self):
      return self._num_ps_tasks

    @property
    def ps_device(self):
      return self._ps_device

    @property
    def worker_device(self):
      return self._worker_device

    def _check_clone_index(self, clone_index):
      """Raises ValueError unless 0 <= clone_index < num_clones."""
      if clone_index < 0:
        # A negative index would otherwise yield names such as
        # '/device:GPU:-1'.
        raise ValueError('clone_index must be non-negative')
      if clone_index >= self._num_clones:
        raise ValueError('clone_index must be less than num_clones')

    def caching_device(self):
      """Returns the device to use for caching variables.

      Returns:
        When parameter server tasks are used, a function that maps an op to
        that op's own device; otherwise None, meaning the variables do not
        need to be cached.
      """
      if self._num_ps_tasks > 0:
        return lambda op: op.device
      else:
        return None

    def clone_device(self, clone_index):
      """Device used to create the clone and all the ops inside the clone.

      The device is the worker job prefix (only when parameter server tasks
      are used) followed by '/device:CPU:0' when cloning on CPU, or by
      '/device:GPU:<clone_index>' when there is more than one clone. With a
      single non-CPU clone no device suffix is added.

      Args:
        clone_index: Int, representing the clone_index.

      Returns:
        A value suitable for `tf.device()`.

      Raises:
        ValueError: if `clone_index` is negative, or greater than or equal to
          the number of clones.
      """
      self._check_clone_index(clone_index)
      device = ''
      if self._num_ps_tasks > 0:
        device += self._worker_device
      if self._clone_on_cpu:
        device += '/device:CPU:0'
      elif self._num_clones > 1:
        device += '/device:GPU:%d' % clone_index
      return device

    def clone_scope(self, clone_index):
      """Name scope to create the clone.

      Args:
        clone_index: Int, representing the clone_index.

      Returns:
        A name_scope suitable for `tf.name_scope()`: 'clone_<clone_index>'
        when there is more than one clone, otherwise the empty string.

      Raises:
        ValueError: if `clone_index` is negative, or greater than or equal to
          the number of clones.
      """
      self._check_clone_index(clone_index)
      scope = ''
      if self._num_clones > 1:
        scope = 'clone_%d' % clone_index
      return scope

    def optimizer_device(self):
      """Device to use with the optimizer.

      Returns:
        A value suitable for `tf.device()`.
      """
      # NOTE(review): `num_clones > 0` holds for every usable config, so the
      # CPU device is returned in practice and the '' branch is only reached
      # with num_clones <= 0 and no ps tasks. Kept as is to preserve behavior.
      if self._num_ps_tasks > 0 or self._num_clones > 0:
        return self._worker_device + '/device:CPU:0'
      else:
        return ''

    def inputs_device(self):
      """Device to use to build the inputs.

      Returns:
        A value suitable for `tf.device()`: the CPU of the worker job when
        parameter server tasks are used, otherwise the local CPU.
      """
      device = ''
      if self._num_ps_tasks > 0:
        device += self._worker_device
      device += '/device:CPU:0'
      return device

    def variables_device(self):
      """Returns the device to use for variables created inside the clone.

      Returns:
        A value suitable for `tf.device()`. Without parameter server tasks
        this is the string '/device:CPU:0'. With parameter server tasks it is
        a device function that keeps an op's existing device when it has one,
        and otherwise assigns each 'Variable' op to the ps CPU of the next
        task in round-robin order.
      """
      device = ''
      if self._num_ps_tasks > 0:
        device += self._ps_device
      device += '/device:CPU:0'

      class _PSDeviceChooser(object):
        """Slim device chooser for variables when using PS."""

        def __init__(self, device, tasks):
          self._device = device
          self._tasks = tasks
          self._task = 0

        def choose(self, op):
          # An explicitly placed op keeps its placement.
          if op.device:
            return op.device
          node_def = op if isinstance(op, tf.NodeDef) else op.node_def
          if node_def.op == 'Variable':
            # Hand out ps tasks in round-robin order.
            task = self._task
            self._task = (self._task + 1) % self._tasks
            return '%s/task:%d' % (self._device, task)
          return op.device

      if not self._num_ps_tasks:
        return device
      else:
        chooser = _PSDeviceChooser(device, self._num_ps_tasks)
        return chooser.choose