# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- r"""Utility functions for Real NVP.
- """
- # pylint: disable=dangerous-default-value
- import numpy
- import tensorflow as tf
- from tensorflow.python.framework import ops
- DEFAULT_BN_LAG = .0
def stable_var(input_, mean=None, axes=[0]):
  """Numerically more stable variance computation."""
  if mean is None:
    mean = tf.reduce_mean(input_, axes)
  res = tf.square(input_ - mean)
  max_sqr = tf.reduce_max(res, axes)
  res /= max_sqr
  res = tf.reduce_mean(res, axes)
  res *= max_sqr

  return res

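# A minimal usage sketch for stable_var (the shapes below are assumptions,
# not part of the original code): dividing the squared deviations by their
# per-axis maximum before averaging keeps intermediate values near 1, which
# avoids overflow for large activations; the result is rescaled afterwards.
#
#   x = tf.random_normal([128, 64])   # hypothetical batch of features
#   v = stable_var(x, axes=[0])       # per-feature variance, shape [64]
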
def variable_on_cpu(name, shape, initializer, trainable=True):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable
    trainable: boolean defining if the variable is for training
  Returns:
    Variable Tensor
  """
  var = tf.get_variable(
      name, shape, initializer=initializer, trainable=trainable)

  return var

# layers
def conv_layer(input_,
               filter_size,
               dim_in,
               dim_out,
               name,
               stddev=1e-2,
               strides=[1, 1, 1, 1],
               padding="SAME",
               nonlinearity=None,
               bias=False,
               weight_norm=False,
               scale=False):
  """Convolutional layer."""
  with tf.variable_scope(name) as scope:
    weights = variable_on_cpu(
        "weights",
        filter_size + [dim_in, dim_out],
        tf.random_uniform_initializer(
            minval=-stddev, maxval=stddev))
    # weight normalization
    if weight_norm:
      weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
      if scale:
        magnitude = variable_on_cpu(
            "magnitude", [dim_out],
            tf.constant_initializer(
                stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
        weights *= magnitude
    res = input_
    # handling filter size bigger than image size
    if hasattr(input_, "shape"):
      if input_.get_shape().as_list()[1] < filter_size[0]:
        pad_1 = tf.zeros([
            input_.get_shape().as_list()[0],
            filter_size[0] - input_.get_shape().as_list()[1],
            input_.get_shape().as_list()[2],
            input_.get_shape().as_list()[3]
        ])
        pad_2 = tf.zeros([
            input_.get_shape().as_list()[0],
            filter_size[0],
            filter_size[1] - input_.get_shape().as_list()[2],
            input_.get_shape().as_list()[3]
        ])
        res = tf.concat(axis=1, values=[pad_1, res])
        res = tf.concat(axis=2, values=[pad_2, res])
    res = tf.nn.conv2d(
        input=res,
        filter=weights,
        strides=strides,
        padding=padding,
        name=scope.name)
    if hasattr(input_, "shape"):
      if input_.get_shape().as_list()[1] < filter_size[0]:
        res = tf.slice(res, [
            0, filter_size[0] - input_.get_shape().as_list()[1],
            filter_size[1] - input_.get_shape().as_list()[2], 0
        ], [-1, -1, -1, -1])
    if bias:
      biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
      res = tf.nn.bias_add(res, biases)
    if nonlinearity is not None:
      res = nonlinearity(res)

  return res

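# Example (a sketch; the shapes and names below are assumptions, not part of
# the original code): a 3x3 convolution mapping 3 input channels to 32 output
# channels on an NHWC batch. `filter_size` is a list, and `dim_in`/`dim_out`
# are the channel counts used to build the weight variable.
#
#   images = tf.placeholder(tf.float32, [16, 32, 32, 3])
#   h = conv_layer(images, filter_size=[3, 3], dim_in=3, dim_out=32,
#                  name="conv0", nonlinearity=tf.nn.relu, bias=True)
#   # h has shape [16, 32, 32, 32] with "SAME" padding and unit strides.
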
def max_pool_2x2(input_):
  """Max pooling."""
  return tf.nn.max_pool(
      input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")

def depool_2x2(input_, stride=2):
  """Depooling."""
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
  res = tf.concat(
      axis=2,
      values=[res,
              tf.zeros([batch_size, height, stride - 1, width, 1, channels])])
  res = tf.concat(
      axis=4,
      values=[res,
              tf.zeros([batch_size, height, stride, width, stride - 1,
                        channels])])
  res = tf.reshape(res, [batch_size, stride * height, stride * width, channels])

  return res

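# Sketch of the shape contract (illustrative values, not from the original
# code): depool_2x2 upsamples by inserting zeros, so a [8, 16, 16, 4] input
# becomes [8, 32, 32, 4] with the original pixels at the even spatial
# positions.
#
#   x = tf.placeholder(tf.float32, [8, 16, 16, 4])
#   y = depool_2x2(x)  # shape [8, 32, 32, 4]
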
# random flip on a batch of images
def batch_random_flip(input_):
  """Simultaneous horizontal random flip."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  res = tf.split(axis=0, num_or_size_splits=batch_size, value=input_)
  res = [elem[0, :, :, :] for elem in res]
  res = [tf.image.random_flip_left_right(elem) for elem in res]
  res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
  res = tf.concat(axis=0, values=res)

  return res

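# Usage sketch (shapes are assumptions): the input must have a fully known
# static shape because the batch is split into individual images before
# tf.image.random_flip_left_right is applied.
#
#   images = tf.placeholder(tf.float32, [32, 64, 64, 3])
#   maybe_flipped = batch_random_flip(images)  # same shape as `images`
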
# build a one hot representation corresponding to the integer tensor
# the one-hot dimension is appended to the integer tensor shape
def as_one_hot(input_, n_indices):
  """Convert indices to one-hot."""
  shape = input_.get_shape().as_list()
  n_elem = numpy.prod(shape)
  indices = tf.range(n_elem)
  indices = tf.cast(indices, tf.int64)
  indices_input = tf.concat(axis=0, values=[indices, tf.reshape(input_, [-1])])
  indices_input = tf.reshape(indices_input, [2, -1])
  indices_input = tf.transpose(indices_input)
  res = tf.sparse_to_dense(
      indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
  res = tf.reshape(res, [elem for elem in shape] + [n_indices])

  return res

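# Example (a sketch; the values are illustrative): int64 class labels of
# shape [batch] become a dense one-hot tensor of shape [batch, n_indices].
#
#   labels = tf.constant([2, 0, 1], dtype=tf.int64)
#   one_hot = as_one_hot(labels, n_indices=4)  # shape [3, 4]
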
def squeeze_2x2(input_):
  """Squeezing operation: reshape to convert space to channels."""
  return squeeze_nxn(input_, n_factor=2)

def squeeze_nxn(input_, n_factor=2):
  """Squeezing operation: reshape to convert space to channels."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if height % n_factor != 0:
    raise ValueError("Height not divisible by %d." % n_factor)
  if width % n_factor != 0:
    raise ValueError("Width not divisible by %d." % n_factor)
  res = tf.reshape(
      input_,
      [batch_size,
       height // n_factor,
       n_factor, width // n_factor,
       n_factor, channels])
  res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
  res = tf.reshape(
      res,
      [batch_size,
       height // n_factor,
       width // n_factor,
       channels * n_factor * n_factor])

  return res

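# Shape sketch (illustrative values): squeezing trades spatial resolution for
# channels, e.g. a [4, 32, 32, 3] tensor becomes [4, 16, 16, 12] with
# n_factor=2; unsqueeze_2x2 below inverts that reshaping for n_factor=2.
#
#   x = tf.placeholder(tf.float32, [4, 32, 32, 3])
#   y = squeeze_2x2(x)  # shape [4, 16, 16, 12]
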
def unsqueeze_2x2(input_):
  """Unsqueezing operation: reshape to convert channels into space."""
  if isinstance(input_, (float, int)):
    return input_
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if channels % 4 != 0:
    raise ValueError("Number of channels not divisible by 4.")
  res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
  res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
  res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])

  return res

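# Round-trip sketch (illustrative shapes): unsqueeze_2x2 undoes squeeze_2x2,
# so squeezing then unsqueezing recovers the original [4, 32, 32, 3] layout.
#
#   z = unsqueeze_2x2(squeeze_2x2(x))  # same shape and contents as `x`
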
# batch norm
def batch_norm(input_,
               dim,
               name,
               scale=True,
               train=True,
               epsilon=1e-8,
               decay=.1,
               axes=[0],
               bn_lag=DEFAULT_BN_LAG):
  """Batch normalization."""
  # create variables
  with tf.variable_scope(name):
    var = variable_on_cpu(
        "var", [dim], tf.constant_initializer(1.), trainable=False)
    mean = variable_on_cpu(
        "mean", [dim], tf.constant_initializer(0.), trainable=False)
    step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
    if scale:
      gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
      beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
    # choose the appropriate moments
    if train:
      used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
      cur_mean, cur_var = used_mean, used_var
      if bn_lag > 0.:
        used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
        used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
        used_mean /= (1. - bn_lag**(step + 1))
        used_var /= (1. - bn_lag**(step + 1))
    else:
      used_mean, used_var = mean, var
      cur_mean, cur_var = used_mean, used_var

    # normalize
    res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
    # de-normalize
    if scale:
      res *= gamma
      res += beta

    # update variables
    if train:
      with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
        with ops.colocate_with(mean):
          new_mean = tf.assign_sub(
              mean,
              tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean."))
      with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
        with ops.colocate_with(var):
          new_var = tf.assign_sub(
              var,
              tf.check_numerics(decay * (var - cur_var),
                                "NaN in moving variance."))
      with tf.name_scope(name, "IncrementTime", [step]):
        with ops.colocate_with(step):
          new_step = tf.assign_add(step, 1.)
      res += 0. * new_mean * new_var * new_step

  return res

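# Usage sketch (the tensor shapes and scope name are assumptions): for an
# NHWC activation, normalize over batch and spatial axes. At training time
# the moving mean/var variables are updated as a side effect; the
# `0. * new_mean * ...` term above ties those update ops into the output so
# they run whenever the normalized tensor is evaluated.
#
#   h_bn = batch_norm(h, dim=32, name="bn0", train=True, axes=[0, 1, 2])
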
# batch normalization taking into account the volume transformation
def batch_norm_log_diff(input_,
                        dim,
                        name,
                        train=True,
                        epsilon=1e-8,
                        decay=.1,
                        axes=[0],
                        reuse=None,
                        bn_lag=DEFAULT_BN_LAG):
  """Batch normalization with corresponding log determinant Jacobian."""
  if reuse is None:
    reuse = not train
  # create variables
  with tf.variable_scope(name) as scope:
    if reuse:
      scope.reuse_variables()
    var = variable_on_cpu(
        "var", [dim], tf.constant_initializer(1.), trainable=False)
    mean = variable_on_cpu(
        "mean", [dim], tf.constant_initializer(0.), trainable=False)
    step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
    # choose the appropriate moments
    if train:
      used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
      cur_mean, cur_var = used_mean, used_var
      if bn_lag > 0.:
        used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
        cur_var = used_var
        used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
        used_mean /= (1. - bn_lag**(step + 1))
        used_var -= (1. - bn_lag) * (used_var - tf.stop_gradient(var))
        used_var /= (1. - bn_lag**(step + 1))
    else:
      used_mean, used_var = mean, var
      cur_mean, cur_var = used_mean, used_var

    # update variables
    if train:
      with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
        with ops.colocate_with(mean):
          new_mean = tf.assign_sub(
              mean,
              tf.check_numerics(
                  decay * (mean - cur_mean), "NaN in moving mean."))
      with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
        with ops.colocate_with(var):
          new_var = tf.assign_sub(
              var,
              tf.check_numerics(decay * (var - cur_var),
                                "NaN in moving variance."))
      with tf.name_scope(name, "IncrementTime", [step]):
        with ops.colocate_with(step):
          new_step = tf.assign_add(step, 1.)
      used_var += 0. * new_mean * new_var * new_step
    used_var += epsilon

  return used_mean, used_var

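# Sketch of how a caller can combine the returned moments (a usage
# assumption, not code from the original file): a coupling layer normalizes
# with the moments and accounts for the volume change, since rescaling by
# 1/sqrt(var) contributes -0.5 * log(var) per normalized dimension to the
# log-determinant of the Jacobian.
#
#   mean, var = batch_norm_log_diff(x, dim=32, name="bn_scale", axes=[0, 1, 2])
#   x_hat = (x - mean) / tf.sqrt(var)
#   log_diff = -.5 * tf.log(var)  # per-channel contribution to the log-det
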
def convnet(input_,
            dim_in,
            dim_hid,
            filter_sizes,
            dim_out,
            name,
            use_batch_norm=True,
            train=True,
            nonlinearity=tf.nn.relu):
  """Chaining of convolutional layers."""
  dims_in = [dim_in] + dim_hid[:-1]
  dims_out = dim_hid
  res = input_

  bias = (not use_batch_norm)
  with tf.variable_scope(name):
    for layer_idx in range(len(dim_hid)):
      res = conv_layer(
          input_=res,
          filter_size=filter_sizes[layer_idx],
          dim_in=dims_in[layer_idx],
          dim_out=dims_out[layer_idx],
          name="h_%d" % layer_idx,
          stddev=1e-2,
          nonlinearity=None,
          bias=bias)
      if use_batch_norm:
        res = batch_norm(
            input_=res,
            dim=dims_out[layer_idx],
            name="bn_%d" % layer_idx,
            scale=(nonlinearity == tf.nn.relu),
            train=train,
            epsilon=1e-8,
            axes=[0, 1, 2])
      if nonlinearity is not None:
        res = nonlinearity(res)

    res = conv_layer(
        input_=res,
        filter_size=filter_sizes[-1],
        dim_in=dims_out[-1],
        dim_out=dim_out,
        name="out",
        stddev=1e-2,
        nonlinearity=None)

  return res

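# Example call (a sketch; hidden sizes and filter sizes are illustrative):
# three hidden conv layers followed by a linear output convolution.
# `dim_hid` lists the hidden channel counts and `filter_sizes` must provide
# one filter size per hidden layer plus one for the output layer.
#
#   out = convnet(images, dim_in=3, dim_hid=[32, 32, 32],
#                 filter_sizes=[[3, 3]] * 4, dim_out=6,
#                 name="convnet0", use_batch_norm=True, train=True)
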
# distributions
# log-likelihood estimation
def standard_normal_ll(input_):
  """Log-likelihood of standard Gaussian distribution."""
  res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))

  return res


def standard_normal_sample(shape):
  """Samples from standard Gaussian distribution."""
  return tf.random_normal(shape)

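# Sketch (assumed shapes): standard_normal_ll is elementwise, so the total
# log-likelihood of a batch under a standard Gaussian prior is obtained by
# summing over the non-batch dimensions.
#
#   z = standard_normal_sample([16, 8])
#   log_p = tf.reduce_sum(standard_normal_ll(z), axis=1)  # shape [16]
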
SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
                              [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])

def squeeze_2x2_ordered(input_, reverse=False):
  """Squeezing operation with a controlled ordering."""
  shape = input_.get_shape().as_list()
  batch_size = shape[0]
  height = shape[1]
  width = shape[2]
  channels = shape[3]
  if reverse:
    if channels % 4 != 0:
      raise ValueError("Number of channels not divisible by 4.")
    channels //= 4
  else:
    if height % 2 != 0:
      raise ValueError("Height not divisible by 2.")
    if width % 2 != 0:
      raise ValueError("Width not divisible by 2.")
  weights = numpy.zeros((2, 2, channels, 4 * channels))
  for idx_ch in range(channels):
    slice_2 = slice(idx_ch, (idx_ch + 1))
    slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
    weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
  shuffle_channels = [idx_ch * 4 for idx_ch in range(channels)]
  shuffle_channels += [idx_ch * 4 + 1 for idx_ch in range(channels)]
  shuffle_channels += [idx_ch * 4 + 2 for idx_ch in range(channels)]
  shuffle_channels += [idx_ch * 4 + 3 for idx_ch in range(channels)]
  shuffle_channels = numpy.array(shuffle_channels)
  weights = weights[:, :, :, shuffle_channels].astype("float32")
  if reverse:
    res = tf.nn.conv2d_transpose(
        value=input_,
        filter=weights,
        output_shape=[batch_size, height * 2, width * 2, channels],
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="unsqueeze_2x2")
  else:
    res = tf.nn.conv2d(
        input=input_,
        filter=weights,
        strides=[1, 2, 2, 1],
        padding="SAME",
        name="squeeze_2x2")

  return res

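# Usage sketch (illustrative shapes): the forward call squeezes space into
# channels with a fixed ordering, and reverse=True inverts it with a
# transposed convolution.
#
#   x = tf.placeholder(tf.float32, [4, 32, 32, 3])
#   y = squeeze_2x2_ordered(x)                     # shape [4, 16, 16, 12]
#   x_back = squeeze_2x2_ordered(y, reverse=True)  # shape [4, 32, 32, 3]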