@@ -0,0 +1,474 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Utility functions for Real NVP.
+"""
+
+# pylint: disable=dangerous-default-value
+
+import numpy
+from six.moves import xrange  # pylint: disable=redefined-builtin
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+DEFAULT_BN_LAG = .0
+
+
+def stable_var(input_, mean=None, axes=[0]):
+  """Numerically more stable variance computation."""
+  if mean is None:
+    mean = tf.reduce_mean(input_, axes)
+  res = tf.square(input_ - mean)
+  # Rescale by the largest squared deviation before averaging and scale back
+  # afterwards, so the mean is taken over values in [0, 1].
+  max_sqr = tf.reduce_max(res, axes)
+  res /= max_sqr
+  res = tf.reduce_mean(res, axes)
+  res *= max_sqr
+
+  return res
+
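+# Rough numerical check for stable_var (an illustrative sketch, not part of
+# the original module):
+#   x = tf.constant([[1.], [2.], [3.], [4.]])
+#   v = stable_var(x)   # same value as tf.nn.moments(x, [0])[1], i.e. 1.25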
+
+def variable_on_cpu(name, shape, initializer, trainable=True):
+  """Helper to create a Variable stored on CPU memory.
+
+  Args:
+    name: name of the variable
+    shape: list of ints
+    initializer: initializer for Variable
+    trainable: boolean defining if the variable is for training
+
+  Returns:
+    Variable Tensor
+  """
+  var = tf.get_variable(
+      name, shape, initializer=initializer, trainable=trainable)
+  return var
+
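+# Note: despite its name, variable_on_cpu does not itself pin the variable to
+# a device; placement is left to the surrounding code. A minimal usage sketch
+# (the scope and shape are illustrative only):
+#   with tf.variable_scope("coupling"):
+#     w = variable_on_cpu("weights", [3, 3, 16, 32],
+#                         tf.random_uniform_initializer(-0.01, 0.01))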
+
+# layers
+def conv_layer(input_,
+               filter_size,
+               dim_in,
+               dim_out,
+               name,
+               stddev=1e-2,
+               strides=[1, 1, 1, 1],
+               padding="SAME",
+               nonlinearity=None,
+               bias=False,
+               weight_norm=False,
+               scale=False):
+  """Convolutional layer."""
+  with tf.variable_scope(name) as scope:
+    weights = variable_on_cpu(
+        "weights",
+        filter_size + [dim_in, dim_out],
+        tf.random_uniform_initializer(
+            minval=-stddev, maxval=stddev))
+    # weight normalization
+    if weight_norm:
+      weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
+      if scale:
+        magnitude = variable_on_cpu(
+            "magnitude", [dim_out],
+            tf.constant_initializer(
+                stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
+        weights *= magnitude
+    res = input_
+    # handling filter size bigger than image size
+    if hasattr(input_, "shape"):
+      if input_.get_shape().as_list()[1] < filter_size[0]:
+        pad_1 = tf.zeros([
+            input_.get_shape().as_list()[0],
+            filter_size[0] - input_.get_shape().as_list()[1],
+            input_.get_shape().as_list()[2],
+            input_.get_shape().as_list()[3]
+        ])
+        pad_2 = tf.zeros([
+            input_.get_shape().as_list()[0],
+            filter_size[0],
+            filter_size[1] - input_.get_shape().as_list()[2],
+            input_.get_shape().as_list()[3]
+        ])
+        res = tf.concat(1, [pad_1, res])
+        res = tf.concat(2, [pad_2, res])
+    res = tf.nn.conv2d(
+        input=res,
+        filter=weights,
+        strides=strides,
+        padding=padding,
+        name=scope.name)
+
+    if hasattr(input_, "shape"):
+      if input_.get_shape().as_list()[1] < filter_size[0]:
+        res = tf.slice(res, [
+            0, filter_size[0] - input_.get_shape().as_list()[1],
+            filter_size[1] - input_.get_shape().as_list()[2], 0
+        ], [-1, -1, -1, -1])
+
+    if bias:
+      biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
+      res = tf.nn.bias_add(res, biases)
+    if nonlinearity is not None:
+      res = nonlinearity(res)
+
+  return res
+
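+# conv_layer usage sketch (shapes and names here are assumptions, not from the
+# original code): a 3x3 convolution from 16 to 32 feature maps with weight
+# normalization, a learned per-channel scale, a bias and a ReLU:
+#   x = tf.placeholder(tf.float32, [64, 28, 28, 16])
+#   h = conv_layer(x, filter_size=[3, 3], dim_in=16, dim_out=32, name="conv1",
+#                  nonlinearity=tf.nn.relu, bias=True, weight_norm=True,
+#                  scale=True)
+#   # With padding="SAME" and unit strides, h has shape [64, 28, 28, 32].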
+
+def max_pool_2x2(input_):
+  """Max pooling."""
+  return tf.nn.max_pool(
+      input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
+
+
+def depool_2x2(input_, stride=2):
+  """Depooling."""
+  shape = input_.get_shape().as_list()
+  batch_size = shape[0]
+  height = shape[1]
+  width = shape[2]
+  channels = shape[3]
+  res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
+  res = tf.concat(
+      2, [res, tf.zeros([batch_size, height, stride - 1, width, 1, channels])])
+  res = tf.concat(4, [
+      res, tf.zeros([batch_size, height, stride, width, stride - 1, channels])
+  ])
+  res = tf.reshape(res, [batch_size, stride * height, stride * width, channels])
+
+  return res
+
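+# depool_2x2 upsamples by zero insertion: each input pixel lands at the
+# top-left of a stride x stride block and the rest of the block is zero
+# (an illustrative sketch):
+#   x = tf.ones([1, 2, 2, 3])
+#   y = depool_2x2(x)   # shape [1, 4, 4, 3]; three quarters of the entries
+#                       # are zeros, the ones sit at even rows and columns.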
+
+# random flip on a batch of images
+def batch_random_flip(input_):
+  """Simultaneous horizontal random flip."""
+  if isinstance(input_, (float, int)):
+    return input_
+  shape = input_.get_shape().as_list()
+  batch_size = shape[0]
+  height = shape[1]
+  width = shape[2]
+  channels = shape[3]
+  res = tf.split(0, batch_size, input_)
+  res = [elem[0, :, :, :] for elem in res]
+  res = [tf.image.random_flip_left_right(elem) for elem in res]
+  res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
+  res = tf.concat(0, res)
+
+  return res
+
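+# batch_random_flip draws one flip decision per image: the batch is split,
+# each image is randomly mirrored left-right, then the batch is re-stacked,
+# so shapes are preserved. Illustrative call:
+#   images = tf.zeros([16, 32, 32, 3])
+#   flipped = batch_random_flip(images)   # still [16, 32, 32, 3]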
+
+# build a one hot representation corresponding to the integer tensor
+# the one-hot dimension is appended to the integer tensor shape
+def as_one_hot(input_, n_indices):
+  """Convert indices to one-hot."""
+  shape = input_.get_shape().as_list()
+  n_elem = numpy.prod(shape)
+  indices = tf.range(n_elem)
+  indices = tf.cast(indices, tf.int64)
+  indices_input = tf.concat(0, [indices, tf.reshape(input_, [-1])])
+  indices_input = tf.reshape(indices_input, [2, -1])
+  indices_input = tf.transpose(indices_input)
+  res = tf.sparse_to_dense(
+      indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
+  res = tf.reshape(res, [elem for elem in shape] + [n_indices])
+
+  return res
+
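+# as_one_hot appends a one-hot axis of size n_indices to the input shape
+# (a sketch with made-up values; the input must be an int64 tensor so it can
+# be concatenated with the flat indices built above):
+#   labels = tf.constant([[0, 2], [1, 1]], dtype=tf.int64)
+#   one_hot = as_one_hot(labels, n_indices=3)   # shape [2, 2, 3]
+#   # e.g. one_hot[0, 1] evaluates to [0., 0., 1.]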
+
+def squeeze_2x2(input_):
+  """Squeezing operation: reshape to convert space to channels."""
+  return squeeze_nxn(input_, n_factor=2)
+
+
+def squeeze_nxn(input_, n_factor=2):
+  """Squeezing operation: reshape to convert space to channels."""
+  if isinstance(input_, (float, int)):
+    return input_
+  shape = input_.get_shape().as_list()
+  batch_size = shape[0]
+  height = shape[1]
+  width = shape[2]
+  channels = shape[3]
+  if height % n_factor != 0:
+    raise ValueError("Height not divisible by %d." % n_factor)
+  if width % n_factor != 0:
+    raise ValueError("Width not divisible by %d." % n_factor)
+  res = tf.reshape(
+      input_,
+      [batch_size,
+       height // n_factor,
+       n_factor, width // n_factor,
+       n_factor, channels])
+  res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
+  res = tf.reshape(
+      res,
+      [batch_size,
+       height // n_factor,
+       width // n_factor,
+       channels * n_factor * n_factor])
+
+  return res
+
+
+def unsqueeze_2x2(input_):
+  """Unsqueezing operation: reshape to convert channels into space."""
+  if isinstance(input_, (float, int)):
+    return input_
+  shape = input_.get_shape().as_list()
+  batch_size = shape[0]
+  height = shape[1]
+  width = shape[2]
+  channels = shape[3]
+  if channels % 4 != 0:
+    raise ValueError("Number of channels not divisible by 4.")
+  res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
+  res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
+  res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])
+
+  return res
+
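+# Shape round trip for the squeezing ops (illustrative shapes only):
+#   x = tf.zeros([8, 32, 32, 3])
+#   y = squeeze_2x2(x)      # [8, 16, 16, 12]: each 2x2 spatial block of a
+#                           # channel becomes four consecutive channels
+#   z = unsqueeze_2x2(y)    # [8, 32, 32, 3], undoing the squeeze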
+
+# batch norm
+def batch_norm(input_,
+               dim,
+               name,
+               scale=True,
+               train=True,
+               epsilon=1e-8,
+               decay=.1,
+               axes=[0],
+               bn_lag=DEFAULT_BN_LAG):
+  """Batch normalization."""
+  # create variables
+  with tf.variable_scope(name):
+    var = variable_on_cpu(
+        "var", [dim], tf.constant_initializer(1.), trainable=False)
+    mean = variable_on_cpu(
+        "mean", [dim], tf.constant_initializer(0.), trainable=False)
+    step = variable_on_cpu(
+        "step", [], tf.constant_initializer(0.), trainable=False)
+    if scale:
+      gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
+      beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
+  # choose the appropriate moments
+  if train:
+    used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
+    cur_mean, cur_var = used_mean, used_var
+    if bn_lag > 0.:
+      # blend the batch moments with the stored moving averages and correct
+      # for the start-up bias of the exponential average
+      used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
+      used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
+      used_mean /= (1. - bn_lag**(step + 1))
+      used_var /= (1. - bn_lag**(step + 1))
+  else:
+    used_mean, used_var = mean, var
+    cur_mean, cur_var = used_mean, used_var
+
+  # normalize
+  res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
+  # de-normalize
+  if scale:
+    res *= gamma
+    res += beta
+
+  # update variables
+  if train:
+    with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
+      with ops.colocate_with(mean):
+        new_mean = tf.assign_sub(
+            mean,
+            tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean."))
+    with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
+      with ops.colocate_with(var):
+        new_var = tf.assign_sub(
+            var,
+            tf.check_numerics(decay * (var - cur_var),
+                              "NaN in moving variance."))
+    with tf.name_scope(name, "IncrementTime", [step]):
+      with ops.colocate_with(step):
+        new_step = tf.assign_add(step, 1.)
+    # zero-valued term that makes the moving-average and step updates run
+    # whenever res is evaluated, without changing its value
+    res += 0. * new_mean * new_var * new_step
+
+  return res
+
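+# Minimal batch_norm usage (a sketch; the shapes and name are assumptions):
+#   h = tf.placeholder(tf.float32, [128, 64])
+#   h_bn = batch_norm(h, dim=64, name="bn0", train=True, axes=[0])
+#   # With the default bn_lag = 0 the raw batch moments normalize h while the
+#   # moving averages are updated in the background for use at test time
+#   # (train=False).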
+
+# batch normalization taking into account the volume transformation
+def batch_norm_log_diff(input_,
+                        dim,
+                        name,
+                        train=True,
+                        epsilon=1e-8,
+                        decay=.1,
+                        axes=[0],
+                        reuse=None,
+                        bn_lag=DEFAULT_BN_LAG):
+  """Batch normalization with corresponding log determinant Jacobian."""
+  if reuse is None:
+    reuse = not train
+  # create variables
+  with tf.variable_scope(name) as scope:
+    if reuse:
+      scope.reuse_variables()
+    var = variable_on_cpu(
+        "var", [dim], tf.constant_initializer(1.), trainable=False)
+    mean = variable_on_cpu(
+        "mean", [dim], tf.constant_initializer(0.), trainable=False)
+    step = variable_on_cpu(
+        "step", [], tf.constant_initializer(0.), trainable=False)
+  # choose the appropriate moments
+  if train:
+    used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
+    cur_mean, cur_var = used_mean, used_var
+    if bn_lag > 0.:
+      used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
+      cur_var = used_var
+      used_mean -= (1 - bn_lag) * (used_mean - tf.stop_gradient(mean))
+      used_mean /= (1. - bn_lag**(step + 1))
+      used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
+      used_var /= (1. - bn_lag**(step + 1))
+  else:
+    used_mean, used_var = mean, var
+    cur_mean, cur_var = used_mean, used_var
+
+  # update variables
+  if train:
+    with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
+      with ops.colocate_with(mean):
+        new_mean = tf.assign_sub(
+            mean,
+            tf.check_numerics(
+                decay * (mean - cur_mean), "NaN in moving mean."))
+    with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
+      with ops.colocate_with(var):
+        new_var = tf.assign_sub(
+            var,
+            tf.check_numerics(decay * (var - cur_var),
+                              "NaN in moving variance."))
+    with tf.name_scope(name, "IncrementTime", [step]):
+      with ops.colocate_with(step):
+        new_step = tf.assign_add(step, 1.)
+    # zero-valued term forcing the update ops to run when used_var is evaluated
+    used_var += 0. * new_mean * new_var * new_step
+  used_var += epsilon
+
+  return used_mean, used_var
+
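+# batch_norm_log_diff returns the moments instead of a normalized tensor, so a
+# caller can both rescale activations and accumulate the log-determinant of
+# the (diagonal) volume change. A rough call-site sketch (not the original
+# code):
+#   mean, var = batch_norm_log_diff(x, dim=64, name="bn0", train=True)
+#   x_hat = (x - mean) / tf.sqrt(var)
+#   log_det = -.5 * tf.reduce_sum(tf.log(var))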
+
+def convnet(input_,
+            dim_in,
+            dim_hid,
+            filter_sizes,
+            dim_out,
+            name,
+            use_batch_norm=True,
+            train=True,
+            nonlinearity=tf.nn.relu):
+  """Chaining of convolutional layers."""
+  dims_in = [dim_in] + dim_hid[:-1]
+  dims_out = dim_hid
+  res = input_
+
+  bias = (not use_batch_norm)
+  with tf.variable_scope(name):
+    for layer_idx in xrange(len(dim_hid)):
+      res = conv_layer(
+          input_=res,
+          filter_size=filter_sizes[layer_idx],
+          dim_in=dims_in[layer_idx],
+          dim_out=dims_out[layer_idx],
+          name="h_%d" % layer_idx,
+          stddev=1e-2,
+          nonlinearity=None,
+          bias=bias)
+      if use_batch_norm:
+        res = batch_norm(
+            input_=res,
+            dim=dims_out[layer_idx],
+            name="bn_%d" % layer_idx,
+            scale=(nonlinearity == tf.nn.relu),
+            train=train,
+            epsilon=1e-8,
+            axes=[0, 1, 2])
+      if nonlinearity is not None:
+        res = nonlinearity(res)
+
+  # final projection layer, without a nonlinearity
+  res = conv_layer(
+      input_=res,
+      filter_size=filter_sizes[-1],
+      dim_in=dims_out[-1],
+      dim_out=dim_out,
+      name="out",
+      stddev=1e-2,
+      nonlinearity=None)
+
+  return res
+
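+# convnet usage sketch (all sizes are made up): two hidden 3x3 layers with
+# batch norm and ReLU, followed by the output projection:
+#   x = tf.placeholder(tf.float32, [32, 16, 16, 8])
+#   y = convnet(x, dim_in=8, dim_hid=[32, 32],
+#               filter_sizes=[[3, 3], [3, 3], [1, 1]], dim_out=16, name="net")
+#   # y has shape [32, 16, 16, 16]; filter_sizes[-1] is used for the output
+#   # layer.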
+
+# distributions
+# log-likelihood estimation
+def standard_normal_ll(input_):
+  """Log-likelihood of standard Gaussian distribution."""
+  res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))
+
+  return res
+
+
+def standard_normal_sample(shape):
+  """Samples from standard Gaussian distribution."""
+  return tf.random_normal(shape)
+
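+# standard_normal_ll is the elementwise log-density
+# log N(x; 0, 1) = -.5 * (x**2 + log(2 * pi)); summing over the non-batch
+# axes gives the prior term of the flow's log-likelihood (an illustrative
+# reduction, not original code):
+#   z = standard_normal_sample([16, 64])
+#   log_p_z = tf.reduce_sum(standard_normal_ll(z), [1])   # shape [16]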
+
+SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
+                              [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])
+
+
+def squeeze_2x2_ordered(input_, reverse=False):
+  """Squeezing operation with a controlled ordering."""
+  shape = input_.get_shape().as_list()
+  batch_size = shape[0]
+  height = shape[1]
+  width = shape[2]
+  channels = shape[3]
+  if reverse:
+    if channels % 4 != 0:
+      raise ValueError("Number of channels not divisible by 4.")
+    channels //= 4
+  else:
+    if height % 2 != 0:
+      raise ValueError("Height not divisible by 2.")
+    if width % 2 != 0:
+      raise ValueError("Width not divisible by 2.")
+  weights = numpy.zeros((2, 2, channels, 4 * channels))
+  for idx_ch in xrange(channels):
+    slice_2 = slice(idx_ch, (idx_ch + 1))
+    slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
+    weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
+  shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)]
+  shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)]
+  shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)]
+  shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)]
+  shuffle_channels = numpy.array(shuffle_channels)
+  weights = weights[:, :, :, shuffle_channels].astype("float32")
+  if reverse:
+    res = tf.nn.conv2d_transpose(
+        value=input_,
+        filter=weights,
+        output_shape=[batch_size, height * 2, width * 2, channels],
+        strides=[1, 2, 2, 1],
+        padding="SAME",
+        name="unsqueeze_2x2")
+  else:
+    res = tf.nn.conv2d(
+        input=input_,
+        filter=weights,
+        strides=[1, 2, 2, 1],
+        padding="SAME",
+        name="squeeze_2x2")
+
+  return res
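+
+# squeeze_2x2_ordered performs the same factor-2 space-to-depth squeeze as
+# squeeze_2x2, but through a fixed 0/1-valued convolution so that the channel
+# ordering is dictated by SQUEEZE_MATRIX; reverse=True inverts it. Shape-only
+# sketch (sizes are illustrative):
+#   y = squeeze_2x2_ordered(tf.zeros([4, 8, 8, 2]))                # [4, 4, 4, 8]
+#   x = squeeze_2x2_ordered(tf.zeros([4, 4, 4, 8]), reverse=True)  # [4, 8, 8, 2]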