composite_optimizer.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """An optimizer that switches between several methods."""
  16. import tensorflow as tf
  17. from tensorflow.python.training import optimizer
  18. class CompositeOptimizer(optimizer.Optimizer):
  19. """Optimizer that switches between several methods.
  20. """
  21. def __init__(self,
  22. optimizer1,
  23. optimizer2,
  24. switch,
  25. use_locking=False,
  26. name='Composite'):
  27. """Construct a new Composite optimizer.
  28. Args:
  29. optimizer1: A tf.python.training.optimizer.Optimizer object.
  30. optimizer2: A tf.python.training.optimizer.Optimizer object.
  31. switch: A tf.bool Tensor, selecting whether to use the first or the second
  32. optimizer.
  33. use_locking: Bool. If True apply use locks to prevent concurrent updates
  34. to variables.
  35. name: Optional name prefix for the operations created when applying
  36. gradients. Defaults to "Composite".
  37. """
  38. super(CompositeOptimizer, self).__init__(use_locking, name)
  39. self._optimizer1 = optimizer1
  40. self._optimizer2 = optimizer2
  41. self._switch = switch
  42. def apply_gradients(self, grads_and_vars, global_step=None, name=None):
  43. return tf.cond(
  44. self._switch,
  45. lambda: self._optimizer1.apply_gradients(grads_and_vars,
  46. global_step, name),
  47. lambda: self._optimizer2.apply_gradients(grads_and_vars,
  48. global_step, name)
  49. )
  50. def get_slot(self, var, name):
  51. slot1 = self._optimizer1.get_slot(var, name)
  52. slot2 = self._optimizer2.get_slot(var, name)
  53. if slot1 and slot2:
  54. raise LookupError('Slot named %s for variable %s populated for both '
  55. 'optimizers' % (name, var.name))
  56. return slot1 or slot2
  57. def get_slot_names(self):
  58. return sorted(self._optimizer1.get_slot_names() +
  59. self._optimizer2.get_slot_names())
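
For reference, a minimal usage sketch, not taken from this file: it assumes the
class is importable as `composite_optimizer` and uses TF 1.x graph-mode APIs
(`tf.train.MomentumOptimizer`, `tf.train.GradientDescentOptimizer`,
`tf.Session`). It switches from momentum to plain gradient descent after
10,000 steps, with a toy loss standing in for a real model.

import tensorflow as tf

from composite_optimizer import CompositeOptimizer  # assumed import path

# Toy model: minimize (x - 3)^2.
x = tf.Variable(1.0)
loss = tf.square(x - 3.0)

global_step = tf.train.get_or_create_global_step()
# Boolean scalar tensor: True selects optimizer1, False selects optimizer2.
use_momentum = tf.less(global_step, 10000)

opt = CompositeOptimizer(
    tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9),
    tf.train.GradientDescentOptimizer(learning_rate=0.01),
    switch=use_momentum)
train_op = opt.minimize(loss, global_step=global_step)

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  for _ in range(3):
    sess.run(train_op)

Note the design implication of tf.cond: both optimizers' update ops are built
into the graph, so both allocate whatever slot variables they need, and only
the selected branch runs each step. Pairing two optimizers that create the
same slot name (e.g., two MomentumOptimizers) would make get_slot raise
LookupError for that name; GradientDescentOptimizer creates no slots, which
sidesteps that here.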