# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """An optimizer that switches between several methods."""
- import tensorflow as tf
- from tensorflow.python.training import optimizer


class CompositeOptimizer(optimizer.Optimizer):
  """Optimizer that switches between two optimizers based on a boolean tensor."""

  def __init__(self,
               optimizer1,
               optimizer2,
               switch,
               use_locking=False,
               name='Composite'):
    """Constructs a new CompositeOptimizer.

    Args:
      optimizer1: A `tf.train.Optimizer` instance.
      optimizer2: A `tf.train.Optimizer` instance.
      switch: A scalar `tf.bool` Tensor. If True, gradients are applied with
        `optimizer1`; otherwise with `optimizer2`.
      use_locking: Bool. If True, use locks to prevent concurrent updates to
        variables.
      name: Optional name prefix for the operations created when applying
        gradients. Defaults to "Composite".
    """
    super(CompositeOptimizer, self).__init__(use_locking, name)
    self._optimizer1 = optimizer1
    self._optimizer2 = optimizer2
    self._switch = switch

  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    # tf.cond builds the update ops of both optimizers, but at each step only
    # the branch selected by `self._switch` is executed.
    return tf.cond(
        self._switch,
        lambda: self._optimizer1.apply_gradients(grads_and_vars,
                                                 global_step, name),
        lambda: self._optimizer2.apply_gradients(grads_and_vars,
                                                 global_step, name))

  def get_slot(self, var, name):
    # A given slot should be populated by at most one of the wrapped
    # optimizers; fail loudly instead of silently preferring one of them.
    slot1 = self._optimizer1.get_slot(var, name)
    slot2 = self._optimizer2.get_slot(var, name)
    if slot1 and slot2:
      raise LookupError('Slot named %s for variable %s populated for both '
                        'optimizers' % (name, var.name))
    return slot1 or slot2

  def get_slot_names(self):
    return sorted(self._optimizer1.get_slot_names() +
                  self._optimizer2.get_slot_names())
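

# A minimal, self-contained usage sketch (an illustrative demo, not part of
# the original module): train a single scalar with MomentumOptimizer while
# the `use_first` placeholder is fed True, then fall back to plain gradient
# descent. The variable, loss, learning rates, and switch schedule below are
# all assumptions made for the example.
if __name__ == '__main__':
  w = tf.Variable(5.0)
  loss = tf.square(w)
  use_first = tf.placeholder(tf.bool, shape=[], name='use_first')

  opt = CompositeOptimizer(
      tf.train.MomentumOptimizer(learning_rate=0.1, momentum=0.9),
      tf.train.GradientDescentOptimizer(learning_rate=0.1),
      switch=use_first)
  train_op = opt.apply_gradients(opt.compute_gradients(loss))

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(20):
      # Momentum for the first 10 steps, vanilla SGD afterwards.
      sess.run(train_op, feed_dict={use_first: step < 10})
    print('Final value of w:', sess.run(w))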