Browse Source

Add cross conv model for next frame prediction.

Xin Pan 8 years ago
parent
commit
ba986cfcb0

+ 91 - 0
next_frame_prediction/README.md

@@ -0,0 +1,91 @@
+<font size=4><b>Visual Dynamics: Probabilistic Future Frame Synthesis via Cross Convolutional Networks.</b></font>
+
+<b>Introduction</b>
+
+https://arxiv.org/pdf/1607.02586v1.pdf
+
+This is an implementation based on my understanding, with small
+variations. It doesn't necessarily represents the paper published
+by the original authors.
+
+Authors: Xin Pan (Github: panyx0718), Anelia Angelova
+
+<b>Results:</b>
+
+<left>
+![Sample1](g3doc/cross_conv.png)
+</left>
+<left>
+![Sample2](g3doc/cross_conv2.png)
+</left>
+
+<left>
+![Loss](g3doc/cross_conv3.png)
+</left>
+
+
+<b>Prerequisite:</b>
+
+1. Install TensorFlow (r0.12), Bazel.
+
+2. Download the Sprites dataset or generate moving object dataset.
+
+Sprites data is located here:
+
+http://www.scottreed.info/files/nips2015-analogy-data.tar.gz
+
+Convert .mat files into images and use sprites_gen.py to convert them
+to tf.SequenceExample.
+
+<b>How to run:</b>
+
+```shell
+ls -R
+.:
+data  next_frame_prediction  WORKSPACE
+
+./data:
+tfrecords  tfrecords_test
+
+./next_frame_prediction:
+cross_conv  g3doc  README.md
+
+./next_frame_prediction/cross_conv:
+BUILD  eval.py  objects_gen.py  model.py  reader.py  sprites_gen.py  train.py
+
+./next_frame_prediction/g3doc:
+cross_conv2.png  cross_conv3.png  cross_conv.png
+
+
+# Build everything.
+bazel build -c opt next_frame_prediction/...
+
+# The following example runs the generated 2d objects.
+# For Sprites dataset, image_size should be 60, norm_scale should be 255.0.
+# Batch size is normally 16~64, depending on your memory size.
+#
+# Run training.
+bazel-bin/next_frame_prediction/cross_conv/train \
+  --batch_size=1 \
+  --data_filepattern=data/tfrecords \
+  --image_size=64 \
+  --log_root=/tmp/predict
+
+step: 1, loss: 24.428671
+step: 2, loss: 19.211605
+step: 3, loss: 5.543143
+step: 4, loss: 3.035339
+step: 5, loss: 1.771392
+step: 6, loss: 2.099824
+step: 7, loss: 1.747665
+step: 8, loss: 1.572436
+step: 9, loss: 1.586816
+step: 10, loss: 1.434191
+#
+# Run eval.
+bazel-bin/next_frame_prediction/cross_conv/eval \
+  --batch_size=1 \
+  --data_filepattern=data/tfrecords_test \
+  --image_size=64 \
+  --log_root=/tmp/predict
+```

+ 48 - 0
next_frame_prediction/cross_conv/BUILD

@@ -0,0 +1,48 @@
+licenses(["notice"])  # Apache 2.0
+
+package_group(
+    name = "internal",
+    packages = [
+        "//next_frame_prediction/...",
+    ],
+)
+
+package(default_visibility = [":internal"])
+
+py_library(
+    name = "model",
+    srcs = ["model.py"],
+)
+
+py_library(
+    name = "reader",
+    srcs = ["reader.py"],
+)
+
+py_binary(
+    name = "train",
+    srcs = ["train.py"],
+    deps = [
+        ":model",
+        ":reader",
+    ],
+)
+
+py_binary(
+    name = "eval",
+    srcs = ["eval.py"],
+    deps = [
+        ":model",
+        ":reader",
+    ],
+)
+
+py_binary(
+    name = "example_gen",
+    srcs = ["example_gen.py"],
+)
+
+py_binary(
+    name = "sprites_gen",
+    srcs = ["sprites_gen.py"],
+)

+ 118 - 0
next_frame_prediction/cross_conv/eval.py

@@ -0,0 +1,118 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Eval Cross Convolutional Model."""
+import io
+import os
+import sys
+import time
+
+import numpy as np
+import tensorflow as tf
+
+import model as cross_conv_model
+import reader
+
+FLAGS = tf.flags.FLAGS
+tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
+tf.flags.DEFINE_string('data_filepattern',
+                       'est',
+                       'training data file pattern.')
+tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
+tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
+tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image')
+tf.flags.DEFINE_float('scale', 10.0,
+                      'Scale the image after norm_scale and move the diff '
+                      'to the positive realm.')
+tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
+tf.flags.DEFINE_integer('eval_batch_count', 100,
+                        'Average the result this number of examples.')
+tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
+tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
+tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')
+
+slim = tf.contrib.slim
+
+
+def _Eval():
+  params = dict()
+  params['batch_size'] = FLAGS.batch_size
+  params['seq_len'] = FLAGS.sequence_length
+  params['image_size'] = FLAGS.image_size
+  params['is_training'] = False
+  params['norm_scale'] = FLAGS.norm_scale
+  params['scale'] = FLAGS.scale
+  params['l2_loss'] = FLAGS.l2_loss
+  params['reconstr_loss'] = FLAGS.reconstr_loss
+  params['kl_loss'] = FLAGS.kl_loss
+
+  eval_dir = os.path.join(FLAGS.log_root, 'eval')
+
+  images = reader.ReadInput(
+      FLAGS.data_filepattern, shuffle=False, params=params)
+  images *= params['scale']
+  # Increase the value makes training much faster.
+  image_diff_list = reader.SequenceToImageAndDiff(images)
+  model = cross_conv_model.CrossConvModel(image_diff_list, params)
+  model.Build()
+
+  summary_writer = tf.summary.FileWriter(eval_dir)
+  saver = tf.train.Saver()
+  sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
+  tf.train.start_queue_runners(sess)
+
+  while True:
+    time.sleep(60)
+    try:
+      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
+    except tf.errors.OutOfRangeError as e:
+      sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
+      continue
+    if not (ckpt_state and ckpt_state.model_checkpoint_path):
+      sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
+      continue
+    sys.stderr.write('Loading checkpoint %s\n' %
+                     ckpt_state.model_checkpoint_path)
+    saver.restore(sess, ckpt_state.model_checkpoint_path)
+    # Use the empirical distribution of z from training set.
+    if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
+      sys.stderr.write('No z at %s\n' % FLAGS.log_root)
+      continue
+
+    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy')) as f:
+      sample_z_mean = np.load(io.BytesIO(f.read()))
+    with tf.gfile.Open(
+        os.path.join(FLAGS.log_root, 'z_stddev_log.npy')) as f:
+      sample_z_stddev_log = np.load(io.BytesIO(f.read()))
+
+    total_loss = 0.0
+    for _ in xrange(FLAGS.eval_batch_count):
+      loss_val, total_steps, summaries = sess.run(
+          [model.loss, model.global_step, model.summary_op],
+          feed_dict={model.z_mean: sample_z_mean,
+                     model.z_stddev_log: sample_z_stddev_log})
+      total_loss += loss_val
+
+    summary_writer.add_summary(summaries, total_steps)
+    sys.stderr.write('steps: %d, loss: %f\n' %
+                     (total_steps, total_loss / FLAGS.eval_batch_count))
+
+
+def main(_):
+  _Eval()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 92 - 0
next_frame_prediction/cross_conv/example_gen.py

@@ -0,0 +1,92 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate examples of two objects moving in different directions."""
+import random
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string('out_file', '',
+                       'Output file for the tfrecords.')
+
+
+def _add_object(obj_type, image, image2, xpos, ypos):
+  """Add a moving obj to two consecutive images."""
+  obj_size = random.randint(8, 10)
+  channel = random.randint(0, 2)
+  move = random.randint(6, 10)
+
+  obj = np.zeros([obj_size, obj_size, 3])
+  if obj_type == 'rectangle':
+    xpos2 = xpos + move
+    ypos2 = ypos
+    for i in xrange(obj_size):
+      obj[i, 0:i+1, channel] = [1.0 for _ in xrange(i+1)]
+  elif obj_type == 'square':
+    xpos2 = xpos
+    ypos2 = ypos + move
+    obj[:, :, channel] = 1.0
+
+  for x in xrange(obj_size):
+    for y in xrange(obj_size):
+      if obj[x, y, channel] == 1.0:
+        image[xpos+x, ypos+y, channel] = 1.0
+        image2[xpos2+x, ypos2+y, channel] = 1.0
+
+
+def _images_to_example(image, image2):
+  """Convert two consecutive images to SequenceExample."""
+  example = tf.SequenceExample()
+  feature_list = example.feature_lists.feature_list['moving_objs']
+  feature = feature_list.feature.add()
+  feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
+  feature = feature_list.feature.add()
+  feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
+  return example
+
+
+def generate_input():
+  """Generate tfrecords."""
+  writer = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file)
+  writer2 = tf.python_io.TFRecordWriter(tf.flags.FLAGS.out_file + '_test')
+
+  examples = []
+  for xpos in xrange(0, 40, 3):
+    for ypos in xrange(0, 40, 3):
+      for xpos2 in xrange(0, 40, 3):
+        for ypos2 in xrange(0, 40, 3):
+          image = np.zeros([64, 64, 3])
+          image2 = np.zeros([64, 64, 3])
+          _add_object('rectangle', image, image2, xpos, ypos)
+          _add_object('square', image, image2, xpos2, ypos2)
+          examples.append(_images_to_example(image, image2))
+
+  sys.stderr.write('Finish generating examples.\n')
+  random.shuffle(examples)
+  for count, ex in enumerate(examples):
+    if count % 10 == 0:
+      writer2.write(ex.SerializeToString())
+    else:
+      writer.write(ex.SerializeToString())
+
+def main(_):
+  generate_input()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 232 - 0
next_frame_prediction/cross_conv/model.py

@@ -0,0 +1,232 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Cross Convolutional Model.
+
+https://arxiv.org/pdf/1607.02586v1.pdf
+"""
+import math
+import sys
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+class CrossConvModel(object):
+
+  def __init__(self, image_diff_list, params):
+    """Constructor.
+
+    Args:
+      image_diff_list: A list of (image, diff) tuples, with shape
+          [batch_size, image_size, image_size, 3] and image_sizes as
+          [32, 64, 128, 256].
+      params: Dict of parameters.
+    """
+    self.images = [i for (i, _) in image_diff_list]
+    # Move the diff to the positive realm.
+    self.diffs = [(d + params['scale']) / 2 for (i, d) in image_diff_list]
+    self.params = params
+
+  def Build(self):
+    with tf.device('/gpu:0'):
+      with slim.arg_scope([slim.conv2d],
+                          activation_fn=tf.nn.relu,
+                          normalizer_fn=slim.batch_norm,
+                          normalizer_params={'is_training':
+                                             self.params['is_training']}):
+        self._BuildMotionKernel()
+        encoded_images = self._BuildImageEncoder()
+        cross_conved_images = self._CrossConv(encoded_images)
+        self._BuildImageDecoder(cross_conved_images)
+        self._BuildLoss()
+
+      image = self.images[1]
+      diff = self.diffs[1]
+
+      self.global_step = tf.Variable(0, name='global_step', trainable=False)
+
+      if self.params['is_training']:
+        self._BuildTrainOp()
+
+      diff = diff * 2.0 - self.params['scale']
+      diff_output = self.diff_output * 2.0 - self.params['scale']
+      concat_image = tf.concat(
+          1, [image, image + diff_output, image + diff, diff_output])
+      tf.summary.image('origin_predict_expect_predictdiff', concat_image)
+      self.summary_op = tf.summary.merge_all()
+      return self.loss
+
+  def _BuildTrainOp(self):
+    lrn_rate = tf.maximum(
+        0.01,  # min_lr_rate.
+        tf.train.exponential_decay(
+            self.params['learning_rate'], self.global_step, 10000, 0.5))
+    tf.summary.scalar('learning rate', lrn_rate)
+    optimizer = tf.train.GradientDescentOptimizer(lrn_rate)
+    self.train_op = slim.learning.create_train_op(
+        self.loss, optimizer, global_step=self.global_step)
+
+  def _BuildLoss(self):
+    # 1. reconstr_loss seems doesn't do better than l2 loss.
+    # 2. Only works when using reduce_mean. reduce_sum doesn't work.
+    # 3. It seems kl loss doesn't play an important role.
+    self.loss = 0
+    with tf.variable_scope('loss'):
+      if self.params['l2_loss']:
+        l2_loss = tf.reduce_mean(tf.square(self.diff_output - self.diffs[1]))
+        tf.summary.scalar('l2_loss', l2_loss)
+        self.loss += l2_loss
+      if self.params['reconstr_loss']:
+        reconstr_loss = (-tf.reduce_mean(
+            self.diffs[1] * (1e-10 + self.diff_output) +
+            (1-self.diffs[1]) * tf.log(1e-10 + 1 - self.diff_output)))
+        reconstr_loss = tf.check_numerics(reconstr_loss, 'reconstr_loss')
+        tf.summary.scalar('reconstr_loss', reconstr_loss)
+        self.loss += reconstr_loss
+      if self.params['kl_loss']:
+        kl_loss = (0.5 * tf.reduce_mean(
+            tf.square(self.z_mean) + tf.square(self.z_stddev) -
+            2 * self.z_stddev_log - 1))
+        tf.summary.scalar('kl_loss', kl_loss)
+        self.loss += kl_loss
+
+      tf.summary.scalar('loss', self.loss)
+
+  def _BuildMotionKernel(self):
+    image = self.images[-2]
+    diff = self.diffs[-2]
+    shape = image.get_shape().as_list()
+    assert shape[1] == shape[2] and shape[1] == 128
+    batch_size = shape[0]
+
+    net = tf.concat(3, [image, diff])
+    with tf.variable_scope('motion_encoder'):
+      with slim.arg_scope([slim.conv2d], padding='VALID'):
+        net = slim.conv2d(net, 96, [5, 5], stride=1)
+        net = slim.max_pool2d(net, [2, 2])
+        net = slim.conv2d(net, 96, [5, 5], stride=1)
+        net = slim.max_pool2d(net, [2, 2])
+        net = slim.conv2d(net, 128, [5, 5], stride=1)
+        net = slim.conv2d(net, 128, [5, 5], stride=1)
+        net = slim.max_pool2d(net, [2, 2])
+        net = slim.conv2d(net, 256, [4, 4], stride=1)
+        net = slim.conv2d(net, 256, [3, 3], stride=1)
+
+        z = tf.reshape(net, shape=[batch_size, -1])
+        self.z_mean, self.z_stddev_log = tf.split(
+            split_dim=1, num_split=2, value=z)
+        self.z_stddev = tf.exp(self.z_stddev_log)
+
+        epsilon = tf.random_normal(
+            self.z_mean.get_shape().as_list(), 0, 1, dtype=tf.float32)
+        kernel = self.z_mean + tf.multiply(self.z_stddev, epsilon)
+
+        width = int(math.sqrt(kernel.get_shape().as_list()[1] // 128))
+        kernel = tf.reshape(kernel, [batch_size, width, width, 128])
+    with tf.variable_scope('kernel_decoder'):
+      with slim.arg_scope([slim.conv2d], padding='SAME'):
+        kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
+        self.kernel = slim.conv2d(kernel, 128, [5, 5], stride=1)
+
+    sys.stderr.write('kernel shape: %s\n' % kernel.get_shape())
+
+  def _BuildImageEncoder(self):
+    feature_maps = []
+    for (i, image) in enumerate(self.images):
+      with tf.variable_scope('image_encoder_%d' % i):
+        with slim.arg_scope([slim.conv2d, slim.max_pool2d], padding='SAME'):
+          net = slim.conv2d(image, 64, [5, 5], stride=1)
+          net = slim.conv2d(net, 64, [5, 5], stride=1)
+          net = slim.max_pool2d(net, [5, 5])
+          net = slim.conv2d(net, 64, [5, 5], stride=1)
+          net = slim.conv2d(net, 32, [5, 5], stride=1)
+          net = slim.max_pool2d(net, [2, 2])
+      sys.stderr.write('image_conv shape: %s\n' % net.get_shape())
+      feature_maps.append(net)
+    return feature_maps
+
+  def _CrossConvHelper(self, encoded_image, kernel):
+    """Cross Convolution.
+
+      The encoded image and kernel are of the same shape. Namely
+      [batch_size, image_size, image_size, channels]. They are split
+      into [image_size, image_size] image squares [kernel_size, kernel_size]
+      kernel squares. kernel squares are used to convolute image squares.
+    """
+    images = tf.expand_dims(encoded_image, 0)
+    kernels = tf.expand_dims(kernel, 3)
+    return tf.nn.depthwise_conv2d(images, kernels, [1, 1, 1, 1], 'SAME')
+
+  def _CrossConv(self, encoded_images):
+    """Apply the motion kernel on the encoded_images."""
+    cross_conved_images = []
+    kernels = tf.split(split_dim=3, num_split=4, value=self.kernel)
+    for (i, encoded_image) in enumerate(encoded_images):
+      with tf.variable_scope('cross_conv_%d' % i):
+        kernel = kernels[i]
+
+        encoded_image = tf.unstack(encoded_image, axis=0)
+        kernel = tf.unstack(kernel, axis=0)
+        assert len(encoded_image) == len(kernel)
+        assert len(encoded_image) == self.params['batch_size']
+        conved_image = []
+        for j in xrange(len(encoded_image)):
+          conved_image.append(self._CrossConvHelper(
+              encoded_image[j], kernel[j]))
+        cross_conved_images.append(tf.concat(0, conved_image))
+        sys.stderr.write('cross_conved shape: %s\n' %
+                         cross_conved_images[-1].get_shape())
+    return cross_conved_images
+
+  def _Deconv(self, net, out_filters, kernel_size, stride):
+    shape = net.get_shape().as_list()
+    in_filters = shape[3]
+    kernel_shape = [kernel_size, kernel_size, out_filters, in_filters]
+
+    weights = tf.get_variable(
+        name='weights',
+        shape=kernel_shape,
+        dtype=tf.float32,
+        initializer=tf.truncated_normal_initializer(stddev=0.01))
+
+
+    out_height = shape[1] * stride
+    out_width = shape[2] * stride
+    batch_size = shape[0]
+
+    output_shape = [batch_size, out_height, out_width, out_filters]
+    net = tf.nn.conv2d_transpose(net, weights, output_shape,
+                                 [1, stride, stride, 1], padding='SAME')
+    slim.batch_norm(net)
+    return net
+
+  def _BuildImageDecoder(self, cross_conved_images):
+    """Decode the cross_conved feature maps into the predicted images."""
+    nets = []
+    for i, cross_conved_image in enumerate(cross_conved_images):
+      with tf.variable_scope('image_decoder_%d' % i):
+        stride = 64 / cross_conved_image.get_shape().as_list()[1]
+        # TODO(xpan): Alternative solution for upsampling?
+        nets.append(self._Deconv(
+            cross_conved_image, 64, kernel_size=3, stride=stride))
+
+    net = tf.concat(3, nets)
+    net = slim.conv2d(net, 128, [9, 9], padding='SAME', stride=1)
+    net = slim.conv2d(net, 128, [1, 1], padding='SAME', stride=1)
+    net = slim.conv2d(net, 3, [1, 1], padding='SAME', stride=1)
+    self.diff_output = net
+    sys.stderr.write('diff_output shape: %s\n' % self.diff_output.get_shape())

+ 85 - 0
next_frame_prediction/cross_conv/reader.py

@@ -0,0 +1,85 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Read image sequence."""
+
+import tensorflow as tf
+
+
+def SequenceToImageAndDiff(images):
+  """Convert image sequence batch into image and diff batch.
+
+    Each image pair is converted to the first image and their diff.
+    Batch size will increase if sequence length is larger than 2.
+
+  Args:
+    images: Image sequence with shape
+        [batch_size, seq_len, image_size, image_size, channel]
+
+  Returns:
+    the list of (image, diff) tuples with shape
+        [batch_size2, image_size, image_size, channel]. image_sizes are
+        [32, 64, 128, 256].
+  """
+  image_diff_list = []
+  image_seq = tf.unstack(images, axis=1)
+  for size in [32, 64, 128, 256]:
+    resized_images = [
+        tf.image.resize_images(i, [size, size]) for i in image_seq]
+    diffs = []
+    for i in xrange(0, len(resized_images)-1):
+      diffs.append(resized_images[i+1] - resized_images[i])
+    image_diff_list.append(
+        (tf.concat(0, resized_images[:-1]), tf.concat(0, diffs)))
+  return image_diff_list
+
+
+def ReadInput(data_filepattern, shuffle, params):
+  """Read the tf.SequenceExample tfrecord files.
+
+  Args:
+    data_filepattern: tf.SequenceExample tfrecord filepattern.
+    shuffle: Whether to shuffle the examples.
+    params: parameter dict.
+
+  Returns:
+    image sequence batch [batch_size, seq_len, image_size, image_size, channel].
+  """
+  image_size = params['image_size']
+  filenames = tf.gfile.Glob(data_filepattern)
+  filename_queue = tf.train.string_input_producer(filenames, shuffle=shuffle)
+  reader = tf.TFRecordReader()
+  _, example = reader.read(filename_queue)
+  feature_sepc = {
+      'moving_objs': tf.FixedLenSequenceFeature(
+          shape=[image_size * image_size * 3], dtype=tf.float32)}
+  _, features = tf.parse_single_sequence_example(
+      example, sequence_features=feature_sepc)
+  moving_objs = tf.reshape(
+      features['moving_objs'], [params['seq_len'], image_size, image_size, 3])
+  if shuffle:
+    examples = tf.train.shuffle_batch(
+        [moving_objs],
+        batch_size=params['batch_size'],
+        num_threads=64,
+        capacity=params['batch_size'] * 100,
+        min_after_dequeue=params['batch_size'] * 4)
+  else:
+    examples = tf.train.batch([moving_objs],
+                              batch_size=params['batch_size'],
+                              num_threads=16,
+                              capacity=params['batch_size'])
+  examples /= params['norm_scale']
+  return examples

+ 97 - 0
next_frame_prediction/cross_conv/sprites_gen.py

@@ -0,0 +1,97 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Generate the sprites tfrecords from raw_images."""
+import os
+import random
+import re
+import sys
+
+import numpy as np
+import scipy.misc
+import tensorflow as tf
+
+
+tf.flags.DEFINE_string('data_filepattern', '', 'The raw images.')
+tf.flags.DEFINE_string('out_file', '',
+                       'File name for the tfrecord output.')
+
+
+def _read_images():
+  """Read images from image files into data structure."""
+  sprites = dict()
+  files = tf.gfile.Glob(tf.flags.FLAGS.data_filepattern)
+  for f in files:
+    image = scipy.misc.imread(f)
+    m = re.search('image_([0-9]+)_([0-9]+)_([0-9]+).jpg', os.path.basename(f))
+    if m.group(1) not in sprites:
+      sprites[m.group(1)] = dict()
+    character = sprites[m.group(1)]
+    if m.group(2) not in character:
+      character[m.group(2)] = dict()
+    pose = character[m.group(2)]
+    pose[int(m.group(3))] = image
+  return sprites
+
+
+def _images_to_example(image, image2):
+  """Convert 2 consecutive image to a SequenceExample."""
+  example = tf.SequenceExample()
+  feature_list = example.feature_lists.feature_list['moving_objs']
+  feature = feature_list.feature.add()
+  feature.float_list.value.extend(np.reshape(image, [-1]).tolist())
+  feature = feature_list.feature.add()
+  feature.float_list.value.extend(np.reshape(image2, [-1]).tolist())
+  return example
+
+
+def generate_input():
+  """Generate tfrecords."""
+  sprites = _read_images()
+  sys.stderr.write('Finish reading images.\n')
+  train_writer = tf.python_io.TFRecordWriter(
+      tf.flags.FLAGS.out_file.replace('sprites', 'sprites_train'))
+  test_writer = tf.python_io.TFRecordWriter(
+      tf.flags.FLAGS.out_file.replace('sprites', 'sprites_test'))
+
+  train_examples = []
+  test_examples = []
+  for i in sprites:
+    if int(i) < 24:
+      examples = test_examples
+    else:
+      examples = train_examples
+
+    character = sprites[i]
+    for j in character.keys():
+      pose = character[j]
+      for k in xrange(1, len(pose), 1):
+        image = pose[k]
+        image2 = pose[k+1]
+        examples.append(_images_to_example(image, image2))
+
+  sys.stderr.write('Finish generating examples: %d, %d.\n' %
+                   (len(train_examples), len(test_examples)))
+  random.shuffle(train_examples)
+  _ = [train_writer.write(ex.SerializeToString()) for ex in train_examples]
+  _ = [test_writer.write(ex.SerializeToString()) for ex in test_examples]
+
+
+def main(_):
+  generate_input()
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 122 - 0
next_frame_prediction/cross_conv/train.py

@@ -0,0 +1,122 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train the cross convolutional model."""
+import os
+import sys
+
+import numpy as np
+import tensorflow as tf
+
+import model as cross_conv_model
+import reader
+
+FLAGS = tf.flags.FLAGS
+tf.flags.DEFINE_string('master', '', 'Session address.')
+tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
+tf.flags.DEFINE_string('data_filepattern', '',
+                       'training data file pattern.')
+tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
+tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
+tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image')
+tf.flags.DEFINE_float('scale', 10.0,
+                      'Scale the image after norm_scale and move the diff '
+                      'to the positive realm.')
+tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
+tf.flags.DEFINE_float('learning_rate', 0.8, 'Learning rate.')
+tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
+tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
+tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')
+
+slim = tf.contrib.slim
+
+
+def _Train():
+  params = dict()
+  params['batch_size'] = FLAGS.batch_size
+  params['seq_len'] = FLAGS.sequence_length
+  params['image_size'] = FLAGS.image_size
+  params['is_training'] = True
+  params['norm_scale'] = FLAGS.norm_scale
+  params['scale'] = FLAGS.scale
+  params['learning_rate'] = FLAGS.learning_rate
+  params['l2_loss'] = FLAGS.l2_loss
+  params['reconstr_loss'] = FLAGS.reconstr_loss
+  params['kl_loss'] = FLAGS.kl_loss
+
+  train_dir = os.path.join(FLAGS.log_root, 'train')
+
+  images = reader.ReadInput(FLAGS.data_filepattern, shuffle=True, params=params)
+  images *= params['scale']
+  # Increase the value makes training much faster.
+  image_diff_list = reader.SequenceToImageAndDiff(images)
+  model = cross_conv_model.CrossConvModel(image_diff_list, params)
+  model.Build()
+  tf.contrib.tfprof.model_analyzer.print_model_analysis(tf.get_default_graph())
+
+  summary_writer = tf.summary.FileWriter(train_dir)
+  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
+                           summary_op=None,
+                           is_chief=True,
+                           save_model_secs=60,
+                           global_step=model.global_step)
+  sess = sv.prepare_or_wait_for_session(
+      FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))
+
+  total_loss = 0.0
+  step = 0
+  sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
+  sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
+  sample_step = 0
+
+  while True:
+    _, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run(
+        [model.train_op, model.loss, model.global_step,
+         model.summary_op,
+         model.z_mean, model.z_stddev_log])
+
+    sample_z_mean += z_mean
+    sample_z_stddev_log += z_stddev_log
+    total_loss += loss_val
+    step += 1
+    sample_step += 1
+
+    if step % 100 == 0:
+      summary_writer.add_summary(summaries, total_steps)
+      sys.stderr.write('step: %d, loss: %f\n' %
+                       (total_steps, total_loss / step))
+      total_loss = 0.0
+      step = 0
+
+    # Sampled z is used for eval.
+    # It seems 10k is better than 1k. Maybe try 100k next?
+    if sample_step % 10000 == 0:
+      with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'), 'w') as f:
+        np.save(f, sample_z_mean / sample_step)
+      with tf.gfile.Open(
+          os.path.join(FLAGS.log_root, 'z_stddev_log.npy'), 'w') as f:
+        np.save(f, sample_z_stddev_log / sample_step)
+      sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
+      sample_z_stddev_log = np.zeros(
+          model.z_stddev_log.get_shape().as_list())
+      sample_step = 0
+
+
+def main(_):
+  _Train()
+
+
+if __name__ == '__main__':
+  tf.app.run()

BIN
next_frame_prediction/g3doc/cross_conv.png


BIN
next_frame_prediction/g3doc/cross_conv2.png


BIN
next_frame_prediction/g3doc/cross_conv3.png