gamma_mixture_model.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Gamma mixture model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys

import numpy as np
import tensorflow as tf

from ops import rmsprop
from ops import util

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions


def train(optimizer):
  """Trains a gamma-normal mixture model.

  From http://ajbc.io/resources/bbvi_for_gammas.pdf.

  Args:
    optimizer: string selecting the optimizer; either 'rmsprop_manual'
        (RMSProp applied by hand) or 'rmsprop_indexed_slices' (an RMSProp
        that works on IndexedSlices).

  Returns:
    Means learned using variational inference with the given optimizer.
  """
  tf.reset_default_graph()
  np_dtype = np.float64
  np.random.seed(11)
  alpha_0 = np.array(0.1, dtype=np_dtype)
  mu_0 = np.array(5., dtype=np_dtype)
  # number of components
  n_components = 12
  # number of datapoints
  n_data = 100
  mu = np.random.gamma(alpha_0, mu_0 / alpha_0, n_components)
  x = np.random.normal(mu, 1., (n_data, n_components))

  ## set up for inference
  # the number of samples to draw for each parameter
  n_samples = 40
  batch_size = 1
  np.random.seed(123232)
  tf.set_random_seed(25343)
  tf_dtype = tf.float64
  inv_softplus = util.inv_softplus
  lambda_alpha_var = tf.get_variable(
      'lambda_alpha', shape=[1, n_components], dtype=tf_dtype,
      initializer=tf.constant_initializer(value=0.1))
  lambda_mu_var = tf.get_variable(
      'lambda_mu', shape=[1, n_components], dtype=tf_dtype,
      initializer=tf.constant_initializer(value=0.1))
  x_indices = tf.placeholder(shape=[batch_size], dtype=tf.int64)
  if optimizer == 'rmsprop_indexed_slices':
    lambda_alpha = tf.nn.embedding_lookup(lambda_alpha_var, x_indices)
    lambda_mu = tf.nn.embedding_lookup(lambda_mu_var, x_indices)
  elif optimizer == 'rmsprop_manual':
    lambda_alpha = lambda_alpha_var
    lambda_mu = lambda_mu_var
  variational = st.StochasticTensor(
      distributions.Gamma,
      alpha=tf.nn.softplus(lambda_alpha),
      beta=(tf.nn.softplus(lambda_alpha) / tf.nn.softplus(lambda_mu)),
      dist_value_type=st.SampleValue(n=n_samples),
      validate_args=False)
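  # The variational Gamma is mean-parameterized: with shape
  # a = softplus(lambda_alpha) and rate b = a / softplus(lambda_mu), its
  # mean is a / b = softplus(lambda_mu), so lambda_mu directly controls
  # the learned component means (through a softplus).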
  # truncate samples away from zero (a sample of exactly 0 would make the
  # log-densities below -inf)
  sample_mus = tf.maximum(variational.value(), 1e-300)
  # log-probability of samples under the prior
  prior = distributions.Gamma(alpha=alpha_0,
                              beta=alpha_0 / mu_0,
                              validate_args=False)
  p = prior.log_pdf(sample_mus)
  # log-probability of samples under the variational distribution
  q = variational.distribution.log_pdf(sample_mus)
  likelihood = distributions.Normal(mu=tf.expand_dims(sample_mus, 1),
                                    sigma=np.array(1., dtype=np_dtype),
                                    validate_args=False)
  # add the log-likelihood of the observations given the samples
  x_ph = tf.expand_dims(tf.constant(x, dtype=tf_dtype), 0)
  p += tf.reduce_sum(likelihood.log_pdf(x_ph), 2)
  elbo = p - q
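  # Monte Carlo ELBO estimate: for each sample mu_s ~ q, the term
  #   elbo_s = log p(mu_s) + log p(x | mu_s) - log q(mu_s)
  # is an unbiased estimate of E_q[log p(x, mu) - log q(mu)].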

  # run BBVI for a fixed number of iterations
  iteration = tf.Variable(0, trainable=False)
  increment_iteration = tf.assign(iteration, iteration + 1)
  # Robbins-Monro sequence for the step size
  rho = tf.pow(tf.cast(iteration, tf_dtype) + 1024., -0.7)
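  # rho_t = (t + 1024)^(-0.7) satisfies the Robbins-Monro conditions:
  # sum_t rho_t diverges while sum_t rho_t^2 converges, since 0.5 < 0.7 <= 1.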
  if optimizer == 'rmsprop_manual':
    # control variates to decrease the variance of the gradient;
    # one for each variational parameter
    g_alpha = tf.pack([tf.gradients(q_sample, lambda_alpha)[0]
                       for q_sample in tf.unpack(q)])
    g_mu = tf.pack([tf.gradients(q_sample, lambda_mu)[0]
                    for q_sample in tf.unpack(q)])
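    # g_alpha and g_mu stack the per-sample score gradients
    # d log q(mu_s) / d lambda, the ingredient of the score-function
    # (REINFORCE) estimator: grad ELBO = E_q[grad log q(mu) * (log p - log q)].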

    def cov(a, b):
      v = (a - tf.reduce_mean(a, 0)) * (b - tf.reduce_mean(b, 0))
      return tf.reduce_mean(v, 0)

    _, var_g_alpha = tf.nn.moments(g_alpha, [0])
    _, var_g_mu = tf.nn.moments(g_mu, [0])
    cov_alpha = cov(g_alpha * (p - q), g_alpha)
    cov_mu = cov(g_mu * (p - q), g_mu)
    cv_alpha = cov_alpha / var_g_alpha
    cv_mu = cov_mu / var_g_mu
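    # cv_* estimates the variance-minimizing control-variate scaling
    # a* = Cov(g * (log p - log q), g) / Var(g); subtracting it below keeps
    # the gradient estimate unbiased while reducing its variance.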
    ms_mu = tf.Variable(tf.ones_like(g_mu), trainable=False)
    ms_alpha = tf.Variable(tf.ones_like(g_alpha), trainable=False)

    def update_ms(ms, var):
      return tf.assign(ms, 0.9 * ms + 0.1 * tf.reduce_sum(tf.square(var), 0))

    update_ms_ops = [update_ms(ms_mu, g_mu), update_ms(ms_alpha, g_alpha)]
    # update each variational parameter with the sample average
    alpha_step = rho * tf.reduce_mean(g_alpha / tf.sqrt(ms_alpha) *
                                      (p - q - cv_alpha), 0)
    update_alpha = tf.assign(lambda_alpha, lambda_alpha + alpha_step)
    mu_step = rho * tf.reduce_mean(g_mu / tf.sqrt(ms_mu) * (p - q - cv_mu), 0)
    update_mu = tf.assign(lambda_mu, lambda_mu + mu_step)
    train_ops = tf.group(update_mu, update_alpha)
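    # The updates above are RMSProp by hand: each score gradient is scaled
    # by 1 / sqrt(ms_*), an exponential moving average of squared gradients
    # (summed over samples), before taking an ascent step of size rho.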
  elif optimizer == 'rmsprop_indexed_slices':
    variable_list = [lambda_mu_var, lambda_alpha_var]
    train_ops = rmsprop.maximize_with_control_variate(
        rho, elbo, q, variable_list, iteration)

  # truncate the variational parameters so that softplus(param) stays in a
  # numerically safe range
  get_min = lambda var, val: tf.assign(var, tf.maximum(var, inv_softplus(val)))
  get_max = lambda var, val: tf.assign(var, tf.minimum(var, inv_softplus(val)))
  get_min_ops = [get_min(lambda_alpha_var, 0.005), get_min(lambda_mu_var, 1e-5)]
  get_max_ops = [get_max(var, sys.float_info.max)
                 for var in [lambda_mu_var, lambda_alpha_var]]
  truncate_ops = get_min_ops + get_max_ops
  with tf.control_dependencies([train_ops]):
    train_ops = tf.group(*truncate_ops)

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    fd = {x_indices: [0]}
    print('running variational inference using: %s' % optimizer)
    for i in range(100):
      if i % 10 == 0:
        print('iteration %d\telbo %.3e'
              % (i, np.mean(np.sum(elbo.eval(fd), axis=1))))
      if optimizer == 'rmsprop_manual':
        sess.run(update_ms_ops)
      sess.run(train_ops, fd)
      sess.run(increment_iteration)
    # return the learned variational means
    np_mu = sess.run(tf.nn.softplus(lambda_mu), fd)
    return np_mu
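

# A minimal example driver (a sketch, assuming the module is run directly):
# train() accepts exactly the two optimizer strings handled above.
if __name__ == '__main__':
  for opt in ('rmsprop_manual', 'rmsprop_indexed_slices'):
    learned_means = train(opt)
    print('%s learned means:\n%s' % (opt, learned_means))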