# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for models and variational distributions for recurrent DEFs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from ops import util

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions


class NormalNormalRDEF(object):
  """Class for a recurrent DEF with normal latent variables and normal weights.
  """

  def __init__(self, n_timesteps, batch_size, p_w_mu_sigma, p_w_sigma_sigma,
               p_z_sigma, fixed_p_z_sigma, z_dim, dtype):
    """Initializes the NormalNormalRDEF class.

    Args:
      n_timesteps: int. number of timesteps.
      batch_size: int. batch size.
      p_w_mu_sigma: float. prior stddev of the weights for the mean of the
        latent variables.
      p_w_sigma_sigma: float. prior stddev of the weights for the variance of
        the latent variables.
      p_z_sigma: float. prior stddev of the latent variables.
      fixed_p_z_sigma: bool. whether the prior variance is fixed rather than
        learned.
      z_dim: int. dimension of each latent variable.
      dtype: dtype of the variables.
    """
    self.n_timesteps = n_timesteps
    self.batch_size = batch_size
    self.p_w_mu_sigma = p_w_mu_sigma
    self.p_w_sigma_sigma = p_w_sigma_sigma
    self.p_z_sigma = p_z_sigma
    self.fixed_p_z_sigma = fixed_p_z_sigma
    self.z_dim = z_dim
    self.dtype = dtype

  def log_prob(self, params, x):
    """Returns the log joint log p(x | z, w)p(z)p(w); shape [batch_size].

    Args:
      params: dict. dictionary of samples of the latent variables.
      x: tensor. minibatch of examples.

    Returns:
      The log joint of the NormalNormalRDEF probability model.
    """
    z_1 = params['z_1']
    w_1_mu = params['w_1_mu']
    w_1_sigma = params['w_1_sigma']
    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
        params, x, self.batch_size)
    self.p_x_zw_bernoulli_p = p
    log_p_z, log_p_w_mu, log_p_w_sigma = self.build_recurrent_layer(
        z_1, w_1_mu, w_1_sigma)
    return log_p_x_zw + log_p_z + log_p_w_mu + log_p_w_sigma

  def build_recurrent_layer(self, z, w_mu, w_sigma):
    """Creates a gaussian layer of the recurrent DEF.

    Args:
      z: sampled gaussian latent variables [batch_size, n_timesteps, z_dim].
      w_mu: sampled gaussian stochastic weights for the mean [z_dim, z_dim].
      w_sigma: sampled gaussian stochastic weights for the stddev
        [z_dim, z_dim].

    Returns:
      log_p_z: log prior of the latent variables evaluated at the samples z.
      log_p_w_mu: log density of the mean weights evaluated at the samples
        w_mu.
      log_p_w_sigma: log density of the stddev weights evaluated at w_sigma.
    """
    # the prior for the weights p(w) has two parts: p(w_mu) and p(w_sigma)
    # prior for the weights for the mean parameter
    p_w_mu = distributions.Normal(
        mu=0., sigma=self.p_w_mu_sigma, validate_args=False)
    log_p_w_mu = tf.reduce_sum(p_w_mu.log_pdf(w_mu))
    if self.fixed_p_z_sigma:
      log_p_w_sigma = 0.0
    else:
      # prior for the weights for the standard deviation
      p_w_sigma = distributions.Normal(mu=0., sigma=self.p_w_sigma_sigma,
                                       validate_args=False)
      log_p_w_sigma = tf.reduce_sum(p_w_sigma.log_pdf(w_sigma))
    # need the underlying tensor for numpy-style indexing
    z = z.value()
    # the prior for the latent variable at the first timestep has mean 0 and
    # stddev p_z_sigma
    z_t0 = z[:, 0, :]
    p_z_t0 = distributions.Normal(
        mu=0., sigma=self.p_z_sigma, validate_args=False)
    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 1)
    # the prior for each subsequent timestep is conditioned on the previous
    # timestep, so its parameters are computed from z shifted by one; tf.pack
    # tiles the weight matrix across the batch for batch_matmul
    mu = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
                         tf.pack([w_mu] * self.batch_size))
    if self.fixed_p_z_sigma:
      sigma = self.p_z_sigma
    else:
      wz = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
                           tf.pack([w_sigma] * self.batch_size))
      sigma = tf.maximum(tf.nn.softplus(wz), 1e-5)
    p_z_t1_to_end = distributions.Normal(mu=mu, sigma=sigma,
                                         validate_args=False)
    log_p_z_t1_to_end = tf.reduce_sum(
        p_z_t1_to_end.log_pdf(z[:, 1:, :]), [1, 2])
    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
    return log_p_z, log_p_w_mu, log_p_w_sigma

  def recurrent_layer_sample(self, w_mu, w_sigma, n_samples_latents):
    """Sample from the model, with learned latent weights.

    Args:
      w_mu: latent weights for the mean parameter. [z_dim, z_dim]
      w_sigma: latent weights for the standard deviation. [z_dim, z_dim]
      n_samples_latents: how many samples of latent variables to draw.

    Returns:
      z: samples from the generative process.
    """
    p_z_t0 = distributions.Normal(
        mu=0., sigma=self.p_z_sigma, validate_args=False)
    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])

    def sample_timestep(z_t_prev, w_mu, w_sigma):
      mu_t = tf.matmul(z_t_prev, w_mu)
      if self.fixed_p_z_sigma:
        sigma_t = self.p_z_sigma
      else:
        wz_t = tf.matmul(z_t_prev, w_sigma)
        sigma_t = tf.maximum(tf.nn.softplus(wz_t), 1e-5)
      p_z_t = distributions.Normal(mu=mu_t, sigma=sigma_t,
                                   validate_args=False)
      if self.z_dim == 1:
        return p_z_t.sample_n(n=1)[0, :, :]
      else:
        return tf.squeeze(p_z_t.sample_n(n=1))

    z_list = [z_t0]
    for _ in range(self.n_timesteps - 1):
      z_t = sample_timestep(z_list[-1], w_mu, w_sigma)
      z_list.append(z_t)
    z = tf.pack(z_list)  # [n_timesteps, n_samples_latents, z_dim]
    z = tf.transpose(z, perm=[1, 0, 2])  # [n_samples, n_timesteps, z_dim]
    return z

  def likelihood_sample(self, params, z_1, n_samples):
    return util.bernoulli_likelihood_sample(params, z_1, n_samples)


class NormalNormalRDEFVariational(object):
  """Creates the variational family for the NormalNormalRDEF model.

  Variational family:
    q_z_1: gaussian approximate posterior q(z_1) for latents of the first
      layer. [n_examples, n_timesteps, z_dim]
    q_w_1_mu: gaussian approximate posterior q(w_1_mu) for the mean weights of
      the first (recurrent) layer. [z_dim, z_dim]
    q_w_1_sigma: gaussian approximate posterior q(w_1_sigma) for the stddev
      weights of the first (recurrent) layer. [z_dim, z_dim]
    q_w_0: gaussian approximate posterior q(w_0) for the weights of the
      observation layer. [z_dim, timestep_dim]
  """

  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
               timestep_dim, init_sigma_q_w_mu, init_sigma_q_z,
               init_sigma_q_w_sigma, fixed_p_z_sigma, fixed_q_z_sigma,
               fixed_q_w_mu_sigma, fixed_q_w_sigma_sigma,
               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, dtype):
    """Initializes the variational family for the NormalNormalRDEF.

    Args:
      x_indexes: tensor. indices of the datapoints.
      n_examples: int. number of examples in the dataset.
      n_timesteps: int. number of timesteps in each datapoint.
      z_dim: int. dimension of latent variables.
      timestep_dim: int. dimension of each timestep.
      init_sigma_q_w_mu: float. initial variance for the weights for the means
        of the latent variables.
      init_sigma_q_z: float. initial variance for the variational distribution
        over the latent variables.
      init_sigma_q_w_sigma: float. initial variance for the weights for the
        variance of the latent variables.
      fixed_p_z_sigma: bool. whether the prior over the latents is kept fixed.
      fixed_q_z_sigma: bool. whether to fix (not train) the variance of the
        variational distributions for the latents.
      fixed_q_w_mu_sigma: bool. whether to fix the variance of the weights for
        the latent variables.
      fixed_q_w_sigma_sigma: bool. whether to fix the variance of the weights
        for the variance of the latent variables.
      fixed_q_w_0_sigma: bool. whether to fix the variance of the weights for
        the observations.
      init_sigma_q_w_0_sigma: float. initial variance for the observation
        weights.
      dtype: dtype of the variables.
    """
    self.x_indexes = x_indexes
    self.n_examples = n_examples
    self.n_timesteps = n_timesteps
    self.z_dim = z_dim
    self.timestep_dim = timestep_dim
    self.init_sigma_q_z = init_sigma_q_z
    self.init_sigma_q_w_mu = init_sigma_q_w_mu
    self.init_sigma_q_w_sigma = init_sigma_q_w_sigma
    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
    self.fixed_p_z_sigma = fixed_p_z_sigma
    self.fixed_q_z_sigma = fixed_q_z_sigma
    self.fixed_q_w_mu_sigma = fixed_q_w_mu_sigma
    self.fixed_q_w_sigma_sigma = fixed_q_w_sigma_sigma
    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
    self.dtype = dtype
    self.build_graph()

  @property
  def sample(self):
    """Returns a dict of samples of the latent variables."""
    return self.params

  def build_graph(self):
    """Builds the graph for the variational family of the NormalNormalRDEF."""
    with tf.variable_scope('q_z_1'):
      z_1 = util.build_gaussian(
          [self.n_examples, self.n_timesteps, self.z_dim],
          init_mu=0., init_sigma=self.init_sigma_q_z,
          x_indexes=self.x_indexes, fixed_sigma=self.fixed_q_z_sigma,
          place_on_cpu=True, dtype=self.dtype)
    with tf.variable_scope('q_w_1_mu'):
      # half of the weights are for the mean, half for the variance
      w_1_mu = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
                                   init_sigma=self.init_sigma_q_w_mu,
                                   fixed_sigma=self.fixed_q_w_mu_sigma,
                                   dtype=self.dtype)
    if self.fixed_p_z_sigma:
      w_1_sigma = None
    else:
      with tf.variable_scope('q_w_1_sigma'):
        w_1_sigma = util.build_gaussian(
            [self.z_dim, self.z_dim],
            init_mu=0., init_sigma=self.init_sigma_q_w_sigma,
            fixed_sigma=self.fixed_q_w_sigma_sigma,
            dtype=self.dtype)
    with tf.variable_scope('q_w_0'):
      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
                                init_sigma=self.init_sigma_q_w_0_sigma,
                                fixed_sigma=self.fixed_q_w_0_sigma,
                                dtype=self.dtype)
    self.params = {'w_0': w_0, 'w_1_mu': w_1_mu, 'w_1_sigma': w_1_sigma,
                   'z_1': z_1}

  def log_prob(self, q_samples):
    """Gets the log joint of the variational family q(z, w_mu, w_sigma, w_0).

    Args:
      q_samples: dict. samples of the latent variables.

    Returns:
      log_prob: tensor. log-probability summed over the dimensions of the
        variables.
    """
    w_0 = q_samples['w_0']
    z_1 = q_samples['z_1']
    w_1_mu = q_samples['w_1_mu']
    w_1_sigma = q_samples['w_1_sigma']
    log_prob = 0.
    # preserve the minibatch dimension [0]
    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1), [1, 2])
    # w_1, w_0 are global, so reduce_sum across all dims
    log_prob += tf.reduce_sum(w_1_mu.distribution.log_pdf(w_1_mu))
    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0))
    if not self.fixed_p_z_sigma:
      log_prob += tf.reduce_sum(w_1_sigma.distribution.log_pdf(w_1_sigma))
    return log_prob
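

# Example usage sketch: one way the model and variational classes above can be
# combined into a single-sample Monte Carlo ELBO estimate. The tensors
# x / x_indexes and the size and hyperparameter values below are illustrative
# placeholders, not defaults from this codebase.
#
#   variational = NormalNormalRDEFVariational(
#       x_indexes=x_indexes, n_examples=n_examples, n_timesteps=n_timesteps,
#       z_dim=z_dim, timestep_dim=timestep_dim, init_sigma_q_w_mu=0.1,
#       init_sigma_q_z=0.1, init_sigma_q_w_sigma=0.1, fixed_p_z_sigma=False,
#       fixed_q_z_sigma=False, fixed_q_w_mu_sigma=False,
#       fixed_q_w_sigma_sigma=False, fixed_q_w_0_sigma=False,
#       init_sigma_q_w_0_sigma=0.1, dtype=tf.float32)
#   model = NormalNormalRDEF(
#       n_timesteps=n_timesteps, batch_size=batch_size, p_w_mu_sigma=1.,
#       p_w_sigma_sigma=1., p_z_sigma=1., fixed_p_z_sigma=False,
#       z_dim=z_dim, dtype=tf.float32)
#   q_samples = variational.sample
#   elbo = model.log_prob(q_samples, x) - variational.log_prob(q_samples)
#   loss = -tf.reduce_mean(elbo)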


class GammaNormalRDEF(object):
  """Class for a recurrent DEF with gamma latent variables and normal weights.
  """

  def __init__(self, n_timesteps, batch_size, p_w_shape_sigma, p_w_mean_sigma,
               p_z_shape, p_z_mean, fixed_p_z_mean, z_dim, n_samples_latents,
               use_bias_observations, dtype):
    """Initializes the GammaNormalRDEF class.

    Args:
      n_timesteps: int. number of timesteps.
      batch_size: int. batch size.
      p_w_shape_sigma: float. prior stddev of the weights for the shape of the
        latent variables.
      p_w_mean_sigma: float. prior stddev of the weights for the mean of the
        latent variables.
      p_z_shape: float. prior shape parameter of the latent variables.
      p_z_mean: float. prior mean of the latent variables.
      fixed_p_z_mean: bool. whether the prior mean is fixed rather than
        learned.
      z_dim: int. dimension of each latent variable.
      n_samples_latents: int. number of samples of the latent variables.
      use_bias_observations: bool. whether to use bias terms.
      dtype: dtype of the variables.
    """
    self.n_timesteps = n_timesteps
    self.batch_size = batch_size
    self.p_w_shape_sigma = p_w_shape_sigma
    self.p_w_mean_sigma = p_w_mean_sigma
    self.p_z_shape = p_z_shape
    self.p_z_mean = p_z_mean
    self.fixed_p_z_mean = fixed_p_z_mean
    self.z_dim = z_dim
    self.n_samples_latents = n_samples_latents
    self.use_bias_observations = use_bias_observations
    self.use_bias_latents = False
    self.dtype = dtype

  def log_prob(self, params, x):
    """Returns the log joint log p(x | z, w)p(z)p(w); shape [batch_size].

    Args:
      params: dict. dictionary of samples of the latent variables.
      x: tensor. minibatch of examples.

    Returns:
      The log joint of the GammaNormalRDEF probability model.
    """
    z_1 = params['z_1']
    w_1_mean = params['w_1_mean']
    w_1_shape = params['w_1_shape']
    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
        params, x, self.batch_size, n_samples_latents=self.n_samples_latents,
        use_bias_observations=self.use_bias_observations)
    self.p_x_zw_bernoulli_p = p
    log_p_z, log_p_w_shape, log_p_w_mean = self.build_recurrent_layer(
        z_1, w_1_shape, w_1_mean)
    return log_p_x_zw + log_p_z + log_p_w_shape + log_p_w_mean

  def build_recurrent_layer(self, z, w_shape, w_mean):
    """Creates a gamma layer of the recurrent DEF.

    Args:
      z: sampled gamma latent variables,
        shape [n_samples_latents, batch_size, n_timesteps, z_dim].
      w_shape: single sample of gaussian stochastic weights for the shape,
        shape [z_dim, z_dim].
      w_mean: single sample of gaussian stochastic weights for the mean,
        shape [z_dim, z_dim].

    Returns:
      log_p_z: log prior of the latent variables evaluated at the samples z.
      log_p_w_shape: log density of the shape weights evaluated at w_shape.
      log_p_w_mean: log density of the mean weights evaluated at w_mean.
    """
    # the prior for the weights p(w) has two parts: p(w_shape) and p(w_mean)
    # prior for the weights for the shape parameter
    cast = lambda x: np.array(x, self.dtype)
    p_w_shape = distributions.Normal(mu=cast(0.),
                                     sigma=cast(self.p_w_shape_sigma),
                                     validate_args=False)
    log_p_w_shape = tf.reduce_sum(p_w_shape.log_pdf(w_shape))
    if self.fixed_p_z_mean:
      log_p_w_mean = 0.0
    else:
      # prior for the weights for the mean parameter
      p_w_mean = distributions.Normal(mu=cast(0.),
                                      sigma=cast(self.p_w_mean_sigma),
                                      validate_args=False)
      log_p_w_mean = tf.reduce_sum(p_w_mean.log_pdf(w_mean))
    # need the underlying tensor for numpy-style indexing
    z = z.value()
    # the prior for the latent variable at the first timestep uses the fixed
    # hyperparameters p_z_shape and p_z_mean
    z_t0 = z[:, :, 0, :]
    # alpha is the shape, beta is the inverse scale. since the mean of
    # Gamma(alpha, beta) is alpha / beta, setting beta = shape / mean gives a
    # distribution with the desired mean.
    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
                                 beta=cast(self.p_z_shape / self.p_z_mean),
                                 validate_args=False)
    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 2)
    # the prior for each subsequent timestep is conditioned on the previous
    # timestep, so its parameters are computed from z shifted by one; tf.pack
    # tiles the weight matrix across the batch and sample dimensions
    shape = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
                            tf.pack([tf.pack([w_shape] * self.batch_size)]
                                    * self.n_samples_latents))
    shape = util.clip_shape(shape)
    if self.fixed_p_z_mean:
      mean = self.p_z_mean
    else:
      wz = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
                           tf.pack([tf.pack([w_mean] * self.batch_size)]
                                   * self.n_samples_latents))
      mean = tf.nn.softplus(wz)
      mean = util.clip_mean(mean)
    p_z_t1_to_end = distributions.Gamma(alpha=shape,
                                        beta=shape / mean,
                                        validate_args=False)
    log_p_z_t1_to_end = tf.reduce_sum(
        p_z_t1_to_end.log_pdf(z[:, :, 1:, :]), [2, 3])
    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
    return log_p_z, log_p_w_shape, log_p_w_mean

  def recurrent_layer_sample(self, w_shape, w_mean, n_samples_latents,
                             b_shape=None, b_mean=None):
    """Sample from the model, with learned latent weights.

    Args:
      w_shape: latent weights for the shape parameter. [z_dim, z_dim]
      w_mean: latent weights for the mean parameter. [z_dim, z_dim]
      n_samples_latents: how many samples of latent variables to draw.
      b_shape: bias for the shape parameters.
      b_mean: bias for the mean parameters.

    Returns:
      z: samples from the generative process.
    """
    cast = lambda x: np.array(x, self.dtype)
    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
                                 beta=cast(self.p_z_shape / self.p_z_mean),
                                 validate_args=False)
    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])

    def sample_timestep(z_t_prev, w_shape, w_mean, b_shape=b_shape,
                        b_mean=b_mean):
      """Samples a single timestep.

      Args:
        z_t_prev: previous timestep latent variable,
          shape [n_samples_latents, z_dim].
        w_shape: latent weights for the shape param, shape [z_dim, z_dim].
        w_mean: latent weights for the mean param, shape [z_dim, z_dim].
        b_shape: bias for the shape parameters.
        b_mean: bias for the mean parameters.

      Returns:
        z_t: a sample of the latent variable for this timestep,
          shape [n_samples_latents, z_dim].
      """
      wz_t = tf.matmul(z_t_prev, w_shape)
      if self.use_bias_latents:
        wz_t += b_shape
      shape_t = tf.nn.softplus(wz_t)
      shape_t = util.clip_shape(shape_t)
      if self.fixed_p_z_mean:
        mean_t = self.p_z_mean
      else:
        wz_t = tf.matmul(z_t_prev, w_mean)
        if self.use_bias_latents:
          wz_t += b_mean
        mean_t = tf.nn.softplus(wz_t)
        mean_t = util.clip_mean(mean_t)
      p_z_t = distributions.Gamma(alpha=shape_t,
                                  beta=shape_t / mean_t,
                                  validate_args=False)
      z_t = p_z_t.sample_n(n=1)[0, :, :]
      return z_t

    z_list = [z_t0]
    for _ in range(self.n_timesteps - 1):
      z_t = sample_timestep(z_list[-1], w_shape, w_mean)
      z_list.append(z_t)
    # pack into shape [n_timesteps, n_samples_latents, z_dim]
    z = tf.pack(z_list)
    # transpose into [n_samples_latents, n_timesteps, z_dim]
    z = tf.transpose(z, perm=[1, 0, 2])
    return z

  def likelihood_sample(self, params, z_1, n_samples):
    return util.bernoulli_likelihood_sample(
        params, z_1, n_samples,
        use_bias_observations=self.use_bias_observations)


class GammaNormalRDEFVariational(object):
  """Creates the variational family for the GammaNormalRDEF model.

  Variational family:
    q_z_1: gamma approximate posterior q(z_1) for latents of the first layer.
      [n_examples, n_timesteps, z_dim]
    q_w_1_shape: gaussian approximate posterior q(w_1_shape) for the shape
      weights of the first (recurrent) layer. [z_dim, z_dim]
    q_w_1_mean: gaussian approximate posterior q(w_1_mean) for the mean
      weights of the first (recurrent) layer. [z_dim, z_dim]
    q_w_0: gaussian approximate posterior q(w_0) for the weights of the
      observation layer. [z_dim, timestep_dim]
  """

  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
               timestep_dim, init_sigma_q_w_shape, init_shape_q_z,
               init_mean_q_z,
               init_sigma_q_w_mean, fixed_p_z_mean, fixed_q_z_mean,
               fixed_q_w_shape_sigma, fixed_q_w_mean_sigma,
               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, n_samples_latents,
               use_bias_observations,
               dtype):
    """Initializes the variational family for the GammaNormalRDEF.

    Args:
      x_indexes: tensor. indices of the datapoints.
      n_examples: int. number of examples in the dataset.
      n_timesteps: int. number of timesteps in each datapoint.
      z_dim: int. dimension of latent variables.
      timestep_dim: int. dimension of each timestep.
      init_sigma_q_w_shape: float. initial variance for the weights for the
        shapes of the latent variables.
      init_shape_q_z: float. initial shape for the variational distribution
        over the latent variables.
      init_mean_q_z: float. initial mean for the variational distribution over
        the latent variables.
      init_sigma_q_w_mean: float. initial variance for the weights for the
        means of the latent variables.
      fixed_p_z_mean: bool. whether to keep the prior over the latents fixed.
      fixed_q_z_mean: bool. whether to fix (not train) the mean of the
        variational distributions for the latents.
      fixed_q_w_shape_sigma: bool. whether to fix the variance of the weights
        for the shape of the latent variables.
      fixed_q_w_mean_sigma: bool. whether to fix the variance of the weights
        for the mean of the latent variables.
      fixed_q_w_0_sigma: bool. whether to fix the variance of the weights for
        the observations.
      init_sigma_q_w_0_sigma: float. initial variance for the observation
        weights.
      n_samples_latents: int. number of samples of the latent variables to
        draw.
      use_bias_observations: bool. whether to use bias terms.
      dtype: dtype of the variables.
    """
    self.x_indexes = x_indexes
    self.n_examples = n_examples
    self.n_timesteps = n_timesteps
    self.z_dim = z_dim
    self.timestep_dim = timestep_dim
    self.init_mean_q_z = init_mean_q_z
    self.init_shape_q_z = init_shape_q_z
    self.init_sigma_q_w_shape = init_sigma_q_w_shape
    self.init_sigma_q_w_mean = init_sigma_q_w_mean
    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
    self.fixed_p_z_mean = fixed_p_z_mean
    self.fixed_q_z_mean = fixed_q_z_mean
    self.fixed_q_w_shape_sigma = fixed_q_w_shape_sigma
    self.fixed_q_w_mean_sigma = fixed_q_w_mean_sigma
    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
    self.n_samples_latents = n_samples_latents
    self.use_bias_observations = use_bias_observations
    self.dtype = dtype
    with tf.variable_scope('variational'):
      self.build_graph()

  @property
  def sample(self):
    """Returns a dict of samples of the latent variables."""
    return self.params

  @property
  def trainable_variables(self):
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'variational')

  def build_graph(self):
    """Builds the graph for the variational family of the GammaNormalRDEF."""
    with tf.variable_scope('q_z_1'):
      z_1 = util.build_gamma(
          [self.n_examples, self.n_timesteps, self.z_dim],
          init_shape=self.init_shape_q_z,
          init_mean=self.init_mean_q_z,
          x_indexes=self.x_indexes,
          fixed_mean=self.fixed_q_z_mean,
          place_on_cpu=False,
          n_samples=self.n_samples_latents,
          dtype=self.dtype)
    with tf.variable_scope('q_w_1_shape'):
      # half of the weights are for the shape, half for the mean
      w_1_shape = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
                                      init_sigma=self.init_sigma_q_w_shape,
                                      fixed_sigma=self.fixed_q_w_shape_sigma,
                                      dtype=self.dtype)
    if self.fixed_p_z_mean:
      w_1_mean = None
    else:
      with tf.variable_scope('q_w_1_mean'):
        w_1_mean = util.build_gaussian(
            [self.z_dim, self.z_dim],
            init_mu=0., init_sigma=self.init_sigma_q_w_mean,
            fixed_sigma=self.fixed_q_w_mean_sigma,
            dtype=self.dtype)
    with tf.variable_scope('q_w_0'):
      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
                                init_sigma=self.init_sigma_q_w_0_sigma,
                                fixed_sigma=self.fixed_q_w_0_sigma,
                                dtype=self.dtype)
    self.params = {'w_0': w_0, 'w_1_shape': w_1_shape, 'w_1_mean': w_1_mean,
                   'z_1': z_1}
    if self.use_bias_observations:
      # b_0 = tf.get_variable(
      #     'b_0', [self.timestep_dim], self.dtype, tf.zeros_initializer,
      #     collections=[tf.GraphKeys.VARIABLES, 'reparam_variables'])
      b_0 = util.build_gaussian([self.timestep_dim], init_mu=0.,
                                init_sigma=0.01, fixed_sigma=False,
                                dtype=self.dtype)
      self.params.update({'b_0': b_0})

  def log_prob(self, q_samples):
    """Gets the log joint of the variational family q(z, w_shape, w_mean, w_0).

    Args:
      q_samples: dict. samples of the latent variables.

    Returns:
      log_prob: tensor. log-probability summed over the dimensions of the
        variables.
    """
    w_0 = q_samples['w_0']
    z_1 = q_samples['z_1']
    w_1_shape = q_samples['w_1_shape']
    w_1_mean = q_samples['w_1_mean']
    log_prob = 0.
    # preserve the sample and minibatch dimensions [0, 1]
    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1.value()), [2, 3])
    # w_1, w_0 are global, so reduce_sum across all dims
    log_prob += tf.reduce_sum(
        w_1_shape.distribution.log_pdf(w_1_shape.value()))
    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0.value()))
    if not self.fixed_p_z_mean:
      log_prob += tf.reduce_sum(
          w_1_mean.distribution.log_pdf(w_1_mean.value()))
    return log_prob
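

# Example usage sketch: one plausible way to draw synthetic sequences from a
# fitted GammaNormalRDEF by sampling the latent trajectory from the recurrent
# layer and then sampling observations from the Bernoulli likelihood. The
# weight samples come from the variational posterior; n_samples is an
# illustrative placeholder.
#
#   q_samples = variational.sample
#   z = model.recurrent_layer_sample(q_samples['w_1_shape'].value(),
#                                    q_samples['w_1_mean'].value(),
#                                    n_samples_latents=n_samples)
#   x_sample = model.likelihood_sample(q_samples, z, n_samples)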