composite_optimizer.py

  1. """An optimizer that switches between several methods."""
  2. import tensorflow as tf
  3. from tensorflow.python.training import optimizer
  4. class CompositeOptimizer(optimizer.Optimizer):
  5. """Optimizer that switches between several methods.
  6. """
  7. def __init__(self,
  8. optimizer1,
  9. optimizer2,
  10. switch,
  11. use_locking=False,
  12. name='Composite'):
  13. """Construct a new Composite optimizer.
  14. Args:
  15. optimizer1: A tf.python.training.optimizer.Optimizer object.
  16. optimizer2: A tf.python.training.optimizer.Optimizer object.
  17. switch: A tf.bool Tensor, selecting whether to use the first or the second
  18. optimizer.
  19. use_locking: Bool. If True apply use locks to prevent concurrent updates
  20. to variables.
  21. name: Optional name prefix for the operations created when applying
  22. gradients. Defaults to "Composite".
  23. """
  24. super(CompositeOptimizer, self).__init__(use_locking, name)
  25. self._optimizer1 = optimizer1
  26. self._optimizer2 = optimizer2
  27. self._switch = switch
  28. def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  29. return tf.cond(
  30. self._switch,
  31. lambda: self._optimizer1.apply_gradients(grads_and_vars,
  32. global_step, name),
  33. lambda: self._optimizer2.apply_gradients(grads_and_vars,
  34. global_step, name)
  35. )
  36. def get_slot(self, var, name):
  37. slot1 = self._optimizer1.get_slot(var, name)
  38. slot2 = self._optimizer2.get_slot(var, name)
  39. if slot1 and slot2:
  40. raise LookupError('Slot named %s for variable %s populated for both '
  41. 'optimizers' % (name, var.name))
  42. return slot1 or slot2
  43. def get_slot_names(self):
  44. return sorted(self._optimizer1.get_slot_names() +
  45. self._optimizer2.get_slot_names())
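

# Minimal usage sketch, not part of the original file: it assumes the TF1-style
# graph API this class is written against; the toy loss, the hyperparameters,
# and the `use_momentum` placeholder name are illustrative assumptions only.
if __name__ == '__main__':
  x = tf.Variable(5.0, name='x')
  loss = tf.square(x)

  # A boolean placeholder chooses which optimizer applies each update.
  use_momentum = tf.placeholder(tf.bool, shape=[], name='use_momentum')
  opt = CompositeOptimizer(
      tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9),
      tf.train.AdagradOptimizer(learning_rate=0.1),
      use_momentum)
  train_op = opt.minimize(loss)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_op, feed_dict={use_momentum: True})   # MomentumOptimizer step
    sess.run(train_op, feed_dict={use_momentum: False})  # AdagradOptimizer step
    print(sess.run(loss))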