rmsprop.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RMSProp for score function gradients and IndexedSlices."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import tensorflow as tf


def _gradients_per_example(loss, variable):
  """Returns per-sample gradients, one per row of `loss`.

  Args:
    loss: A [n_samples, batch_size] shape tensor.
    variable: A variable to optimize, of shape var_shape.

  Returns:
    grad: A tensor of shape [n_samples, *var_shape], or an IndexedSlices
      whose values have that shape.
  """
  grad_list = [tf.gradients(loss_sample, variable)[0]
               for loss_sample in tf.unpack(loss)]
  if isinstance(grad_list[0], tf.IndexedSlices):
    # Assumes every sample touches the same rows, so the indices of the
    # first sample's gradient are valid for all of them.
    grad = tf.pack([g.values for g in grad_list])
    grad = tf.IndexedSlices(values=grad, indices=grad_list[0].indices)
  else:
    grad = tf.pack(grad_list)
  return grad
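
# For example (hypothetical shapes): with loss of shape [3, batch_size] and a
# variable of shape [d], tf.unpack(loss) yields three [batch_size] tensors,
# each tf.gradients call returns one [d] gradient, and tf.pack stacks them
# into a [3, d] result.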


def _cov(a, b):
  """Calculates the elementwise covariance of a and b across samples."""
  v = (a - tf.reduce_mean(a, 0)) * (b - tf.reduce_mean(b, 0))
  return tf.reduce_mean(v, 0)


def _var(a):
  """Returns the variance across the sample dimension."""
  _, var = tf.nn.moments(a, [0])
  return var
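
# Both helpers reduce over axis 0, the sample dimension: inputs of shape
# [n_samples, *dims] yield outputs of shape [*dims].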


def _update_mean_square(mean_square, variable):
  """Updates the mean-square accumulator for a variable.

  Here `variable` is the stacked per-sample gradient (a tensor or an
  IndexedSlices) produced by _gradients_per_example.
  """
  if isinstance(variable, tf.IndexedSlices):
    square_sum = tf.reduce_sum(tf.square(variable.values), 0)
    mean_square_lookup = tf.nn.embedding_lookup(mean_square, variable.indices)
    moving_mean_square = 0.9 * mean_square_lookup + 0.1 * square_sum
    return tf.scatter_update(mean_square, variable.indices, moving_mean_square)
  else:
    square_sum = tf.reduce_sum(tf.square(variable), 0)
    moving_mean_square = 0.9 * mean_square + 0.1 * square_sum
    return tf.assign(mean_square, moving_mean_square)
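
# Both branches apply the standard RMSProp accumulator with a fixed decay of
# 0.9:  ms <- 0.9 * ms + 0.1 * sum_s(g_s ** 2), where the sum runs over the
# Monte Carlo samples in the leading dimension.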


def _get_mean_square(variable):
  """Creates the mean-square accumulator for a variable."""
  with tf.variable_scope('optimizer_state'):
    # variable.name ends in ':0'; strip the suffix to get a legal slot name.
    mean_square = tf.get_variable(name=variable.name[:-2],
                                  shape=variable.get_shape(),
                                  initializer=tf.ones_initializer,
                                  dtype=variable.dtype.base_dtype)
  return mean_square


def _control_variate(grad, learning_signal):
  """Returns the per-parameter scaling for the score function control variate."""
  if isinstance(grad, tf.IndexedSlices):
    grad = grad.values
  cov = _cov(grad * learning_signal, grad)
  var = _var(grad)
  return cov / var
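
# This is the usual variance-minimizing coefficient for a score function
# (REINFORCE) estimator with the score itself as the control variate:
#   a* = Cov(g * L, g) / Var(g),
# where g is the per-sample score and L the learning signal. Subtracting a*
# from L keeps the gradient estimate unbiased (E[g] = 0) while reducing its
# variance.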


def _rmsprop_maximize(learning_rate, learning_signal, log_prob, variable,
                      clip_min=None, clip_max=None):
  """Builds rmsprop maximization ops for a single variable."""
  grad = _gradients_per_example(log_prob, variable)
  if learning_signal.get_shape().ndims == 2:
    # If we have multiple samples of latent variables, we need to broadcast
    # grad of shape [n_samples_latents, batch_size, n_timesteps, z_dim]
    # with learning_signal of shape [n_samples_latents, batch_size]:
    learning_signal = tf.expand_dims(tf.expand_dims(learning_signal, 2), 2)
  control_variate = _control_variate(grad, learning_signal)
  mean_square = _get_mean_square(variable)
  update_mean_square = _update_mean_square(mean_square, grad)
  variance_reduced_learning_sig = learning_signal - control_variate
  update_name = variable.name[:-2] + '/score_function_grad_estimator'
  if isinstance(grad, tf.IndexedSlices):
    mean_square_lookup = tf.nn.embedding_lookup(mean_square, grad.indices)
    mean_square_lookup = tf.expand_dims(mean_square_lookup, 0)
    update_per_sample = (grad.values / tf.sqrt(mean_square_lookup)
                         * variance_reduced_learning_sig)
    update = tf.reduce_mean(update_per_sample, 0, name=update_name)
    step = learning_rate * update
    if clip_min is None and clip_max is None:
      apply_step = tf.scatter_add(variable, grad.indices, step)
    else:
      var_lookup = tf.nn.embedding_lookup(variable, grad.indices)
      new_var = var_lookup + step
      new_var_clipped = tf.clip_by_value(
          new_var, clip_value_min=clip_min, clip_value_max=clip_max)
      apply_step = tf.scatter_update(variable, grad.indices, new_var_clipped)
  else:
    update_per_sample = (grad / tf.sqrt(mean_square)
                         * variance_reduced_learning_sig)
    update = tf.reduce_mean(update_per_sample, 0, name=update_name)
    step = learning_rate * update
    if clip_min is None and clip_max is None:
      apply_step = tf.assign(variable, variable + step)
    else:
      new_var = variable + step
      new_var_clipped = tf.clip_by_value(
          new_var, clip_value_min=clip_min, clip_value_max=clip_max)
      apply_step = tf.assign(variable, new_var_clipped)
  # Add the averaged estimator to a collection for tracking statistics.
  tf.add_to_collection('non_reparam_variable_grads', update)
  with tf.control_dependencies([update_mean_square]):
    train_op = tf.group(apply_step)
  return train_op
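
# In equation form, each call builds the ascent step
#   theta <- theta + eta * mean_s[g_s / sqrt(ms) * (L_s - a*)],
# with g_s the per-sample score, ms the mean-square accumulator, L_s the
# learning signal, and a* the control variate coefficient from above.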


def maximize_with_control_variate(learning_rate, learning_signal, log_prob,
                                  variable_list, global_step=None):
  """Builds a covariance control variate with rmsprop updates.

  Args:
    learning_rate: Step size.
    learning_signal: Usually the ELBO; the bound we optimize.
      Shape [n_samples, batch_size].
    log_prob: Log probability of samples of latent variables.
      Shape [n_samples, batch_size].
    variable_list: List of variables to update.
    global_step: Optional global step variable to increment.

  Returns:
    train_op: Group of operations that apply an RMSProp update with the
      control variate.
  """
  train_ops = []
  for variable in variable_list:
    clip_max, clip_min = (None, None)
    # Special-case clipping for softplus-inverse parameterized variables.
    if 'shape_softplus_inv' in variable.name:
      clip_max = sys.float_info.max
      clip_min = 5e-3
    elif 'mean_softplus_inv' in variable.name:
      clip_max = sys.float_info.max
      clip_min = 1e-5
    train_ops.append(_rmsprop_maximize(
        learning_rate, learning_signal, log_prob, variable,
        clip_max=clip_max, clip_min=clip_min))
  if global_step is not None:
    increment_global_step = tf.assign(global_step, global_step + 1)
    with tf.control_dependencies(train_ops):
      train_op = tf.group(increment_global_step)
  else:
    train_op = tf.group(*train_ops)
  return train_op
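
# A minimal usage sketch (hypothetical: the `elbo` and `z_log_prob` tensors,
# the variable list, and all shapes are assumptions, not part of this module).
# Both inputs carry a leading Monte Carlo sample dimension, and the variable
# list should contain only non-reparameterizable (score function) parameters:
#
#   elbo = ...        # learning signal, shape [n_samples, batch_size]
#   z_log_prob = ...  # log q(z) of the sampled latents, same shape
#   score_fn_vars = [...]  # variables of the variational distribution
#   train_op = maximize_with_control_variate(
#       learning_rate=0.01,
#       learning_signal=elbo,
#       log_prob=z_log_prob,
#       variable_list=score_fn_vars,
#       global_step=tf.Variable(0, trainable=False, name='global_step'))
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     sess.run(train_op)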