util.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys

import h5py
import numpy as np
import tensorflow as tf

st = tf.contrib.bayesflow.stochastic_tensor
distributions = tf.contrib.distributions


def provide_tfrecords_data(path, split_name, batch_size, n_timesteps,
                           timestep_dim):
  """Provides batches of MNIST digits.

  Args:
    path: String specifying location of tf.records files.
    split_name: string. name of the split.
    batch_size: int. batch size.
    n_timesteps: int. number of timesteps.
    timestep_dim: int. dimension of each timestep.

  Returns:
    labels: minibatch tensor of the indices of each datapoint.
    images: minibatch tensor of images.
  """
  # Load the data:
  image, label = read_and_decode_single_example(
      os.path.join(path, 'binarized_mnist_{}.tfrecords'.format(split_name)))

  # Preprocess the images.
  image = tf.reshape(image, [28, 28])
  if n_timesteps < 28:
    image = image[0:n_timesteps, :]
  if timestep_dim < 28:
    image = image[:, 0:timestep_dim]
  image = tf.expand_dims(image, 2)

  # Creates a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=15,
      capacity=batch_size * 5000)
  return labels, images
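
# Example usage (a minimal sketch; the path, split name, and session handling
# below are illustrative assumptions, not taken from a particular training
# script):
#
#   labels, images = provide_tfrecords_data(
#       '/tmp/data', split_name='train', batch_size=64,
#       n_timesteps=28, timestep_dim=28)
#   sess = tf.Session()
#   coord = tf.train.Coordinator()
#   threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#   label_vals, image_vals = sess.run([labels, images])
#   coord.request_stop()
#   coord.join(threads)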


def read_and_decode_single_example(filename):
  """Read and decode a single example.

  Args:
    filename: str. path to a tf.records file.

  Returns:
    image: tensor. a single image.
    label: tensor. the index for the image.
  """
  # First construct a queue containing a list of filenames.
  # This lets a user split up their dataset into multiple files to keep
  # the size down.
  filename_queue = tf.train.string_input_producer([filename],
                                                  num_epochs=None)
  # Unlike the TFRecordWriter, the TFRecordReader is symbolic.
  reader = tf.TFRecordReader()
  # One can read a single serialized example from a filename.
  # serialized_example is a Tensor of type string.
  _, serialized_example = reader.read(filename_queue)
  # The serialized example is converted back to actual values.
  # One needs to describe the format of the objects to be returned.
  features = tf.parse_single_example(
      serialized_example,
      features={
          # We know the length of both fields. If we did not,
          # tf.VarLenFeature could be used.
          'image': tf.FixedLenFeature([784], tf.float32),
          'label': tf.FixedLenFeature([], tf.int64)
      })
  # Now return the converted data.
  image = features['image']
  label = features['label']
  return image, label


def provide_hdf5_data(path, split_name, n_examples, batch_size, n_timesteps,
                      timestep_dim, dataset):
  """Provides batches of MNIST digits.

  Args:
    path: str. path to the dataset.
    split_name: string. name of the split.
    n_examples: int. number of examples to serve from the dataset.
    batch_size: int. batch size.
    n_timesteps: int. number of timesteps.
    timestep_dim: int. dimension of each timestep.
    dataset: str. which dataset to serve ('alternating' or 'MNIST').

  Returns:
    data_iterator: a generator of minibatches.
  """
  if dataset == 'alternating':
    # Synthetic data whose timesteps alternate between all zeros and all ones.
    data_list = []
    start_zeros = np.vstack([np.zeros(timestep_dim) if t % 2 == 0 else
                             np.ones(timestep_dim) for t in range(n_timesteps)])
    start_ones = np.roll(start_zeros, 1, axis=0)
    start_zeros = start_zeros.flatten()
    start_ones = start_ones.flatten()
    data_list = [start_zeros if n % 2 == 0 else
                 start_ones for n in range(n_examples)]
    data = np.vstack(data_list)
  elif dataset == 'MNIST':
    f = h5py.File(path, 'r')
    if split_name == 'train_and_valid':
      train = f['train'][:]
      valid = f['valid'][:]
      data = np.vstack([train, valid])
    else:
      data = f[split_name][:]
    data = data[0:n_examples]

  # Create indexes for the data points. Materialize the list so that indexing
  # also works under Python 3, where zip returns an iterator.
  indexed_data = list(zip(range(len(data)), np.split(data, len(data))))

  def data_iterator():
    """Generate minibatches of examples from the dataset."""
    while True:
      # Shuffle the data on each pass through the dataset.
      idxs = np.arange(0, len(data))
      np.random.shuffle(idxs)
      shuf_data = [indexed_data[idx] for idx in idxs]
      for batch_idx in range(0, len(data), batch_size):
        indexed_images_batch = shuf_data[batch_idx:batch_idx + batch_size]
        indexes, images_batch = zip(*indexed_images_batch)
        images_batch = np.vstack(images_batch)
        if timestep_dim == 784:
          images_batch = images_batch.reshape(
              (batch_size, 1, 784, 1))
        elif dataset == 'alternating':
          images_batch = images_batch.reshape(
              (batch_size, n_timesteps, timestep_dim, 1))
        else:
          images_batch = images_batch.reshape(
              (batch_size, 28, 28, 1))[:, :n_timesteps, :timestep_dim]
        yield indexes, images_batch

  return data_iterator()
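
# Example usage (a minimal sketch; the hdf5 path and sizes are illustrative):
#
#   iterator = provide_hdf5_data(
#       '/tmp/binarized_mnist.hdf5', 'train', n_examples=50000, batch_size=64,
#       n_timesteps=28, timestep_dim=28, dataset='MNIST')
#   indexes, images = next(iterator)  # images has shape (64, 28, 28, 1).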


def inv_softplus(x):
  """Inverse softplus."""
  return np.log(np.exp(x) - 1.)


def softplus(x):
  """Softplus."""
  return np.log(np.exp(x) + 1.)
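
# Note (illustrative): these numpy helpers are used to initialize the
# softplus-parametrized variables below, and they invert each other, e.g.
#   softplus(inv_softplus(2.0))  # == 2.0 up to floating point error.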


def build_gamma(shape, init_shape=1., init_mean=1., x_indexes=None,
                fixed_mean=False, place_on_cpu=False, n_samples=1,
                dtype='float64'):
  """Builds a Gamma DistributionTensor.

  Truncation: we truncate shape and mean parameters because gamma sampling is
  numerically unstable. Reference: http://ajbc.io/resources/bbvi_for_gammas.pdf

  Args:
    shape: list. shape of the distribution.
    init_shape: float. initial shape.
    init_mean: float. initial mean.
    x_indexes: tensor. integer placeholder for mean-field parameters.
    fixed_mean: bool. whether to keep the mean fixed instead of learning it.
    place_on_cpu: bool. whether to place the op on cpu.
    n_samples: number of samples.
    dtype: dtype.

  Returns:
    A Gamma DistributionTensor of the specified shape, with variables for
    shape and mean safely parametrized to avoid over/underflow.
  """
  if place_on_cpu:
    with tf.device('/cpu:0'):
      shape_softplus_inv = tf.get_variable(
          'shape_softplus_inv', shape, dtype, tf.constant_initializer(
              inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
                                                      'non_reparam_variables'])
  else:
    shape_softplus_inv = tf.get_variable(
        'shape_softplus_inv', shape, dtype, tf.constant_initializer(
            inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
                                                    'non_reparam_variables'])

  if fixed_mean:
    mean_softplus_inv = None
  else:
    mean_softplus_arg = tf.constant_initializer(inv_softplus(init_mean))
    if place_on_cpu:
      with tf.device('/cpu:0'):
        mean_softplus_inv = tf.get_variable(
            'mean_softplus_inv', shape, dtype, mean_softplus_arg)
    else:
      mean_softplus_inv = tf.get_variable('mean_softplus_inv', shape,
                                          dtype, mean_softplus_arg,
                                          collections=[tf.GraphKeys.VARIABLES,
                                                       'non_reparam_variables'])

  if x_indexes is not None:
    # Mean-field parameters: look up the rows for the datapoints in the batch.
    shape_softplus_inv_batch = tf.nn.embedding_lookup(
        shape_softplus_inv, x_indexes)
    if not fixed_mean:
      mean_softplus_inv_batch = tf.nn.embedding_lookup(
          mean_softplus_inv, x_indexes)
  else:
    shape_softplus_inv_batch, mean_softplus_inv_batch = (shape_softplus_inv,
                                                         mean_softplus_inv)

  shape_batch = tf.nn.softplus(shape_softplus_inv_batch)
  if fixed_mean:
    mean_batch = tf.constant(init_mean)
  else:
    mean_batch = tf.nn.softplus(mean_softplus_inv_batch)

  # The Gamma is parametrized by shape (alpha) and rate (beta = shape / mean).
  with st.value_type(st.SampleValue(n=n_samples)):
    dist = st.StochasticTensor(distributions.Gamma,
                               alpha=shape_batch,
                               beta=shape_batch / mean_batch,
                               validate_args=False)
  return dist
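
# Example usage (a minimal sketch; the scope name, sizes, and initial values
# are illustrative assumptions):
#
#   x_idx = tf.placeholder(tf.int64, [None])
#   with tf.variable_scope('q_z'):
#     q_z = build_gamma([n_examples, z_dim], init_shape=0.1, init_mean=1.,
#                       x_indexes=x_idx, n_samples=2)
#   # q_z.value() holds 2 samples of the per-datapoint latents in the batch.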


def truncate(max_or_min, var, val):
  """Truncate a softplus-parametrized variable to a max or min value."""
  if max_or_min == 'max':
    tf_fn = tf.minimum
  elif max_or_min == 'min':
    tf_fn = tf.maximum
  else:
    raise ValueError("max_or_min must be 'max' or 'min'.")
  if isinstance(var, tf.IndexedSlices):
    assign_op = tf.assign(var.values, tf_fn(var.values, inv_softplus(val)))
  else:
    assign_op = tf.assign(var, tf_fn(var, inv_softplus(val)))
  return assign_op
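
# Example usage (a minimal sketch; the variable name and floor value are
# illustrative, chosen to match the 5e-3 floor in clip_shape below):
#
#   shape_var = tf.get_variable('shape_softplus_inv', [10], 'float64')
#   truncate_op = truncate('min', shape_var, 5e-3)
#   # Run truncate_op after each gradient step to keep the gamma shape stable.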


def build_gaussian(shape, init_mu=0., init_sigma=1.0, x_indexes=None,
                   fixed_sigma=False, place_on_cpu=False, dtype='float64'):
  """Builds a Gaussian DistributionTensor.

  Args:
    shape: list. shape of the distribution.
    init_mu: float. initial mean.
    init_sigma: float. initial standard deviation.
    x_indexes: tensor. integer placeholder for mean-field parameters.
    fixed_sigma: bool. whether to keep sigma fixed instead of learning it.
    place_on_cpu: bool. whether to place the op on cpu.
    dtype: dtype.

  Returns:
    A Gaussian DistributionTensor of the specified shape, with variables for
    mean and standard deviation safely parametrized to avoid over/underflow.
  """
  if place_on_cpu:
    with tf.device('/cpu:0'):
      mu = tf.get_variable(
          'mu', shape, dtype, tf.random_normal_initializer(
              mean=init_mu, stddev=0.1))
  else:
    mu = tf.get_variable('mu', shape, dtype,
                         tf.random_normal_initializer(mean=init_mu, stddev=0.1),
                         collections=[tf.GraphKeys.VARIABLES,
                                      'reparam_variables'])

  if fixed_sigma:
    sigma_softplus_inv = None
  else:
    sigma_softplus_arg = tf.truncated_normal_initializer(
        mean=inv_softplus(init_sigma), stddev=0.1)
    if place_on_cpu:
      with tf.device('/cpu:0'):
        sigma_softplus_inv = tf.get_variable(
            'sigma_softplus_inv', shape, dtype, sigma_softplus_arg)
    else:
      sigma_softplus_inv = tf.get_variable('sigma_softplus_inv', shape,
                                           dtype, sigma_softplus_arg,
                                           collections=[tf.GraphKeys.VARIABLES,
                                                        'reparam_variables'])

  if x_indexes is not None:
    # Mean-field parameters: look up the rows for the datapoints in the batch.
    mu_batch = tf.nn.embedding_lookup(mu, x_indexes)
    if not fixed_sigma:
      sigma_softplus_inv_batch = tf.nn.embedding_lookup(
          sigma_softplus_inv, x_indexes)
  else:
    mu_batch, sigma_softplus_inv_batch = mu, sigma_softplus_inv

  if fixed_sigma:
    sigma_batch = np.array(init_sigma, dtype)
  else:
    # Floor sigma to avoid numerical underflow.
    sigma_batch = tf.maximum(tf.nn.softplus(sigma_softplus_inv_batch), 1e-5)

  dist = st.StochasticTensor(distributions.Normal, mu=mu_batch,
                             sigma=sigma_batch, validate_args=False)
  return dist
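
# Example usage (a minimal sketch; the scope name and sizes are illustrative):
#
#   with tf.variable_scope('q_w'):
#     q_w = build_gaussian([z_dim, timestep_dim], init_mu=0., init_sigma=0.1)
#   # q_w.value() is a reparametrized sample of shape [z_dim, timestep_dim].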


def get_np_dtype(tensor):
  """Returns the numpy dtype."""
  return np.float32 if 'float32' in str(tensor.dtype) else np.float64


def build_bernoulli_log_likelihood(params, x, batch_size,
                                   n_samples_latents=1,
                                   use_bias_observations=False):
  """Builds the likelihood given stochastic latents and weights.

  Args:
    params: dict that contains:
      z_1: tensor. sampled latent variables
        [n_samples_latents] + [batch_size, n_timesteps, z_dim]
      w_0: tensor. sampled stochastic weights [z_dim, timestep_dim]
      b_0: optional tensor. biases [timestep_dim]
    x: tensor. minibatch of examples.
    batch_size: integer number of minibatch examples.
    n_samples_latents: number of samples of latent variables.
    use_bias_observations: whether to add a bias to the observation logits.

  Returns:
    log_p_x_zw: the Bernoulli log-likelihood of the data, summed over
      timesteps and dimensions.
    p: the Bernoulli probabilities of the likelihood,
      [n_samples, batch_size, n_timesteps, timestep_dim, 1].
  """
  z_1 = params['z_1']
  w_0 = params['w_0']
  if use_bias_observations:
    b_0 = params['b_0']
  if n_samples_latents > 1:
    # Tile the weights across latent samples and the batch for batch_matmul.
    wz = tf.batch_matmul(z_1, tf.pack([tf.pack([w_0] * batch_size)]
                                      * n_samples_latents))
    if use_bias_observations:
      wz += b_0
    logits = tf.expand_dims(wz, 4)
    dims_to_reduce = [2, 3, 4]
  else:
    wz = tf.batch_matmul(z_1, tf.pack([w_0] * batch_size))
    if use_bias_observations:
      wz += b_0
    logits = tf.expand_dims(wz, 3)
    dims_to_reduce = [1, 2, 3]
  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
  log_p_x_zw = tf.reduce_sum(p_x_zw.log_pmf(x), dims_to_reduce)
  print('log_p_x_zw', log_p_x_zw.get_shape())
  print('logits', logits.get_shape())
  print('z_1', z_1.value().get_shape())
  return log_p_x_zw, p_x_zw.p
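
# Example usage (a minimal sketch; q_z and q_w stand for StochasticTensors
# built with build_gamma / build_gaussian above, and x is a minibatch tensor
# of shape [batch_size, n_timesteps, timestep_dim, 1]):
#
#   params = {'z_1': q_z, 'w_0': q_w.value()}
#   log_p_x_zw, probs = build_bernoulli_log_likelihood(
#       params, x, batch_size=64, n_samples_latents=1)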


def clip_mean(mean):
  """Clip mean parameter of gamma."""
  return tf.clip_by_value(mean, clip_value_max=sys.float_info.max,
                          clip_value_min=1e-5)


def clip_shape(shape):
  """Clip shape parameter of gamma."""
  return tf.clip_by_value(shape, clip_value_max=sys.float_info.max,
                          clip_value_min=5e-3)


def bernoulli_likelihood_sample(params, z_1, n_samples,
                                use_bias_observations=False):
  """Sample from the model likelihood.

  Args:
    params: dict that contains
      w_0: tensor. sample of latent weights
      b_0: optional tensor. bias
    z_1: tensor. sample of latent variables.
    n_samples: int. number of samples to draw.
    use_bias_observations: whether to add a bias to the observation logits.

  Returns:
    A tensor sample from the model likelihood.
  """
  w_0 = params['w_0']
  if isinstance(z_1, st.StochasticTensor):
    z_1 = z_1.value()
  if z_1.get_shape().ndims == 4:
    # Use the first sample of the latents if several were drawn.
    z_1 = z_1[0, :, :, :]
  wz = tf.batch_matmul(z_1, tf.pack([w_0] * n_samples))
  if use_bias_observations:
    wz += params['b_0']
  logits = tf.expand_dims(wz, 3)
  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
  return tf.cast(p_x_zw.sample_n(n=1)[0, :, :, :, :], logits.dtype)
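
# Example usage (a minimal sketch; z_prior stands for a sample of latents with
# shape [n_samples, n_timesteps, z_dim], e.g. drawn from the model prior, and
# q_w for a weight StochasticTensor built with build_gaussian above):
#
#   sampled_images = bernoulli_likelihood_sample(
#       {'w_0': q_w.value()}, z_prior, n_samples=16)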