composite_optimizer_test.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Tests for CompositeOptimizer."""

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from dragnn.python import composite_optimizer


class MockAdamOptimizer(tf.train.AdamOptimizer):
  """Adam optimizer that counts its dense updates per variable in a slot."""

  def __init__(self,
               learning_rate=0.001,
               beta1=0.9,
               beta2=0.999,
               epsilon=1e-8,
               use_locking=False,
               name="Adam"):
    super(MockAdamOptimizer, self).__init__(learning_rate, beta1, beta2,
                                            epsilon, use_locking, name)

  def _create_slots(self, var_list):
    super(MockAdamOptimizer, self)._create_slots(var_list)
    for v in var_list:
      self._zeros_slot(v, "adam_counter", self._name)

  def _apply_dense(self, grad, var):
    # Apply the regular Adam update, then increment this variable's counter.
    train_op = super(MockAdamOptimizer, self)._apply_dense(grad, var)
    counter = self.get_slot(var, "adam_counter")
    return tf.group(train_op, tf.assign_add(counter, [1.0]))


class MockMomentumOptimizer(tf.train.MomentumOptimizer):
  """Momentum optimizer that counts its dense updates per variable in a slot."""

  def __init__(self,
               learning_rate,
               momentum,
               use_locking=False,
               name="Momentum",
               use_nesterov=False):
    super(MockMomentumOptimizer, self).__init__(learning_rate, momentum,
                                                use_locking, name, use_nesterov)

  def _create_slots(self, var_list):
    super(MockMomentumOptimizer, self)._create_slots(var_list)
    for v in var_list:
      self._zeros_slot(v, "momentum_counter", self._name)

  def _apply_dense(self, grad, var):
    # Apply the regular Momentum update, then increment this variable's counter.
    train_op = super(MockMomentumOptimizer, self)._apply_dense(grad, var)
    counter = self.get_slot(var, "momentum_counter")
    return tf.group(train_op, tf.assign_add(counter, [1.0]))


class CompositeOptimizerTest(test_util.TensorFlowTestCase):

  def test_switching(self):
    """Checks that training switches from Adam to Momentum after 100 steps."""
    with self.test_session() as sess:
      # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3.
      x_data = np.random.rand(100).astype(np.float32)
      y_data = x_data * 0.1 + 0.3

      # Try to find values for w and b that compute y_data = w * x_data + b.
      # (We know that w should be 0.1 and b 0.3, but TensorFlow will
      # figure that out for us.)
      w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
      b = tf.Variable(tf.zeros([1]))
      y = w * x_data + b

      # Minimize the mean squared errors.
      loss = tf.reduce_mean(tf.square(y - y_data))

      # Set up optimizers: Adam while step < 100, then Momentum.
      step = tf.get_variable(
          "step",
          shape=[],
          initializer=tf.zeros_initializer(),
          trainable=False,
          dtype=tf.int32)
      optimizer1 = MockAdamOptimizer(0.05)
      optimizer2 = MockMomentumOptimizer(0.05, 0.5)
      switch = tf.less(step, 100)
      optimizer = composite_optimizer.CompositeOptimizer(optimizer1, optimizer2,
                                                         switch)
      train_op = optimizer.minimize(loss)

      sess.run(tf.global_variables_initializer())

      # Fit the line.
      for iteration in range(201):
        self.assertEqual(sess.run(switch), iteration < 100)
        sess.run(train_op)
        sess.run(tf.assign_add(step, 1))

        # The composite optimizer exposes the slots of both sub-optimizers.
        slot_names = optimizer.get_slot_names()
        self.assertItemsEqual(
            slot_names,
            ["m", "v", "momentum", "adam_counter", "momentum_counter"])

        # Each counter slot records how many updates its optimizer has applied.
        adam_counter = sess.run(optimizer.get_slot(w, "adam_counter"))
        momentum_counter = sess.run(optimizer.get_slot(w, "momentum_counter"))
        self.assertEqual(adam_counter, min(iteration + 1, 100))
        self.assertEqual(momentum_counter, max(iteration - 99, 0))

        if iteration % 20 == 0:
          logging.info("%d %s %d %d", iteration, sess.run([switch, step, w, b]),
                       adam_counter, momentum_counter)


if __name__ == "__main__":
  googletest.main()