
Real NVP code

Laurent Dinh 8 years ago
parent commit
e871d29598

+ 278 - 0
real_nvp/README.md

@@ -0,0 +1,278 @@
+# Real NVP in TensorFlow
+
+*A TensorFlow implementation of the training procedure of*
+[*Density estimation using Real NVP*](https://arxiv.org/abs/1605.08803)*, by
+Laurent Dinh, Jascha Sohl-Dickstein and Samy Bengio, for Imagenet
+(32x32 and 64x64), CelebA and LSUN, including the scripts to
+put the datasets in `.tfrecords` format.*
+
+We are happy to open source the code for *Real NVP*, a novel approach to
+density estimation using deep neural networks that enables tractable density
+estimation and efficient one-pass inference and sampling. This model
+successfully decomposes images into hierarchical features ranging from
+high-level concepts to low-resolution details. Visualizations are available
+[here](http://goo.gl/yco14s).
+
+## Installation
+*   Python 2.7:
+    * Python 3 support is not available yet
+*   pip (python package manager)
+    * `apt-get install python-pip` on Ubuntu
+    * `brew` installs pip along with Python on OS X
+*   Install the dependencies for [LSUN](https://github.com/fyu/lsun.git)
+    * Install [OpenCV](http://opencv.org/)
+    * `pip install numpy lmdb`
+*   Install the python dependencies
+    * `pip install scipy scikit-image Pillow`
+*   Install the
+[latest TensorFlow pip package](https://www.tensorflow.org/get_started/os_setup.html#using-pip)
+for Python 2.7
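+
+As a quick sanity check, you can verify that the dependencies import correctly
+(a minimal sketch; the version printout assumes a standard pip install):
+```python
+# Verify the Python dependencies used by the scripts in this folder.
+import numpy
+import scipy.ndimage
+import skimage.transform
+import PIL.Image
+import tensorflow as tf
+
+print "TensorFlow version:", tf.__version__
+```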
+
+## Getting Started
+Once you have successfully installed the dependencies, you can start by
+downloading the repository:
+```shell
+git clone --recursive https://github.com/tensorflow/models.git
+```
+Afterward, you can use the utilities in this folder to prepare the datasets.
+
+## Preparing datasets
+### CelebA
+For [*CelebA*](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html), download
+`img_align_celeba.zip` from the Dropbox link on this
+[page](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) under the
+link *Align&Cropped Images* in the *Img* directory and `list_eval_partition.txt`
+under the link *Train/Val/Test Partitions* in the *Eval* directory. Then do:
+
+```shell
+mkdir celeba
+cd celeba
+unzip img_align_celeba.zip
+```
+
+We'll format the training subset:
+```shell
+python2.7 ../models/real_nvp/celeba_formatting.py \
+    --partition_fn list_eval_partition.txt \
+    --file_out celeba_train \
+    --fn_root img_align_celeba \
+    --set 0
+```
+
+Then the validation subset:
+```shell
+python2.7 ../models/real_nvp/celeba_formatting.py \
+    --partition_fn list_eval_partition.txt \
+    --file_out celeba_valid \
+    --fn_root img_align_celeba \
+    --set 1
+```
+
+And finally the test subset:
+```shell
+python2.7 ../models/real_nvp/celeba_formatting.py \
+    --partition_fn list_eval_partition.txt \
+    --file_out celeba_test \
+    --fn_root img_align_celeba \
+    --set 2
+```
+
+Afterward:
+```shell
+cd ..
+```
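+
+Optionally, you can count the records in the generated files to check the
+conversion (a sketch, run from the directory above `celeba/`; the counts
+should match the partition sizes):
+```python
+# Count the serialized examples in each generated .tfrecords file.
+import tensorflow as tf
+
+for subset in ["celeba_train", "celeba_valid", "celeba_test"]:
+    fn = "celeba/%s.tfrecords" % subset
+    count = sum(1 for _ in tf.python_io.tf_record_iterator(fn))
+    print fn, "contains", count, "examples"
+```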
+
+### Small Imagenet
+Downloading the [*small Imagenet*](http://image-net.org/small/download.php)
+dataset is more straightforward and can be done entirely in shell:
+```shell
+mkdir small_imnet
+cd small_imnet
+for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar
+do
+    curl -O http://image-net.org/small/$FILENAME
+    tar -xvf $FILENAME
+done
+```
+
+Then, you can format the datasets as follows:
+```shell
+for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64
+do
+    python2.7 ../models/real_nvp/imnet_formatting.py \
+        --file_out $DIRNAME \
+        --fn_root $DIRNAME
+done
+cd ..
+```
+
+### LSUN
+To prepare the [*LSUN*](http://lsun.cs.princeton.edu/2016/) dataset, we will
+need the associated code:
+```shell
+git clone https://github.com/fyu/lsun.git
+cd lsun
+```
+Then we'll download the db files:
+```shell
+for CATEGORY in bedroom church_outdoor tower
+do
+    python2.7 download.py -c $CATEGORY
+    unzip "$CATEGORY"_train_lmdb.zip
+    unzip "$CATEGORY"_val_lmdb.zip
+    python2.7 data.py export "$CATEGORY"_train_lmdb \
+        --out_dir "$CATEGORY"_train --flat
+    python2.7 data.py export "$CATEGORY"_val_lmdb \
+        --out_dir "$CATEGORY"_val --flat
+done
+```
+
+Finally, we format the dataset into `.tfrecords`:
+```shell
+for CATEGORY in bedroom church_outdoor tower
+do
+    python2.7 ../models/real_nvp/lsun_formatting.py \
+        --file_out "$CATEGORY"_train \
+        --fn_root "$CATEGORY"_train
+    python2.7 ../models/real_nvp/lsun_formatting.py \
+        --file_out "$CATEGORY"_val \
+        --fn_root "$CATEGORY"_val
+done
+cd ..
+```
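+
+All three converters write records with the same schema: `height`, `width` and
+`depth` as int64 features and `image_raw` as the raw `uint8` image bytes. A
+record can therefore be decoded back into an image array as in this sketch
+(the file name is illustrative):
+```python
+# Decode the first example of a shard back into a numpy image array.
+import numpy
+import tensorflow as tf
+
+iterator = tf.python_io.tf_record_iterator(
+    "small_imnet/train_32x32_00000.tfrecords")
+example = tf.train.Example()
+example.ParseFromString(next(iterator))
+feature = example.features.feature
+height = feature["height"].int64_list.value[0]
+width = feature["width"].int64_list.value[0]
+depth = feature["depth"].int64_list.value[0]
+image = numpy.fromstring(
+    feature["image_raw"].bytes_list.value[0], dtype=numpy.uint8)
+image = image.reshape(height, width, depth)
+print image.shape  # e.g. (32, 32, 3)
+```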
+
+
+## Training
+We'll give an example of how to train a model on the small Imagenet
+dataset (32x32):
+```shell
+cd models/real_nvp/
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 32 \
+--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset imnet \
+--traindir /tmp/real_nvp_imnet32/train \
+--logdir /tmp/real_nvp_imnet32/train \
+--data_path ../../small_imnet/train_32x32_?????.tfrecords
+```
+In parallel, you can run the script to generate visualizations from the model:
+```shell
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 32 \
+--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset imnet \
+--traindir /tmp/real_nvp_imnet32/train \
+--logdir /tmp/real_nvp_imnet32/sample \
+--data_path ../../small_imnet/valid_32x32_?????.tfrecords \
+--mode sample
+```
+Additionally, you can run the script to evaluate the model on the
+validation set:
+```shell
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 32 \
+--hpconfig=n_scale=4,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset imnet \
+--traindir /tmp/real_nvp_imnet32/train \
+--logdir /tmp/real_nvp_imnet32/eval \
+--data_path ../../small_imnet/valid_32x32_?????.tfrecords \
+--eval_set_size 50000 \
+--mode eval
+```
+The visualizations and validation set evaluation can be viewed through
+[TensorBoard](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tensorboard/README.md).
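+For example, pointing TensorBoard at the parent directory with
+`tensorboard --logdir=/tmp/real_nvp_imnet32` displays the `train`, `sample`
+and `eval` runs side by side.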
+
+Another example shows how to run the model on LSUN (church_outdoor category):
+```shell
+# train the model
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset lsun \
+--traindir /tmp/real_nvp_church_outdoor/train \
+--logdir /tmp/real_nvp_church_outdoor/train \
+--data_path ../../lsun/church_outdoor_train_?????.tfrecords
+```
+
+```shell
+# sample from the model
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset lsun \
+--traindir /tmp/real_nvp_church_outdoor/train \
+--logdir /tmp/real_nvp_church_outdoor/sample \
+--data_path ../../lsun/church_outdoor_val_?????.tfrecords \
+--mode sample
+```
+
+```shell
+# evaluate the model
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset lsun \
+--traindir /tmp/real_nvp_church_outdoor/train \
+--logdir /tmp/real_nvp_church_outdoor/eval \
+--data_path ../../lsun/church_outdoor_val_?????.tfrecords \
+--eval_set_size 300 \
+--mode eval
+```
+
+Finally, we'll give the commands to run the model on the CelebA dataset:
+```shell
+# train the model
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset celeba \
+--traindir /tmp/real_nvp_celeba/train \
+--logdir /tmp/real_nvp_celeba/train \
+--data_path ../../celeba/celeba_train.tfrecords
+```
+
+```shell
+# sample from the model
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset celeba \
+--traindir /tmp/real_nvp_celeba/train \
+--logdir /tmp/real_nvp_celeba/sample \
+--data_path ../../celeba/celeba_valid.tfrecords \
+--mode sample
+```
+
+```shell
+# evaluate the model on validation set
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset celeba \
+--traindir /tmp/real_nvp_celeba/train \
+--logdir /tmp/real_nvp_celeba/eval_valid \
+--data_path ../../celeba/celeba_valid.tfrecords \
+--eval_set_size 19867 \
+--mode eval
+
+# evaluate the model on test set
+python2.7 real_nvp_multiscale_dataset.py \
+--image_size 64 \
+--hpconfig=n_scale=5,base_dim=32,clip_gradient=100,residual_blocks=4 \
+--dataset celeba \
+--traindir /tmp/real_nvp_celeba/train \
+--logdir /tmp/real_nvp_celeba/eval_test \
+--data_path ../../celeba/celeba_test.tfrecords \
+--eval_set_size 19962 \
+--mode eval
+```
+
+## Credits
+This code was written by Laurent Dinh
+([@laurent-dinh](https://github.com/laurent-dinh)) with
+the help of
+Jascha Sohl-Dickstein ([@Sohl-Dickstein](https://github.com/Sohl-Dickstein)
+and [jaschasd@google.com](mailto:jaschasd@google.com)),
+Samy Bengio, Jon Shlens, Sherry Moore and
+David Andersen.

+ 0 - 0
real_nvp/__init__.py


+ 94 - 0
real_nvp/celeba_formatting.py

@@ -0,0 +1,94 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""CelebA dataset formating.
+
+Download img_align_celeba.zip from
+http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html under the
+link "Align&Cropped Images" in the "Img" directory and list_eval_partition.txt
+under the link "Train/Val/Test Partitions" in the "Eval" directory. Then do:
+unzip img_align_celeba.zip
+
+Use the script as follows:
+python celeba_formatting.py \
+    --partition_fn [PARTITION_FILE_PATH] \
+    --file_out [OUTPUT_FILE_PATH_PREFIX] \
+    --fn_root [CELEBA_FOLDER] \
+    --set [SUBSET_INDEX]
+
+"""
+
+import os
+import os.path
+
+import scipy.ndimage
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string("file_out", "",
+                       "Filename of the output .tfrecords file.")
+tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
+tf.flags.DEFINE_string("partition_fn", "", "Partition file path.")
+tf.flags.DEFINE_string("set", "", "Name of subset.")
+
+FLAGS = tf.flags.FLAGS
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def main():
+    """Main converter function."""
+    # Celeb A
+    with open(FLAGS.partition_fn, "r") as infile:
+        img_fn_list = infile.readlines()
+    img_fn_list = [elem.strip().split() for elem in img_fn_list]
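+    # Each line of the partition file is "<image filename> <set index>", where
+    # set 0 is train, 1 is validation and 2 is test.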
+    img_fn_list = [elem[0] for elem in img_fn_list if elem[1] == FLAGS.set]
+    fn_root = FLAGS.fn_root
+    num_examples = len(img_fn_list)
+
+    file_out = "%s.tfrecords" % FLAGS.file_out
+    writer = tf.python_io.TFRecordWriter(file_out)
+    for example_idx, img_fn in enumerate(img_fn_list):
+        if example_idx % 1000 == 0:
+            print example_idx, "/", num_examples
+        image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
+        rows = image_raw.shape[0]
+        cols = image_raw.shape[1]
+        depth = image_raw.shape[2]
+        image_raw = image_raw.tostring()
+        example = tf.train.Example(
+            features=tf.train.Features(
+                feature={
+                    "height": _int64_feature(rows),
+                    "width": _int64_feature(cols),
+                    "depth": _int64_feature(depth),
+                    "image_raw": _bytes_feature(image_raw)
+                }
+            )
+        )
+        writer.write(example.SerializeToString())
+    writer.close()
+
+
+if __name__ == "__main__":
+    main()

+ 103 - 0
real_nvp/imnet_formatting.py

@@ -0,0 +1,103 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""LSUN dataset formatting.
+
+Download and format the small Imagenet dataset as follows:
+mkdir [IMAGENET_PATH]
+cd [IMAGENET_PATH]
+for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar
+do
+    curl -O http://image-net.org/small/$FILENAME
+    tar -xvf $FILENAME
+done
+
+Then use the script as follows:
+for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64
+do
+    python imnet_formatting.py \
+        --file_out $DIRNAME \
+        --fn_root $DIRNAME
+done
+
+"""
+
+import os
+import os.path
+
+import scipy.ndimage
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string("file_out", "",
+                       "Filename of the output .tfrecords file.")
+tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
+
+FLAGS = tf.flags.FLAGS
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def main():
+    """Main converter function."""
+    # Small Imagenet
+    fn_root = FLAGS.fn_root
+    img_fn_list = os.listdir(fn_root)
+    img_fn_list = [img_fn for img_fn in img_fn_list
+                   if img_fn.endswith('.png')]
+    num_examples = len(img_fn_list)
+
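+    # Shard the output: each .tfrecords file holds up to n_examples_per_file
+    # examples and is named <file_out>_00000.tfrecords, <file_out>_00001.tfrecords,
+    # etc., matching the train_32x32_?????.tfrecords glob used at training time.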
+    n_examples_per_file = 10000
+    for example_idx, img_fn in enumerate(img_fn_list):
+        if example_idx % n_examples_per_file == 0:
+            file_out = "%s_%05d.tfrecords"
+            file_out = file_out % (FLAGS.file_out,
+                                   example_idx // n_examples_per_file)
+            print "Writing on:", file_out
+            writer = tf.python_io.TFRecordWriter(file_out)
+        if example_idx % 1000 == 0:
+            print example_idx, "/", num_examples
+        image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
+        rows = image_raw.shape[0]
+        cols = image_raw.shape[1]
+        depth = image_raw.shape[2]
+        image_raw = image_raw.astype("uint8")
+        image_raw = image_raw.tostring()
+        example = tf.train.Example(
+            features=tf.train.Features(
+                feature={
+                    "height": _int64_feature(rows),
+                    "width": _int64_feature(cols),
+                    "depth": _int64_feature(depth),
+                    "image_raw": _bytes_feature(image_raw)
+                }
+            )
+        )
+        writer.write(example.SerializeToString())
+        if example_idx % n_examples_per_file == (n_examples_per_file - 1):
+            writer.close()
+    writer.close()
+
+
+if __name__ == "__main__":
+    main()

+ 104 - 0
real_nvp/lsun_formatting.py

@@ -0,0 +1,104 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""LSUN dataset formatting.
+
+Download and format the LSUN dataset as follows:
+git clone https://github.com/fyu/lsun.git
+cd lsun
+python2.7 download.py -c [CATEGORY]
+
+Then unzip the downloaded .zip files before executing:
+python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat
+
+Then use the script as follows:
+python lsun_formatting.py \
+    --file_out [OUTPUT_FILE_PATH_PREFIX] \
+    --fn_root [LSUN_FOLDER]
+
+"""
+
+import os
+import os.path
+
+import numpy
+import skimage.transform
+from PIL import Image
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string("file_out", "",
+                       "Filename of the output .tfrecords file.")
+tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
+
+FLAGS = tf.flags.FLAGS
+
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def main():
+    """Main converter function."""
+    fn_root = FLAGS.fn_root
+    img_fn_list = os.listdir(fn_root)
+    img_fn_list = [img_fn for img_fn in img_fn_list
+                   if img_fn.endswith('.webp')]
+    num_examples = len(img_fn_list)
+
+    n_examples_per_file = 10000
+    for example_idx, img_fn in enumerate(img_fn_list):
+        if example_idx % n_examples_per_file == 0:
+            file_out = "%s_%05d.tfrecords"
+            file_out = file_out % (FLAGS.file_out,
+                                   example_idx // n_examples_per_file)
+            print "Writing on:", file_out
+            writer = tf.python_io.TFRecordWriter(file_out)
+        if example_idx % 1000 == 0:
+            print example_idx, "/", num_examples
+        image_raw = numpy.array(Image.open(os.path.join(fn_root, img_fn)))
+        rows = image_raw.shape[0]
+        cols = image_raw.shape[1]
+        depth = image_raw.shape[2]
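+        # Downscale so the shorter side becomes 96 pixels. pyramid_reduce
+        # returns floats in [0, 1], hence the rescaling to [0, 255] before
+        # casting back to uint8.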
+        downscale = min(rows / 96., cols / 96.)
+        image_raw = skimage.transform.pyramid_reduce(image_raw, downscale)
+        image_raw *= 255.
+        image_raw = image_raw.astype("uint8")
+        rows = image_raw.shape[0]
+        cols = image_raw.shape[1]
+        depth = image_raw.shape[2]
+        image_raw = image_raw.tostring()
+        example = tf.train.Example(
+            features=tf.train.Features(
+                feature={
+                    "height": _int64_feature(rows),
+                    "width": _int64_feature(cols),
+                    "depth": _int64_feature(depth),
+                    "image_raw": _bytes_feature(image_raw)
+                }
+            )
+        )
+        writer.write(example.SerializeToString())
+        if example_idx % n_examples_per_file == (n_examples_per_file - 1):
+            writer.close()
+    writer.close()
+
+
+if __name__ == "__main__":
+    main()

File diff suppressed because it is too large
+ 1636 - 0
real_nvp/real_nvp_multiscale_dataset.py


+ 474 - 0
real_nvp/real_nvp_utils.py

@@ -0,0 +1,474 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""Utility functions for Real NVP.
+"""
+
+# pylint: disable=dangerous-default-value
+
+import numpy
+import tensorflow as tf
+from tensorflow.python.framework import ops
+
+DEFAULT_BN_LAG = .0
+
+
+def stable_var(input_, mean=None, axes=[0]):
+    """Numerically more stable variance computation."""
+    if mean is None:
+        mean = tf.reduce_mean(input_, axes)
+    res = tf.square(input_ - mean)
+    max_sqr = tf.reduce_max(res, axes)
+    res /= max_sqr
+    res = tf.reduce_mean(res, axes)
+    res *= max_sqr
+
+    return res
+
+
+def variable_on_cpu(name, shape, initializer, trainable=True):
+    """Helper to create a Variable stored on CPU memory.
+
+    Args:
+            name: name of the variable
+            shape: list of ints
+            initializer: initializer for Variable
+            trainable: boolean defining if the variable is for training
+    Returns:
+            Variable Tensor
+    """
+    var = tf.get_variable(
+        name, shape, initializer=initializer, trainable=trainable)
+    return var
+
+
+# layers
+def conv_layer(input_,
+               filter_size,
+               dim_in,
+               dim_out,
+               name,
+               stddev=1e-2,
+               strides=[1, 1, 1, 1],
+               padding="SAME",
+               nonlinearity=None,
+               bias=False,
+               weight_norm=False,
+               scale=False):
+    """Convolutional layer."""
+    with tf.variable_scope(name) as scope:
+        weights = variable_on_cpu(
+            "weights",
+            filter_size + [dim_in, dim_out],
+            tf.random_uniform_initializer(
+                minval=-stddev, maxval=stddev))
+        # weight normalization
+        if weight_norm:
+            weights /= tf.sqrt(tf.reduce_sum(tf.square(weights), [0, 1, 2]))
+            if scale:
+                magnitude = variable_on_cpu(
+                    "magnitude", [dim_out],
+                    tf.constant_initializer(
+                        stddev * numpy.sqrt(dim_in * numpy.prod(filter_size) / 12.)))
+                weights *= magnitude
+        res = input_
+        # handling filter size bigger than image size
+        if hasattr(input_, "shape"):
+            if input_.get_shape().as_list()[1] < filter_size[0]:
+                pad_1 = tf.zeros([
+                    input_.get_shape().as_list()[0],
+                    filter_size[0] - input_.get_shape().as_list()[1],
+                    input_.get_shape().as_list()[2],
+                    input_.get_shape().as_list()[3]
+                ])
+                pad_2 = tf.zeros([
+                    input_.get_shape().as_list()[0],
+                    filter_size[0],
+                    filter_size[1] - input_.get_shape().as_list()[2],
+                    input_.get_shape().as_list()[3]
+                ])
+                res = tf.concat(1, [pad_1, res])
+                res = tf.concat(2, [pad_2, res])
+        res = tf.nn.conv2d(
+            input=res,
+            filter=weights,
+            strides=strides,
+            padding=padding,
+            name=scope.name)
+
+        if hasattr(input_, "shape"):
+            if input_.get_shape().as_list()[1] < filter_size[0]:
+                res = tf.slice(res, [
+                    0, filter_size[0] - input_.get_shape().as_list()[1],
+                    filter_size[1] - input_.get_shape().as_list()[2], 0
+                ], [-1, -1, -1, -1])
+
+        if bias:
+            biases = variable_on_cpu("biases", [dim_out], tf.constant_initializer(0.))
+            res = tf.nn.bias_add(res, biases)
+        if nonlinearity is not None:
+            res = nonlinearity(res)
+
+    return res
+
+
+def max_pool_2x2(input_):
+    """Max pooling."""
+    return tf.nn.max_pool(
+        input_, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding="SAME")
+
+
+def depool_2x2(input_, stride=2):
+    """Depooling."""
+    shape = input_.get_shape().as_list()
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    channels = shape[3]
+    res = tf.reshape(input_, [batch_size, height, 1, width, 1, channels])
+    res = tf.concat(
+        2, [res, tf.zeros([batch_size, height, stride - 1, width, 1, channels])])
+    res = tf.concat(4, [
+        res, tf.zeros([batch_size, height, stride, width, stride - 1, channels])
+    ])
+    res = tf.reshape(res, [batch_size, stride * height, stride * width, channels])
+
+    return res
+
+
+# random flip on a batch of images
+def batch_random_flip(input_):
+    """Simultaneous horizontal random flip."""
+    if isinstance(input_, (float, int)):
+        return input_
+    shape = input_.get_shape().as_list()
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    channels = shape[3]
+    res = tf.split(0, batch_size, input_)
+    res = [elem[0, :, :, :] for elem in res]
+    res = [tf.image.random_flip_left_right(elem) for elem in res]
+    res = [tf.reshape(elem, [1, height, width, channels]) for elem in res]
+    res = tf.concat(0, res)
+
+    return res
+
+
+# build a one hot representation corresponding to the integer tensor
+# the one-hot dimension is appended to the integer tensor shape
+def as_one_hot(input_, n_indices):
+    """Convert indices to one-hot."""
+    shape = input_.get_shape().as_list()
+    n_elem = numpy.prod(shape)
+    indices = tf.range(n_elem)
+    indices = tf.cast(indices, tf.int64)
+    indices_input = tf.concat(0, [indices, tf.reshape(input_, [-1])])
+    indices_input = tf.reshape(indices_input, [2, -1])
+    indices_input = tf.transpose(indices_input)
+    res = tf.sparse_to_dense(
+        indices_input, [n_elem, n_indices], 1., 0., name="flat_one_hot")
+    res = tf.reshape(res, [elem for elem in shape] + [n_indices])
+
+    return res
+
+
+def squeeze_2x2(input_):
+    """Squeezing operation: reshape to convert space to channels."""
+    return squeeze_nxn(input_, n_factor=2)
+
+
+def squeeze_nxn(input_, n_factor=2):
+    """Squeezing operation: reshape to convert space to channels."""
+    if isinstance(input_, (float, int)):
+        return input_
+    shape = input_.get_shape().as_list()
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    channels = shape[3]
+    if height % n_factor != 0:
+        raise ValueError("Height not divisible by %d." % n_factor)
+    if width % n_factor != 0:
+        raise ValueError("Width not divisible by %d." % n_factor)
+    res = tf.reshape(
+        input_,
+        [batch_size,
+         height // n_factor,
+         n_factor, width // n_factor,
+         n_factor, channels])
+    res = tf.transpose(res, [0, 1, 3, 5, 2, 4])
+    res = tf.reshape(
+        res,
+        [batch_size,
+         height // n_factor,
+         width // n_factor,
+         channels * n_factor * n_factor])
+
+    return res
+
+
+def unsqueeze_2x2(input_):
+    """Unsqueezing operation: reshape to convert channels into space."""
+    if isinstance(input_, (float, int)):
+        return input_
+    shape = input_.get_shape().as_list()
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    channels = shape[3]
+    if channels % 4 != 0:
+        raise ValueError("Number of channels not divisible by 4.")
+    res = tf.reshape(input_, [batch_size, height, width, channels // 4, 2, 2])
+    res = tf.transpose(res, [0, 1, 4, 2, 5, 3])
+    res = tf.reshape(res, [batch_size, 2 * height, 2 * width, channels // 4])
+
+    return res
+
+
+# batch norm
+def batch_norm(input_,
+               dim,
+               name,
+               scale=True,
+               train=True,
+               epsilon=1e-8,
+               decay=.1,
+               axes=[0],
+               bn_lag=DEFAULT_BN_LAG):
+    """Batch normalization."""
+    # create variables
+    with tf.variable_scope(name):
+        var = variable_on_cpu(
+            "var", [dim], tf.constant_initializer(1.), trainable=False)
+        mean = variable_on_cpu(
+            "mean", [dim], tf.constant_initializer(0.), trainable=False)
+        step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
+        if scale:
+            gamma = variable_on_cpu("gamma", [dim], tf.constant_initializer(1.))
+        beta = variable_on_cpu("beta", [dim], tf.constant_initializer(0.))
+    # choose the appropriate moments
+    if train:
+        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
+        cur_mean, cur_var = used_mean, used_var
+        if bn_lag > 0.:
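+            # Blend the batch moments with the moving averages,
+            #   used = bn_lag * batch + (1 - bn_lag) * moving,
+            # then divide by (1 - bn_lag**(step + 1)) to correct the start-up
+            # bias of the moving average (as in Adam's bias correction).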
+            used_mean -= (1. - bn_lag) * (used_mean - tf.stop_gradient(mean))
+            used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
+            used_mean /= (1. - bn_lag**(step + 1))
+            used_var /= (1. - bn_lag**(step + 1))
+    else:
+        used_mean, used_var = mean, var
+        cur_mean, cur_var = used_mean, used_var
+
+    # normalize
+    res = (input_ - used_mean) / tf.sqrt(used_var + epsilon)
+    # de-normalize
+    if scale:
+        res *= gamma
+    res += beta
+
+    # update variables
+    if train:
+        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
+            with ops.colocate_with(mean):
+                new_mean = tf.assign_sub(
+                    mean,
+                    tf.check_numerics(decay * (mean - cur_mean), "NaN in moving mean."))
+        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
+            with ops.colocate_with(var):
+                new_var = tf.assign_sub(
+                    var,
+                    tf.check_numerics(decay * (var - cur_var),
+                                      "NaN in moving variance."))
+        with tf.name_scope(name, "IncrementTime", [step]):
+            with ops.colocate_with(step):
+                new_step = tf.assign_add(step, 1.)
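+        # Reference the assign ops (scaled by zero) so the moving-average
+        # updates are executed with the graph without changing the output.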
+        res += 0. * new_mean * new_var * new_step
+
+    return res
+
+
+# batch normalization taking into account the volume transformation
+def batch_norm_log_diff(input_,
+                        dim,
+                        name,
+                        train=True,
+                        epsilon=1e-8,
+                        decay=.1,
+                        axes=[0],
+                        reuse=None,
+                        bn_lag=DEFAULT_BN_LAG):
+    """Batch normalization with corresponding log determinant Jacobian."""
+    if reuse is None:
+        reuse = not train
+    # create variables
+    with tf.variable_scope(name) as scope:
+        if reuse:
+            scope.reuse_variables()
+        var = variable_on_cpu(
+            "var", [dim], tf.constant_initializer(1.), trainable=False)
+        mean = variable_on_cpu(
+            "mean", [dim], tf.constant_initializer(0.), trainable=False)
+        step = variable_on_cpu("step", [], tf.constant_initializer(0.), trainable=False)
+    # choose the appropriate moments
+    if train:
+        used_mean, used_var = tf.nn.moments(input_, axes, name="batch_norm")
+        cur_mean, cur_var = used_mean, used_var
+        if bn_lag > 0.:
+            used_var = stable_var(input_=input_, mean=used_mean, axes=axes)
+            cur_var = used_var
+            used_mean -= (1 - bn_lag) * (used_mean - tf.stop_gradient(mean))
+            used_mean /= (1. - bn_lag**(step + 1))
+            used_var -= (1 - bn_lag) * (used_var - tf.stop_gradient(var))
+            used_var /= (1. - bn_lag**(step + 1))
+    else:
+        used_mean, used_var = mean, var
+        cur_mean, cur_var = used_mean, used_var
+
+    # update variables
+    if train:
+        with tf.name_scope(name, "AssignMovingAvg", [mean, cur_mean, decay]):
+            with ops.colocate_with(mean):
+                new_mean = tf.assign_sub(
+                    mean,
+                    tf.check_numerics(
+                        decay * (mean - cur_mean), "NaN in moving mean."))
+        with tf.name_scope(name, "AssignMovingAvg", [var, cur_var, decay]):
+            with ops.colocate_with(var):
+                new_var = tf.assign_sub(
+                    var,
+                    tf.check_numerics(decay * (var - cur_var),
+                                      "NaN in moving variance."))
+        with tf.name_scope(name, "IncrementTime", [step]):
+            with ops.colocate_with(step):
+                new_step = tf.assign_add(step, 1.)
+        used_var += 0. * new_mean * new_var * new_step
+    used_var += epsilon
+
+    return used_mean, used_var
+
+
+def convnet(input_,
+            dim_in,
+            dim_hid,
+            filter_sizes,
+            dim_out,
+            name,
+            use_batch_norm=True,
+            train=True,
+            nonlinearity=tf.nn.relu):
+    """Chaining of convolutional layers."""
+    dims_in = [dim_in] + dim_hid[:-1]
+    dims_out = dim_hid
+    res = input_
+
+    bias = (not use_batch_norm)
+    with tf.variable_scope(name):
+        for layer_idx in xrange(len(dim_hid)):
+            res = conv_layer(
+                input_=res,
+                filter_size=filter_sizes[layer_idx],
+                dim_in=dims_in[layer_idx],
+                dim_out=dims_out[layer_idx],
+                name="h_%d" % layer_idx,
+                stddev=1e-2,
+                nonlinearity=None,
+                bias=bias)
+            if use_batch_norm:
+                res = batch_norm(
+                    input_=res,
+                    dim=dims_out[layer_idx],
+                    name="bn_%d" % layer_idx,
+                    scale=(nonlinearity == tf.nn.relu),
+                    train=train,
+                    epsilon=1e-8,
+                    axes=[0, 1, 2])
+            if nonlinearity is not None:
+                res = nonlinearity(res)
+
+        res = conv_layer(
+            input_=res,
+            filter_size=filter_sizes[-1],
+            dim_in=dims_out[-1],
+            dim_out=dim_out,
+            name="out",
+            stddev=1e-2,
+            nonlinearity=None)
+
+    return res
+
+
+# distributions
+# log-likelihood estimation
+def standard_normal_ll(input_):
+    """Log-likelihood of standard Gaussian distribution."""
+    res = -.5 * (tf.square(input_) + numpy.log(2. * numpy.pi))
+
+    return res
+
+
+def standard_normal_sample(shape):
+    """Samples from standard Gaussian distribution."""
+    return tf.random_normal(shape)
+
+
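+# Fixed 2x2 kernel that implements the squeeze operation as a stride-2
+# convolution with a deterministic channel ordering, so it can be inverted
+# exactly with conv2d_transpose in squeeze_2x2_ordered(reverse=True).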
+SQUEEZE_MATRIX = numpy.array([[[[1., 0., 0., 0.]], [[0., 0., 1., 0.]]],
+                              [[[0., 0., 0., 1.]], [[0., 1., 0., 0.]]]])
+
+
+def squeeze_2x2_ordered(input_, reverse=False):
+    """Squeezing operation with a controlled ordering."""
+    shape = input_.get_shape().as_list()
+    batch_size = shape[0]
+    height = shape[1]
+    width = shape[2]
+    channels = shape[3]
+    if reverse:
+        if channels % 4 != 0:
+            raise ValueError("Number of channels not divisible by 4.")
+        channels /= 4
+    else:
+        if height % 2 != 0:
+            raise ValueError("Height not divisible by 2.")
+        if width % 2 != 0:
+            raise ValueError("Width not divisible by 2.")
+    weights = numpy.zeros((2, 2, channels, 4 * channels))
+    for idx_ch in xrange(channels):
+        slice_2 = slice(idx_ch, (idx_ch + 1))
+        slice_3 = slice((idx_ch * 4), ((idx_ch + 1) * 4))
+        weights[:, :, slice_2, slice_3] = SQUEEZE_MATRIX
+    shuffle_channels = [idx_ch * 4 for idx_ch in xrange(channels)]
+    shuffle_channels += [idx_ch * 4 + 1 for idx_ch in xrange(channels)]
+    shuffle_channels += [idx_ch * 4 + 2 for idx_ch in xrange(channels)]
+    shuffle_channels += [idx_ch * 4 + 3 for idx_ch in xrange(channels)]
+    shuffle_channels = numpy.array(shuffle_channels)
+    weights = weights[:, :, :, shuffle_channels].astype("float32")
+    if reverse:
+        res = tf.nn.conv2d_transpose(
+            value=input_,
+            filter=weights,
+            output_shape=[batch_size, height * 2, width * 2, channels],
+            strides=[1, 2, 2, 1],
+            padding="SAME",
+            name="unsqueeze_2x2")
+    else:
+        res = tf.nn.conv2d(
+            input=input_,
+            filter=weights,
+            strides=[1, 2, 2, 1],
+            padding="SAME",
+            name="squeeze_2x2")
+
+    return res