Browse Source

Add experimental WIP rdefs (Jaan Altosaar's code).

Eugene Brevdo 9 years ago
parent
commit
dac5d1eed0

+ 183 - 0
experimental/rdefs/python/gamma_mixture_model.py

@@ -0,0 +1,183 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Gamma mixture model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+import numpy as np
+import tensorflow as tf
+from ops import rmsprop
+from ops import util
+
+st = tf.contrib.bayesflow.stochastic_tensor
+distributions = tf.contrib.distributions
+
+
+def train(optimizer):
+  """Trains a gamma-normal mixture model.
+
+  From http://ajbc.io/resources/bbvi_for_gammas.pdf.
+
+  Args:
+    optimizer: string specifying whether to use manual rmsprop or rmsprop
+        that works on IndexedSlices
+
+  Returns:
+    Means learned using variational inference with the given optimizer
+  """
+  tf.reset_default_graph()
+  np_dtype = np.float64
+  np.random.seed(11)
+  alpha_0 = np.array(0.1, dtype=np_dtype)
+  mu_0 = np.array(5., dtype=np_dtype)
+  # number of components
+  n_components = 12
+  # number of datapoints
+  n_data = 100
+  mu = np.random.gamma(alpha_0, mu_0 / alpha_0, n_components)
+  x = np.random.normal(mu, 1., (n_data, n_components))
+  ## set up for inference
+  # the number of samples to draw for each parameter
+  n_samples = 40
+  batch_size = 1
+  np.random.seed(123232)
+  tf.set_random_seed(25343)
+  tf_dtype = tf.float64
+  inv_softplus = util.inv_softplus
+
+  lambda_alpha_var = tf.get_variable(
+      'lambda_alpha', shape=[1, n_components], dtype=tf_dtype,
+      initializer=tf.constant_initializer(value=0.1))
+  lambda_mu_var = tf.get_variable(
+      'lambda_mu', shape=[1, n_components], dtype=tf_dtype,
+      initializer=tf.constant_initializer(value=0.1))
+
+  x_indices = tf.placeholder(shape=[batch_size], dtype=tf.int64)
+
+  if optimizer == 'rmsprop_indexed_slices':
+    lambda_alpha = tf.nn.embedding_lookup(lambda_alpha_var, x_indices)
+    lambda_mu = tf.nn.embedding_lookup(lambda_mu_var, x_indices)
+  elif optimizer == 'rmsprop_manual':
+    lambda_alpha = lambda_alpha_var
+    lambda_mu = lambda_mu_var
+
+  variational = st.StochasticTensor(distributions.Gamma,
+                                    alpha=tf.nn.softplus(lambda_alpha),
+                                    beta=(tf.nn.softplus(lambda_alpha)
+                                          / tf.nn.softplus(lambda_mu)),
+                                    dist_value_type=st.SampleValue(n=n_samples),
+                                    validate_args=False)
+
+  # truncate samples (don't sample zero)
+  sample_mus = tf.maximum(variational.value(), 1e-300)
+
+  # probability of samples given prior
+  prior = distributions.Gamma(alpha=alpha_0,
+                              beta=alpha_0/mu_0,
+                              validate_args=False)
+
+  p = prior.log_pdf(sample_mus)
+
+  # probability of samples given variational parameters
+  q = variational.distribution.log_pdf(sample_mus)
+
+  likelihood = distributions.Normal(mu=tf.expand_dims(sample_mus, 1),
+                                    sigma=np.array(1., dtype=np_dtype),
+                                    validate_args=False)
+
+  # probability of observations given samples
+  x_ph = tf.expand_dims(tf.constant(x, dtype=tf_dtype), 0)
+  p += tf.reduce_sum(likelihood.log_pdf(x_ph), 2)
+
+  elbo = p - q
+
+  # run BBVI for a fixed number of iterations
+  iteration = tf.Variable(0, trainable=False)
+  increment_iteration = tf.assign(iteration, iteration + 1)
+
+  # Robbins-Monro sequence for step size
+  rho = tf.pow(tf.cast(iteration, tf_dtype) + 1024., -0.7)
+
+  if optimizer == 'rmsprop_manual':
+    # control variates to decrease the variance of the gradient;
+    # one for each variational parameter
+    g_alpha = tf.pack([tf.gradients(q_sample, lambda_alpha)[0]
+                       for q_sample in tf.unpack(q)])
+    g_mu = tf.pack([tf.gradients(q_sample, lambda_mu)[0]
+                    for q_sample in tf.unpack(q)])
+
+    def cov(a, b):
+      v = (a - tf.reduce_mean(a, 0)) * (b - tf.reduce_mean(b, 0))
+      return tf.reduce_mean(v, 0)
+
+    _, var_g_alpha = tf.nn.moments(g_alpha, [0])
+    _, var_g_mu = tf.nn.moments(g_mu, [0])
+
+    cov_alpha = cov(g_alpha * (p - q), g_alpha)
+    cov_mu = cov(g_mu * (p - q), g_mu)
+
+    cv_alpha = cov_alpha / var_g_alpha
+    cv_mu = cov_mu / var_g_mu
+
+    ms_mu = tf.Variable(tf.ones_like(g_mu), trainable=False)
+    ms_alpha = tf.Variable(tf.ones_like(g_alpha), trainable=False)
+    def update_ms(ms, var):
+      return tf.assign(ms, 0.9 * ms + 0.1 * tf.reduce_sum(tf.square(var), 0))
+
+    update_ms_ops = [update_ms(ms_mu, g_mu), update_ms(ms_alpha, g_alpha)]
+
+    # update each variational parameter with the sample average
+    alpha_step = rho * tf.reduce_mean(g_alpha / tf.sqrt(ms_alpha) *
+                                      (p - q - cv_alpha), 0)
+    update_alpha = tf.assign(lambda_alpha, lambda_alpha + alpha_step)
+    mu_step = rho * tf.reduce_mean(g_mu / tf.sqrt(ms_mu) * (p - q - cv_mu), 0)
+    update_mu = tf.assign(lambda_mu, lambda_mu + mu_step)
+    train_ops = tf.group(update_mu, update_alpha)
+  elif optimizer == 'rmsprop_indexed_slices':
+    variable_list = [lambda_mu_var, lambda_alpha_var]
+    train_ops = rmsprop.maximize_with_control_variate(
+        rho, elbo, q, variable_list, iteration)
+
+  # truncate variational parameters
+  get_min = lambda var, val: tf.assign(var, tf.maximum(var, inv_softplus(val)))
+  get_max = lambda var, val: tf.assign(var, tf.minimum(var, inv_softplus(val)))
+
+  get_min_ops = [get_min(lambda_alpha_var, 0.005), get_min(lambda_mu_var, 1e-5)]
+  get_max_ops = [get_max(var, sys.float_info.max)
+                 for var in [lambda_mu_var, lambda_alpha_var]]
+
+  truncate_ops = get_min_ops + get_max_ops
+
+  with tf.control_dependencies([train_ops]):
+    train_ops = tf.group(*truncate_ops)
+
+  with tf.Session() as sess:
+    sess.run(tf.initialize_all_variables())
+    fd = {x_indices: [0]}
+    print('running variational inference using: %s' % optimizer)
+    for i in range(100):
+      if i % 10 == 0:
+        print('iteration %d\telbo %.3e'
+              % (i, np.mean(np.sum(elbo.eval(fd), axis=1))))
+      if optimizer == 'rmsprop_manual':
+        sess.run(update_ms_ops)
+      sess.run(train_ops, fd)
+      sess.run(increment_iteration)
+    # return the learned variational means
+    np_mu = sess.run(tf.nn.softplus(lambda_mu), fd)
+  return np_mu
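The 'rmsprop_manual' branch above is the standard score-function (BBVI) gradient with a covariance control variate; a minimal NumPy sketch of that estimator (illustrative only, not part of this commit, with the RMSProp 1/sqrt(ms) scaling omitted), assuming per-sample score gradients g and learning signal p - q:

import numpy as np

def control_variate_scale(g, learning_signal):
  # g: per-sample score gradients, shape [n_samples, dim]
  # learning_signal: p - q, broadcastable against g
  f = g * learning_signal
  cov = np.mean((f - f.mean(axis=0)) * (g - g.mean(axis=0)), axis=0)
  var = g.var(axis=0)
  return cov / var  # per-dimension scaling, as in cv_alpha / cv_mu above

def bbvi_gradient(g, learning_signal):
  # Monte Carlo estimate of the ELBO gradient w.r.t. a variational parameter;
  # the TF code above additionally rescales by 1 / sqrt(ms) before the update.
  a = control_variate_scale(g, learning_signal)
  return np.mean(g * (learning_signal - a), axis=0)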

+ 41 - 0
experimental/rdefs/python/gamma_normal_rmsprop_test.py

@@ -0,0 +1,41 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rmsprop."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import tensorflow as tf
+
+import gamma_mixture_model
+
+
+FLAGS = tf.flags.FLAGS
+
+
+class RMSPropTest(tf.test.TestCase):
+
+  def testGammaMixtureModel(self):
+    """Test a Gamma Mixture model.
+    """
+    mu_manual = gamma_mixture_model.train('rmsprop_manual')
+    mu_indexed_slices = gamma_mixture_model.train('rmsprop_indexed_slices')
+    self.assertAllClose(mu_indexed_slices, mu_manual, rtol=2e-1, atol=2e-1)
+
+if __name__ == '__main__':
+  tf.test.main()

+ 18 - 0
experimental/rdefs/python/ops/__init__.py

@@ -0,0 +1,18 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function

+ 56 - 0
experimental/rdefs/python/ops/inference.py

@@ -0,0 +1,56 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for variational inference."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+sg = tf.contrib.bayesflow.stochastic_graph
+distributions = tf.contrib.distributions
+
+
+class VariationalInference(object):
+  """VariationalInference class."""
+
+  def __init__(self, model, variational, data):
+    """Initializes the VariationalInference class.
+
+    Args:
+      model: the probability model; an object with log_prob and sample methods.
+      variational: the variational family for the model; an object with a
+          log_prob method and a sample property.
+      data: dict of observations used to fit the model, keyed by 'x'.
+    """
+    self.model = model
+    self.variational = variational
+    self.data = data
+
+  def build_graph(self):
+    """Builds the graph for variational inference."""
+    q_samples = self.variational.sample
+    log_p = self.model.log_prob(q_samples, self.data['x'])
+    log_q = self.variational.log_prob(q_samples)
+    elbo = log_p - log_q
+    if elbo.get_shape().ndims > 1:
+      # first dimension is samples, second is batch_size
+      self.scalar_elbo = tf.reduce_mean(tf.reduce_mean(elbo, 0), 0)
+    else:
+      self.scalar_elbo = tf.reduce_sum(elbo, 0)
+    self.elbo = elbo
+    self.log_p = log_p
+    self.log_q = log_q

+ 654 - 0
experimental/rdefs/python/ops/model_factory.py

@@ -0,0 +1,654 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes for models and variational distributions for recurrent DEFs.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import numpy as np
+import tensorflow as tf
+
+from ops import util
+
+st = tf.contrib.bayesflow.stochastic_tensor
+
+distributions = tf.contrib.distributions
+
+
+class NormalNormalRDEF(object):
+  """Class for a recurrent DEF with normal latent variables and normal weights.
+  """
+
+  def __init__(self, n_timesteps, batch_size, p_w_mu_sigma, p_w_sigma_sigma,
+               p_z_sigma, fixed_p_z_sigma, z_dim, dtype):
+    """Initializes the NormalNormalRDEF class.
+
+    Args:
+      n_timesteps: int. number of timesteps
+      batch_size: int. batch size
+      p_w_mu_sigma: float. prior for the weights for the mean of the latent
+          variables
+      p_w_sigma_sigma: float. prior for the weights for the variance of the
+          latent variables
+      p_z_sigma: floating point prior for the latent variables
+      fixed_p_z_sigma: bool. whether the prior variance is learned
+      z_dim: int. dimension of each latent variable
+      dtype: dtype
+    """
+
+    self.n_timesteps = n_timesteps
+    self.batch_size = batch_size
+    self.p_w_mu_sigma = p_w_mu_sigma
+    self.p_w_sigma_sigma = p_w_sigma_sigma
+    self.p_z_sigma = p_z_sigma
+    self.fixed_p_z_sigma = fixed_p_z_sigma
+    self.z_dim = z_dim
+    self.dtype = dtype
+
+  def log_prob(self, params, x):
+    """Returns the log joint. log p(x | z, w)p(z)p(w); [batch_size].
+
+    Args:
+      params: dict. dictionary of samples of the latent variables.
+      x: tensor. minibatch of examples
+
+    Returns:
+      The log joint of the NormalNormalRDEF probability model.
+    """
+    z_1 = params['z_1']
+    w_1_mu = params['w_1_mu']
+    w_1_sigma = params['w_1_sigma']
+    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
+        params, x, self.batch_size)
+    self.p_x_zw_bernoulli_p = p
+
+    log_p_z, log_p_w_mu, log_p_w_sigma = self.build_recurrent_layer(
+        z_1, w_1_mu, w_1_sigma)
+    return log_p_x_zw + log_p_z + log_p_w_mu + log_p_w_sigma
+
+  def build_recurrent_layer(self, z, w_mu, w_sigma):
+    """Creates a gaussian layer of the recurrent DEF.
+
+    Args:
+      z: sampled gaussian latent variables [batch_size, n_timesteps, z_dim]
+      w_mu: sampled gaussian stochastic weights [z_dim, z_dim]
+      w_sigma: sampled gaussian stochastic weights for stddev
+          [z_dim, z_dim]
+
+    Returns:
+      log_p_z: log prior of latent variables evaluated at the samples z.
+      log_p_w_mu: log density of the weights evaluated at the sampled weights w.
+      log_p_w_sigma: log density of weights for stddev.
+    """
+    # the prior for the weights p(w) has two parts: p(w_mu) and p(w_sigma)
+    # prior for the weights for the mean parameter
+    p_w_mu = distributions.Normal(
+        mu=0., sigma=self.p_w_mu_sigma, validate_args=False)
+    log_p_w_mu = tf.reduce_sum(p_w_mu.log_pdf(w_mu))
+
+    if self.fixed_p_z_sigma:
+      log_p_w_sigma = 0.0
+    else:
+      # prior for the weights for the standard deviation
+      p_w_sigma = distributions.Normal(mu=0., sigma=self.p_w_sigma_sigma,
+                                       validate_args=False)
+      log_p_w_sigma = tf.reduce_sum(p_w_sigma.log_pdf(w_sigma))
+
+    # need this for indexing npy-style
+    z = z.value()
+
+    # the prior for the latent variable at the first timestep is N(0, p_z_sigma)
+    z_t0 = z[:, 0, :]
+    p_z_t0 = distributions.Normal(
+        mu=0., sigma=self.p_z_sigma, validate_args=False)
+    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 1)
+
+    # the prior for subsequent timesteps is off by one
+    mu = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
+                         tf.pack([w_mu] * self.batch_size))
+    if self.fixed_p_z_sigma:
+      sigma = self.p_z_sigma
+    else:
+      wz = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
+                           tf.pack([w_sigma] * self.batch_size))
+      sigma = tf.maximum(tf.nn.softplus(wz), 1e-5)
+    p_z_t1_to_end = distributions.Normal(mu=mu, sigma=sigma,
+                                         validate_args=False)
+    log_p_z_t1_to_end = tf.reduce_sum(
+        p_z_t1_to_end.log_pdf(z[:, 1:, :]), [1, 2])
+    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
+    return log_p_z, log_p_w_mu, log_p_w_sigma
+
+  def recurrent_layer_sample(self, w_mu, w_sigma, n_samples_latents):
+    """Sample from the model, with learned latent weights.
+
+    Args:
+      w_mu: latent weights for the mean parameter. [z_dim, z_dim]
+      w_sigma: latent weights for the standard deviation. [z_dim, z_dim]
+      n_samples_latents: how many samples of latent variables
+
+    Returns:
+      z: samples from the generative process.
+    """
+    p_z_t0 = distributions.Normal(
+        mu=0., sigma=self.p_z_sigma, validate_args=False)
+    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
+    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])
+    def sample_timestep(z_t_prev, w_mu, w_sigma):
+      mu_t = tf.matmul(z_t_prev, w_mu)
+      if self.fixed_p_z_sigma:
+        sigma_t = self.p_z_sigma
+      else:
+        wz_t = tf.matmul(z_t_prev, w_sigma)
+        sigma_t = tf.maximum(tf.nn.softplus(wz_t), 1e-5)
+      p_z_t = distributions.Normal(mu=mu_t, sigma=sigma_t, validate_args=False)
+      if self.z_dim == 1:
+        return p_z_t.sample_n(n=1)[0, :, :]
+      else:
+        return tf.squeeze(p_z_t.sample_n(n=1))
+    z_list = [z_t0]
+    for _ in range(self.n_timesteps - 1):
+      z_t = sample_timestep(z_list[-1], w_mu, w_sigma)
+      z_list.append(z_t)
+    z = tf.pack(z_list)  # [n_timesteps, n_samples_latents, z_dim]
+    z = tf.transpose(z, perm=[1, 0, 2])  # [n_samples, n_timesteps, z_dim]
+    return z
+
+  def likelihood_sample(self, params, z_1, n_samples):
+    return util.bernoulli_likelihood_sample(params, z_1, n_samples)
+
+
+class NormalNormalRDEFVariational(object):
+  """Creates the variational family for the recurrent DEF model.
+
+    Variational family:
+      q_z_1: gaussian approximate posterior q(z_1) for latents of first layer.
+        [n_examples, n_timesteps, z_dim]
+      q_w_1_mu: gaussian approximate posterior q(w_1) for mean weights of first
+        (recurrent) layer [z_dim, z_dim]
+      q_w_1_sigma: gaussian approximate posterior q(w_1) for std weights, first
+        (recurrent) layer [z_dim, z_dim]
+      q_w_0: gaussian approximate posterior q(w_0) for weights of observation
+        layer [z_dim, timestep_dim]
+  """
+
+  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
+               timestep_dim, init_sigma_q_w_mu, init_sigma_q_z,
+               init_sigma_q_w_sigma, fixed_p_z_sigma, fixed_q_z_sigma,
+               fixed_q_w_mu_sigma, fixed_q_w_sigma_sigma,
+               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, dtype):
+    """Initializes the variational family for the NormalNormalRDEF.
+
+    Args:
+      x_indexes: tensor. indices of the datapoints.
+      n_examples: int. number of examples in the dataset.
+      n_timesteps: int. number of timesteps in each datapoint.
+      z_dim: int. dimension of latent variables.
+      timestep_dim: int. dimension of each timestep.
+      init_sigma_q_w_mu: float. initial variance for weights for the means of
+          the latent variables.
+      init_sigma_q_z: float. initial variance for the variational distribution
+          for the latent variables.
+      init_sigma_q_w_sigma: float. initial variance for the weights for the
+          variance of the latent variables.
+      fixed_p_z_sigma: bool. whether to keep the prior over latents fixed.
+      fixed_q_z_sigma: bool. whether to train the variance of the variational
+          distributions for the latents.
+      fixed_q_w_mu_sigma: bool. whether to train the variance of the weights for
+          the latent variables.
+      fixed_q_w_sigma_sigma: bool. whether to train the variance of the weights
+          for the variance of the latent variables.
+      fixed_q_w_0_sigma: bool. whether to train the variance of the weights for
+          the observations.
+      init_sigma_q_w_0_sigma: float. initial variance for the observation
+          weights.
+      dtype: dtype
+    """
+    self.x_indexes = x_indexes
+    self.n_examples = n_examples
+    self.n_timesteps = n_timesteps
+    self.z_dim = z_dim
+    self.timestep_dim = timestep_dim
+    self.init_sigma_q_z = init_sigma_q_z
+    self.init_sigma_q_w_mu = init_sigma_q_w_mu
+    self.init_sigma_q_w_sigma = init_sigma_q_w_sigma
+    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
+    self.fixed_p_z_sigma = fixed_p_z_sigma
+    self.fixed_q_z_sigma = fixed_q_z_sigma
+    self.fixed_q_w_mu_sigma = fixed_q_w_mu_sigma
+    self.fixed_q_w_sigma_sigma = fixed_q_w_sigma_sigma
+    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
+    self.dtype = dtype
+    self.build_graph()
+
+  @property
+  def sample(self):
+    """Returns a dict of samples of the latent variables."""
+    return self.params
+
+  def build_graph(self):
+    """Builds the graph for the variational family for the NormalNormalRDEF."""
+    with tf.variable_scope('q_z_1'):
+      z_1 = util.build_gaussian(
+          [self.n_examples, self.n_timesteps, self.z_dim],
+          init_mu=0., init_sigma=self.init_sigma_q_z, x_indexes=self.x_indexes,
+          fixed_sigma=self.fixed_q_z_sigma, place_on_cpu=True, dtype=self.dtype)
+
+    with tf.variable_scope('q_w_1_mu'):
+      # half of the weights are for the mean, half for the variance
+      w_1_mu = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
+                                   init_sigma=self.init_sigma_q_w_mu,
+                                   fixed_sigma=self.fixed_q_w_mu_sigma,
+                                   dtype=self.dtype)
+
+    if self.fixed_p_z_sigma:
+      w_1_sigma = None
+    else:
+      with tf.variable_scope('q_w_1_sigma'):
+        w_1_sigma = util.build_gaussian(
+            [self.z_dim, self.z_dim],
+            init_mu=0., init_sigma=self.init_sigma_q_w_sigma,
+            fixed_sigma=self.fixed_q_w_sigma_sigma,
+            dtype=self.dtype)
+
+    with tf.variable_scope('q_w_0'):
+      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
+                                init_sigma=self.init_sigma_q_w_0_sigma,
+                                fixed_sigma=self.fixed_q_w_0_sigma,
+                                dtype=self.dtype)
+
+      self.params = {'w_0': w_0, 'w_1_mu': w_1_mu, 'w_1_sigma': w_1_sigma,
+                     'z_1': z_1}
+
+  def log_prob(self, q_samples):
+    """Get the log joint of variational family: log(q(z, w_mu, w_sigma, w_0)).
+
+    Args:
+      q_samples: dict. samples of latent variables
+
+    Returns:
+      log_prob: tensor log-probability summed over dimensions of the variables
+    """
+    w_0 = q_samples['w_0']
+    z_1 = q_samples['z_1']
+    w_1_mu = q_samples['w_1_mu']
+    w_1_sigma = q_samples['w_1_sigma']
+
+    log_prob = 0.
+    # preserve the minibatch dimension [0]
+    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1), [1, 2])
+    # w_1, w_0 are global, so reduce_sum across all dims
+    log_prob += tf.reduce_sum(w_1_mu.distribution.log_pdf(w_1_mu))
+    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0))
+    if not self.fixed_p_z_sigma:
+      log_prob += tf.reduce_sum(w_1_sigma.distribution.log_pdf(w_1_sigma))
+    return log_prob
+
+
+class GammaNormalRDEF(object):
+  """Class for a recurrent DEF with normal latent variables and normal weights.
+  """
+
+  def __init__(self, n_timesteps, batch_size, p_w_shape_sigma, p_w_mean_sigma,
+               p_z_shape, p_z_mean, fixed_p_z_mean, z_dim, n_samples_latents,
+               use_bias_observations, dtype):
+    """Initializes the NormalNormalRDEF class.
+
+    Args:
+      n_timesteps: int. number of timesteps
+      batch_size: int. batch size
+      p_w_shape_sigma: float. prior for the weights for the shape of the latent
+          variables
+      p_w_mean_sigma: float. prior for the weights for the mean of the
+          latent variables
+      p_z_shape: float. prior shape for the latent variables.
+      p_z_mean: float. prior mean for the latent variables
+      fixed_p_z_mean: bool. whether the prior mean is learned
+      z_dim: int. dimension of each latent variable
+      n_samples_latents: number of samples of latent variables
+      use_bias_observations: whether to use bias terms
+      dtype: dtype
+    """
+
+    self.n_timesteps = n_timesteps
+    self.batch_size = batch_size
+    self.p_w_shape_sigma = p_w_shape_sigma
+    self.p_w_mean_sigma = p_w_mean_sigma
+    self.p_z_shape = p_z_shape
+    self.p_z_mean = p_z_mean
+    self.fixed_p_z_mean = fixed_p_z_mean
+    self.z_dim = z_dim
+    self.n_samples_latents = n_samples_latents
+    self.use_bias_observations = use_bias_observations
+    self.use_bias_latents = False
+    self.dtype = dtype
+
+  def log_prob(self, params, x):
+    """Returns the log joint. log p(x | z, w)p(z)log p(w); [batch_size].
+
+    Args:
+      params: dict. dictionary of samples of the latent variables.
+      x: tensor. minibatch of examples
+
+    Returns:
+      The log joint of the GammaNormalRDEF probability model.
+    """
+    z_1 = params['z_1']
+    w_1_mean = params['w_1_mean']
+    w_1_shape = params['w_1_shape']
+    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
+        params, x, self.batch_size, n_samples_latents=self.n_samples_latents,
+        use_bias_observations=self.use_bias_observations)
+    self.p_x_zw_bernoulli_p = p
+
+    log_p_z, log_p_w_shape, log_p_w_mean = self.build_recurrent_layer(
+        z_1, w_1_shape, w_1_mean)
+    return log_p_x_zw + log_p_z + log_p_w_shape + log_p_w_mean
+
+  def build_recurrent_layer(self, z, w_shape, w_mean):
+    """Creates a gaussian layer of the recurrent DEF.
+
+    Args:
+      z: sampled gamma latent variables,
+          shape [n_samples_latents, batch_size, n_timesteps, z_dim]
+      w_shape: single sample of gaussian stochastic weights for shape,
+          shape [z_dim, z_dim]
+      w_mean: single sample of gaussian stochastic weights for mean,
+          shape [z_dim, z_dim]
+
+    Returns:
+      log_p_z: log prior of latent variables evaluated at the samples z.
+      log_p_w_shape: log density of the weights for the shape, evaluated at
+          the sampled weights.
+      log_p_w_mean: log density of the weights for the mean.
+    """
+    # the prior for the weights p(w) has two parts: p(w_shape) and p(w_mean)
+    # prior for the weights for the shape parameter
+    cast = lambda x: np.array(x, self.dtype)
+    p_w_shape = distributions.Normal(mu=cast(0.),
+                                     sigma=cast(self.p_w_shape_sigma),
+                                     validate_args=False)
+    log_p_w_shape = tf.reduce_sum(p_w_shape.log_pdf(w_shape))
+
+    if self.fixed_p_z_mean:
+      log_p_w_mean = 0.0
+    else:
+      # prior for the weights for the mean
+      p_w_mean = distributions.Normal(mu=cast(0.),
+                                      sigma=cast(self.p_w_mean_sigma),
+                                      validate_args=False)
+      log_p_w_mean = tf.reduce_sum(p_w_mean.log_pdf(w_mean))
+
+    # need this for indexing npy-style
+    z = z.value()
+
+    # the prior for the latent variable at the first timestep is the fixed
+    # Gamma(p_z_shape, p_z_shape / p_z_mean) prior
+    z_t0 = z[:, :, 0, :]
+    # alpha is shape, beta is inverse scale. we set the scale to be the mean
+    # over the shape, so beta = shape / mean.
+    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
+                                 beta=cast(self.p_z_shape / self.p_z_mean),
+                                 validate_args=False)
+    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 2)
+
+    # the prior for subsequent timesteps is off by one
+    shape = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
+                            tf.pack([tf.pack([w_shape] * self.batch_size)]
+                                    * self.n_samples_latents))
+    shape = util.clip_shape(shape)
+
+    if self.fixed_p_z_mean:
+      mean = self.p_z_mean
+    else:
+      wz = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
+                           tf.pack([tf.pack([w_mean] * self.batch_size)]
+                                   * self.n_samples_latents))
+      mean = tf.nn.softplus(wz)
+      mean = util.clip_mean(mean)
+    p_z_t1_to_end = distributions.Gamma(alpha=shape,
+                                        beta=shape / mean,
+                                        validate_args=False)
+    log_p_z_t1_to_end = tf.reduce_sum(
+        p_z_t1_to_end.log_pdf(z[:, :, 1:, :]), [2, 3])
+    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
+    return log_p_z, log_p_w_shape, log_p_w_mean
+
+  def recurrent_layer_sample(self, w_shape, w_mean, n_samples_latents,
+                             b_shape=None, b_mean=None):
+    """Sample from the model, with learned latent weights.
+
+    Args:
+      w_shape: latent weights for the shape parameter. [z_dim, z_dim]
+      w_mean: latent weights for the mean parameter. [z_dim, z_dim]
+      n_samples_latents: how many samples
+      b_shape: bias for shape parameters
+      b_mean: bias for mean parameters
+
+    Returns:
+      z: samples from the generative process.
+    """
+    cast = lambda x: np.array(x, self.dtype)
+    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
+                                 beta=cast(self.p_z_shape / self.p_z_mean),
+                                 validate_args=False)
+    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
+
+    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])
+
+    def sample_timestep(z_t_prev, w_shape, w_mean, b_shape=b_shape,
+                        b_mean=b_mean):
+      """Sample a single timestep.
+
+      Args:
+        z_t_prev: previous timestep latent variable,
+            shape [n_samples_latents, z_dim]
+        w_shape: latent weights for shape param, shape [z_dim, z_dim]
+        w_mean: latent weights for mean param, shape [z_dim, z_dim]
+        b_shape: bias for shape parameters
+        b_mean: bias for mean parameters
+
+      Returns:
+        z_t: a sample of the latent variable for this timestep.
+      """
+      wz_t = tf.matmul(z_t_prev, w_shape)
+      if self.use_bias_latents:
+        wz_t += b_shape
+      shape_t = tf.nn.softplus(wz_t)
+      shape_t = util.clip_shape(shape_t)
+      if self.fixed_p_z_mean:
+        mean_t = self.p_z_mean
+      else:
+        wz_t = tf.matmul(z_t_prev, w_mean)
+        if self.use_bias_latents:
+          wz_t += b_mean
+        mean_t = tf.nn.softplus(wz_t)
+        mean_t = util.clip_mean(mean_t)
+      p_z_t = distributions.Gamma(alpha=shape_t,
+                                  beta=shape_t / mean_t,
+                                  validate_args=False)
+      z_t = p_z_t.sample_n(n=1)[0, :, :]
+      return z_t
+    z_list = [z_t0]
+    for _ in range(self.n_timesteps - 1):
+      z_t = sample_timestep(z_list[-1], w_shape, w_mean)
+      z_list.append(z_t)
+    # pack into shape [n_timesteps, n_samples_latents, z_dim]
+    z = tf.pack(z_list)
+    # transpose into [n_samples_latents, n_timesteps, z_dim]
+    z = tf.transpose(z, perm=[1, 0, 2])
+    return z
+
+  def likelihood_sample(self, params, z_1, n_samples):
+    return util.bernoulli_likelihood_sample(
+        params, z_1, n_samples,
+        use_bias_observations=self.use_bias_observations)
+
+
+class GammaNormalRDEFVariational(object):
+  """Creates the variational family for the recurrent DEF model.
+
+    Variational family:
+      q_z_1: gaussian approximate posterior q(z_1) for latents of first layer.
+        [n_examples, n_timesteps, z_dim]
+      q_w_1_shape: gaussian approximate posterior q(w_1) for shape weights of
+        first (recurrent) layer [z_dim, z_dim]
+      q_w_1_mean: gaussian approximate posterior q(w_1) for mean weights of
+        first (recurrent) layer [z_dim, z_dim]
+      q_w_0: gaussian approximate posterior q(w_0) for weights of observation
+        layer [z_dim, timestep_dim]
+  """
+
+  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
+               timestep_dim, init_sigma_q_w_shape, init_shape_q_z,
+               init_mean_q_z,
+               init_sigma_q_w_mean, fixed_p_z_mean, fixed_q_z_mean,
+               fixed_q_w_shape_sigma, fixed_q_w_mean_sigma,
+               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, n_samples_latents,
+               use_bias_observations,
+               dtype):
+    """Initializes the variational family for the NormalNormalRDEF.
+
+    Args:
+      x_indexes: tensor. indices of the datapoints.
+      n_examples: int. number of examples in the dataset.
+      n_timesteps: int. number of timesteps in each datapoint.
+      z_dim: int. dimension of latent variables.
+      timestep_dim: int. dimension of each timestep.
+      init_sigma_q_w_shape: float. initial variance for the weights for the
+          shapes of the latent variables.
+      init_shape_q_z: float. initial shape for the variational distribution
+          for the latent variables.
+      init_mean_q_z: float. initial mean for the variational distribution for
+          the latent variables.
+      init_sigma_q_w_mean: float. initial variance for the weights for the
+          means of the latent variables.
+      fixed_p_z_mean: bool. whether to keep the prior over latents fixed.
+      fixed_q_z_mean: bool. whether to train the mean of the variational
+          distributions for the latents.
+      fixed_q_w_shape_sigma: bool. whether to train the variance of the weights
+          for the shapes of the latent variables.
+      fixed_q_w_mean_sigma: bool. whether to train the variance of the weights
+          for the means of the latent variables.
+      fixed_q_w_0_sigma: bool. whether to train the variance of the weights for
+          the observations.
+      init_sigma_q_w_0_sigma: float. initial variance for the observation
+          weights.
+      n_samples_latents: number of samples of latent variables to draw
+      use_bias_observations: whether to use bias terms
+      dtype: dtype
+    """
+    self.x_indexes = x_indexes
+    self.n_examples = n_examples
+    self.n_timesteps = n_timesteps
+    self.z_dim = z_dim
+    self.timestep_dim = timestep_dim
+    self.init_mean_q_z = init_mean_q_z
+    self.init_shape_q_z = init_shape_q_z
+    self.init_sigma_q_w_shape = init_sigma_q_w_shape
+    self.init_sigma_q_w_mean = init_sigma_q_w_mean
+    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
+    self.fixed_p_z_mean = fixed_p_z_mean
+    self.fixed_q_z_mean = fixed_q_z_mean
+    self.fixed_q_w_shape_sigma = fixed_q_w_shape_sigma
+    self.fixed_q_w_mean_sigma = fixed_q_w_mean_sigma
+    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
+    self.n_samples_latents = n_samples_latents
+    self.use_bias_observations = use_bias_observations
+    self.dtype = dtype
+    with tf.variable_scope('variational'):
+      self.build_graph()
+
+  @property
+  def sample(self):
+    """Returns a dict of samples of the latent variables."""
+    return self.params
+
+  @property
+  def trainable_variables(self):
+    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'variational')
+
+  def build_graph(self):
+    """Builds the graph for the variational family for the NormalNormalRDEF."""
+    with tf.variable_scope('q_z_1'):
+      z_1 = util.build_gamma(
+          [self.n_examples, self.n_timesteps, self.z_dim],
+          init_shape=self.init_shape_q_z,
+          init_mean=self.init_mean_q_z,
+          x_indexes=self.x_indexes,
+          fixed_mean=self.fixed_q_z_mean,
+          place_on_cpu=False,
+          n_samples=self.n_samples_latents,
+          dtype=self.dtype)
+
+    with tf.variable_scope('q_w_1_shape'):
+      # half of the weights are for the shape, half for the mean
+      w_1_shape = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
+                                      init_sigma=self.init_sigma_q_w_shape,
+                                      fixed_sigma=self.fixed_q_w_shape_sigma,
+                                      dtype=self.dtype)
+
+    if self.fixed_p_z_mean:
+      w_1_mean = None
+    else:
+      with tf.variable_scope('q_w_1_mean'):
+        w_1_mean = util.build_gaussian(
+            [self.z_dim, self.z_dim],
+            init_mu=0., init_sigma=self.init_sigma_q_w_mean,
+            fixed_sigma=self.fixed_q_w_mean_sigma,
+            dtype=self.dtype)
+
+    with tf.variable_scope('q_w_0'):
+      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
+                                init_sigma=self.init_sigma_q_w_0_sigma,
+                                fixed_sigma=self.fixed_q_w_0_sigma,
+                                dtype=self.dtype)
+
+    self.params = {'w_0': w_0, 'w_1_shape': w_1_shape, 'w_1_mean': w_1_mean,
+                   'z_1': z_1}
+
+    if self.use_bias_observations:
+      # b_0 = tf.get_variable(
+      #     'b_0', [self.timestep_dim], self.dtype, tf.zeros_initializer,
+      #     collections=[tf.GraphKeys.VARIABLES, 'reparam_variables'])
+      b_0 = util.build_gaussian([self.timestep_dim], init_mu=0.,
+                                init_sigma=0.01, fixed_sigma=False,
+                                dtype=self.dtype)
+      self.params.update({'b_0': b_0})
+
+  def log_prob(self, q_samples):
+    """Get the log joint of variational family: log(q(z, w_shape, w_mean, w_0)).
+
+    Args:
+      q_samples: dict. samples of latent variables.
+
+    Returns:
+      log_prob: tensor log-probability summed over dimensions of the variables
+    """
+    w_0 = q_samples['w_0']
+    z_1 = q_samples['z_1']
+    w_1_shape = q_samples['w_1_shape']
+    w_1_mean = q_samples['w_1_mean']
+
+    log_prob = 0.
+    # preserve the sample and minibatch dimensions [0, 1]
+    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1.value()), [2, 3])
+    # w_1, w_0 are global, so reduce_sum across all dims
+    log_prob += tf.reduce_sum(w_1_shape.distribution.log_pdf(w_1_shape.value()))
+    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0.value()))
+    if not self.fixed_p_z_mean:
+      log_prob += tf.reduce_sum(w_1_mean.distribution.log_pdf(w_1_mean.value()))
+    return log_prob

+ 172 - 0
experimental/rdefs/python/ops/rmsprop.py

@@ -0,0 +1,172 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RMSProp for score function gradients and IndexedSlices."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import sys
+
+import tensorflow as tf
+
+
+def _gradients_per_example(loss, variable):
+  """Returns per-example gradients.
+
+  Args:
+    loss: A [n_samples, batch_size] shape tensor
+    variable: A variable to optimize of shape var_shape
+
+  Returns:
+    grad: A tensor of shape [n_samples, *var_shape]
+  """
+  grad_list = [tf.gradients(loss_sample, variable)[0] for loss_sample in
+               tf.unpack(loss)]
+  if isinstance(grad_list[0], tf.IndexedSlices):
+    grad = tf.pack([g.values for g in grad_list])
+    grad = tf.IndexedSlices(values=grad, indices=grad_list[0].indices)
+  else:
+    grad = tf.pack(grad_list)
+  return grad
+
+
+def _cov(a, b):
+  """Calculates covariance between a and b."""
+  v = (a - tf.reduce_mean(a, 0)) * (b - tf.reduce_mean(b, 0))
+  return tf.reduce_mean(v, 0)
+
+
+def _var(a):
+  """Returns the variance across the sample dimension."""
+  _, var = tf.nn.moments(a, [0])
+  return var
+
+
+def _update_mean_square(mean_square, variable):
+  """Update mean square for a variable."""
+  if isinstance(variable, tf.IndexedSlices):
+    square_sum = tf.reduce_sum(tf.square(variable.values), 0)
+    mean_square_lookup = tf.nn.embedding_lookup(mean_square, variable.indices)
+    moving_mean_square = 0.9 * mean_square_lookup + 0.1 * square_sum
+    return tf.scatter_update(mean_square, variable.indices, moving_mean_square)
+  else:
+    square_sum = tf.reduce_sum(tf.square(variable), 0)
+    moving_mean_square = 0.9 * mean_square + 0.1 * square_sum
+  return tf.assign(mean_square, moving_mean_square)
+
+
+def _get_mean_square(variable):
+  with tf.variable_scope('optimizer_state'):
+    mean_square = tf.get_variable(name=variable.name[:-2],
+                                  shape=variable.get_shape(),
+                                  initializer=tf.ones_initializer,
+                                  dtype=variable.dtype.base_dtype)
+  return mean_square
+
+
+def _control_variate(grad, learning_signal):
+  if isinstance(grad, tf.IndexedSlices):
+    grad = grad.values
+  cov = _cov(grad * learning_signal, grad)
+  var = _var(grad)
+  return cov / var
+
+
+def _rmsprop_maximize(learning_rate, learning_signal, log_prob, variable,
+                      clip_min=None, clip_max=None):
+  """Builds rmsprop maximization ops for a single variable."""
+  grad = _gradients_per_example(log_prob, variable)
+  if learning_signal.get_shape().ndims == 2:
+    # if we have multiple samples of latent variables, need to broadcast
+    # grad of shape [n_samples_latents, batch_size, n_timesteps, z_dim]
+    # with learning_signal of shape [n_samples_latents, batch_size]:
+    learning_signal = tf.expand_dims(tf.expand_dims(learning_signal, 2), 2)
+  control_variate = _control_variate(grad, learning_signal)
+  mean_square = _get_mean_square(variable)
+  update_mean_square = _update_mean_square(mean_square, grad)
+  variance_reduced_learning_sig = learning_signal - control_variate
+  update_name = variable.name[:-2] + '/score_function_grad_estimator'
+  if isinstance(grad, tf.IndexedSlices):
+    mean_square_lookup = tf.nn.embedding_lookup(mean_square, grad.indices)
+    mean_square_lookup = tf.expand_dims(mean_square_lookup, 0)
+
+    update_per_sample = (grad.values / tf.sqrt(mean_square_lookup)
+                         * variance_reduced_learning_sig)
+    update = tf.reduce_mean(update_per_sample, 0, name=update_name)
+    step = learning_rate * update
+    if clip_min is None and clip_max is None:
+      apply_step = tf.scatter_add(variable, grad.indices, step)
+    else:
+      var_lookup = tf.nn.embedding_lookup(variable, grad.indices)
+      new_var = var_lookup + step
+      new_var_clipped = tf.clip_by_value(
+          new_var, clip_value_min=clip_min, clip_value_max=clip_max)
+      apply_step = tf.scatter_update(variable, grad.indices, new_var_clipped)
+  else:
+    update_per_sample = (grad / tf.sqrt(mean_square)
+                         * variance_reduced_learning_sig)
+    update = tf.reduce_mean(update_per_sample, 0,
+                            name=update_name)
+    step = learning_rate * update
+    if clip_min is None and clip_max is None:
+      apply_step = tf.assign(variable, variable + step)
+    else:
+      new_var = variable + step
+      new_var_clipped = tf.clip_by_value(
+          new_var, clip_value_min=clip_min, clip_value_max=clip_max)
+      apply_step = tf.assign(variable, new_var_clipped)
+  # add to collection for keeping track of stats
+  tf.add_to_collection('non_reparam_variable_grads', update)
+  with tf.control_dependencies([update_mean_square]):
+    train_op = tf.group(apply_step)
+  return train_op
+
+
+def maximize_with_control_variate(learning_rate, learning_signal, log_prob,
+                                  variable_list, global_step=None):
+  """Build a covariance control variate with rmsprop updates.
+
+  Args:
+    learning_rate: Step size
+    learning_signal: Usually the ELBO; the bound we optimize
+        Shape [n_samples, batch_size]
+    log_prob: log probability of samples of latent variables
+    variable_list: List of variables
+    global_step: Global step
+
+  Returns:
+    train_op: Group of operations that apply an RMSProp update with the
+        control variate
+  """
+  train_ops = []
+  for variable in variable_list:
+    clip_max, clip_min = (None, None)
+    if 'shape_softplus_inv' in variable.name:
+      clip_max = sys.float_info.max
+      clip_min = 5e-3
+    elif 'mean_softplus_inv' in variable.name:
+      clip_max = sys.float_info.max
+      clip_min = 1e-5
+    train_ops.append(_rmsprop_maximize(
+        learning_rate, learning_signal, log_prob, variable, clip_max=clip_max,
+        clip_min=clip_min))
+  if global_step is not None:
+    increment_global_step = tf.assign(global_step, global_step + 1)
+    with tf.control_dependencies(train_ops):
+      train_op = tf.group(increment_global_step)
+  else:
+    train_op = tf.group(*train_ops)
+  return train_op

+ 54 - 0
experimental/rdefs/python/ops/tf_lib.py

@@ -0,0 +1,54 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for working with TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import ast
+
+
+class HParams(object):
+
+  """Creates an object for passing around hyperparameter values.
+
+  Use the parse method to overwrite the default hyperparameters with values
+  passed in as a string representation of a Python dictionary mapping
+  hyperparameters to values.
+  Ex.
+  hparams = tf_lib.HParams(batch_size=128, hidden_size=256)
+  hparams = hparams.parse('{"hidden_size":512}')
+  assert hparams.batch_size == 128
+  assert hparams.hidden_size == 512
+  """
+
+  def __init__(self, **init_hparams):
+    object.__setattr__(self, 'keyvals', init_hparams)
+
+  def __getattr__(self, key):
+    """Returns the value for key, or None if key does not exist."""
+    return self.keyvals.get(key)
+
+  def __setattr__(self, key, value):
+    self.keyvals[key] = value
+
+  def parse(self, string):
+    new_hparams = ast.literal_eval(string)
+    return HParams(**dict(self.keyvals, **new_hparams))
+
+  def values(self):
+    return self.keyvals
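A short usage sketch for HParams (illustrative, not part of this commit; the hparams flag and the default values are assumptions). Note that parse returns a new HParams object rather than mutating in place:

import tensorflow as tf

from ops import tf_lib

tf.flags.DEFINE_string('hparams', '{}',
                       'Hyperparameter overrides as a Python dict literal.')
FLAGS = tf.flags.FLAGS

hparams = tf_lib.HParams(batch_size=128, hidden_size=256)
hparams = hparams.parse(FLAGS.hparams)  # e.g. --hparams='{"hidden_size": 512}'
print(hparams.values())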

+ 414 - 0
experimental/rdefs/python/ops/util.py

@@ -0,0 +1,414 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import sys
+
+import h5py
+import numpy as np
+import tensorflow as tf
+
+st = tf.contrib.bayesflow.stochastic_tensor
+distributions = tf.contrib.distributions
+
+
+def provide_tfrecords_data(path, split_name, batch_size, n_timesteps,
+                           timestep_dim):
+  """Provides batches of MNIST digits.
+
+  Args:
+    path: String specifying location of tf.records files.
+    split_name: string. name of the split.
+    batch_size: int. batch size.
+    n_timesteps: int. number of timesteps.
+    timestep_dim: int. dimension of each timestep.
+
+  Returns:
+    labels: minibatch tensor of the indices of each datapoint.
+    images: minibatch tensor of images.
+  """
+  # Load the data:
+  image, label = read_and_decode_single_example(
+      os.path.join(path, 'binarized_mnist_{}.tfrecords'.format(split_name)))
+
+  # Preprocess the images.
+  image = tf.reshape(image, [28, 28])
+  if n_timesteps < 28:
+    image = image[0:n_timesteps, :]
+  if timestep_dim < 28:
+    image = image[:, 0:timestep_dim]
+  image = tf.expand_dims(image, 2)
+
+  # Creates a QueueRunner for the pre-fetching operation.
+  images, labels = tf.train.batch(
+      [image, label],
+      batch_size=batch_size,
+      num_threads=15,
+      capacity=batch_size * 5000)
+
+  return labels, images
+
+
+def read_and_decode_single_example(filename):
+  """Read and decode a single example.
+
+  Args:
+    filename: str. path to a tf.records file.
+
+  Returns:
+    image: tensor. a single image.
+    label: tensor. the index for the image.
+  """
+  # first construct a queue containing a list of filenames.
+  # this lets a user split up their dataset into multiple files to keep
+  # size down
+  filename_queue = tf.train.string_input_producer([filename],
+                                                  num_epochs=None)
+  # Unlike the TFRecordWriter, the TFRecordReader is symbolic
+  reader = tf.TFRecordReader()
+  # One can read a single serialized example from a filename
+  # serialized_example is a Tensor of type string.
+  _, serialized_example = reader.read(filename_queue)
+  # The serialized example is converted back to actual values.
+  # One needs to describe the format of the objects to be returned
+  features = tf.parse_single_example(
+      serialized_example,
+      features={
+          # We know the length of both fields. If not the
+          # tf.VarLenFeature could be used
+          'image': tf.FixedLenFeature([784], tf.float32),
+          'label': tf.FixedLenFeature([], tf.int64)
+      })
+  # now return the converted data
+  image = features['image']
+  label = features['label']
+  return image, label
+
+
+def provide_hdf5_data(path, split_name, n_examples, batch_size, n_timesteps,
+                      timestep_dim, dataset):
+  """Provides batches of MNIST digits.
+
+  Args:
+   path: str. path to the dataset.
+   split_name: string. name of the split.
+   n_examples: int. number of examples to serve from the dataset.
+   batch_size: int. batch size.
+   n_timesteps: int. number of timesteps.
+   timestep_dim: int. dimension of each timestep.
+   dataset: String specifying dataset.
+
+  Returns:
+    data_iterator: a generator of minibatches.
+  """
+  if dataset == 'alternating':
+    data_list = []
+    start_zeros = np.vstack([np.zeros(timestep_dim) if t % 2 == 0 else
+                             np.ones(timestep_dim) for t in range(n_timesteps)])
+    start_ones = np.roll(start_zeros, 1, axis=0)
+    start_zeros = start_zeros.flatten()
+    start_ones = start_ones.flatten()
+    data_list = [start_zeros if n % 2 == 0 else
+                 start_ones for n in range(n_examples)]
+    data = np.vstack(data_list)
+  elif dataset == 'MNIST':
+    f = h5py.File(path, 'r')
+    if split_name == 'train_and_valid':
+      train = f['train'][:]
+      valid = f['valid'][:]
+      data = np.vstack([train, valid])
+    else:
+      data = f[split_name][:]
+
+    data = data[0:n_examples]
+
+  # create indexes for the data points.
+  indexed_data = zip(range(len(data)), np.split(data, len(data)))
+  def data_iterator():
+    """Generate minibatches of examples from the dataset."""
+    batch_idx = 0
+    while True:
+      # shuffle data
+      idxs = np.arange(0, len(data))
+      np.random.shuffle(idxs)
+      shuf_data = [indexed_data[idx] for idx in idxs]
+      for batch_idx in range(0, len(data), batch_size):
+        indexed_images_batch = shuf_data[batch_idx:batch_idx+batch_size]
+        indexes, images_batch = zip(*indexed_images_batch)
+        images_batch = np.vstack(images_batch)
+        if timestep_dim == 784:
+          images_batch = images_batch.reshape(
+              (batch_size, 1, 784, 1))
+        else:
+          if dataset == 'alternating':
+            images_batch = images_batch.reshape(
+                (batch_size, n_timesteps, timestep_dim, 1))
+          else:
+            images_batch = images_batch.reshape(
+                (batch_size, 28, 28, 1))[:, :n_timesteps, :timestep_dim]
+        yield indexes, images_batch
+
+  return data_iterator()
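A brief usage sketch of the generator above (illustrative, not part of this commit), using the synthetic 'alternating' dataset so no hdf5 file is required:

iterator = provide_hdf5_data(path=None, split_name='train', n_examples=10,
                             batch_size=5, n_timesteps=28, timestep_dim=28,
                             dataset='alternating')
indexes, images = next(iterator)
# indexes: ids of the 5 sampled datapoints; images: array of shape [5, 28, 28, 1]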
+
+
+def inv_softplus(x):
+  """Inverse softplus."""
+  return np.log(np.exp(x) - 1.)
+
+
+def softplus(x):
+  """Softplus."""
+  return np.log(np.exp(x) + 1.)
+
+
+def build_gamma(shape, init_shape=1., init_mean=1., x_indexes=None,
+                fixed_mean=False, place_on_cpu=False, n_samples=1,
+                dtype='float64'):
+  """Builds a Gaussian DistributionTensor.
+
+  Truncation: we truncate shape and mean parameters because gamma sampling is
+  numerically unstable. Reference: http://ajbc.io/resources/bbvi_for_gammas.pdf
+
+  Args:
+    shape: list. shape of the distribution.
+    init_shape: float. initial shape
+    init_mean: float. initial mean
+    x_indexes: tensor. integer placeholder for mean-field parameters
+    fixed_mean: bool. whether to learn mean
+    place_on_cpu: bool. whether to place the op on cpu.
+    n_samples: number of samples
+    dtype: dtype
+
+  Returns:
+    A Gamma DistributionTensor of the specified shape, with variables for
+    shape and mean safely parametrized to avoid over/underflow.
+  """
+  if place_on_cpu:
+    with tf.device('/cpu:0'):
+      shape_softplus_inv = tf.get_variable(
+          'shape_softplus_inv', shape, dtype, tf.constant_initializer(
+              inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
+                                                      'non_reparam_variables'])
+  else:
+    shape_softplus_inv = tf.get_variable(
+        'shape_softplus_inv', shape, dtype, tf.constant_initializer(
+            inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
+                                                    'non_reparam_variables'])
+  if fixed_mean:
+    mean_softplus_inv = None
+  else:
+    mean_softplus_arg = tf.constant_initializer(inv_softplus(init_mean))
+    if place_on_cpu:
+      with tf.device('/cpu:0'):
+        mean_softplus_inv = tf.get_variable(
+            'mean_softplus_inv', shape, dtype, mean_softplus_arg)
+    else:
+      mean_softplus_inv = tf.get_variable('mean_softplus_inv', shape,
+                                          dtype, mean_softplus_arg,
+                                          collections=[tf.GraphKeys.VARIABLES,
+                                                       'non_reparam_variables'])
+
+  if x_indexes is not None:
+    shape_softplus_inv_batch = tf.nn.embedding_lookup(
+        shape_softplus_inv, x_indexes)
+    if not fixed_mean:
+      mean_softplus_inv_batch = tf.nn.embedding_lookup(
+          mean_softplus_inv, x_indexes)
+  else:
+    shape_softplus_inv_batch, mean_softplus_inv_batch = (shape_softplus_inv,
+                                                         mean_softplus_inv)
+  shape_batch = tf.nn.softplus(shape_softplus_inv_batch)
+
+  if fixed_mean:
+    mean_batch = tf.constant(init_mean)
+  else:
+    mean_batch = tf.nn.softplus(mean_softplus_inv_batch)
+
+  with st.value_type(st.SampleValue(n=n_samples)):
+    dist = st.StochasticTensor(distributions.Gamma,
+                               alpha=shape_batch,
+                               beta=shape_batch / mean_batch,
+                               validate_args=False)
+  return dist
+
+
+def truncate(max_or_min, var, val):
+  """Truncate variable to a max or min value."""
+  if max_or_min == 'max':
+    tf_fn = tf.minimum
+  elif max_or_min == 'min':
+    tf_fn = tf.maximum
+
+  if isinstance(var, tf.IndexedSlices):
+    assign_op = tf.assign(var.values, tf_fn(var.values, inv_softplus(val)))
+  else:
+    assign_op = tf.assign(var, tf_fn(var, inv_softplus(val)))
+  return assign_op
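+
+
+# Hedged sketch of combining truncate() with a training op (assumed usage;
+# `train_op` and `shape_softplus_inv` are placeholder names here):
+#
+#   with tf.control_dependencies([train_op]):
+#     clip_op = truncate('min', shape_softplus_inv, 5e-3)
+#   sess.run(clip_op)  # gradient step, then clamp so softplus(var) >= 5e-3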
+
+
+def build_gaussian(shape, init_mu=0., init_sigma=1.0, x_indexes=None,
+                   fixed_sigma=False, place_on_cpu=False, dtype='float64'):
+  """Builds a Gaussian DistributionTensor.
+
+  Args:
+    shape: list. shape of the distribution.
+    init_mu: float. initial mean
+    init_sigma: float. initial standard deviation
+    x_indexes: tensor. integer placeholder for mean-field parameters
+    fixed_sigma: bool. whether to keep sigma fixed instead of learning it
+    place_on_cpu: bool. whether to place the op on cpu.
+    dtype: dtype
+
+  Returns:
+    A Gaussian DistributionTensor of the specified shape, with variables for
+    mean and standard deviation safely parametrized to avoid over/underflow.
+  """
+  if place_on_cpu:
+    with tf.device('/cpu:0'):
+      mu = tf.get_variable(
+          'mu', shape, dtype, tf.random_normal_initializer(
+              mean=init_mu, stddev=0.1))
+  else:
+    mu = tf.get_variable('mu', shape, dtype,
+                         tf.random_normal_initializer(mean=init_mu, stddev=0.1),
+                         collections=[tf.GraphKeys.VARIABLES,
+                                      'reparam_variables'])
+  if fixed_sigma:
+    sigma_softplus_inv = None
+  else:
+    sigma_softplus_arg = tf.truncated_normal_initializer(
+        mean=inv_softplus(init_sigma), stddev=0.1)
+    if place_on_cpu:
+      with tf.device('/cpu:0'):
+        sigma_softplus_inv = tf.get_variable(
+            'sigma_softplus_inv', shape, dtype, sigma_softplus_arg)
+    else:
+      sigma_softplus_inv = tf.get_variable('sigma_softplus_inv', shape,
+                                           dtype, sigma_softplus_arg,
+                                           collections=[tf.GraphKeys.VARIABLES,
+                                                        'reparam_variables'])
+
+  if x_indexes is not None:
+    mu_batch = tf.nn.embedding_lookup(mu, x_indexes)
+    if not fixed_sigma:
+      sigma_softplus_inv_batch = tf.nn.embedding_lookup(
+          sigma_softplus_inv, x_indexes)
+  else:
+    mu_batch, sigma_softplus_inv_batch = mu, sigma_softplus_inv
+
+  if fixed_sigma:
+    sigma_batch = np.array(init_sigma, dtype)
+  else:
+    sigma_batch = tf.maximum(tf.nn.softplus(sigma_softplus_inv_batch), 1e-5)
+
+  dist = st.StochasticTensor(distributions.Normal, mu=mu_batch,
+                             sigma=sigma_batch, validate_args=False)
+  return dist
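+
+
+# Usage mirrors build_gamma above (illustrative sketch only), e.g.
+#   q_w = build_gaussian([z_dim, timestep_dim], init_mu=0., init_sigma=1.)
+# yields a Normal StochasticTensor whose mean is a free variable and whose
+# sigma is parametrized through a softplus to stay positive.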
+
+
+def get_np_dtype(tensor):
+  """Returns the numpy dtype."""
+  return np.float32 if 'float32' in str(tensor.dtype) else np.float64
+
+
+def build_bernoulli_log_likelihood(params, x, batch_size,
+                                   n_samples_latents=1,
+                                   use_bias_observations=False):
+  """Builds the likelihood given stochastic latents and weights.
+
+  Args:
+    params: dict that contains:
+        z_1 tensor. sampled latent variables
+          [n_samples_latents] + [batch_size, n_timesteps, z_dim]
+        w_0 tensor. sampled stochastic weights [z_dim, timestep_dim]
+        b_0 optional tensor. biases [timestep_dim]
+    x: tensor. minibatch of examples
+    batch_size: integer number of minibatch examples.
+    n_samples_latents: number of samples of latent variables
+    use_bias_observations: use bias
+
+  Returns:
+    likelihood: the bernoulli likelihood distribution of the data.
+        [n_samples, batch_size, n_timesteps, timestep_dim]
+  """
+  z_1 = params['z_1']
+  w_0 = params['w_0']
+  if use_bias_observations:
+    b_0 = params['b_0']
+  if n_samples_latents > 1:
+    wz = tf.batch_matmul(z_1, tf.pack([tf.pack([w_0] * batch_size)]
+                                      * n_samples_latents))
+    if use_bias_observations:
+      wz += b_0
+    logits = tf.expand_dims(wz, 4)
+    dims_to_reduce = [2, 3, 4]
+  else:
+    wz = tf.batch_matmul(z_1, tf.pack([w_0] * batch_size))
+    if use_bias_observations:
+      wz += b_0
+    logits = tf.expand_dims(wz, 3)
+    dims_to_reduce = [1, 2, 3]
+  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
+  log_p_x_zw = tf.reduce_sum(p_x_zw.log_pmf(x), dims_to_reduce)
+  print('log_p_x_zw', log_p_x_zw.get_shape())
+  print('logits', logits.get_shape())
+  print('z_1', z_1.value().get_shape())
+  return log_p_x_zw, p_x_zw.p
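+
+
+# Shape walkthrough for the n_samples_latents > 1 branch (illustrative):
+#   z_1:    [n_samples, batch_size, n_timesteps, z_dim]
+#   w_0:    [z_dim, timestep_dim], tiled to [n_samples, batch_size, ...]
+#   wz:     [n_samples, batch_size, n_timesteps, timestep_dim]
+#   logits: [n_samples, batch_size, n_timesteps, timestep_dim, 1]
+# Reducing dims [2, 3, 4] of the Bernoulli log-pmf leaves one log-likelihood
+# per latent sample per example: [n_samples, batch_size].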
+
+
+def clip_mean(mean):
+  """Clip mean parameter of gamma."""
+  return tf.clip_by_value(mean, clip_value_max=sys.float_info.max,
+                          clip_value_min=1e-5)
+
+
+def clip_shape(shape):
+  """Clip shape parameter of gamma."""
+  return tf.clip_by_value(shape, clip_value_max=sys.float_info.max,
+                          clip_value_min=5e-3)
+
+
+def bernoulli_likelihood_sample(params, z_1, n_samples,
+                                use_bias_observations=False):
+  """Sample from the model likelihood.
+
+  Args:
+    params: dict that contains
+        w_0 tensor. sample of latent weights
+        b_0 optional tensor. bias
+    z_1: tensor. sample of latent variables
+    n_samples: int. number of samples to draw
+    use_bias_observations: use bias
+
+  Returns:
+    A tensor sample from the model likelihood.
+  """
+  w_0 = params['w_0']
+  if isinstance(z_1, st.StochasticTensor):
+    z_1 = z_1.value()
+  if z_1.get_shape().ndims == 4:
+    z_1 = z_1[0, :, :, :]
+  wz = tf.batch_matmul(z_1, tf.pack([w_0] * n_samples))
+  if use_bias_observations:
+    wz += params['b_0']
+  logits = tf.expand_dims(wz, 3)
+  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
+  return tf.cast(p_x_zw.sample_n(n=1)[0, :, :, :, :], logits.dtype)

+ 102 - 0
experimental/rdefs/python/save_mnist_tf_records.py

@@ -0,0 +1,102 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Save MNIST into tf.records format."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import urllib
+
+import numpy as np
+import tensorflow as tf
+# from https://github.com/yburda/iwae/blob/master/datasets.py
+
+DATASETS_DIR = '/tmp/BinarizedMNIST'
+if not os.path.exists(DATASETS_DIR):
+  os.makedirs(DATASETS_DIR)
+subdatasets = ['train', 'valid', 'test']
+for subdataset in subdatasets:
+  filename = 'binarized_mnist_{}.amat'.format(subdataset)
+  url = 'http://www.cs.toronto.edu/~larocheh/public/datasets/binarized_mnist/binarized_mnist_{}.amat'.format(subdataset)  # pylint: disable=line-too-long
+  local_filename = os.path.join(DATASETS_DIR, filename)
+  if not os.path.exists(local_filename):
+    urllib.urlretrieve(url, local_filename)
+
+
+def binarized_mnist_fixed_binarization():
+  """parse .mat file and get numpy array of MNIST."""
+  def lines_to_np_array(lines):
+    return np.array([[int(i) for i in line.split()] for line in lines])
+  with open(os.path.join(DATASETS_DIR, 'binarized_mnist_train.amat')) as f:
+    lines = f.readlines()
+  train_data = lines_to_np_array(lines).astype('float32')
+  with open(os.path.join(DATASETS_DIR, 'binarized_mnist_valid.amat')) as f:
+    lines = f.readlines()
+  validation_data = lines_to_np_array(lines).astype('float32')
+  with open(os.path.join(DATASETS_DIR, 'binarized_mnist_test.amat')) as f:
+    lines = f.readlines()
+  test_data = lines_to_np_array(lines).astype('float32')
+  return train_data, validation_data, test_data
+
+train, validation, test = binarized_mnist_fixed_binarization()
+
+train_and_validation = np.vstack([train, validation])
+
+data_dict = {'train': train, 'valid': validation, 'test': test,
+             'train_and_valid': train_and_validation}
+
+
+def serialize_array_with_label(array, path):
+  """Shuffles the rows of `array` and writes them to a TFRecord file."""
+  writer = tf.python_io.TFRecordWriter(path)
+  indices = range(array.shape[0])
+  # one MUST randomly shuffle data before putting it into one of these
+  # formats. Without this, one cannot make use of tensorflow's great
+  # out of core shuffling.
+  np.random.shuffle(indices)
+  # iterate over each example
+  for example_idx in indices:
+    features = array[example_idx]
+
+    # construct the Example proto object
+    example = tf.train.Example(
+        # Example contains a Features proto object
+        features=tf.train.Features(
+            # Features contains a map of string to Feature proto objects
+            feature={
+                # A Feature contains one of either a int64_list,
+                # float_list, or bytes_list
+                'image': tf.train.Feature(
+                    float_list=tf.train.FloatList(
+                        value=features.astype('float'))),
+                'label': tf.train.Feature(
+                    int64_list=tf.train.Int64List(value=[example_idx]))
+            }
+        )
+    )
+    # use the proto object to serialize the example to a string
+    serialized = example.SerializeToString()
+    # write the serialized object to disk
+    writer.write(serialized)
+  writer.close()
+
+subdatasets = ['train', 'valid', 'test', 'train_and_valid']
+
+for subdataset in subdatasets:
+  print('serializing %s' % subdataset)
+  file_name = os.path.join(
+      DATASETS_DIR, 'binarized_mnist_{}_labeled.tfrecords'.format(subdataset))
+  if not os.path.exists(file_name):
+    serialize_array_with_label(data_dict[subdataset], file_name)
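+
+# Hedged sketch of reading these records back with the queue-based input
+# pipeline of this TensorFlow version (the feature keys match the 'image' and
+# 'label' entries written above; file name and shapes are assumptions):
+#
+#   filename_queue = tf.train.string_input_producer(
+#       [os.path.join(DATASETS_DIR, 'binarized_mnist_train_labeled.tfrecords')])
+#   reader = tf.TFRecordReader()
+#   _, serialized = reader.read(filename_queue)
+#   features = tf.parse_single_example(
+#       serialized,
+#       features={'image': tf.FixedLenFeature([28 * 28], tf.float32),
+#                 'label': tf.FixedLenFeature([1], tf.int64)})
+#   image = tf.reshape(features['image'], [28, 28, 1])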

+ 355 - 0
experimental/rdefs/python/train_gamma_normal_def.py

@@ -0,0 +1,355 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trains a recurrent DEF with gamma latent variables and gaussian weights.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+
+
+import numpy as np
+from scipy.misc import imsave
+import tensorflow as tf
+
+from ops import inference
+from ops import model_factory
+from ops import rmsprop
+from ops import tf_lib
+from ops import util
+
+flags = tf.flags
+flags.DEFINE_string('master', 'local',
+                    'BNS name of the TensorFlow master to use.')
+flags.DEFINE_string('logdir', '/tmp/write_logs',
+                    'Directory where to write event logs.')
+flags.DEFINE_integer('seed', 41312, 'Random seed for TensorFlow and Numpy')
+flags.DEFINE_boolean('delete_logdir', True, 'Whether to clear the log dir.')
+flags.DEFINE_string('trials_root_dir',
+                    '/tmp/logs',
+                    'Directory where to write event logs.')
+flags.DEFINE_integer(
+    'save_summaries_secs', 10,
+    'The frequency with which summaries are saved, in seconds.')
+flags.DEFINE_integer('save_interval_secs', 10,
+                     'The frequency with which the model is saved, in seconds.')
+flags.DEFINE_integer('max_steps', 200000,
+                     'The maximum number of gradient steps.')
+flags.DEFINE_integer('print_stats_every', 100, 'print stats every')
+flags.DEFINE_integer(
+    'ps_tasks', 0,
+    'The number of parameter servers. If the value is 0, then the parameters '
+    'are handled locally by the worker.')
+flags.DEFINE_integer(
+    'task', 0,
+    'The Task ID. This value is used when training with multiple workers to '
+    'identify each worker.')
+flags.DEFINE_string('trainer', 'supervisor', 'slim/local/supervisor')
+flags.DEFINE_integer('samples_to_save', 1, 'number of samples to save')
+flags.DEFINE_boolean('check_nans', False, 'add ops to check for nans.')
+flags.DEFINE_string('data_path',
+                    '/readahead/256M/cns/in-d/home/jaana/binarized_mnist_new',
+                    'Where to read the data from.')
+
+FLAGS = flags.FLAGS
+
+sg = tf.contrib.bayesflow.stochastic_graph
+distributions = tf.contrib.distributions
+
+
+def run_training(hparams, train_dir, max_steps, tuner, container='',
+                 trainer='supervisor'):
+  """Trains a Gaussian Recurrent DEF.
+
+  Args:
+    hparams: A tf.HParams object with hyperparameters for training.
+    train_dir: Where to store events files and checkpoints.
+    max_steps: Integer number of steps to train.
+    tuner: An instance of a vizier tuner.
+    container: String specifying container for resource sharing.
+    trainer: Train locally by loading an hdf5 file or with Supervisor.
+
+  Returns:
+    vi: Optionally, the VariationalInference object that has been trained.
+    sess: Optionally, the session for training.
+
+  Raises:
+    ValueError: if ELBO is nan.
+  """
+  hps = hparams
+  tf.set_random_seed(FLAGS.seed)
+  np.random.seed(FLAGS.seed)
+  g = tf.Graph()
+  if FLAGS.ps_tasks > 0:
+    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
+  else:
+    device_fn = None
+  with g.as_default(), g.device(device_fn), tf.container(container):
+    if trainer == 'local':
+      x_indexes = tf.placeholder(tf.int32, [hps.batch_size])
+      x = tf.placeholder(tf.float32,
+                         [hps.batch_size, hps.n_timesteps, hps.timestep_dim, 1])
+      data_iterator = util.provide_hdf5_data(
+          FLAGS.data_path,
+          'train',
+          hps.n_examples,
+          hps.batch_size,
+          hps.n_timesteps,
+          hps.timestep_dim,
+          hps.dataset)
+    else:
+      x_indexes, x = util.provide_tfrecords_data(
+          FLAGS.data_path,
+          'train_labeled',
+          hps.batch_size,
+          hps.n_timesteps,
+          hps.timestep_dim)
+
+    data = {'x': x, 'x_indexes': x_indexes}
+
+    model = model_factory.GammaNormalRDEF(
+        n_timesteps=hps.n_timesteps,
+        batch_size=hps.batch_size,
+        p_z_shape=hps.p_z_shape,
+        p_z_mean=hps.p_z_mean,
+        p_w_mean_sigma=hps.p_w_mean_sigma,
+        fixed_p_z_mean=hps.fixed_p_z_mean,
+        p_w_shape_sigma=hps.p_w_shape_sigma,
+        z_dim=hps.z_dim,
+        use_bias_observations=hps.use_bias_observations,
+        n_samples_latents=hps.n_samples_latents,
+        dtype=hps.dtype)
+
+    variational = model_factory.GammaNormalRDEFVariational(
+        x_indexes=x_indexes,
+        n_examples=hps.n_examples,
+        n_timesteps=hps.n_timesteps,
+        z_dim=hps.z_dim,
+        timestep_dim=hps.timestep_dim,
+        init_shape_q_z=hps.init_shape_q_z,
+        init_mean_q_z=hps.init_mean_q_z,
+        init_sigma_q_w_mean=hps.p_w_mean_sigma * hps.init_q_sigma_scale,
+        init_sigma_q_w_shape=hps.p_w_shape_sigma * hps.init_q_sigma_scale,
+        init_sigma_q_w_0_sigma=hps.p_w_shape_sigma * hps.init_q_sigma_scale,
+        fixed_p_z_mean=hps.fixed_p_z_mean,
+        fixed_q_z_mean=hps.fixed_q_z_mean,
+        fixed_q_w_mean_sigma=hps.fixed_q_w_mean_sigma,
+        fixed_q_w_shape_sigma=hps.fixed_q_w_shape_sigma,
+        fixed_q_w_0_sigma=hps.fixed_q_w_0_sigma,
+        n_samples_latents=hps.n_samples_latents,
+        use_bias_observations=hps.use_bias_observations,
+        dtype=hps.dtype)
+
+    vi = inference.VariationalInference(model, variational, data)
+    vi.build_graph()
+
+    # Build prior and posterior predictive samples
+    z_1_prior_sample = model.recurrent_layer_sample(
+        variational.sample['w_1_shape'], variational.sample['w_1_mean'],
+        hps.batch_size)
+    prior_predictive = model.likelihood_sample(
+        variational.sample, z_1_prior_sample, hps.batch_size)
+    posterior_predictive = model.likelihood_sample(
+        variational.sample, variational.sample['z_1'], hps.batch_size)
+
+    # Build summaries.
+    float32 = lambda x: tf.cast(x, tf.float32)
+    tf.image_summary('prior_predictive',
+                     float32(prior_predictive),
+                     max_images=10)
+    tf.image_summary('posterior_predictive',
+                     float32(posterior_predictive),
+                     max_images=10)
+    tf.scalar_summary('ELBO', vi.scalar_elbo / hps.batch_size)
+    tf.scalar_summary('log_p', tf.reduce_mean(vi.log_p))
+    tf.scalar_summary('log_q', tf.reduce_mean(vi.log_q))
+
+    global_step = tf.contrib.framework.get_or_create_global_step()
+
+    # Specify optimization scheme.
+    optimizer = tf.train.AdamOptimizer(learning_rate=hps.learning_rate)
+    if hps.control_variate == 'none':
+      train_op = optimizer.minimize(-vi.surrogate_elbo, global_step=global_step)
+    elif hps.control_variate == 'covariance':
+      train_non_reparam = rmsprop.maximize_with_control_variate(
+          learning_rate=hps.learning_rate,
+          learning_signal=vi.elbo,
+          log_prob=vi.log_q,
+          variable_list=tf.get_collection('non_reparam_variables'),
+          global_step=global_step)
+      grad_tensors = [v.values if 'embedding_lookup' in v.name else v
+                      for v in tf.get_collection('non_reparam_variable_grads')]
+
+      train_reparam = optimizer.minimize(
+          -tf.reduce_mean(vi.elbo, 0),  # optimize the mean across samples
+          var_list=tf.get_collection('reparam_variables'))
+      train_op = tf.group(train_reparam, train_non_reparam)
+
+    if trainer == 'supervisor':
+      global_step = tf.contrib.framework.get_or_create_global_step()
+      train_op = optimizer.minimize(-vi.elbo, global_step=global_step)
+      summary_op = tf.merge_all_summaries()
+      saver = tf.train.Saver()
+      sv = tf.train.Supervisor(
+          logdir=train_dir,
+          is_chief=(FLAGS.task == 0),
+          saver=saver,
+          summary_op=summary_op,
+          global_step=global_step,
+          save_summaries_secs=FLAGS.save_summaries_secs,
+          save_model_secs=FLAGS.save_summaries_secs,
+          recovery_wait_secs=5)
+      sess = sv.PrepareSession(FLAGS.master)
+      sv.StartQueueRunners(sess)
+      local_step = 0
+      while not sv.ShouldStop():
+        _, np_elbo, np_global_step = sess.run(
+            [train_op, vi.elbo, global_step])
+        if tuner is not None:
+          if np.isnan(np_elbo):
+            tuner.report_done(infeasible=True, infeasible_reason='ELBO is nan')
+            should_stop = True
+          else:
+            should_stop = tuner.report_measure(float(np_elbo),
+                                               global_step=np_global_step)
+            if should_stop:
+              tuner.report_done()
+              sv.RequestStop()
+        if np_global_step >= max_steps:
+          break
+        if local_step % FLAGS.print_stats_every == 0:
+          print('step %d: %g' % (np_global_step - 1, np_elbo / hps.batch_size))
+        local_step += 1
+      sv.Stop()
+      sess.close()
+    elif trainer == 'local':
+      sess = tf.InteractiveSession()
+      sess.run(tf.initialize_all_variables())
+      t0 = time.time()
+      if tf.gfile.Exists(train_dir):
+        tf.gfile.DeleteRecursively(train_dir)
+        tf.gfile.MakeDirs(train_dir)
+      else:
+        tf.gfile.MakeDirs(train_dir)
+      for i in range(max_steps):
+        indexes, images = data_iterator.next()
+        feed_dict = {x_indexes: indexes, x: images}
+        if i % FLAGS.print_stats_every == 0:
+          _, np_prior_predictive, np_posterior_predictive = sess.run(
+              [train_op, prior_predictive, posterior_predictive],
+              feed_dict)
+          print('prior_predictive', np_prior_predictive.flatten())
+          print('posterior_predictive', np_posterior_predictive.flatten())
+          print('data', images.flatten())
+          examples_per_s = (hps.batch_size * FLAGS.print_stats_every /
+                            (time.time() - t0))
+          q_z = variational.params['z_1'].distribution
+          alpha = q_z.alpha
+          beta = q_z.beta
+          mean = alpha / beta
+          grad_list = []
+          elbo_list = []
+          for k in range(100):
+            elbo_list.append(vi.elbo.eval(feed_dict))
+            grads = sess.run(grad_tensors, feed_dict)
+            grad_list.append(grads)
+          np_elbo = np.mean(np.vstack([np.sum(v, axis=1) for v in elbo_list]))
+          if np.isnan(np_elbo):
+            raise ValueError('ELBO is NaN. Please keep trying!')
+          grads_per_var = [np.stack(
+              [g_sample[var_idx] for g_sample in grad_list])
+                           for var_idx in range(
+                               len(tf.get_collection(
+                                   'non_reparam_variable_grads')))]
+          grads_per_timestep = [np.split(g, hps.n_timesteps, axis=2)
+                                for g in grads_per_var]
+          grads_per_timestep_per_dim = [[np.split(g, hps.z_dim, axis=3) for g in
+                                         g_list] for g_list
+                                        in grads_per_timestep]
+          grads_per_timestep_per_dim = [sum(g_list, []) for g_list in
+                                        grads_per_timestep_per_dim]
+          print('variance of gradients for each variable: ')
+          for var_idx, var in enumerate(
+              tf.get_collection('non_reparam_variable_grads')):
+            print('variable: %s' % var.name)
+            var = [np.var(g, axis=0) for g in
+                   grads_per_timestep_per_dim[var_idx]]
+            print('variance is: ', np.stack(var).flatten())
+          print('alpha ', alpha.eval(feed_dict).flatten())
+          print('mean ', mean.eval(feed_dict).flatten())
+          print('bernoulli p ', np.mean(
+              vi.model.p_x_zw_bernoulli_p.eval(feed_dict), axis=0).flatten())
+          t0 = time.time()
+          print('iter %d\telbo: %.3e\texamples/s: %.3f' % (
+              i, np_elbo, examples_per_s))
+          for k in range(hps.samples_to_save):
+            im_name = 'i_%d_k_%d_' % (i, k)
+            prior_name = im_name + 'prior_predictive.jpg'
+            posterior_name = im_name + 'posterior_predictive.jpg'
+            imsave(os.path.join(train_dir, prior_name),
+                   np_prior_predictive[k, :, :, 0])
+            imsave(os.path.join(train_dir, posterior_name),
+                   np_posterior_predictive[k, :, :, 0])
+        else:
+          _ = sess.run(train_op, feed_dict)
+      return vi, sess
+
+
+def main(unused_argv):
+  """Trains a gaussian def locally."""
+  if tf.gfile.Exists(FLAGS.logdir) and FLAGS.delete_logdir:
+    tf.gfile.DeleteRecursively(FLAGS.logdir)
+  tf.gfile.MakeDirs(FLAGS.logdir)
+  # The HParams are commented in training_params.py
+  try:
+    hparams = tf.HParams
+  except AttributeError:
+    hparams = tf_lib.HParams
+  hparams = hparams(
+      dataset='alternating',
+      z_dim=1,
+      timestep_dim=1,
+      n_timesteps=2,
+      batch_size=1,
+      n_examples=1,
+      samples_to_save=1,
+      learning_rate=0.01,
+      momentum=0.0,
+      n_samples_latents=100,
+      p_z_shape=0.1,
+      p_z_mean=1.,
+      p_w_mean_sigma=5.,
+      p_w_shape_sigma=5.,
+      init_q_sigma_scale=0.1,
+      use_bias_observations=True,
+      init_q_z_scale=1.,
+      init_shape_q_z=util.softplus(0.1),
+      init_mean_q_z=util.softplus(0.01),
+      fixed_p_z_mean=False,
+      fixed_q_z_mean=False,
+      fixed_q_z_shape=False,
+      fixed_q_w_mean_sigma=False,
+      fixed_q_w_shape_sigma=False,
+      fixed_q_w_0_sigma=False,
+      dtype='float64',
+      control_variate='covariance')
+  run_training(hparams, FLAGS.logdir, FLAGS.max_steps, None, trainer='local')
+
+if __name__ == '__main__':
+  tf.app.run()

+ 70 - 0
experimental/rdefs/python/train_normal_normal_def.py

@@ -0,0 +1,70 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wraps train_normal_normal_def."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import tensorflow as tf
+
+import train_normal_normal_def_lib
+from ops import tf_lib
+
+
+FLAGS = tf.flags.FLAGS
+
+
+def main(unused_argv):
+  """Trains a gaussian def locally."""
+  if tf.gfile.Exists(FLAGS.logdir) and FLAGS.delete_logdir:
+    tf.gfile.DeleteRecursively(FLAGS.logdir)
+  tf.gfile.MakeDirs(FLAGS.logdir)
+  # The HParams are commented in training_params.py
+  try:
+    hparams = tf.HParams
+  except AttributeError:
+    hparams = tf_lib.HParams
+  hparams = hparams(
+      dataset='MNIST',
+      z_dim=200,
+      timestep_dim=28,
+      n_timesteps=28,
+      batch_size=50,
+      samples_to_save=10,
+      learning_rate=0.1,
+      n_examples=50,
+      momentum=0.0,
+      p_z_sigma=1.,
+      p_w_mu_sigma=1.,
+      p_w_sigma_sigma=1.,
+      init_q_sigma_scale=0.1,
+      fixed_p_z_sigma=True,
+      fixed_q_z_sigma=False,
+      fixed_q_w_mu_sigma=False,
+      fixed_q_w_sigma_sigma=False,
+      fixed_q_w_0_sigma=False,
+      dtype='float32')
+  tf.logging.info('Starting experiment in %s with params %s', FLAGS.logdir,
+                  hparams)
+  train_normal_normal_def_lib.run_training(
+      hparams, FLAGS.logdir, FLAGS.max_steps, None,
+      trainer=FLAGS.trainer)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 258 - 0
experimental/rdefs/python/train_normal_normal_def_lib.py

@@ -0,0 +1,258 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trains a recurrent DEF with gaussian latent variables and gaussian weights.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import time
+
+
+
+import numpy as np
+from scipy.misc import imsave
+import tensorflow as tf
+
+from ops import inference
+from ops import model_factory
+from ops import util
+
+slim = tf.contrib.slim
+
+flags = tf.flags
+flags.DEFINE_string('master', 'local',
+                    'BNS name of the TensorFlow master to use.')
+flags.DEFINE_string('logdir', '/tmp/logs',
+                    'Directory where to write event logs.')
+flags.DEFINE_boolean('delete_logdir', True, 'Whether to clear the log dir.')
+flags.DEFINE_string('trials_root_dir',
+                    '/tmp/logs',
+                    'Directory where to write event logs.')
+flags.DEFINE_integer(
+    'save_summaries_secs', 10,
+    'The frequency with which summaries are saved, in seconds.')
+flags.DEFINE_integer('save_interval_secs', 10,
+                     'The frequency with which the model is saved, in seconds.')
+flags.DEFINE_integer('max_steps', 200000,
+                     'The maximum number of gradient steps.')
+flags.DEFINE_integer('print_stats_every', 100, 'print stats every')
+flags.DEFINE_integer(
+    'ps_tasks', 0,
+    'The number of parameter servers. If the value is 0, then the parameters '
+    'are handled locally by the worker.')
+flags.DEFINE_integer(
+    'task', 0,
+    'The Task ID. This value is used when training with multiple workers to '
+    'identify each worker.')
+flags.DEFINE_string('trainer', 'supervisor', 'slim/local/supervisor')
+flags.DEFINE_integer('samples_to_save', 1, 'number of samples to save')
+flags.DEFINE_boolean('check_nans', False, 'add ops to check for nans.')
+flags.DEFINE_string('data_path',
+                    '/readahead/256M/cns/in-d/home/jaana/binarized_mnist_new',
+                    'Where to read the data from.')
+
+FLAGS = flags.FLAGS
+
+
+def run_training(hparams, train_dir, max_steps, container='',
+                 trainer='supervisor', reporter_fn=None):
+  """Trains a Gaussian Recurrent DEF.
+
+  Args:
+    hparams: A tf.HParams object with hyperparameters for training.
+    train_dir: Where to store events files and checkpoints.
+    max_steps: Integer number of steps to train.
+    container: String specifying container for resource sharing.
+    trainer: Train locally by loading an hdf5 file or with Supervisor.
+    reporter_fn: Optional reporter function
+
+  Returns:
+    vi: Optionally, the VariationalInference object that has been trained.
+    sess: Optionally, the session for training.
+  """
+  hps = hparams
+  tf.set_random_seed(4235)
+  np.random.seed(4234)
+  g = tf.Graph()
+  if FLAGS.ps_tasks > 0:
+    device_fn = tf.train.replica_device_setter(FLAGS.ps_tasks)
+  else:
+    device_fn = None
+  with g.as_default(), g.device(device_fn), tf.container(container):
+    if trainer == 'local':
+      x_indexes = tf.placeholder(tf.int32, [None])
+      x = tf.placeholder(tf.float32,
+                         [None, hps.n_timesteps, hps.timestep_dim, 1])
+      data_iterator = util.provide_hdf5_data(
+          FLAGS.data_path,
+          'train',
+          hps.n_examples,
+          hps.batch_size,
+          hps.n_timesteps,
+          hps.timestep_dim,
+          hps.dataset)
+    else:
+      x_indexes, x = util.provide_tfrecords_data(
+          FLAGS.data_path,
+          'train_labeled',
+          hps.batch_size,
+          hps.n_timesteps,
+          hps.timestep_dim)
+
+    data = {'x': x, 'x_indexes': x_indexes}
+
+    model = model_factory.NormalNormalRDEF(
+        n_timesteps=hps.n_timesteps,
+        batch_size=hps.batch_size,
+        p_z_sigma=hps.p_z_sigma,
+        p_w_mu_sigma=hps.p_w_mu_sigma,
+        fixed_p_z_sigma=hps.fixed_p_z_sigma,
+        p_w_sigma_sigma=hps.p_w_sigma_sigma,
+        z_dim=hps.z_dim,
+        dtype=hps.dtype)
+
+    variational = model_factory.NormalNormalRDEFVariational(
+        x_indexes=x_indexes,
+        n_examples=hps.n_examples,
+        n_timesteps=hps.n_timesteps,
+        z_dim=hps.z_dim,
+        timestep_dim=hps.timestep_dim,
+        init_sigma_q_z=hps.p_z_sigma * hps.init_q_sigma_scale,
+        init_sigma_q_w_mu=hps.p_w_mu_sigma * hps.init_q_sigma_scale,
+        init_sigma_q_w_sigma=hps.p_w_sigma_sigma * hps.init_q_sigma_scale,
+        init_sigma_q_w_0_sigma=hps.p_w_sigma_sigma * hps.init_q_sigma_scale,
+        fixed_p_z_sigma=hps.fixed_p_z_sigma,
+        fixed_q_z_sigma=hps.fixed_q_z_sigma,
+        fixed_q_w_mu_sigma=hps.fixed_q_w_mu_sigma,
+        fixed_q_w_sigma_sigma=hps.fixed_q_w_sigma_sigma,
+        fixed_q_w_0_sigma=hps.fixed_q_w_0_sigma,
+        dtype=hps.dtype)
+
+    vi = inference.VariationalInference(model, variational, data)
+
+    # Build graph for variational inference.
+    vi.build_graph()
+
+    # Build prior and posterior predictive samples
+    z_1_prior_sample = model.recurrent_layer_sample(
+        variational.sample['w_1_mu'], variational.sample['w_1_sigma'],
+        hps.batch_size)
+    prior_predictive = model.likelihood_sample(
+        variational.sample, z_1_prior_sample, hps.batch_size)
+    posterior_predictive = model.likelihood_sample(
+        variational.sample, variational.sample['z_1'], hps.batch_size)
+
+    # Build summaries.
+    tf.image_summary('prior_predictive', prior_predictive, max_images=10)
+    tf.image_summary('posterior_predictive', posterior_predictive,
+                     max_images=10)
+    tf.scalar_summary('ELBO', vi.scalar_elbo / hps.batch_size)
+    tf.scalar_summary('log_p', tf.reduce_mean(vi.log_p))
+    tf.scalar_summary('log_q', tf.reduce_mean(vi.log_q))
+
+    # Total loss is the negative ELBO (we maximize the evidence lower bound).
+    total_loss = -vi.elbo
+
+    if FLAGS.check_nans:
+      checks = tf.add_check_numerics_ops()
+      total_loss = tf.control_flow_ops.with_dependencies([checks], total_loss)
+
+    # Specify optimization scheme.
+    optimizer = tf.train.AdamOptimizer(learning_rate=hps.learning_rate)
+
+    # Run training.
+    if trainer == 'slim':
+      train_op = slim.learning.create_train_op(total_loss, optimizer)
+      slim.learning.train(
+          train_op=train_op,
+          logdir=train_dir,
+          master=FLAGS.master,
+          is_chief=FLAGS.task == 0,
+          number_of_steps=max_steps,
+          save_summaries_secs=FLAGS.save_summaries_secs,
+          save_interval_secs=FLAGS.save_interval_secs)
+    elif trainer == 'supervisor':
+      global_step = tf.contrib.framework.get_or_create_global_step()
+      train_op = optimizer.minimize(total_loss, global_step=global_step)
+      summary_op = tf.merge_all_summaries()
+      saver = tf.train.Saver()
+      sv = tf.train.Supervisor(
+          logdir=train_dir,
+          is_chief=(FLAGS.task == 0),
+          saver=saver,
+          summary_op=summary_op,
+          global_step=global_step,
+          save_summaries_secs=FLAGS.save_summaries_secs,
+          save_model_secs=FLAGS.save_summaries_secs,
+          recovery_wait_secs=5)
+      sess = sv.PrepareSession(FLAGS.master)
+      sv.StartQueueRunners(sess)
+      local_step = 0
+      while not sv.ShouldStop():
+        _, np_elbo, np_global_step = sess.run(
+            [train_op, vi.elbo, global_step])
+        np_elbo = np.mean(np_elbo)
+        if reporter_fn:
+          should_stop = reporter_fn(np_elbo, np_global_step)
+          if should_stop:
+            sv.RequestStop()
+        if np_global_step >= max_steps:
+          break
+        if local_step % FLAGS.print_stats_every == 0:
+          print('step %d: %g' % (np_global_step - 1, np_elbo / hps.batch_size))
+        local_step += 1
+      sv.Stop()
+      sess.close()
+    elif trainer == 'local':
+      global_step = tf.contrib.framework.get_or_create_global_step()
+      train_op = tf.contrib.layers.optimize_loss(
+          total_loss,
+          global_step,
+          hps.learning_rate,
+          'Adam')
+      sess = tf.InteractiveSession()
+      sess.run(tf.initialize_all_variables())
+      t0 = time.time()
+      if tf.gfile.Exists(train_dir):
+        tf.gfile.DeleteRecursively(train_dir)
+        tf.gfile.MakeDirs(train_dir)
+      else:
+        tf.gfile.MakeDirs(train_dir)
+      for i in range(max_steps):
+        indexes, images = data_iterator.next()
+        feed_dict = {x_indexes: indexes, x: images}
+        if i % FLAGS.print_stats_every == 0:
+          np_elbo, _, np_prior_predictive, np_posterior_predictive = sess.run(
+              [vi.elbo, train_op, prior_predictive, posterior_predictive],
+              feed_dict)
+          examples_per_s = (hps.batch_size * FLAGS.print_stats_every /
+                            (time.time() - t0))
+          t0 = time.time()
+          print('iter %d\telbo: %.3e\texamples/s: %.3f' % (
+              i, (np.mean(np_elbo) / hps.batch_size), examples_per_s))
+          for k in range(hps.samples_to_save):
+            im_name = 'i_%d_k_%d_' % (i, k)
+            prior_name = im_name + 'prior_predictive.jpg'
+            posterior_name = im_name + 'posterior_predictive.jpg'
+            imsave(os.path.join(train_dir, prior_name),
+                   np_prior_predictive[k, :, :, 0])
+            imsave(os.path.join(train_dir, posterior_name),
+                   np_posterior_predictive[k, :, :, 0])
+        else:
+          _ = sess.run(train_op, feed_dict)
+      return vi, sess

+ 81 - 0
experimental/rdefs/python/train_normal_normal_def_test.py

@@ -0,0 +1,81 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for train_normal_normal_def."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import numpy as np
+import tensorflow as tf
+
+import train_normal_normal_def
+
+FLAGS = tf.flags.FLAGS
+
+
+class TrainGaussianDefTest(tf.test.TestCase):
+
+  def testBernoulliDegenerateSolution(self):
+    """Test whether we recover bernoulli parameter 0 if we feed in zeros.
+
+    In this case, the model is a gaussian-bernoulli factor model, with one
+    time step (i.e. no recurrence)
+    """
+    tf.set_random_seed(1322423)
+    np.random.seed(1423234)
+    hparams = tf.HParams(
+        dataset='alternating',
+        z_dim=1,
+        timestep_dim=1,
+        n_timesteps=1,
+        batch_size=1,
+        samples_to_save=1,
+        learning_rate=0.05,
+        n_examples=1,
+        momentum=0.0,
+        p_z_sigma=1.,
+        p_w_mu_sigma=1.,
+        p_w_sigma_sigma=1.,
+        init_q_sigma_scale=0.1,
+        fixed_p_z_sigma=True,
+        fixed_q_z_sigma=False,
+        fixed_q_w_mu_sigma=False,
+        fixed_q_w_sigma_sigma=False,
+        fixed_q_w_0_sigma=False,
+        dtype='float32',
+        trainer='local')
+
+    tmp_dir = tf.test.get_temp_dir()
+
+    vi, sess = train_normal_normal_def.run_training(
+        hparams, tmp_dir, 6000, None, trainer='local')
+
+    zero = np.array(0.)
+    zero = np.reshape(zero, (1, 1, 1, 1))
+    p_list = []
+    for _ in range(100):
+      bernoulli_p = sess.run(vi.model.p_x_zw_bernoulli_p,
+                             {vi.data['x']: zero,
+                              vi.data['x_indexes']: np.reshape(zero, (1,))})
+      p_list.append(bernoulli_p)
+
+    mean_p = np.mean(p_list)
+    self.assertAllClose(mean_p, 0., rtol=1e-1, atol=1e-1)
+
+if __name__ == '__main__':
+  tf.test.main()

+ 94 - 0
experimental/rdefs/python/training_params.py

@@ -0,0 +1,94 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines hyperparameters for training the def model.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import tensorflow as tf
+
+from ops import tf_lib
+
+
+flags = tf.flags
+
+# Dataset options
+flags.DEFINE_enum('dataset', 'MNIST', ['MNIST', 'alternating'],
+                  'Dataset to use. mnist or synthetic bernoulli data')
+flags.DEFINE_integer('n_timesteps', 28, 'Number of timesteps per example')
+flags.DEFINE_integer('timestep_dim', 28, 'Dimensionality of each timestep')
+flags.DEFINE_integer('n_examples', 50000, 'Number of examples to use from the '
+                     'dataset.')
+
+# Model options
+flags.DEFINE_integer('z_dim', 2, 'Latent dimensionality')
+flags.DEFINE_float('p_z_sigma', 1., 'Prior variance for latent variables')
+flags.DEFINE_float('p_w_mu_sigma', 1., 'Prior variance for weights for mean')
+flags.DEFINE_float('p_w_sigma_sigma', 1., 'Prior variance for weights for '
+                   'standard deviation')
+flags.DEFINE_boolean('fixed_p_z_sigma', True, 'Whether to have the variance '
+                     'depend recurrently across timesteps')
+
+# Variational family options
+flags.DEFINE_float('init_q_sigma_scale', 0.1, 'Factor by which to scale prior'
+                   ' variances to use as initialization for variational stddev')
+flags.DEFINE_boolean('fixed_q_z_sigma', False, 'Whether to learn variational '
+                     'variance parameters for latents')
+flags.DEFINE_boolean('fixed_q_w_mu_sigma', False, 'Whether to learn '
+                     'variational variance parameters for weights for mean')
+flags.DEFINE_boolean('fixed_q_w_sigma_sigma', False, 'Whether to learn '
+                     'variance parameters for weights for variance')
+flags.DEFINE_boolean('fixed_q_w_0_sigma', False, 'Whether to learn '
+                     'variance parameters for weights for observations')
+
+# Training options
+flags.DEFINE_enum('optimizer', 'Adam', ['Adam', 'RMSProp', 'SGD', 'Adagrad'],
+                  'Optimizer to use')
+flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate')
+flags.DEFINE_float('momentum', 0., 'Momentum for optimizer')
+flags.DEFINE_integer('batch_size', 10, 'Batch size')
+
+FLAGS = tf.flags.FLAGS
+
+
+def h_params():
+  """Returns hyperparameters defaulting to the corresponding flag values."""
+  try:
+    hparams = tf.HParams
+  except AttributeError:
+    hparams = tf_lib.HParams
+  return hparams(
+      dataset=FLAGS.dataset,
+      z_dim=FLAGS.z_dim,
+      timestep_dim=FLAGS.timestep_dim,
+      n_timesteps=FLAGS.n_timesteps,
+      batch_size=FLAGS.batch_size,
+      learning_rate=FLAGS.learning_rate,
+      n_examples=FLAGS.n_examples,
+      momentum=FLAGS.momentum,
+      p_z_sigma=FLAGS.p_z_sigma,
+      p_w_mu_sigma=FLAGS.p_w_mu_sigma,
+      p_w_sigma_sigma=FLAGS.p_w_sigma_sigma,
+      init_q_sigma_scale=FLAGS.init_q_sigma_scale,
+      fixed_p_z_sigma=FLAGS.fixed_p_z_sigma,
+      fixed_q_z_sigma=FLAGS.fixed_q_z_sigma,
+      fixed_q_w_mu_sigma=FLAGS.fixed_q_w_mu_sigma,
+      fixed_q_w_sigma_sigma=FLAGS.fixed_q_w_sigma_sigma,
+      fixed_q_w_0_sigma=FLAGS.fixed_q_w_0_sigma,
+      dtype='float32')
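+
+
+# Hedged usage sketch (names assumed, not part of this file): a training
+# binary can import this module so the flags above are registered, then build
+# its hyperparameters with h_params() instead of constructing them inline:
+#
+#   import training_params
+#   hparams = training_params.h_params()
+#   train_normal_normal_def_lib.run_training(
+#       hparams, FLAGS.logdir, FLAGS.max_steps, trainer=FLAGS.trainer)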