# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Classes for models and variational distributions for recurrent DEFs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from ops import util

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions
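
# The classes below come in model/variational pairs: each model class exposes
# the log joint log p(x, z, w) and ancestral sampling, and each *Variational
# class exposes samples from q and log q(z, w). A single-sample estimate of
# the evidence lower bound (ELBO) is then
#   ELBO = E_q[log p(x, z, w) - log q(z, w)]
#        ~ model.log_prob(variational.sample, x)
#          - variational.log_prob(variational.sample).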


class NormalNormalRDEF(object):
  """Class for a recurrent DEF with normal latent variables and normal weights.
  """

  def __init__(self, n_timesteps, batch_size, p_w_mu_sigma, p_w_sigma_sigma,
               p_z_sigma, fixed_p_z_sigma, z_dim, dtype):
    """Initializes the NormalNormalRDEF class.

    Args:
      n_timesteps: int. number of timesteps.
      batch_size: int. batch size.
      p_w_mu_sigma: float. prior stddev of the weights for the mean of the
        latent variables.
      p_w_sigma_sigma: float. prior stddev of the weights for the variance of
        the latent variables.
      p_z_sigma: float. prior stddev of the latent variables.
      fixed_p_z_sigma: bool. whether the prior stddev is fixed rather than
        parameterized by the recurrent weights.
      z_dim: int. dimension of each latent variable.
      dtype: dtype of the variables.
    """
    self.n_timesteps = n_timesteps
    self.batch_size = batch_size
    self.p_w_mu_sigma = p_w_mu_sigma
    self.p_w_sigma_sigma = p_w_sigma_sigma
    self.p_z_sigma = p_z_sigma
    self.fixed_p_z_sigma = fixed_p_z_sigma
    self.z_dim = z_dim
    self.dtype = dtype

  def log_prob(self, params, x):
    """Returns the log joint, log p(x | z, w)p(z)p(w); [batch_size].

    Args:
      params: dict. dictionary of samples of the latent variables.
      x: tensor. minibatch of examples.

    Returns:
      The log joint of the NormalNormalRDEF probability model.
    """
    z_1 = params['z_1']
    w_1_mu = params['w_1_mu']
    w_1_sigma = params['w_1_sigma']
    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
        params, x, self.batch_size)
    self.p_x_zw_bernoulli_p = p
    log_p_z, log_p_w_mu, log_p_w_sigma = self.build_recurrent_layer(
        z_1, w_1_mu, w_1_sigma)
    return log_p_x_zw + log_p_z + log_p_w_mu + log_p_w_sigma
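
  # For reference, the return value above evaluates
  #   log p(x | z, w_0) + log p(z | w_mu, w_sigma)
  #       + log p(w_mu) + log p(w_sigma),
  # with the Bernoulli likelihood term built in ops/util and the remaining
  # terms built in build_recurrent_layer below.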

  def build_recurrent_layer(self, z, w_mu, w_sigma):
    """Creates a gaussian layer of the recurrent DEF.

    Args:
      z: sampled gaussian latent variables [batch_size, n_timesteps, z_dim].
      w_mu: sampled gaussian stochastic weights for the mean [z_dim, z_dim].
      w_sigma: sampled gaussian stochastic weights for the stddev
        [z_dim, z_dim].

    Returns:
      log_p_z: log prior of latent variables evaluated at the samples z.
      log_p_w_mu: log density of the weights evaluated at the sampled
        weights w.
      log_p_w_sigma: log density of the weights for the stddev.
    """
    # the prior for the weights p(w) has two parts: p(w_mu) and p(w_sigma)
    # prior for the weights for the mean parameter
    p_w_mu = distributions.Normal(
        mu=0., sigma=self.p_w_mu_sigma, validate_args=False)
    log_p_w_mu = tf.reduce_sum(p_w_mu.log_pdf(w_mu))
    if self.fixed_p_z_sigma:
      log_p_w_sigma = 0.0
    else:
      # prior for the weights for the standard deviation
      p_w_sigma = distributions.Normal(mu=0., sigma=self.p_w_sigma_sigma,
                                       validate_args=False)
      log_p_w_sigma = tf.reduce_sum(p_w_sigma.log_pdf(w_sigma))
    # need this for indexing numpy-style
    z = z.value()
    # the prior for the latent variable at the first timestep is
    # Normal(0, p_z_sigma)
    z_t0 = z[:, 0, :]
    p_z_t0 = distributions.Normal(
        mu=0., sigma=self.p_z_sigma, validate_args=False)
    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 1)
    # the prior for subsequent timesteps is off by one
    mu = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
                         tf.pack([w_mu] * self.batch_size))
    if self.fixed_p_z_sigma:
      sigma = self.p_z_sigma
    else:
      wz = tf.batch_matmul(z[:, :self.n_timesteps-1, :],
                           tf.pack([w_sigma] * self.batch_size))
      sigma = tf.maximum(tf.nn.softplus(wz), 1e-5)
    p_z_t1_to_end = distributions.Normal(mu=mu, sigma=sigma,
                                         validate_args=False)
    log_p_z_t1_to_end = tf.reduce_sum(
        p_z_t1_to_end.log_pdf(z[:, 1:, :]), [1, 2])
    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
    return log_p_z, log_p_w_mu, log_p_w_sigma
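
  # In equation form, the transition prior implemented above is
  #   z_{t=0} ~ Normal(0, p_z_sigma)
  #   z_t | z_{t-1} ~ Normal(z_{t-1} w_mu,
  #                          max(softplus(z_{t-1} w_sigma), 1e-5))
  # when the stddev is learned; with fixed_p_z_sigma the stddev stays at
  # p_z_sigma for every timestep.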

  def recurrent_layer_sample(self, w_mu, w_sigma, n_samples_latents):
    """Samples from the model, with learned latent weights.

    Args:
      w_mu: latent weights for the mean parameter. [z_dim, z_dim]
      w_sigma: latent weights for the standard deviation. [z_dim, z_dim]
      n_samples_latents: how many samples of latent variables to draw.

    Returns:
      z: samples from the generative process.
    """
    p_z_t0 = distributions.Normal(
        mu=0., sigma=self.p_z_sigma, validate_args=False)
    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])

    def sample_timestep(z_t_prev, w_mu, w_sigma):
      mu_t = tf.matmul(z_t_prev, w_mu)
      if self.fixed_p_z_sigma:
        sigma_t = self.p_z_sigma
      else:
        wz_t = tf.matmul(z_t_prev, w_sigma)
        sigma_t = tf.maximum(tf.nn.softplus(wz_t), 1e-5)
      p_z_t = distributions.Normal(mu=mu_t, sigma=sigma_t,
                                   validate_args=False)
      if self.z_dim == 1:
        return p_z_t.sample_n(n=1)[0, :, :]
      else:
        return tf.squeeze(p_z_t.sample_n(n=1))

    z_list = [z_t0]
    for _ in range(self.n_timesteps - 1):
      z_t = sample_timestep(z_list[-1], w_mu, w_sigma)
      z_list.append(z_t)
    z = tf.pack(z_list)  # [n_timesteps, n_samples_latents, z_dim]
    z = tf.transpose(z, perm=[1, 0, 2])  # [n_samples, n_timesteps, z_dim]
    return z

  def likelihood_sample(self, params, z_1, n_samples):
    """Draws Bernoulli observations given latent samples z_1."""
    return util.bernoulli_likelihood_sample(params, z_1, n_samples)
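

# Prior-sampling sketch for the normal model (a hedged example with
# hypothetical hyperparameters and zero weights; in practice the weights
# would come from the fitted variational posterior):
#
#   model = NormalNormalRDEF(n_timesteps=20, batch_size=32, p_w_mu_sigma=1.,
#                            p_w_sigma_sigma=1., p_z_sigma=1.,
#                            fixed_p_z_sigma=True, z_dim=5, dtype=tf.float32)
#   w_mu = tf.zeros([5, 5])
#   w_sigma = tf.zeros([5, 5])
#   z = model.recurrent_layer_sample(w_mu, w_sigma, n_samples_latents=10)
#   # z: [n_samples_latents=10, n_timesteps=20, z_dim=5]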


class NormalNormalRDEFVariational(object):
  """Creates the variational family for the recurrent DEF model.

  Variational family:
    q_z_1: gaussian approximate posterior q(z_1) for latents of first layer.
      [n_examples, n_timesteps, z_dim]
    q_w_1_mu: gaussian approximate posterior q(w_1) for mean weights of first
      (recurrent) layer [z_dim, z_dim]
    q_w_1_sigma: gaussian approximate posterior q(w_1) for stddev weights,
      first (recurrent) layer [z_dim, z_dim]
    q_w_0: gaussian approximate posterior q(w_0) for weights of observation
      layer [z_dim, timestep_dim]
  """

  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
               timestep_dim, init_sigma_q_w_mu, init_sigma_q_z,
               init_sigma_q_w_sigma, fixed_p_z_sigma, fixed_q_z_sigma,
               fixed_q_w_mu_sigma, fixed_q_w_sigma_sigma,
               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, dtype):
    """Initializes the variational family for the NormalNormalRDEF.

    Args:
      x_indexes: tensor. indices of the datapoints.
      n_examples: int. number of examples in the dataset.
      n_timesteps: int. number of timesteps in each datapoint.
      z_dim: int. dimension of latent variables.
      timestep_dim: int. dimension of each timestep.
      init_sigma_q_w_mu: float. initial stddev for the weights for the means
        of the latent variables.
      init_sigma_q_z: float. initial stddev for the variational distribution
        over the latent variables.
      init_sigma_q_w_sigma: float. initial stddev for the weights for the
        variance of the latent variables.
      fixed_p_z_sigma: bool. whether the prior over latents is kept fixed.
      fixed_q_z_sigma: bool. whether to fix (not train) the stddev of the
        variational distributions for the latents.
      fixed_q_w_mu_sigma: bool. whether to fix (not train) the stddev of the
        weights for the latent variables.
      fixed_q_w_sigma_sigma: bool. whether to fix (not train) the stddev of
        the weights for the variance of the latent variables.
      fixed_q_w_0_sigma: bool. whether to fix (not train) the stddev of the
        weights for the observations.
      init_sigma_q_w_0_sigma: float. initial stddev for the observation
        weights.
      dtype: dtype of the variables.
    """
    self.x_indexes = x_indexes
    self.n_examples = n_examples
    self.n_timesteps = n_timesteps
    self.z_dim = z_dim
    self.timestep_dim = timestep_dim
    self.init_sigma_q_z = init_sigma_q_z
    self.init_sigma_q_w_mu = init_sigma_q_w_mu
    self.init_sigma_q_w_sigma = init_sigma_q_w_sigma
    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
    self.fixed_p_z_sigma = fixed_p_z_sigma
    self.fixed_q_z_sigma = fixed_q_z_sigma
    self.fixed_q_w_mu_sigma = fixed_q_w_mu_sigma
    self.fixed_q_w_sigma_sigma = fixed_q_w_sigma_sigma
    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
    self.dtype = dtype
    self.build_graph()

  @property
  def sample(self):
    """Returns a dict of samples of the latent variables."""
    return self.params

  def build_graph(self):
    """Builds the graph for the variational family for the NormalNormalRDEF."""
    with tf.variable_scope('q_z_1'):
      z_1 = util.build_gaussian(
          [self.n_examples, self.n_timesteps, self.z_dim],
          init_mu=0., init_sigma=self.init_sigma_q_z,
          x_indexes=self.x_indexes, fixed_sigma=self.fixed_q_z_sigma,
          place_on_cpu=True, dtype=self.dtype)
    with tf.variable_scope('q_w_1_mu'):
      # half of the weights are for the mean, half for the variance
      w_1_mu = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
                                   init_sigma=self.init_sigma_q_w_mu,
                                   fixed_sigma=self.fixed_q_w_mu_sigma,
                                   dtype=self.dtype)
    if self.fixed_p_z_sigma:
      w_1_sigma = None
    else:
      with tf.variable_scope('q_w_1_sigma'):
        w_1_sigma = util.build_gaussian(
            [self.z_dim, self.z_dim],
            init_mu=0., init_sigma=self.init_sigma_q_w_sigma,
            fixed_sigma=self.fixed_q_w_sigma_sigma,
            dtype=self.dtype)
    with tf.variable_scope('q_w_0'):
      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
                                init_sigma=self.init_sigma_q_w_0_sigma,
                                fixed_sigma=self.fixed_q_w_0_sigma,
                                dtype=self.dtype)
    self.params = {'w_0': w_0, 'w_1_mu': w_1_mu, 'w_1_sigma': w_1_sigma,
                   'z_1': z_1}
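
  # The family is fully factorized (mean field):
  #   q(z_1, w_1_mu, w_1_sigma, w_0) = q(z_1) q(w_1_mu) q(w_1_sigma) q(w_0),
  # with each factor a Gaussian built by util.build_gaussian.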

  def log_prob(self, q_samples):
    """Gets the log joint of the variational family, log q(z, w_mu, w_sigma, w_0).

    Args:
      q_samples: dict. samples of latent variables.

    Returns:
      log_prob: tensor. log-probability summed over dimensions of the
        variables.
    """
    w_0 = q_samples['w_0']
    z_1 = q_samples['z_1']
    w_1_mu = q_samples['w_1_mu']
    w_1_sigma = q_samples['w_1_sigma']
    log_prob = 0.
    # preserve the minibatch dimension [0]
    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1), [1, 2])
    # w_1, w_0 are global, so reduce_sum across all dims
    log_prob += tf.reduce_sum(w_1_mu.distribution.log_pdf(w_1_mu))
    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0))
    if not self.fixed_p_z_sigma:
      log_prob += tf.reduce_sum(w_1_sigma.distribution.log_pdf(w_1_sigma))
    return log_prob
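

# A minimal single-sample ELBO sketch for the normal pair (assuming `model`,
# `variational`, and a minibatch tensor `x` were built as sketched above):
#
#   q_samples = variational.sample
#   elbo = model.log_prob(q_samples, x) - variational.log_prob(q_samples)
#   loss = -tf.reduce_mean(elbo)
#   train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)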


class GammaNormalRDEF(object):
  """Class for a recurrent DEF with gamma latent variables and normal weights.
  """

  def __init__(self, n_timesteps, batch_size, p_w_shape_sigma, p_w_mean_sigma,
               p_z_shape, p_z_mean, fixed_p_z_mean, z_dim, n_samples_latents,
               use_bias_observations, dtype):
    """Initializes the GammaNormalRDEF class.

    Args:
      n_timesteps: int. number of timesteps.
      batch_size: int. batch size.
      p_w_shape_sigma: float. prior stddev of the weights for the shape of
        the latent variables.
      p_w_mean_sigma: float. prior stddev of the weights for the mean of the
        latent variables.
      p_z_shape: float. prior shape of the latent variables.
      p_z_mean: float. prior mean of the latent variables.
      fixed_p_z_mean: bool. whether the prior mean is fixed rather than
        parameterized by the recurrent weights.
      z_dim: int. dimension of each latent variable.
      n_samples_latents: int. number of samples of latent variables.
      use_bias_observations: bool. whether to use bias terms.
      dtype: dtype of the variables.
    """
    self.n_timesteps = n_timesteps
    self.batch_size = batch_size
    self.p_w_shape_sigma = p_w_shape_sigma
    self.p_w_mean_sigma = p_w_mean_sigma
    self.p_z_shape = p_z_shape
    self.p_z_mean = p_z_mean
    self.fixed_p_z_mean = fixed_p_z_mean
    self.z_dim = z_dim
    self.n_samples_latents = n_samples_latents
    self.use_bias_observations = use_bias_observations
    self.use_bias_latents = False
    self.dtype = dtype

  def log_prob(self, params, x):
    """Returns the log joint, log p(x | z, w)p(z)p(w); [batch_size].

    Args:
      params: dict. dictionary of samples of the latent variables.
      x: tensor. minibatch of examples.

    Returns:
      The log joint of the GammaNormalRDEF probability model.
    """
    z_1 = params['z_1']
    w_1_mean = params['w_1_mean']
    w_1_shape = params['w_1_shape']
    log_p_x_zw, p = util.build_bernoulli_log_likelihood(
        params, x, self.batch_size, n_samples_latents=self.n_samples_latents,
        use_bias_observations=self.use_bias_observations)
    self.p_x_zw_bernoulli_p = p
    log_p_z, log_p_w_shape, log_p_w_mean = self.build_recurrent_layer(
        z_1, w_1_shape, w_1_mean)
    return log_p_x_zw + log_p_z + log_p_w_shape + log_p_w_mean

  def build_recurrent_layer(self, z, w_shape, w_mean):
    """Creates a gamma layer of the recurrent DEF.

    Args:
      z: sampled gamma latent variables,
        shape [n_samples_latents, batch_size, n_timesteps, z_dim].
      w_shape: single sample of gaussian stochastic weights for the shape,
        shape [z_dim, z_dim].
      w_mean: single sample of gaussian stochastic weights for the mean,
        shape [z_dim, z_dim].

    Returns:
      log_p_z: log prior of latent variables evaluated at the samples z.
      log_p_w_shape: log density of the weights evaluated at the sampled
        weights.
      log_p_w_mean: log density of the weights for the mean.
    """
    # the prior for the weights p(w) has two parts: p(w_shape) and p(w_mean)
    # prior for the weights for the shape parameter
    cast = lambda x: np.array(x, self.dtype)
    p_w_shape = distributions.Normal(mu=cast(0.),
                                     sigma=cast(self.p_w_shape_sigma),
                                     validate_args=False)
    log_p_w_shape = tf.reduce_sum(p_w_shape.log_pdf(w_shape))
    if self.fixed_p_z_mean:
      log_p_w_mean = 0.0
    else:
      # prior for the weights for the mean
      p_w_mean = distributions.Normal(mu=cast(0.),
                                      sigma=cast(self.p_w_mean_sigma),
                                      validate_args=False)
      log_p_w_mean = tf.reduce_sum(p_w_mean.log_pdf(w_mean))
    # need this for indexing numpy-style
    z = z.value()
    # the prior for the latent variable at the first timestep is
    # Gamma(p_z_shape, p_z_shape / p_z_mean)
    z_t0 = z[:, :, 0, :]
    # alpha is shape, beta is inverse scale. we set the scale to be the mean
    # over the shape, so beta = shape / mean.
    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
                                 beta=cast(self.p_z_shape / self.p_z_mean),
                                 validate_args=False)
    log_p_z_t0 = tf.reduce_sum(p_z_t0.log_pdf(z_t0), 2)
    # the prior for subsequent timesteps is off by one
    shape = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
                            tf.pack([tf.pack([w_shape] * self.batch_size)]
                                    * self.n_samples_latents))
    shape = util.clip_shape(shape)
    if self.fixed_p_z_mean:
      mean = self.p_z_mean
    else:
      wz = tf.batch_matmul(z[:, :, :self.n_timesteps-1, :],
                           tf.pack([tf.pack([w_mean] * self.batch_size)]
                                   * self.n_samples_latents))
      mean = tf.nn.softplus(wz)
      mean = util.clip_mean(mean)
    p_z_t1_to_end = distributions.Gamma(alpha=shape,
                                        beta=shape / mean,
                                        validate_args=False)
    log_p_z_t1_to_end = tf.reduce_sum(
        p_z_t1_to_end.log_pdf(z[:, :, 1:, :]), [2, 3])
    log_p_z = log_p_z_t0 + log_p_z_t1_to_end
    return log_p_z, log_p_w_shape, log_p_w_mean
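
  # Under this mean parameterization of the gamma, alpha = shape and
  # beta = shape / mean, so E[z_t] = alpha / beta = mean and
  # Var[z_t] = alpha / beta**2 = mean**2 / shape: the shape weights control
  # dispersion while the mean weights drive the recurrence.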

  def recurrent_layer_sample(self, w_shape, w_mean, n_samples_latents,
                             b_shape=None, b_mean=None):
    """Samples from the model, with learned latent weights.

    Args:
      w_shape: latent weights for the shape parameter. [z_dim, z_dim]
      w_mean: latent weights for the mean parameter. [z_dim, z_dim]
      n_samples_latents: how many samples of latent variables to draw.
      b_shape: bias for the shape parameters.
      b_mean: bias for the mean parameters.

    Returns:
      z: samples from the generative process.
    """
    cast = lambda x: np.array(x, self.dtype)
    p_z_t0 = distributions.Gamma(alpha=cast(self.p_z_shape),
                                 beta=cast(self.p_z_shape / self.p_z_mean),
                                 validate_args=False)
    z_t0 = p_z_t0.sample_n(n=n_samples_latents * self.z_dim)
    z_t0 = tf.reshape(z_t0, [n_samples_latents, self.z_dim])

    def sample_timestep(z_t_prev, w_shape, w_mean, b_shape=b_shape,
                        b_mean=b_mean):
      """Samples a single timestep.

      Args:
        z_t_prev: previous timestep latent variable,
          shape [n_samples_latents, z_dim].
        w_shape: latent weights for the shape param, shape [z_dim, z_dim].
        w_mean: latent weights for the mean param, shape [z_dim, z_dim].
        b_shape: bias for the shape parameters.
        b_mean: bias for the mean parameters.

      Returns:
        z_t: a sample of the latent variables for this timestep,
          shape [n_samples_latents, z_dim].
      """
      wz_t = tf.matmul(z_t_prev, w_shape)
      if self.use_bias_latents:
        wz_t += b_shape
      shape_t = tf.nn.softplus(wz_t)
      shape_t = util.clip_shape(shape_t)
      if self.fixed_p_z_mean:
        mean_t = self.p_z_mean
      else:
        wz_t = tf.matmul(z_t_prev, w_mean)
        if self.use_bias_latents:
          wz_t += b_mean
        mean_t = tf.nn.softplus(wz_t)
        mean_t = util.clip_mean(mean_t)
      p_z_t = distributions.Gamma(alpha=shape_t,
                                  beta=shape_t / mean_t,
                                  validate_args=False)
      z_t = p_z_t.sample_n(n=1)[0, :, :]
      return z_t

    z_list = [z_t0]
    for _ in range(self.n_timesteps - 1):
      z_t = sample_timestep(z_list[-1], w_shape, w_mean)
      z_list.append(z_t)
    # pack into shape [n_timesteps, n_samples_latents, z_dim]
    z = tf.pack(z_list)
    # transpose into [n_samples_latents, n_timesteps, z_dim]
    z = tf.transpose(z, perm=[1, 0, 2])
    return z

  def likelihood_sample(self, params, z_1, n_samples):
    """Draws Bernoulli observations given latent samples z_1."""
    return util.bernoulli_likelihood_sample(
        params, z_1, n_samples,
        use_bias_observations=self.use_bias_observations)
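

# Generative sampling sketch for the gamma model (hedged; `params` stands for
# the dict of weight samples from the variational family, and the weights are
# hypothetical placeholders):
#
#   z = model.recurrent_layer_sample(w_shape, w_mean, n_samples_latents=10)
#   x_sample = model.likelihood_sample(params, z, n_samples=10)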


class GammaNormalRDEFVariational(object):
  """Creates the variational family for the recurrent DEF model.

  Variational family:
    q_z_1: gamma approximate posterior q(z_1) for latents of first layer.
      [n_examples, n_timesteps, z_dim]
    q_w_1_shape: gaussian approximate posterior q(w_1) for shape weights of
      first (recurrent) layer [z_dim, z_dim]
    q_w_1_mean: gaussian approximate posterior q(w_1) for mean weights of
      first (recurrent) layer [z_dim, z_dim]
    q_w_0: gaussian approximate posterior q(w_0) for weights of observation
      layer [z_dim, timestep_dim]
  """

  def __init__(self, x_indexes, n_examples, n_timesteps, z_dim,
               timestep_dim, init_sigma_q_w_shape, init_shape_q_z,
               init_mean_q_z,
               init_sigma_q_w_mean, fixed_p_z_mean, fixed_q_z_mean,
               fixed_q_w_shape_sigma, fixed_q_w_mean_sigma,
               fixed_q_w_0_sigma, init_sigma_q_w_0_sigma, n_samples_latents,
               use_bias_observations,
               dtype):
    """Initializes the variational family for the GammaNormalRDEF.

    Args:
      x_indexes: tensor. indices of the datapoints.
      n_examples: int. number of examples in the dataset.
      n_timesteps: int. number of timesteps in each datapoint.
      z_dim: int. dimension of latent variables.
      timestep_dim: int. dimension of each timestep.
      init_sigma_q_w_shape: float. initial stddev for the weights for the
        shape of the latent variables.
      init_shape_q_z: float. initial shape for the variational distribution
        over the latent variables.
      init_mean_q_z: float. initial mean for the variational distribution
        over the latent variables.
      init_sigma_q_w_mean: float. initial stddev for the weights for the
        mean of the latent variables.
      fixed_p_z_mean: bool. whether the prior over latents is kept fixed.
      fixed_q_z_mean: bool. whether to fix (not train) the mean of the
        variational distributions for the latents.
      fixed_q_w_shape_sigma: bool. whether to fix (not train) the stddev of
        the weights for the shape of the latent variables.
      fixed_q_w_mean_sigma: bool. whether to fix (not train) the stddev of
        the weights for the mean of the latent variables.
      fixed_q_w_0_sigma: bool. whether to fix (not train) the stddev of the
        weights for the observations.
      init_sigma_q_w_0_sigma: float. initial stddev for the observation
        weights.
      n_samples_latents: int. number of samples of latent variables to draw.
      use_bias_observations: bool. whether to use bias terms.
      dtype: dtype of the variables.
    """
    self.x_indexes = x_indexes
    self.n_examples = n_examples
    self.n_timesteps = n_timesteps
    self.z_dim = z_dim
    self.timestep_dim = timestep_dim
    self.init_mean_q_z = init_mean_q_z
    self.init_shape_q_z = init_shape_q_z
    self.init_sigma_q_w_shape = init_sigma_q_w_shape
    self.init_sigma_q_w_mean = init_sigma_q_w_mean
    self.init_sigma_q_w_0_sigma = init_sigma_q_w_0_sigma
    self.fixed_p_z_mean = fixed_p_z_mean
    self.fixed_q_z_mean = fixed_q_z_mean
    self.fixed_q_w_shape_sigma = fixed_q_w_shape_sigma
    self.fixed_q_w_mean_sigma = fixed_q_w_mean_sigma
    self.fixed_q_w_0_sigma = fixed_q_w_0_sigma
    self.n_samples_latents = n_samples_latents
    self.use_bias_observations = use_bias_observations
    self.dtype = dtype
    with tf.variable_scope('variational'):
      self.build_graph()

  @property
  def sample(self):
    """Returns a dict of samples of the latent variables."""
    return self.params

  @property
  def trainable_variables(self):
    """Returns the trainable variables in the 'variational' scope."""
    return tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'variational')

  def build_graph(self):
    """Builds the graph for the variational family for the GammaNormalRDEF."""
    with tf.variable_scope('q_z_1'):
      z_1 = util.build_gamma(
          [self.n_examples, self.n_timesteps, self.z_dim],
          init_shape=self.init_shape_q_z,
          init_mean=self.init_mean_q_z,
          x_indexes=self.x_indexes,
          fixed_mean=self.fixed_q_z_mean,
          place_on_cpu=False,
          n_samples=self.n_samples_latents,
          dtype=self.dtype)
    with tf.variable_scope('q_w_1_shape'):
      # half of the weights are for the shape, half for the mean
      w_1_shape = util.build_gaussian([self.z_dim, self.z_dim], init_mu=0.,
                                      init_sigma=self.init_sigma_q_w_shape,
                                      fixed_sigma=self.fixed_q_w_shape_sigma,
                                      dtype=self.dtype)
    if self.fixed_p_z_mean:
      w_1_mean = None
    else:
      with tf.variable_scope('q_w_1_mean'):
        w_1_mean = util.build_gaussian(
            [self.z_dim, self.z_dim],
            init_mu=0., init_sigma=self.init_sigma_q_w_mean,
            fixed_sigma=self.fixed_q_w_mean_sigma,
            dtype=self.dtype)
    with tf.variable_scope('q_w_0'):
      w_0 = util.build_gaussian([self.z_dim, self.timestep_dim], init_mu=0.,
                                init_sigma=self.init_sigma_q_w_0_sigma,
                                fixed_sigma=self.fixed_q_w_0_sigma,
                                dtype=self.dtype)
    self.params = {'w_0': w_0, 'w_1_shape': w_1_shape, 'w_1_mean': w_1_mean,
                   'z_1': z_1}
    if self.use_bias_observations:
      # b_0 = tf.get_variable(
      #     'b_0', [self.timestep_dim], self.dtype, tf.zeros_initializer,
      #     collections=[tf.GraphKeys.VARIABLES, 'reparam_variables'])
      b_0 = util.build_gaussian([self.timestep_dim], init_mu=0.,
                                init_sigma=0.01, fixed_sigma=False,
                                dtype=self.dtype)
      self.params.update({'b_0': b_0})

  def log_prob(self, q_samples):
    """Gets the log joint of the variational family, log q(z, w_shape, w_mean, w_0).

    Args:
      q_samples: dict. samples of latent variables.

    Returns:
      log_prob: tensor. log-probability summed over dimensions of the
        variables.
    """
    w_0 = q_samples['w_0']
    z_1 = q_samples['z_1']
    w_1_shape = q_samples['w_1_shape']
    w_1_mean = q_samples['w_1_mean']
    log_prob = 0.
    # preserve the sample and minibatch dimensions [0, 1]
    log_prob += tf.reduce_sum(z_1.distribution.log_pdf(z_1.value()), [2, 3])
    # w_1, w_0 are global, so reduce_sum across all dims
    log_prob += tf.reduce_sum(
        w_1_shape.distribution.log_pdf(w_1_shape.value()))
    log_prob += tf.reduce_sum(w_0.distribution.log_pdf(w_0.value()))
    if not self.fixed_p_z_mean:
      log_prob += tf.reduce_sum(
          w_1_mean.distribution.log_pdf(w_1_mean.value()))
    return log_prob
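

# End-to-end training sketch (hedged; the actual training loop lives in the
# scripts that import this module). Unlike the normal model, gradients through
# the gamma latents cannot use the straightforward reparameterization trick,
# so in practice the stochastic-tensor machinery in tf.contrib.bayesflow
# (imported above as `st`) would supply score-function gradients:
#
#   variational = GammaNormalRDEFVariational(...)  # args as documented above
#   model = GammaNormalRDEF(...)
#   q_samples = variational.sample
#   elbo = model.log_prob(q_samples, x) - variational.log_prob(q_samples)
#   loss = -tf.reduce_mean(elbo)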