# composite_optimizer_test.py
"""Tests for CompositeOptimizer."""

import numpy as np
import tensorflow as tf

from tensorflow.python.framework import test_util
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging

from dragnn.python import composite_optimizer
  9. class MockAdamOptimizer(tf.train.AdamOptimizer):
  10. def __init__(self,
  11. learning_rate=0.001,
  12. beta1=0.9,
  13. beta2=0.999,
  14. epsilon=1e-8,
  15. use_locking=False,
  16. name="Adam"):
  17. super(MockAdamOptimizer, self).__init__(learning_rate, beta1, beta2,
  18. epsilon, use_locking, name)
  19. def _create_slots(self, var_list):
  20. super(MockAdamOptimizer, self)._create_slots(var_list)
  21. for v in var_list:
  22. self._zeros_slot(v, "adam_counter", self._name)
  23. def _apply_dense(self, grad, var):
  24. train_op = super(MockAdamOptimizer, self)._apply_dense(grad, var)
  25. counter = self.get_slot(var, "adam_counter")
  26. return tf.group(train_op, tf.assign_add(counter, [1.0]))
  27. class MockMomentumOptimizer(tf.train.MomentumOptimizer):
  28. def __init__(self,
  29. learning_rate,
  30. momentum,
  31. use_locking=False,
  32. name="Momentum",
  33. use_nesterov=False):
  34. super(MockMomentumOptimizer, self).__init__(learning_rate, momentum,
  35. use_locking, name, use_nesterov)
  36. def _create_slots(self, var_list):
  37. super(MockMomentumOptimizer, self)._create_slots(var_list)
  38. for v in var_list:
  39. self._zeros_slot(v, "momentum_counter", self._name)
  40. def _apply_dense(self, grad, var):
  41. train_op = super(MockMomentumOptimizer, self)._apply_dense(grad, var)
  42. counter = self.get_slot(var, "momentum_counter")
  43. return tf.group(train_op, tf.assign_add(counter, [1.0]))
  44. class CompositeOptimizerTest(test_util.TensorFlowTestCase):
  45. def test_switching(self):
  46. with self.test_session() as sess:
  47. # Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
  48. x_data = np.random.rand(100).astype(np.float32)
  49. y_data = x_data * 0.1 + 0.3
  50. # Try to find values for w and b that compute y_data = w * x_data + b
  51. # (We know that w should be 0.1 and b 0.3, but TensorFlow will
  52. # figure that out for us.)
  53. w = tf.Variable(tf.random_uniform([1], -1.0, 1.0))
  54. b = tf.Variable(tf.zeros([1]))
  55. y = w * x_data + b
  56. # Minimize the mean squared errors.
  57. loss = tf.reduce_mean(tf.square(y - y_data))
  58. # Set up optimizers.
  59. step = tf.get_variable(
  60. "step",
  61. shape=[],
  62. initializer=tf.zeros_initializer(),
  63. trainable=False,
  64. dtype=tf.int32)
  65. optimizer1 = MockAdamOptimizer(0.05)
  66. optimizer2 = MockMomentumOptimizer(0.05, 0.5)
  67. switch = tf.less(step, 100)
  68. optimizer = composite_optimizer.CompositeOptimizer(optimizer1, optimizer2,
  69. switch)
  70. train_op = optimizer.minimize(loss)
  71. sess.run(tf.global_variables_initializer())
  72. # Fit the line.:
  73. for iteration in range(201):
  74. self.assertEqual(sess.run(switch), iteration < 100)
  75. sess.run(train_op)
  76. sess.run(tf.assign_add(step, 1))
  77. slot_names = optimizer.get_slot_names()
  78. self.assertItemsEqual(
  79. slot_names,
  80. ["m", "v", "momentum", "adam_counter", "momentum_counter"])
  81. adam_counter = sess.run(optimizer.get_slot(w, "adam_counter"))
  82. momentum_counter = sess.run(optimizer.get_slot(w, "momentum_counter"))
  83. self.assertEqual(adam_counter, min(iteration + 1, 100))
  84. self.assertEqual(momentum_counter, max(iteration - 99, 0))
  85. if iteration % 20 == 0:
  86. logging.info("%d %s %d %d", iteration, sess.run([switch, step, w, b]),
  87. adam_counter, momentum_counter)
# Entry point: run the tests under TensorFlow's googletest runner.
if __name__ == "__main__":
  googletest.main()