# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Utility functions.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
- import h5py
- import numpy as np
- import tensorflow as tf
- st = tf.contrib.bayesflow.stochastic_tensor
- distributions = tf.contrib.distributions


def provide_tfrecords_data(path, split_name, batch_size, n_timesteps,
                           timestep_dim):
  """Provides batches of MNIST digits.

  Args:
    path: string. directory containing the tf.records files.
    split_name: string. name of the split.
    batch_size: int. batch size.
    n_timesteps: int. number of timesteps.
    timestep_dim: int. dimension of each timestep.

  Returns:
    labels: minibatch tensor of the indices of each datapoint.
    images: minibatch tensor of images.
  """
  # Load the data.
  image, label = read_and_decode_single_example(
      os.path.join(path, 'binarized_mnist_{}.tfrecords'.format(split_name)))
  # Preprocess the images: reshape to 28x28, then crop to the requested
  # number of timesteps (rows) and timestep dimension (columns).
  image = tf.reshape(image, [28, 28])
  if n_timesteps < 28:
    image = image[0:n_timesteps, :]
  if timestep_dim < 28:
    image = image[:, 0:timestep_dim]
  image = tf.expand_dims(image, 2)
  # Create a QueueRunner for the pre-fetching operation.
  images, labels = tf.train.batch(
      [image, label],
      batch_size=batch_size,
      num_threads=15,
      capacity=batch_size * 5000)
  return labels, images
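
# Usage sketch for the queue-based loader above (not from the original
# source; the data path is a hypothetical placeholder):
#
#   labels, images = provide_tfrecords_data(
#       '/path/to/data', 'train', batch_size=64, n_timesteps=28,
#       timestep_dim=28)
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     label_vals, image_vals = sess.run([labels, images])
#     coord.request_stop()
#     coord.join(threads)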


def read_and_decode_single_example(filename):
  """Read and decode a single example.

  Args:
    filename: str. path to a tf.records file.

  Returns:
    image: tensor. a single image.
    label: tensor. the index for the image.
  """
  # First construct a queue containing a list of filenames. This lets a
  # user split up their dataset into multiple files to keep size down.
  filename_queue = tf.train.string_input_producer([filename],
                                                  num_epochs=None)
  # Unlike the TFRecordWriter, the TFRecordReader is symbolic.
  reader = tf.TFRecordReader()
  # One can read a single serialized example from a filename;
  # serialized_example is a Tensor of type string.
  _, serialized_example = reader.read(filename_queue)
  # The serialized example is converted back to actual values.
  # One needs to describe the format of the objects to be returned.
  features = tf.parse_single_example(
      serialized_example,
      features={
          # We know the length of both fields. If we did not,
          # tf.VarLenFeature could be used.
          'image': tf.FixedLenFeature([784], tf.float32),
          'label': tf.FixedLenFeature([], tf.int64)
      })
  # Now return the converted data.
  image = features['image']
  label = features['label']
  return image, label
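
# Sketch of how a matching record could be written (not part of the
# original source; `images` is a hypothetical numpy array of shape
# [N, 784], and 'label' stores each datapoint's index, matching the
# reader above):
#
#   writer = tf.python_io.TFRecordWriter('binarized_mnist_train.tfrecords')
#   for index, image in enumerate(images):
#     example = tf.train.Example(features=tf.train.Features(feature={
#         'image': tf.train.Feature(
#             float_list=tf.train.FloatList(value=image.tolist())),
#         'label': tf.train.Feature(
#             int64_list=tf.train.Int64List(value=[index])),
#     }))
#     writer.write(example.SerializeToString())
#   writer.close()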


def provide_hdf5_data(path, split_name, n_examples, batch_size, n_timesteps,
                      timestep_dim, dataset):
  """Provides batches of MNIST digits.

  Args:
    path: str. path to the dataset.
    split_name: string. name of the split.
    n_examples: int. number of examples to serve from the dataset.
    batch_size: int. batch size.
    n_timesteps: int. number of timesteps.
    timestep_dim: int. dimension of each timestep.
    dataset: string specifying the dataset ('alternating' or 'MNIST').

  Returns:
    data_iterator: a generator of minibatches.
  """
  if dataset == 'alternating':
    # Build a synthetic dataset whose rows alternate between all-zeros and
    # all-ones; every other example starts on the opposite phase.
    start_zeros = np.vstack([np.zeros(timestep_dim) if t % 2 == 0 else
                             np.ones(timestep_dim) for t in range(n_timesteps)])
    start_ones = np.roll(start_zeros, 1, axis=0)
    start_zeros = start_zeros.flatten()
    start_ones = start_ones.flatten()
    data_list = [start_zeros if n % 2 == 0 else
                 start_ones for n in range(n_examples)]
    data = np.vstack(data_list)
  elif dataset == 'MNIST':
    f = h5py.File(path, 'r')
    if split_name == 'train_and_valid':
      train = f['train'][:]
      valid = f['valid'][:]
      data = np.vstack([train, valid])
    else:
      data = f[split_name][:]
    data = data[0:n_examples]
  # Create indexes for the data points. Materialize the zip so it can be
  # indexed repeatedly (zip returns a one-shot iterator on Python 3).
  indexed_data = list(zip(range(len(data)), np.split(data, len(data))))

  def data_iterator():
    """Generate minibatches of examples from the dataset."""
    while True:
      # Shuffle the data at the start of every epoch.
      idxs = np.arange(0, len(data))
      np.random.shuffle(idxs)
      shuf_data = [indexed_data[idx] for idx in idxs]
      for batch_idx in range(0, len(data), batch_size):
        indexed_images_batch = shuf_data[batch_idx:batch_idx + batch_size]
        indexes, images_batch = zip(*indexed_images_batch)
        images_batch = np.vstack(images_batch)
        if timestep_dim == 784:
          images_batch = images_batch.reshape(
              (batch_size, 1, 784, 1))
        else:
          if dataset == 'alternating':
            images_batch = images_batch.reshape(
                (batch_size, n_timesteps, timestep_dim, 1))
          else:
            images_batch = images_batch.reshape(
                (batch_size, 28, 28, 1))[:, :n_timesteps, :timestep_dim]
        yield indexes, images_batch
  return data_iterator()
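
# Usage sketch for the generator above (not from the original source; the
# HDF5 path is a hypothetical placeholder):
#
#   it = provide_hdf5_data('/path/to/binarized_mnist.hdf5', 'train',
#                          n_examples=50000, batch_size=64, n_timesteps=28,
#                          timestep_dim=28, dataset='MNIST')
#   indexes, images_batch = next(it)  # images_batch: [64, 28, 28, 1]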


def inv_softplus(x):
  """Inverse softplus."""
  return np.log(np.exp(x) - 1.)


def softplus(x):
  """Softplus."""
  return np.log(np.exp(x) + 1.)
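
# Quick sanity check (follows directly from the definitions):
# softplus(inv_softplus(x)) == x for x > 0, since
# log(exp(log(exp(x) - 1)) + 1) = log(exp(x) - 1 + 1) = x. For example:
#
#   np.allclose(softplus(inv_softplus(2.0)), 2.0)  # True
#
# Note that np.exp(x) overflows for large x; these helpers are only applied
# to small initializer values here.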


def build_gamma(shape, init_shape=1., init_mean=1., x_indexes=None,
                fixed_mean=False, place_on_cpu=False, n_samples=1,
                dtype='float64'):
  """Builds a Gamma DistributionTensor.

  Truncation: we truncate the shape and mean parameters because gamma
  sampling is numerically unstable.
  Reference: http://ajbc.io/resources/bbvi_for_gammas.pdf

  Args:
    shape: list. shape of the distribution.
    init_shape: float. initial shape.
    init_mean: float. initial mean.
    x_indexes: tensor. integer placeholder for mean-field parameters.
    fixed_mean: bool. whether the mean is fixed (i.e. not learned).
    place_on_cpu: bool. whether to place the op on cpu.
    n_samples: number of samples.
    dtype: dtype.

  Returns:
    A Gamma DistributionTensor of the specified shape, with variables for
    shape and mean safely parametrized to avoid over/underflow.
  """
  if place_on_cpu:
    with tf.device('/cpu:0'):
      shape_softplus_inv = tf.get_variable(
          'shape_softplus_inv', shape, dtype, tf.constant_initializer(
              inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
                                                      'non_reparam_variables'])
  else:
    shape_softplus_inv = tf.get_variable(
        'shape_softplus_inv', shape, dtype, tf.constant_initializer(
            inv_softplus(init_shape)), collections=[tf.GraphKeys.VARIABLES,
                                                    'non_reparam_variables'])
  if fixed_mean:
    mean_softplus_inv = None
  else:
    mean_softplus_arg = tf.constant_initializer(inv_softplus(init_mean))
    if place_on_cpu:
      with tf.device('/cpu:0'):
        mean_softplus_inv = tf.get_variable(
            'mean_softplus_inv', shape, dtype, mean_softplus_arg)
    else:
      mean_softplus_inv = tf.get_variable('mean_softplus_inv', shape,
                                          dtype, mean_softplus_arg,
                                          collections=[tf.GraphKeys.VARIABLES,
                                                       'non_reparam_variables'])
  if x_indexes is not None:
    shape_softplus_inv_batch = tf.nn.embedding_lookup(
        shape_softplus_inv, x_indexes)
    if not fixed_mean:
      mean_softplus_inv_batch = tf.nn.embedding_lookup(
          mean_softplus_inv, x_indexes)
  else:
    shape_softplus_inv_batch, mean_softplus_inv_batch = (shape_softplus_inv,
                                                         mean_softplus_inv)
  # Positivity is enforced by storing the parameters on the inverse
  # softplus scale and mapping them through softplus here.
  shape_batch = tf.nn.softplus(shape_softplus_inv_batch)
  if fixed_mean:
    mean_batch = tf.constant(init_mean, dtype=dtype)
  else:
    mean_batch = tf.nn.softplus(mean_softplus_inv_batch)
  # Parametrize the Gamma by (shape, mean): with alpha = shape and
  # beta = shape / mean, the distribution has E[X] = alpha / beta = mean.
  with st.value_type(st.SampleValue(n=n_samples)):
    dist = st.StochasticTensor(distributions.Gamma,
                               alpha=shape_batch,
                               beta=shape_batch / mean_batch,
                               validate_args=False)
  return dist
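
# Usage sketch (not from the original source; the scope name and sizes are
# hypothetical). Variables are created via tf.get_variable, so each
# distribution needs its own variable scope:
#
#   with tf.variable_scope('q_z'):
#     q_z = build_gamma([100, 10], init_shape=1., init_mean=1., n_samples=2)
#   z_sample = q_z.value()  # shape [2, 100, 10]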


def truncate(max_or_min, var, val):
  """Truncate a variable to a max or min value (on the softplus-inverse scale)."""
  if max_or_min == 'max':
    tf_fn = tf.minimum
  elif max_or_min == 'min':
    tf_fn = tf.maximum
  else:
    raise ValueError("max_or_min must be 'max' or 'min'.")
  if isinstance(var, tf.IndexedSlices):
    assign_op = tf.assign(var.values, tf_fn(var.values, inv_softplus(val)))
  else:
    assign_op = tf.assign(var, tf_fn(var, inv_softplus(val)))
  return assign_op
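
# Sketch of how truncation might be applied after each training step (an
# assumption for illustration; the bounds 5e-3 and 1e-5 mirror clip_shape
# and clip_mean below, and the variable names come from build_gamma):
#
#   with tf.variable_scope('q_z', reuse=True):
#     shape_var = tf.get_variable('shape_softplus_inv', dtype='float64')
#     mean_var = tf.get_variable('mean_softplus_inv', dtype='float64')
#   truncate_ops = [truncate('min', shape_var, 5e-3),
#                   truncate('min', mean_var, 1e-5)]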


def build_gaussian(shape, init_mu=0., init_sigma=1.0, x_indexes=None,
                   fixed_sigma=False, place_on_cpu=False, dtype='float64'):
  """Builds a Gaussian DistributionTensor.

  Args:
    shape: list. shape of the distribution.
    init_mu: float. initial mean.
    init_sigma: float. initial standard deviation.
    x_indexes: tensor. integer placeholder for mean-field parameters.
    fixed_sigma: bool. whether sigma is fixed (i.e. not learned).
    place_on_cpu: bool. whether to place the op on cpu.
    dtype: dtype.

  Returns:
    A Gaussian DistributionTensor of the specified shape, with variables for
    mean and standard deviation safely parametrized to avoid over/underflow.
  """
  if place_on_cpu:
    with tf.device('/cpu:0'):
      mu = tf.get_variable(
          'mu', shape, dtype, tf.random_normal_initializer(
              mean=init_mu, stddev=0.1))
  else:
    mu = tf.get_variable('mu', shape, dtype,
                         tf.random_normal_initializer(mean=init_mu, stddev=0.1),
                         collections=[tf.GraphKeys.VARIABLES,
                                      'reparam_variables'])
  if fixed_sigma:
    sigma_softplus_inv = None
  else:
    sigma_softplus_arg = tf.truncated_normal_initializer(
        mean=inv_softplus(init_sigma), stddev=0.1)
    if place_on_cpu:
      with tf.device('/cpu:0'):
        sigma_softplus_inv = tf.get_variable(
            'sigma_softplus_inv', shape, dtype, sigma_softplus_arg)
    else:
      sigma_softplus_inv = tf.get_variable('sigma_softplus_inv', shape,
                                           dtype, sigma_softplus_arg,
                                           collections=[tf.GraphKeys.VARIABLES,
                                                        'reparam_variables'])
  if x_indexes is not None:
    mu_batch = tf.nn.embedding_lookup(mu, x_indexes)
    if not fixed_sigma:
      sigma_softplus_inv_batch = tf.nn.embedding_lookup(
          sigma_softplus_inv, x_indexes)
  else:
    mu_batch, sigma_softplus_inv_batch = mu, sigma_softplus_inv
  if fixed_sigma:
    sigma_batch = np.array(init_sigma, dtype)
  else:
    # Keep sigma strictly positive and bounded away from zero.
    sigma_batch = tf.maximum(tf.nn.softplus(sigma_softplus_inv_batch), 1e-5)
  dist = st.StochasticTensor(distributions.Normal, mu=mu_batch,
                             sigma=sigma_batch, validate_args=False)
  return dist
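
# Usage sketch mirroring the gamma example above (the scope name and sizes
# are hypothetical):
#
#   with tf.variable_scope('q_w'):
#     q_w = build_gaussian([10, 784], init_mu=0., init_sigma=1.)
#   w_sample = q_w.value()  # shape [10, 784]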


def get_np_dtype(tensor):
  """Returns the numpy dtype of a tensor."""
  return np.float32 if 'float32' in str(tensor.dtype) else np.float64


def build_bernoulli_log_likelihood(params, x, batch_size,
                                   n_samples_latents=1,
                                   use_bias_observations=False):
  """Builds the likelihood given stochastic latents and weights.

  Args:
    params: dict that contains:
      z_1: tensor. sampled latent variables,
        [n_samples_latents] + [batch_size, n_timesteps, z_dim]
      w_0: tensor. sampled stochastic weights [z_dim, timestep_dim]
      b_0: optional tensor. biases [timestep_dim]
    x: tensor. minibatch of examples.
    batch_size: int. number of minibatch examples.
    n_samples_latents: number of samples of latent variables.
    use_bias_observations: bool. whether to add a bias to the logits.

  Returns:
    log_p_x_zw: the Bernoulli log-likelihood of the data, summed over all
      but the sample and batch dimensions.
    p: the Bernoulli probabilities,
      [n_samples, batch_size, n_timesteps, timestep_dim].
  """
  z_1 = params['z_1']
  w_0 = params['w_0']
  if use_bias_observations:
    b_0 = params['b_0']
  if n_samples_latents > 1:
    # Tile the weights across samples and batch, then contract with z_1.
    wz = tf.batch_matmul(z_1, tf.pack([tf.pack([w_0] * batch_size)]
                                      * n_samples_latents))
    if use_bias_observations:
      wz += b_0
    logits = tf.expand_dims(wz, 4)
    dims_to_reduce = [2, 3, 4]
  else:
    wz = tf.batch_matmul(z_1, tf.pack([w_0] * batch_size))
    if use_bias_observations:
      wz += b_0
    logits = tf.expand_dims(wz, 3)
    dims_to_reduce = [1, 2, 3]
  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
  log_p_x_zw = tf.reduce_sum(p_x_zw.log_pmf(x), dims_to_reduce)
  print('log_p_x_zw', log_p_x_zw.get_shape())
  print('logits', logits.get_shape())
  print('z_1', z_1.value().get_shape())
  return log_p_x_zw, p_x_zw.p
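
# Shape bookkeeping for the function above, in the n_samples_latents > 1
# branch (with S samples, batch B, T timesteps, z_dim K, timestep_dim D):
#   z_1:  [S, B, T, K]    stacked w_0: [S, B, K, D]
#   wz:   [S, B, T, D]    logits:      [S, B, T, D, 1]
#   log_p_x_zw after reduce_sum over [2, 3, 4]: [S, B]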


def clip_mean(mean):
  """Clip the mean parameter of the gamma."""
  return tf.clip_by_value(mean, clip_value_max=sys.float_info.max,
                          clip_value_min=1e-5)


def clip_shape(shape):
  """Clip the shape parameter of the gamma."""
  return tf.clip_by_value(shape, clip_value_max=sys.float_info.max,
                          clip_value_min=5e-3)


def bernoulli_likelihood_sample(params, z_1, n_samples,
                                use_bias_observations=False):
  """Sample from the model likelihood.

  Args:
    params: dict that contains:
      w_0: tensor. sample of latent weights
      b_0: optional tensor. bias
    z_1: tensor. sample of latent variables
    n_samples: int. number of samples to draw
    use_bias_observations: bool. whether to add a bias to the logits.

  Returns:
    A tensor sample from the model likelihood.
  """
  w_0 = params['w_0']
  if isinstance(z_1, st.StochasticTensor):
    z_1 = z_1.value()
  if z_1.get_shape().ndims == 4:
    # Use only the first sample of the latent variables.
    z_1 = z_1[0, :, :, :]
  wz = tf.batch_matmul(z_1, tf.pack([w_0] * n_samples))
  if use_bias_observations:
    wz += params['b_0']
  logits = tf.expand_dims(wz, 3)
  p_x_zw = distributions.Bernoulli(logits=logits, validate_args=False)
  return tf.cast(p_x_zw.sample_n(n=1)[0, :, :, :, :], logits.dtype)