@@ -0,0 +1,654 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes for models and variational distributions for recurrent DEFs.
+
+DEF stands for deep exponential family.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+import numpy as np
+import tensorflow as tf
+
+from ops import util
+
+st = tf.contrib.bayesflow.stochastic_tensor
+
+distributions = tf.contrib.distributions
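+
+# Note: this module uses the early (contrib-era) TensorFlow APIs:
+# tf.contrib.bayesflow.stochastic_tensor, tf.contrib.distributions with
+# mu/sigma arguments and log_pdf, tf.pack, and tf.batch_matmul.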
+
+
+class NormalNormalRDEF(object):
+  """Class for a recurrent DEF with normal latent variables and normal weights.
+  """
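+
+  # Generative process implemented below (sketch; the Bernoulli observation
+  # model itself is built in ops.util from the latents and the weights w_0):
+  #   z_0 ~ Normal(0, p_z_sigma)
+  #   z_t ~ Normal(z_{t-1} w_mu, softplus(z_{t-1} w_sigma)),  t = 1..T-1
+  #   x_t ~ Bernoulli(p_t),  with p_t a function of z_t and w_0.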
+
+  def __init__(self, n_timesteps, batch_size, p_w_mu_sigma, p_w_sigma_sigma,
+               p_z_sigma, fixed_p_z_sigma, z_dim, dtype):
+    """Initializes the NormalNormalRDEF class.
+
+    Args:
+      n_timesteps: int. number of timesteps.
+      batch_size: int. batch size.
+      p_w_mu_sigma: float. standard deviation of the prior over the weights
+          for the mean of the latent variables.
+      p_w_sigma_sigma: float. standard deviation of the prior over the weights
+          for the variance of the latent variables.
+      p_z_sigma: float. standard deviation of the prior over the latent
+          variables.
+      fixed_p_z_sigma: bool. whether the prior variance is held fixed rather
+          than learned.
+      z_dim: int. dimension of each latent variable.
+      dtype: dtype
+    """
+
+    self.n_timesteps = n_timesteps
+    self.batch_size = batch_size
+    self.p_w_mu_sigma = p_w_mu_sigma
+    self.p_w_sigma_sigma = p_w_sigma_sigma
+    self.p_z_sigma = p_z_sigma
+    self.fixed_p_z_sigma = fixed_p_z_sigma
+    self.z_dim = z_dim
+    self.dtype = dtype
+
+  def log_prob(self, params, x):
+    """Returns the log joint. log p(x | z, w)p(z)p(w); [batch_size].
+
+    Args:
+      params: dict. dictionary of samples of the latent variables.
+      x: tensor. minibatch of examples.
+
+    Returns:
+      The log joint of the NormalNormalRDEF probability model.
+    """
+    z_1 = params['z_1']
+    w_1_mu = params['w_1_mu']
+    w_1_sigma = params['w_1_sigma']
+    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
+        params, x, self.batch_size)
+    self.p_x_zw_bernoulli_p = p
+
+    log_p_z, log_p_w_mu, log_p_w_sigma = self.build_recurrent_layer(
+        z_1, w_1_mu, w_1_sigma)
+    return log_p_x_zw + log_p_z + log_p_w_mu + log_p_w_sigma
+
+  def build_recurrent_layer(self, z, w_mu, w_sigma):
+    """Creates a gaussian layer of the recurrent DEF.
+
+    Args:
+      z: sampled gaussian latent variables [batch_size, n_timesteps, z_dim]
+      w_mu: sampled gaussian stochastic weights for the mean [z_dim, z_dim]
+      w_sigma: sampled gaussian stochastic weights for the stddev
+          [z_dim, z_dim]
+
+    Returns:
+      log_p_z: log prior of latent variables evaluated at the samples z.
+      log_p_w_mu: log density of the weights evaluated at the sampled weights w.
+      log_p_w_sigma: log density of the weights for the stddev.
+    """
+    # the prior for the weights p(w) has two parts: p(w_mu) and p(w_sigma)
+    # prior for the weights for the mean parameter
+    p_w_mu = distributions.Normal(
+        mu=0., sigma=self.p_w_mu_sigma, validate_args=False)
+    log_p_w_mu = tf.reduce_sum(p_w_mu.log_pdf(w_mu))
+
+    if self.fixed_p_z_sigma:
+      log_p_w_sigma = 0.0
+    else:
+      # prior for the weights for the standard deviation
+      p_w_sigma = distributions.Normal(mu=0., sigma=self.p_w_sigma_sigma,
+                                       validate_args=False)
+      log_p_w_sigma = tf.reduce_sum(p_w_sigma.log_pdf(w_sigma))
+
+    # need this for numpy-style indexing
+    z = z.value()
+
+    # the prior for the latent variable at the first timestep is
+    # Normal(0, p_z_sigma)
+    z_t0 = z[:, 0, :]
+    p_z_t0 = distributions.Normal(
+        mu=0., sigma=self.p_z_sigma, validate_args=False)
+    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 1)
+
+    # the prior for subsequent timesteps is off by one
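+    # Each timestep's prior is centered on a linear function of the previous
+    # timestep: mu_t = z_{t-1} w_mu and (unless fixed) sigma_t =
+    # softplus(z_{t-1} w_sigma).  tf.pack tiles the [z_dim, z_dim] weights
+    # across the batch so tf.batch_matmul can apply them to every sequence.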
+    mu = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
+                         tf.pack([w_mu] * self.batch_size))
+    if self.fixed_p_z_sigma:
+      sigma = self.p_z_sigma
+    else:
+      wz = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
+                           tf.pack([w_sigma] * self.batch_size))
+      sigma = tf.maximum(tf.nn.softplus(wz), 1e-5)
+    p_z_t1_to_end = distributions.Normal(mu=mu, sigma=sigma,
+                                         validate_args=False)
+    log_p_z_t1_to_end = tf.reduce_sum(
+        p_z_t1_to_end.log_pdf(z[:, 1:, :]), [1, 2])
+    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
+    return log_p_z, log_p_w_mu, log_p_w_sigma
+
+  def recurrent_layer_sample(self, w_mu, w_sigma, n_samples_latents):
+    """Sample from the model, with learned latent weights.
+
+    Args:
+      w_mu: latent weights for the mean parameter. [z_dim, z_dim]
+      w_sigma: latent weights for the standard deviation. [z_dim, z_dim]
+      n_samples_latents: how many samples of latent variables
+
+    Returns:
+      z: samples from the generative process.
+    """
+    p_z_t0 = distributions.Normal(
+        mu=0., sigma=self.p_z_sigma, validate_args=False)
+    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
+    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])
+
+    def sample_timestep(z_t_prev, w_mu, w_sigma):
+      mu_t = tf.matmul(z_t_prev, w_mu)
+      if self.fixed_p_z_sigma:
+        sigma_t = self.p_z_sigma
+      else:
+        wz_t = tf.matmul(z_t_prev, w_sigma)
+        sigma_t = tf.maximum(tf.nn.softplus(wz_t), 1e-5)
+      p_z_t = distributions.Normal(mu=mu_t, sigma=sigma_t, validate_args=False)
+      if self.z_dim == 1:
+        return p_z_t.sample_n(n=1)[0, :, :]
+      else:
+        return tf.squeeze(p_z_t.sample_n(n=1))
+
+    z_list = [z_t0]
+    for _ in range(self.n_timesteps - 1):
+      z_t = sample_timestep(z_list[-1], w_mu, w_sigma)
+      z_list.append(z_t)
+    z = tf.pack(z_list)  # [n_timesteps, n_samples_latents, z_dim]
+    z = tf.transpose(z, perm=[1, 0, 2])  # [n_samples, n_timesteps, z_dim]
+    return z
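+
+  # Illustrative use (hypothetical values; w_mu and w_sigma would be samples
+  # of the recurrent weights, e.g. drawn from the variational posterior):
+  #   rdef = NormalNormalRDEF(n_timesteps=20, batch_size=16, p_w_mu_sigma=1.,
+  #                           p_w_sigma_sigma=1., p_z_sigma=1.,
+  #                           fixed_p_z_sigma=True, z_dim=2, dtype=tf.float32)
+  #   z = rdef.recurrent_layer_sample(w_mu, w_sigma, n_samples_latents=5)
+  #   # => z has shape [5, 20, 2]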
+
+  def likelihood_sample(self, params, z_1, n_samples):
+    return util.bernoulli_likelihood_sample(params, z_1, n_samples)
+
+
+class NormalNormalRDEFVariational(object):
+  """Creates the variational family for the recurrent DEF model.
+
+  Variational family:
+    q_z_1: gaussian approximate posterior q(z_1) for latents of first layer.
+        [n_examples, n_timesteps, z_dim]
+    q_w_1_mu: gaussian approximate posterior q(w_1) for mean weights of first
+        (recurrent) layer [z_dim, z_dim]
+    q_w_1_sigma: gaussian approximate posterior q(w_1) for std weights, first
+        (recurrent) layer [z_dim, z_dim]
+    q_w_0: gaussian approximate posterior q(w_0) for weights of observation
+        layer [z_dim, timestep_dim]
+  """
+
+  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
+               timestep_dim, init_sigma_q_w_mu, init_sigma_q_z,
+               init_sigma_q_w_sigma, fixed_p_z_sigma, fixed_q_z_sigma,
+               fixed_q_w_mu_sigma, fixed_q_w_sigma_sigma,
+               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, dtype):
+    """Initializes the variational family for the NormalNormalRDEF.
+
+    Args:
+      x_indexes: tensor. indices of the datapoints.
+      n_examples: int. number of examples in the dataset.
+      n_timesteps: int. number of timesteps in each datapoint.
+      z_dim: int. dimension of latent variables.
+      timestep_dim: int. dimension of each timestep.
+      init_sigma_q_w_mu: float. initial standard deviation for the weights for
+          the means of the latent variables.
+      init_sigma_q_z: float. initial standard deviation for the variational
+          distribution over the latent variables.
+      init_sigma_q_w_sigma: float. initial standard deviation for the weights
+          for the variance of the latent variables.
+      fixed_p_z_sigma: bool. whether to keep the prior over latents fixed.
+      fixed_q_z_sigma: bool. whether to keep the variance of the variational
+          distributions for the latents fixed (not trained).
+      fixed_q_w_mu_sigma: bool. whether to keep the variance of the weights
+          for the means of the latent variables fixed (not trained).
+      fixed_q_w_sigma_sigma: bool. whether to keep the variance of the weights
+          for the variance of the latent variables fixed (not trained).
+      fixed_q_w_0_sigma: bool. whether to keep the variance of the weights for
+          the observations fixed (not trained).
+      init_sigma_q_w_0_sigma: float. initial standard deviation for the
+          observation weights.
+      dtype: dtype
+    """
+    self.x_indexes = x_indexes
+    self.n_examples = n_examples
+    self.n_timesteps = n_timesteps
+    self.z_dim = z_dim
+    self.timestep_dim = timestep_dim
+    self.init_sigma_q_z = init_sigma_q_z
+    self.init_sigma_q_w_mu = init_sigma_q_w_mu
+    self.init_sigma_q_w_sigma = init_sigma_q_w_sigma
+    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
+    self.fixed_p_z_sigma = fixed_p_z_sigma
+    self.fixed_q_z_sigma = fixed_q_z_sigma
+    self.fixed_q_w_mu_sigma = fixed_q_w_mu_sigma
+    self.fixed_q_w_sigma_sigma = fixed_q_w_sigma_sigma
+    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
+    self.dtype = dtype
+    self.build_graph()
+
+  @property
+  def sample(self):
+    """Returns a dict of samples of the latent variables."""
+    return self.params
+
+  def build_graph(self):
+    """Builds the graph for the variational family for the NormalNormalRDEF."""
+    with tf.variable_scope('q_z_1'):
+      z_1 = util.build_gaussian(
+          [self.n_examples, self.n_timesteps, self.z_dim],
+          init_mu=0., init_sigma=self.init_sigma_q_z, x_indexes=self.x_indexes,
+          fixed_sigma=self.fixed_q_z_sigma, place_on_cpu=True, dtype=self.dtype)
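+      # util.build_gaussian presumably returns a stochastic tensor (see the
+      # `st` alias above) whose .distribution is the underlying Normal and
+      # which behaves like a reparameterized sample when used as a Tensor;
+      # log_prob below relies on that interface.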
+
+    with tf.variable_scope('q_w_1_mu'):
+      # the recurrent layer has two weight matrices: w_1_mu for the mean and
+      # w_1_sigma (below) for the standard deviation
+      w_1_mu = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
+                                   init_sigma=self.init_sigma_q_w_mu,
+                                   fixed_sigma=self.fixed_q_w_mu_sigma,
+                                   dtype=self.dtype)
+
+    if self.fixed_p_z_sigma:
+      w_1_sigma = None
+    else:
+      with tf.variable_scope('q_w_1_sigma'):
+        w_1_sigma = util.build_gaussian(
+            [self.z_dim, self.z_dim],
+            init_mu=0., init_sigma=self.init_sigma_q_w_sigma,
+            fixed_sigma=self.fixed_q_w_sigma_sigma,
+            dtype=self.dtype)
+
+    with tf.variable_scope('q_w_0'):
+      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
+                                init_sigma=self.init_sigma_q_w_0_sigma,
+                                fixed_sigma=self.fixed_q_w_0_sigma,
+                                dtype=self.dtype)
+
+    self.params = {'w_0': w_0, 'w_1_mu': w_1_mu, 'w_1_sigma': w_1_sigma,
+                   'z_1': z_1}
+
+  def log_prob(self, q_samples):
+    """Get the log joint of variational family: log(q(z, w_mu, w_sigma, w_0)).
+
+    Args:
+      q_samples: dict. samples of latent variables.
+
+    Returns:
+      log_prob: tensor log-probability summed over dimensions of the variables.
+    """
+    w_0 = q_samples['w_0']
+    z_1 = q_samples['z_1']
+    w_1_mu = q_samples['w_1_mu']
+    w_1_sigma = q_samples['w_1_sigma']
+
+    log_prob = 0.
+    # preserve the minibatch dimension [0]
+    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1), [1, 2])
+    # w_1, w_0 are global, so reduce_sum across all dims
+    log_prob += tf.reduce_sum(w_1_mu.distribution.log_pdf(w_1_mu))
+    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0))
+    if not self.fixed_p_z_sigma:
+      log_prob += tf.reduce_sum(w_1_sigma.distribution.log_pdf(w_1_sigma))
+    return log_prob
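+
+  # The training code presumably combines the two log_prob methods into a
+  # Monte Carlo ELBO estimate, roughly:
+  #   elbo = model.log_prob(variational.sample, x)
+  #          - variational.log_prob(variational.sample)
+  # averaged over the minibatch (both terms are [batch_size]-shaped here).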
+
+
+class GammaNormalRDEF(object):
+  """Class for a recurrent DEF with gamma latent variables and normal weights.
+  """
+
+  def __init__(self, n_timesteps, batch_size, p_w_shape_sigma, p_w_mean_sigma,
+               p_z_shape, p_z_mean, fixed_p_z_mean, z_dim, n_samples_latents,
+               use_bias_observations, dtype):
+    """Initializes the GammaNormalRDEF class.
+
+    Args:
+      n_timesteps: int. number of timesteps.
+      batch_size: int. batch size.
+      p_w_shape_sigma: float. standard deviation of the prior over the weights
+          for the shape of the latent variables.
+      p_w_mean_sigma: float. standard deviation of the prior over the weights
+          for the mean of the latent variables.
+      p_z_shape: float. shape parameter of the gamma prior over the latent
+          variables.
+      p_z_mean: float. mean of the gamma prior over the latent variables.
+      fixed_p_z_mean: bool. whether the prior mean is held fixed rather than
+          learned.
+      z_dim: int. dimension of each latent variable.
+      n_samples_latents: number of samples of latent variables.
+      use_bias_observations: whether to use bias terms for the observations.
+      dtype: dtype
+    """
+
+    self.n_timesteps = n_timesteps
+    self.batch_size = batch_size
+    self.p_w_shape_sigma = p_w_shape_sigma
+    self.p_w_mean_sigma = p_w_mean_sigma
+    self.p_z_shape = p_z_shape
+    self.p_z_mean = p_z_mean
+    self.fixed_p_z_mean = fixed_p_z_mean
+    self.z_dim = z_dim
+    self.n_samples_latents = n_samples_latents
+    self.use_bias_observations = use_bias_observations
+    self.use_bias_latents = False
+    self.dtype = dtype
+
+  def log_prob(self, params, x):
+    """Returns the log joint. log p(x | z, w)p(z)p(w); [batch_size].
+
+    Args:
+      params: dict. dictionary of samples of the latent variables.
+      x: tensor. minibatch of examples.
+
+    Returns:
+      The log joint of the GammaNormalRDEF probability model.
+    """
+    z_1 = params['z_1']
+    w_1_mean = params['w_1_mean']
+    w_1_shape = params['w_1_shape']
+    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
+        params, x, self.batch_size, n_samples_latents=self.n_samples_latents,
+        use_bias_observations=self.use_bias_observations)
+    self.p_x_zw_bernoulli_p = p
+
+    log_p_z, log_p_w_shape, log_p_w_mean = self.build_recurrent_layer(
+        z_1, w_1_shape, w_1_mean)
+    return log_p_x_zw + log_p_z + log_p_w_shape + log_p_w_mean
+
+  def build_recurrent_layer(self, z, w_shape, w_mean):
+    """Creates a gamma layer of the recurrent DEF.
+
+    Args:
+      z: sampled gamma latent variables,
+          shape [n_samples_latents, batch_size, n_timesteps, z_dim]
+      w_shape: single sample of gaussian stochastic weights for the shape,
+          shape [z_dim, z_dim]
+      w_mean: single sample of gaussian stochastic weights for the mean,
+          shape [z_dim, z_dim]
+
+    Returns:
+      log_p_z: log prior of latent variables evaluated at the samples z.
+      log_p_w_shape: log density of the weights for the shape, evaluated at
+          the sampled weights.
+      log_p_w_mean: log density of the weights for the mean.
+    """
+    # the prior for the weights p(w) has two parts: p(w_shape) and p(w_mean)
+    # prior for the weights for the shape parameter
+    cast = lambda x: np.array(x, self.dtype)
+    p_w_shape = distributions.Normal(mu=cast(0.),
+                                     sigma=cast(self.p_w_shape_sigma),
+                                     validate_args=False)
+    log_p_w_shape = tf.reduce_sum(p_w_shape.log_pdf(w_shape))
+
+    if self.fixed_p_z_mean:
+      log_p_w_mean = 0.0
+    else:
+      # prior for the weights for the mean
+      p_w_mean = distributions.Normal(mu=cast(0.),
+                                      sigma=cast(self.p_w_mean_sigma),
+                                      validate_args=False)
+      log_p_w_mean = tf.reduce_sum(p_w_mean.log_pdf(w_mean))
+
+    # need this for numpy-style indexing
+    z = z.value()
+
+    # the prior for the latent variable at the first timestep is
+    # Gamma(p_z_shape, p_z_shape / p_z_mean)
+    z_t0 = z[:, :, 0, :]
+    # alpha is shape, beta is inverse scale. we set the scale to be the mean
+    # over the shape, so beta = shape / mean.
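+    # With this parameterization the gamma mean is alpha / beta, so p_z_mean
+    # is the prior mean of each latent dimension.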
+    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
+                                 beta=cast(self.p_z_shape / self.p_z_mean),
+                                 validate_args=False)
+    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 2)
+
+    # the prior for subsequent timesteps is off by one
+    shape = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
+                            tf.pack([tf.pack([w_shape] * self.batch_size)]
+                                    * self.n_samples_latents))
+    shape = util.clip_shape(shape)
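+    # Note: unlike sample_timestep below, no softplus is applied to the shape
+    # here; util.clip_shape presumably keeps it in a valid positive range.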
+
+    if self.fixed_p_z_mean:
+      mean = self.p_z_mean
+    else:
+      wz = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
+                           tf.pack([tf.pack([w_mean] * self.batch_size)]
+                                   * self.n_samples_latents))
+      mean = tf.nn.softplus(wz)
+      mean = util.clip_mean(mean)
+    p_z_t1_to_end = distributions.Gamma(alpha=shape,
+                                        beta=shape / mean,
+                                        validate_args=False)
+    log_p_z_t1_to_end = tf.reduce_sum(
+        p_z_t1_to_end.log_pdf(z[:, :, 1:, :]), [2, 3])
+    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
+    return log_p_z, log_p_w_shape, log_p_w_mean
+
+  def recurrent_layer_sample(self, w_shape, w_mean, n_samples_latents,
+                             b_shape=None, b_mean=None):
+    """Sample from the model, with learned latent weights.
+
+    Args:
+      w_shape: latent weights for the shape parameter. [z_dim, z_dim]
+      w_mean: latent weights for the mean parameter. [z_dim, z_dim]
+      n_samples_latents: how many samples of latent variables to draw.
+      b_shape: bias for the shape parameters.
+      b_mean: bias for the mean parameters.
+
+    Returns:
+      z: samples from the generative process.
+    """
+    cast = lambda x: np.array(x, self.dtype)
+    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
+                                 beta=cast(self.p_z_shape / self.p_z_mean),
+                                 validate_args=False)
+    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
+
+    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])
+
+    def sample_timestep(z_t_prev, w_shape, w_mean, b_shape=b_shape,
+                        b_mean=b_mean):
+      """Sample a single timestep.
+
+      Args:
+        z_t_prev: previous timestep latent variable,
+            shape [n_samples_latents, z_dim]
+        w_shape: latent weights for the shape param, shape [z_dim, z_dim]
+        w_mean: latent weights for the mean param, shape [z_dim, z_dim]
+        b_shape: bias for the shape parameters
+        b_mean: bias for the mean parameters
+
+      Returns:
+        z_t: a sample of the latent variable for a single timestep,
+            shape [n_samples_latents, z_dim]
+      """
+      wz_t = tf.matmul(z_t_prev, w_shape)
+      if self.use_bias_latents:
+        wz_t += b_shape
+      shape_t = tf.nn.softplus(wz_t)
+      shape_t = util.clip_shape(shape_t)
+      if self.fixed_p_z_mean:
+        mean_t = self.p_z_mean
+      else:
+        wz_t = tf.matmul(z_t_prev, w_mean)
+        if self.use_bias_latents:
+          wz_t += b_mean
+        mean_t = tf.nn.softplus(wz_t)
+        mean_t = util.clip_mean(mean_t)
+      p_z_t = distributions.Gamma(alpha=shape_t,
+                                  beta=shape_t / mean_t,
+                                  validate_args=False)
+      z_t = p_z_t.sample_n(n=1)[0, :, :]
+      return z_t
+
+    z_list = [z_t0]
+    for _ in range(self.n_timesteps - 1):
+      z_t = sample_timestep(z_list[-1], w_shape, w_mean)
+      z_list.append(z_t)
+    # pack into shape [n_timesteps, n_samples_latents, z_dim]
+    z = tf.pack(z_list)
+    # transpose into [n_samples_latents, n_timesteps, z_dim]
+    z = tf.transpose(z, perm=[1, 0, 2])
+    return z
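+
+  # Illustrative use (hypothetical values; w_shape and w_mean would be samples
+  # of the recurrent weights, e.g. drawn from the variational posterior):
+  #   z = rdef.recurrent_layer_sample(w_shape, w_mean, n_samples_latents=5)
+  #   # => z has shape [5, n_timesteps, z_dim]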
+
+  def likelihood_sample(self, params, z_1, n_samples):
+    return util.bernoulli_likelihood_sample(
+        params, z_1, n_samples,
+        use_bias_observations=self.use_bias_observations)
+
+
+class GammaNormalRDEFVariational(object):
+  """Creates the variational family for the recurrent DEF model.
+
+  Variational family:
+    q_z_1: gamma approximate posterior q(z_1) for latents of first layer.
+        [n_examples, n_timesteps, z_dim]
+    q_w_1_shape: gaussian approximate posterior q(w_1) for shape weights of
+        first (recurrent) layer [z_dim, z_dim]
+    q_w_1_mean: gaussian approximate posterior q(w_1) for mean weights, first
+        (recurrent) layer [z_dim, z_dim]
+    q_w_0: gaussian approximate posterior q(w_0) for weights of observation
+        layer [z_dim, timestep_dim]
+  """
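+
+  # Unlike the NormalNormal variational family, q(z_1) here carries a leading
+  # n_samples_latents dimension (see util.build_gamma below), which is why
+  # log_prob reduces over dims [2, 3] instead of [1, 2].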
+
+  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
+               timestep_dim, init_sigma_q_w_shape, init_shape_q_z,
+               init_mean_q_z,
+               init_sigma_q_w_mean, fixed_p_z_mean, fixed_q_z_mean,
+               fixed_q_w_shape_sigma, fixed_q_w_mean_sigma,
+               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, n_samples_latents,
+               use_bias_observations,
+               dtype):
+    """Initializes the variational family for the GammaNormalRDEF.
+
+    Args:
+      x_indexes: tensor. indices of the datapoints.
+      n_examples: int. number of examples in the dataset.
+      n_timesteps: int. number of timesteps in each datapoint.
+      z_dim: int. dimension of latent variables.
+      timestep_dim: int. dimension of each timestep.
+      init_sigma_q_w_shape: float. initial standard deviation for the weights
+          for the shapes of the latent variables.
+      init_shape_q_z: float. initial shape for the variational distribution
+          over the latent variables.
+      init_mean_q_z: float. initial mean for the variational distribution over
+          the latent variables.
+      init_sigma_q_w_mean: float. initial standard deviation for the weights
+          for the means of the latent variables.
+      fixed_p_z_mean: bool. whether to keep the prior over latents fixed.
+      fixed_q_z_mean: bool. whether to keep the mean of the variational
+          distributions for the latents fixed (not trained).
+      fixed_q_w_shape_sigma: bool. whether to keep the variance of the weights
+          for the shapes of the latent variables fixed (not trained).
+      fixed_q_w_mean_sigma: bool. whether to keep the variance of the weights
+          for the means of the latent variables fixed (not trained).
+      fixed_q_w_0_sigma: bool. whether to keep the variance of the weights for
+          the observations fixed (not trained).
+      init_sigma_q_w_0_sigma: float. initial standard deviation for the
+          observation weights.
+      n_samples_latents: number of samples of latent variables to draw.
+      use_bias_observations: whether to use bias terms for the observations.
+      dtype: dtype
+    """
+    self.x_indexes = x_indexes
+    self.n_examples = n_examples
+    self.n_timesteps = n_timesteps
+    self.z_dim = z_dim
+    self.timestep_dim = timestep_dim
+    self.init_mean_q_z = init_mean_q_z
+    self.init_shape_q_z = init_shape_q_z
+    self.init_sigma_q_w_shape = init_sigma_q_w_shape
+    self.init_sigma_q_w_mean = init_sigma_q_w_mean
+    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
+    self.fixed_p_z_mean = fixed_p_z_mean
+    self.fixed_q_z_mean = fixed_q_z_mean
+    self.fixed_q_w_shape_sigma = fixed_q_w_shape_sigma
+    self.fixed_q_w_mean_sigma = fixed_q_w_mean_sigma
+    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
+    self.n_samples_latents = n_samples_latents
+    self.use_bias_observations = use_bias_observations
+    self.dtype = dtype
+    with tf.variable_scope('variational'):
+      self.build_graph()
+
+  @property
+  def sample(self):
+    """Returns a dict of samples of the latent variables."""
+    return self.params
+
+  @property
+  def trainable_variables(self):
+    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'variational')
+
+  def build_graph(self):
+    """Builds the graph for the variational family for the GammaNormalRDEF."""
+    with tf.variable_scope('q_z_1'):
+      z_1 = util.build_gamma(
+          [self.n_examples, self.n_timesteps, self.z_dim],
+          init_shape=self.init_shape_q_z,
+          init_mean=self.init_mean_q_z,
+          x_indexes=self.x_indexes,
+          fixed_mean=self.fixed_q_z_mean,
+          place_on_cpu=False,
+          n_samples=self.n_samples_latents,
+          dtype=self.dtype)
+
+    with tf.variable_scope('q_w_1_shape'):
+      # the recurrent layer has two weight matrices: w_1_shape for the shape
+      # and w_1_mean (below) for the mean
+      w_1_shape = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
+                                      init_sigma=self.init_sigma_q_w_shape,
+                                      fixed_sigma=self.fixed_q_w_shape_sigma,
+                                      dtype=self.dtype)
+
+    if self.fixed_p_z_mean:
+      w_1_mean = None
+    else:
+      with tf.variable_scope('q_w_1_mean'):
+        w_1_mean = util.build_gaussian(
+            [self.z_dim, self.z_dim],
+            init_mu=0., init_sigma=self.init_sigma_q_w_mean,
+            fixed_sigma=self.fixed_q_w_mean_sigma,
+            dtype=self.dtype)
+
+    with tf.variable_scope('q_w_0'):
+      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
+                                init_sigma=self.init_sigma_q_w_0_sigma,
+                                fixed_sigma=self.fixed_q_w_0_sigma,
+                                dtype=self.dtype)
+
+    self.params = {'w_0': w_0, 'w_1_shape': w_1_shape, 'w_1_mean': w_1_mean,
+                   'z_1': z_1}
+
+    if self.use_bias_observations:
+      # b_0 = tf.get_variable(
+      #     'b_0', [self.timestep_dim], self.dtype, tf.zeros_initializer,
+      #     collections=[tf.GraphKeys.VARIABLES, 'reparam_variables'])
+      b_0 = util.build_gaussian([self.timestep_dim], init_mu=0.,
+                                init_sigma=0.01, fixed_sigma=False,
+                                dtype=self.dtype)
+      self.params.update({'b_0': b_0})
+
+  def log_prob(self, q_samples):
+    """Get the log joint of variational family: log(q(z, w_shape, w_mean, w_0)).
+
+    Args:
+      q_samples: dict. samples of latent variables.
+
+    Returns:
+      log_prob: tensor log-probability summed over dimensions of the variables.
+    """
+    w_0 = q_samples['w_0']
+    z_1 = q_samples['z_1']
+    w_1_shape = q_samples['w_1_shape']
+    w_1_mean = q_samples['w_1_mean']
+
+    log_prob = 0.
+    # preserve the sample and minibatch dimensions [0, 1]
+    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1.value()), [2, 3])
+    # w_1, w_0 are global, so reduce_sum across all dims
+    log_prob += tf.reduce_sum(
+        w_1_shape.distribution.log_pdf(w_1_shape.value()))
+    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0.value()))
+    if not self.fixed_p_z_mean:
+      log_prob += tf.reduce_sum(
+          w_1_mean.distribution.log_pdf(w_1_mean.value()))
+    return log_prob