
video prediction model code

Chelsea Finn, 9 years ago
Parent
Current commit
d67ea24901

+ 99 - 0
video_prediction/README.md

@@ -0,0 +1,99 @@
+# Video Prediction with Neural Advection
+
+*A TensorFlow implementation of the models described in
+[Finn et al. (2016)](http://arxiv.org/abs/1605.07157).*
+
+This video prediction model, which is optionally conditioned on actions,
+predicts future video by internally predicting how to transform the last
+image (which may itself have been predicted) into the next image. As a result,
+it can reuse appearance information from previous frames and generalizes
+better to objects not seen in the training set. Some example predictions on
+novel objects are shown below:
+
+![Animation](https://storage.googleapis.com/push_gens/novelgengifs9/16_70.gif)
+![Animation](https://storage.googleapis.com/push_gens/novelgengifs9/2_96.gif)
+![Animation](https://storage.googleapis.com/push_gens/novelgengifs9/1_38.gif)
+![Animation](https://storage.googleapis.com/push_gens/novelgengifs9/11_10.gif)
+![Animation](https://storage.googleapis.com/push_gens/novelgengifs9/3_34.gif)
+
+When the model is conditioned on actions, it changes its predictions based on
+the action passed in. Here we show the model's predictions in response to
+varying the magnitude of the input actions, from small to large:
+
+![Animation](https://storage.googleapis.com/push_gens/webgifs/0xact_0.gif)
+![Animation](https://storage.googleapis.com/push_gens/05xact_0.gif)
+![Animation](https://storage.googleapis.com/push_gens/webgifs/1xact_0.gif)
+![Animation](https://storage.googleapis.com/push_gens/webgifs/15xact_0.gif)
+
+![Animation](https://storage.googleapis.com/push_gens/webgifs/0xact_17.gif)
+![Animation](https://storage.googleapis.com/push_gens/webgifs/05xact_17.gif)
+![Animation](https://storage.googleapis.com/push_gens/webgifs/1xact_17.gif)
+![Animation](https://storage.googleapis.com/push_gens/webgifs/15xact_17.gif)
+
+
+Because the model is trained with an l2 objective, it represents uncertainty as
+blur.
+
+## Requirements
+* TensorFlow (see tensorflow.org for installation instructions)
+* The spatial_transformer model in tensorflow/models, needed for the spatial
+  transformer predictor (STP).
+
+## Data
+The data used to train this model is located
+[here](https://sites.google.com/site/brainrobotdata/home/push-dataset).
+
+To download the robot data, run the following.
+```shell
+./download_data.sh
+```
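+
+The script also accepts an optional listing file and output directory (see
+download_data.sh). For example, the following would download the files listed
+in push_datafiles.txt into a hypothetical ./push_data directory:
+```shell
+./download_data.sh push_datafiles.txt ./push_data
+```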
+
+## Training the model
+
+To train the model, run the prediction_train.py file.
+```shell
+python prediction_train.py
+```
+
+There are several flags that control the model that is trained, as
+exemplified below:
+```shell
+python prediction_train.py \
+  --data_dir=push/push_train \ # path to the training set.
+  --model=CDNA \ # the model type to use - DNA, CDNA, or STP
+  --output_dir=./checkpoints \ # where to save model checkpoints
+  --event_log_dir=./summaries \ # where to save training statistics
+  --num_iterations=100000 \ # number of training iterations
+  --pretrained_model=model \ # path to model to initialize from, random if empty
+  --sequence_length=10 \ # the number of total frames in a sequence
+  --context_frames=2 \ # the number of ground truth frames to pass in at start
+  --use_state=1 \ # whether or not to condition on actions and the initial state
+  --num_masks=10 \ # the number of transformations and corresponding masks
+  --schedsamp_k=900.0 \ # the k constant for scheduled sampling, or -1 for none
+  --train_val_split=0.95 \ # the fraction of data used for training (rest for validation)
+  --batch_size=32 \ # the training batch size
+  --learning_rate=0.001 \ # the initial learning rate for the Adam optimizer
+```
+
+If the dynamic neural advection (DNA) model is being used, the `--num_masks`
+option should be set to one.
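+
+For example, a DNA model could be trained with a command along these lines
+(the data path below is just a placeholder):
+```shell
+python prediction_train.py --model=DNA --num_masks=1 --data_dir=push/push_train
+```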
+
+The `--context_frames` option sets both the number of initial ground truth
+frames to pass in and the point at which the model's predictions start being
+penalized.
+
+The data directory `--data_dir` should contain tfrecord files with the format
+used in the released push dataset. See
+[here](https://sites.google.com/site/brainrobotdata/home/push-dataset) for
+details. If the `--use_state` option is not set, then the data only needs to
+contain image sequences, not states and actions.
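+
+As a quick sanity check of the data format, a single record can be inspected
+with a short script along the lines of the sketch below (the feature keys
+mirror the ones read in prediction_input.py; the file path is a placeholder):
+```python
+import tensorflow as tf
+
+# Read the first serialized example from one training shard.
+path = 'push/push_train/push_train.tfrecord-00000-of-00264'
+example = tf.train.Example.FromString(
+    next(tf.python_io.tf_record_iterator(path)))
+features = example.features.feature
+
+# First frame: JPEG bytes, end-effector state, and commanded action.
+image_bytes = features['move/0/image/encoded'].bytes_list.value[0]
+state = features['move/0/endeffector/vec_pitch_yaw'].float_list.value
+action = features['move/0/commanded_pose/vec_pitch_yaw'].float_list.value
+print(len(image_bytes), list(state), list(action))
+```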
+
+
+## Contact
+
+To ask questions or report issues, please open an issue on the tensorflow/models
+[issues tracker](https://github.com/tensorflow/models/issues).
+Please assign issues to @cbfinn.
+
+## Credits
+
+This code was written by Chelsea Finn.

+ 55 - 0
video_prediction/download_data.sh

@@ -0,0 +1,55 @@
+#!/bin/bash
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+# Example:
+#
+#   download_data.sh datafiles.txt ./tmp
+#
+# will download all of the files listed in the file, datafiles.txt, into
+# a directory, "./tmp".
+#
+# Each line of the datafiles.txt file should contain the path from the
+# bucket root to a file.
+
+ARGC="$#"
+LISTING_FILE=push_datafiles.txt
+if [ "${ARGC}" -ge 1 ]; then
+  LISTING_FILE=$1
+fi
+OUTPUT_DIR="./"
+if [ "${ARGC}" -ge 2 ]; then
+  OUTPUT_DIR=$2
+fi
+
+echo "OUTPUT_DIR=$OUTPUT_DIR"
+
+mkdir -p "${OUTPUT_DIR}"
+
+function download_file {
+  FILE=$1
+  BUCKET="https://storage.googleapis.com/brain-robotics-data"
+  URL="${BUCKET}/${FILE}"
+  OUTPUT_FILE="${OUTPUT_DIR}/${FILE}"
+  DIRECTORY=$(dirname "${OUTPUT_FILE}")
+  echo "DIRECTORY=${DIRECTORY}"
+  mkdir -p "${DIRECTORY}"
+  curl --output "${OUTPUT_FILE}" "${URL}"
+}
+
+while read filename; do
+  download_file $filename
+done <${LISTING_FILE}

+ 110 - 0
video_prediction/lstm_ops.py

@@ -0,0 +1,110 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Convolutional LSTM implementation."""
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim import add_arg_scope
+from tensorflow.contrib.slim import layers
+
+
+def init_state(inputs,
+               state_shape,
+               state_initializer=tf.zeros_initializer,
+               dtype=tf.float32):
+  """Helper function to create an initial state given inputs.
+
+  Args:
+    inputs: input Tensor, at least 2D, the first dimension being batch_size
+    state_shape: the shape of the state.
+    state_initializer: Initializer(shape, dtype) for state Tensor.
+    dtype: Optional dtype, needed when inputs is None.
+  Returns:
+     A tensor representing the initial state.
+  """
+  if inputs is not None:
+    # Handle both the dynamic shape as well as the inferred shape.
+    inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0]
+    batch_size = tf.shape(inputs)[0]
+    dtype = inputs.dtype
+  else:
+    inferred_batch_size = 0
+    batch_size = 0
+
+  initial_state = state_initializer(
+      tf.pack([batch_size] + state_shape),
+      dtype=dtype)
+  initial_state.set_shape([inferred_batch_size] + state_shape)
+
+  return initial_state
+
+
+@add_arg_scope
+def basic_conv_lstm_cell(inputs,
+                         state,
+                         num_channels,
+                         filter_size=5,
+                         forget_bias=1.0,
+                         scope=None,
+                         reuse=None):
+  """Basic LSTM recurrent network cell, with 2D convolution connctions.
+
+  We add forget_bias (default: 1) to the biases of the forget gate in order to
+  reduce the scale of forgetting in the beginning of the training.
+
+  It does not allow cell clipping or a projection layer, and does not
+  use peep-hole connections: it is the basic baseline.
+
+  Args:
+    inputs: input Tensor, 4D, batch x height x width x channels.
+    state: state Tensor, 4D, batch x height x width x channels.
+    num_channels: the number of output channels in the layer.
+    filter_size: the size of each convolution filter.
+    forget_bias: the initial value of the forget biases.
+    scope: Optional scope for variable_scope.
+    reuse: whether or not the layer and the variables should be reused.
+
+  Returns:
+     a tuple of tensors representing output and the new state.
+  """
+  spatial_size = inputs.get_shape()[1:3]
+  if state is None:
+    state = init_state(inputs, list(spatial_size) + [2 * num_channels])
+  with tf.variable_scope(scope,
+                         'BasicConvLstmCell',
+                         [inputs, state],
+                         reuse=reuse):
+    inputs.get_shape().assert_has_rank(4)
+    state.get_shape().assert_has_rank(4)
+    c, h = tf.split(3, 2, state)
+    inputs_h = tf.concat(3, [inputs, h])
+    # Parameters of gates are concatenated into one conv for efficiency.
+    i_j_f_o = layers.conv2d(inputs_h,
+                            4 * num_channels, [filter_size, filter_size],
+                            stride=1,
+                            activation_fn=None,
+                            scope='Gates')
+
+    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
+    i, j, f, o = tf.split(3, 4, i_j_f_o)
+
+    new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j)
+    new_h = tf.tanh(new_c) * tf.sigmoid(o)
+
+    return new_h, tf.concat(3, [new_c, new_h])
+
+
+
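+# Example usage (a minimal sketch, not part of the library): unrolling the cell
+# over a list of 4D frames, each of shape [batch, height, width, channels]:
+#
+#   state = None
+#   for frame in frames:
+#     output, state = basic_conv_lstm_cell(frame, state, num_channels=32,
+#                                          scope='conv_lstm',
+#                                          reuse=state is not None)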

+ 119 - 0
video_prediction/prediction_input.py

@@ -0,0 +1,119 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Code for building the input for the prediction model."""
+
+import os
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import gfile
+
+
+FLAGS = flags.FLAGS
+
+# Original image dimensions
+ORIGINAL_WIDTH = 640
+ORIGINAL_HEIGHT = 512
+COLOR_CHAN = 3
+
+# Default image dimensions.
+IMG_WIDTH = 64
+IMG_HEIGHT = 64
+
+# Dimension of the state and action.
+STATE_DIM = 5
+
+
+def build_tfrecord_input(training=True):
+  """Create input tfrecord tensors.
+
+  Args:
+    training: training or validation data.
+  Returns:
+    list of tensors corresponding to images, actions, and states. The images
+    tensor is 5D, batch x time x height x width x channels. The state and
+    action tensors are 3D, batch x time x dimension.
+  Raises:
+    RuntimeError: if no files found.
+  """
+  filenames = gfile.Glob(os.path.join(FLAGS.data_dir, '*'))
+  if not filenames:
+    raise RuntimeError('No data files found.')
+  index = int(np.floor(FLAGS.train_val_split * len(filenames)))
+  if training:
+    filenames = filenames[:index]
+  else:
+    filenames = filenames[index:]
+  filename_queue = tf.train.string_input_producer(filenames, shuffle=True)
+  reader = tf.TFRecordReader()
+  _, serialized_example = reader.read(filename_queue)
+
+  image_seq, state_seq, action_seq = [], [], []
+
+  for i in range(FLAGS.sequence_length):
+    image_name = 'move/' + str(i) + '/image/encoded'
+    action_name = 'move/' + str(i) + '/commanded_pose/vec_pitch_yaw'
+    state_name = 'move/' + str(i) + '/endeffector/vec_pitch_yaw'
+    if FLAGS.use_state:
+      features = {image_name: tf.FixedLenFeature([1], tf.string),
+                  action_name: tf.FixedLenFeature([STATE_DIM], tf.float32),
+                  state_name: tf.FixedLenFeature([STATE_DIM], tf.float32)}
+    else:
+      features = {image_name: tf.FixedLenFeature([1], tf.string)}
+    features = tf.parse_single_example(serialized_example, features=features)
+
+    image_buffer = tf.reshape(features[image_name], shape=[])
+    image = tf.image.decode_jpeg(image_buffer, channels=COLOR_CHAN)
+    image.set_shape([ORIGINAL_HEIGHT, ORIGINAL_WIDTH, COLOR_CHAN])
+
+    if IMG_HEIGHT != IMG_WIDTH:
+      raise ValueError('Unequal height and width unsupported')
+
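+    # Center-crop to a square, then downsample to IMG_HEIGHT x IMG_WIDTH and
+    # scale pixel values to [0, 1].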
+    crop_size = min(ORIGINAL_HEIGHT, ORIGINAL_WIDTH)
+    image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size)
+    image = tf.reshape(image, [1, crop_size, crop_size, COLOR_CHAN])
+    image = tf.image.resize_bicubic(image, [IMG_HEIGHT, IMG_WIDTH])
+    image = tf.cast(image, tf.float32) / 255.0
+    image_seq.append(image)
+
+    if FLAGS.use_state:
+      state = tf.reshape(features[state_name], shape=[1, STATE_DIM])
+      state_seq.append(state)
+      action = tf.reshape(features[action_name], shape=[1, STATE_DIM])
+      action_seq.append(action)
+
+  image_seq = tf.concat(0, image_seq)
+
+  if FLAGS.use_state:
+    state_seq = tf.concat(0, state_seq)
+    action_seq = tf.concat(0, action_seq)
+    [image_batch, action_batch, state_batch] = tf.train.batch(
+        [image_seq, action_seq, state_seq],
+        FLAGS.batch_size,
+        num_threads=FLAGS.batch_size,
+        capacity=100 * FLAGS.batch_size)
+    return image_batch, action_batch, state_batch
+  else:
+    image_batch = tf.train.batch(
+        [image_seq],
+        FLAGS.batch_size,
+        num_threads=FLAGS.batch_size,
+        capacity=100 * FLAGS.batch_size)
+    zeros_batch = tf.zeros([FLAGS.batch_size, FLAGS.sequence_length, STATE_DIM])
+    return image_batch, zeros_batch, zeros_batch
+

+ 346 - 0
video_prediction/prediction_model.py

@@ -0,0 +1,346 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model architecture for predictive model, including CDNA, DNA, and STP."""
+
+import numpy as np
+import tensorflow as tf
+
+import tensorflow.contrib.slim as slim
+from tensorflow.contrib.layers.python import layers as tf_layers
+from lstm_ops import basic_conv_lstm_cell
+
+# Amount to use when lower bounding tensors
+RELU_SHIFT = 1e-12
+
+# kernel size for DNA and CDNA.
+DNA_KERN_SIZE = 5
+
+
+def construct_model(images,
+                    actions=None,
+                    states=None,
+                    iter_num=-1.0,
+                    k=-1,
+                    use_state=True,
+                    num_masks=10,
+                    stp=False,
+                    cdna=True,
+                    dna=False,
+                    context_frames=2):
+  """Build convolutional lstm video predictor using STP, CDNA, or DNA.
+
+  Args:
+    images: tensor of ground truth image sequences
+    actions: tensor of action sequences
+    states: tensor of ground truth state sequences
+    iter_num: tensor of the current training iteration (for sched. sampling)
+    k: constant used for scheduled sampling. -1 to feed in own prediction.
+    use_state: True to include state and action in prediction
+    num_masks: the number of different pixel motion predictions (and
+               the number of masks for each of those predictions)
+    stp: True to use Spatial Transformer Predictor (STP)
+    cdna: True to use Convolutional Dynamic Neural Advection (CDNA)
+    dna: True to use Dynamic Neural Advection (DNA)
+    context_frames: number of ground truth frames to pass in before
+                    feeding in own predictions
+  Returns:
+    gen_images: predicted future image frames
+    gen_states: predicted future states
+
+  Raises:
+    ValueError: if more than one network option specified or more than 1 mask
+    specified for DNA model.
+  """
+  if stp + cdna + dna != 1:
+    raise ValueError('More than one, or no network option specified.')
+  batch_size, img_height, img_width, color_channels = images[0].get_shape()[0:4]
+  lstm_func = basic_conv_lstm_cell
+
+  # Generated robot states and images.
+  gen_states, gen_images = [], []
+  current_state = states[0]
+
+  if k == -1:
+    feedself = True
+  else:
+    # Scheduled sampling:
+    # Calculate number of ground-truth frames to pass in.
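+    # This follows an inverse-sigmoid decay in iter_num: the count starts near
+    # batch_size (k/(k + 1) of it) and approaches zero as training progresses.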
+    num_ground_truth = tf.to_int32(
+        tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k)))))
+    feedself = False
+
+  # LSTM state sizes and states.
+  lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32]))
+  lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None
+  lstm_state5, lstm_state6, lstm_state7 = None, None, None
+
+  for image, action in zip(images[:-1], actions[:-1]):
+    # Reuse variables after the first timestep.
+    reuse = bool(gen_images)
+
+    done_warm_start = len(gen_images) > context_frames - 1
+    with slim.arg_scope(
+        [lstm_func, slim.layers.conv2d, slim.layers.fully_connected,
+         tf_layers.layer_norm, slim.layers.conv2d_transpose],
+        reuse=reuse):
+
+      if feedself and done_warm_start:
+        # Feed in generated image.
+        prev_image = gen_images[-1]
+      elif done_warm_start:
+        # Scheduled sampling
+        prev_image = scheduled_sample(image, gen_images[-1], batch_size,
+                                      num_ground_truth)
+      else:
+        # Always feed in ground_truth
+        prev_image = image
+
+      # Predicted state is always fed back in
+      state_action = tf.concat(1, [action, current_state])
+
+      enc0 = slim.layers.conv2d(
+          prev_image,
+          32, [5, 5],
+          stride=2,
+          scope='scale1_conv1',
+          normalizer_fn=tf_layers.layer_norm,
+          normalizer_params={'scope': 'layer_norm1'})
+
+      hidden1, lstm_state1 = lstm_func(
+          enc0, lstm_state1, lstm_size[0], scope='state1')
+      hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2')
+      hidden2, lstm_state2 = lstm_func(
+          hidden1, lstm_state2, lstm_size[1], scope='state2')
+      hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3')
+      enc1 = slim.layers.conv2d(
+          hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2')
+
+      hidden3, lstm_state3 = lstm_func(
+          enc1, lstm_state3, lstm_size[2], scope='state3')
+      hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4')
+      hidden4, lstm_state4 = lstm_func(
+          hidden3, lstm_state4, lstm_size[3], scope='state4')
+      hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5')
+      enc2 = slim.layers.conv2d(
+          hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3')
+
+      # Pass in state and action.
+      smear = tf.reshape(
+          state_action,
+          [int(batch_size), 1, 1, int(state_action.get_shape()[1])])
+      smear = tf.tile(
+          smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1])
+      if use_state:
+        enc2 = tf.concat(3, [enc2, smear])
+      enc3 = slim.layers.conv2d(
+          enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4')
+
+      hidden5, lstm_state5 = lstm_func(
+          enc3, lstm_state5, lstm_size[4], scope='state5')  # last 8x8
+      hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6')
+      enc4 = slim.layers.conv2d_transpose(
+          hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1')
+
+      hidden6, lstm_state6 = lstm_func(
+          enc4, lstm_state6, lstm_size[5], scope='state6')  # 16x16
+      hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7')
+      # Skip connection.
+      hidden6 = tf.concat(3, [hidden6, enc1])  # both 16x16
+
+      enc5 = slim.layers.conv2d_transpose(
+          hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2')
+      hidden7, lstm_state7 = lstm_func(
+          enc5, lstm_state7, lstm_size[6], scope='state7')  # 32x32
+      hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8')
+
+      # Skip connection.
+      hidden7 = tf.concat(3, [hidden7, enc0])  # both 32x32
+
+      enc6 = slim.layers.conv2d_transpose(
+          hidden7,
+          hidden7.get_shape()[3], 3, stride=2, scope='convt3',
+          normalizer_fn=tf_layers.layer_norm,
+          normalizer_params={'scope': 'layer_norm9'})
+
+      if dna:
+        # Using largest hidden state for predicting untied conv kernels.
+        enc7 = slim.layers.conv2d_transpose(
+            enc6, DNA_KERN_SIZE**2, 1, stride=1, scope='convt4')
+      else:
+        # Using largest hidden state for predicting a new image layer.
+        enc7 = slim.layers.conv2d_transpose(
+            enc6, color_channels, 1, stride=1, scope='convt4')
+        # This allows the network to also generate one image from scratch,
+        # which is useful when regions of the image become unoccluded.
+        transformed = [tf.nn.sigmoid(enc7)]
+
+      if stp:
+        stp_input0 = tf.reshape(hidden5, [int(batch_size), -1])
+        stp_input1 = slim.layers.fully_connected(
+            stp_input0, 100, scope='fc_stp')
+        transformed += stp_transformation(prev_image, stp_input1, num_masks)
+      elif cdna:
+        cdna_input = tf.reshape(hidden5, [int(batch_size), -1])
+        transformed += cdna_transformation(prev_image, cdna_input, num_masks,
+                                           int(color_channels))
+      elif dna:
+        # Only one mask is supported (more should be unnecessary).
+        if num_masks != 1:
+          raise ValueError('Only one mask is supported for DNA model.')
+        transformed = [dna_transformation(prev_image, enc7)]
+
+      masks = slim.layers.conv2d_transpose(
+          enc6, num_masks + 1, 1, stride=1, scope='convt7')
+      masks = tf.reshape(
+          tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])),
+          [int(batch_size), int(img_height), int(img_width), num_masks + 1])
+      mask_list = tf.split(3, num_masks + 1, masks)
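+      # Composite the prediction: the first mask weights the (static) previous
+      # image and the remaining masks weight the transformed candidate images.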
+      output = mask_list[0] * prev_image
+      for layer, mask in zip(transformed, mask_list[1:]):
+        output += layer * mask
+      gen_images.append(output)
+
+      current_state = slim.layers.fully_connected(
+          state_action,
+          int(current_state.get_shape()[1]),
+          scope='state_pred',
+          activation_fn=None)
+      gen_states.append(current_state)
+
+  return gen_images, gen_states
+
+
+## Utility functions
+def stp_transformation(prev_image, stp_input, num_masks):
+  """Apply spatial transformer predictor (STP) to previous image.
+
+  Args:
+    prev_image: previous image to be transformed.
+    stp_input: hidden layer to be used for computing STN parameters.
+    num_masks: number of masks and hence the number of STP transformations.
+  Returns:
+    List of images transformed by the predicted STP parameters.
+  """
+  # Only import spatial transformer if needed.
+  from spatial_transformer import transformer
+
+  identity_params = tf.convert_to_tensor(
+      np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32))
+  transformed = []
+  for i in range(num_masks - 1):
+    params = slim.layers.fully_connected(
+        stp_input, 6, scope='stp_params' + str(i),
+        activation_fn=None) + identity_params
+    transformed.append(transformer(prev_image, params))
+
+  return transformed
+
+
+def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
+  """Apply convolutional dynamic neural advection to previous image.
+
+  Args:
+    prev_image: previous image to be transformed.
+    cdna_input: hidden layer to be used for computing CDNA kernels.
+    num_masks: the number of masks and hence the number of CDNA transformations.
+    color_channels: the number of color channels in the images.
+  Returns:
+    List of images transformed by the predicted CDNA kernels.
+  """
+  batch_size = int(cdna_input.get_shape()[0])
+
+  # Predict kernels using linear function of last hidden layer.
+  cdna_kerns = slim.layers.fully_connected(
+      cdna_input,
+      DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks,
+      scope='cdna_params',
+      activation_fn=None)
+
+  # Reshape and normalize.
+  cdna_kerns = tf.reshape(
+      cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks])
+  cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT
+  norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
+  cdna_kerns /= norm_factor
+
+  cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
+  cdna_kerns = tf.split(0, batch_size, cdna_kerns)
+  prev_images = tf.split(0, batch_size, prev_image)
+
+  # Transform image.
+  transformed = []
+  for kernel, preimg in zip(cdna_kerns, prev_images):
+    kernel = tf.squeeze(kernel)
+    if len(kernel.get_shape()) == 3:
+      kernel = tf.expand_dims(kernel, -1)
+    transformed.append(
+        tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
+  transformed = tf.concat(0, transformed)
+  transformed = tf.split(3, num_masks, transformed)
+  return transformed
+
+
+def dna_transformation(prev_image, dna_input):
+  """Apply dynamic neural advection to previous image.
+
+  Args:
+    prev_image: previous image to be transformed.
+    dna_input: hidden layer to be used for computing DNA transformation.
+  Returns:
+    Image transformed by the predicted DNA kernels.
+  """
+  # Construct translated images.
+  prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]])
+  image_height = int(prev_image.get_shape()[1])
+  image_width = int(prev_image.get_shape()[2])
+
+  inputs = []
+  for xkern in range(DNA_KERN_SIZE):
+    for ykern in range(DNA_KERN_SIZE):
+      inputs.append(
+          tf.expand_dims(
+              tf.slice(prev_image_pad, [0, xkern, ykern, 0],
+                       [-1, image_height, image_width, -1]), [3]))
+  inputs = tf.concat(3, inputs)
+
+  # Normalize channels to 1.
+  kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT
+  kernel = tf.expand_dims(
+      kernel / tf.reduce_sum(
+          kernel, [3], keep_dims=True), [4])
+  return tf.reduce_sum(kernel * inputs, [3], keep_dims=False)
+
+
+def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth):
+  """Sample batch with specified mix of ground truth and generated data points.
+
+  Args:
+    ground_truth_x: tensor of ground-truth data points.
+    generated_x: tensor of generated data points.
+    batch_size: batch size
+    num_ground_truth: number of ground-truth examples to include in batch.
+  Returns:
+    New batch with num_ground_truth sampled from ground_truth_x and the rest
+    from generated_x.
+  """
+  idx = tf.random_shuffle(tf.range(int(batch_size)))
+  ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth))
+  generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size)))
+
+  ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx)
+  generated_examps = tf.gather(generated_x, generated_idx)
+  return tf.dynamic_stitch([ground_truth_idx, generated_idx],
+                           [ground_truth_examps, generated_examps])

+ 249 - 0
video_prediction/prediction_train.py

@@ -0,0 +1,249 @@
+# Copyright 2016 The TensorFlow Authors All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Code for training the prediction model."""
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+
+from prediction_input import build_tfrecord_input
+from prediction_model import construct_model
+
+# How often to record tensorboard summaries.
+SUMMARY_INTERVAL = 40
+
+# How often to run a batch through the validation model.
+VAL_INTERVAL = 200
+
+# How often to save a model checkpoint
+SAVE_INTERVAL = 2000
+
+# tf record data location:
+DATA_DIR = 'push/push_train'
+
+# local output directory
+OUT_DIR = '/tmp/data'
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('data_dir', DATA_DIR, 'directory containing data.')
+flags.DEFINE_string('output_dir', OUT_DIR, 'directory for model checkpoints.')
+flags.DEFINE_string('event_log_dir', OUT_DIR, 'directory for writing summary.')
+flags.DEFINE_integer('num_iterations', 100000, 'number of training iterations.')
+flags.DEFINE_string('pretrained_model', '',
+                    'filepath of a pretrained model to initialize from.')
+
+flags.DEFINE_integer('sequence_length', 10,
+                     'sequence length, including context frames.')
+flags.DEFINE_integer('context_frames', 2, '# of frames before predictions.')
+flags.DEFINE_integer('use_state', 1,
+                     'Whether or not to give the state+action to the model')
+
+flags.DEFINE_string('model', 'CDNA',
+                    'model architecture to use - CDNA, DNA, or STP')
+
+flags.DEFINE_integer('num_masks', 10,
+                     'number of masks, usually 1 for DNA, 10 for CDNA, STN.')
+flags.DEFINE_float('schedsamp_k', 900.0,
+                   'The k hyperparameter for scheduled sampling,'
+                   '-1 for no scheduled sampling.')
+flags.DEFINE_float('train_val_split', 0.95,
+                   'The percentage of files to use for the training set,'
+                   ' vs. the validation set.')
+
+flags.DEFINE_integer('batch_size', 32, 'batch size for training')
+flags.DEFINE_float('learning_rate', 0.001,
+                   'the base learning rate of the generator')
+
+
+## Helper functions
+def peak_signal_to_noise_ratio(true, pred):
+  """Image quality metric based on maximal signal power vs. power of the noise.
+
+  Args:
+    true: the ground truth image.
+    pred: the predicted image.
+  Returns:
+    peak signal to noise ratio (PSNR)
+  """
+  return 10.0 * tf.log(1.0 / mean_squared_error(true, pred)) / tf.log(10.0)
+
+
+def mean_squared_error(true, pred):
+  """L2 distance between tensors true and pred.
+
+  Args:
+    true: the ground truth image.
+    pred: the predicted image.
+  Returns:
+    mean squared error between ground truth and predicted image.
+  """
+  return tf.reduce_sum(tf.square(true - pred)) / tf.to_float(tf.size(pred))
+
+
+class Model(object):
+
+  def __init__(self,
+               images=None,
+               actions=None,
+               states=None,
+               sequence_length=None,
+               reuse_scope=None):
+
+    if sequence_length is None:
+      sequence_length = FLAGS.sequence_length
+
+    self.prefix = prefix = tf.placeholder(tf.string, [])
+    self.iter_num = tf.placeholder(tf.float32, [])
+    summaries = []
+
+    # Split into timesteps.
+    actions = tf.split(1, actions.get_shape()[1], actions)
+    actions = [tf.squeeze(act) for act in actions]
+    states = tf.split(1, states.get_shape()[1], states)
+    states = [tf.squeeze(st) for st in states]
+    images = tf.split(1, images.get_shape()[1], images)
+    images = [tf.squeeze(img) for img in images]
+
+    if reuse_scope is None:
+      gen_images, gen_states = construct_model(
+          images,
+          actions,
+          states,
+          iter_num=self.iter_num,
+          k=FLAGS.schedsamp_k,
+          use_state=FLAGS.use_state,
+          num_masks=FLAGS.num_masks,
+          cdna=FLAGS.model == 'CDNA',
+          dna=FLAGS.model == 'DNA',
+          stp=FLAGS.model == 'STP',
+          context_frames=FLAGS.context_frames)
+    else:  # If it's a validation or test model.
+      with tf.variable_scope(reuse_scope, reuse=True):
+        gen_images, gen_states = construct_model(
+            images,
+            actions,
+            states,
+            iter_num=self.iter_num,
+            k=FLAGS.schedsamp_k,
+            use_state=FLAGS.use_state,
+            num_masks=FLAGS.num_masks,
+            cdna=FLAGS.model == 'CDNA',
+            dna=FLAGS.model == 'DNA',
+            stp=FLAGS.model == 'STP',
+            context_frames=FLAGS.context_frames)
+
+    # L2 loss, PSNR for eval.
+    loss, psnr_all = 0.0, 0.0
+    for i, x, gx in zip(
+        range(len(gen_images)), images[FLAGS.context_frames:],
+        gen_images[FLAGS.context_frames - 1:]):
+      recon_cost = mean_squared_error(x, gx)
+      psnr_i = peak_signal_to_noise_ratio(x, gx)
+      psnr_all += psnr_i
+      summaries.append(
+          tf.scalar_summary(prefix + '_recon_cost' + str(i), recon_cost))
+      summaries.append(tf.scalar_summary(prefix + '_psnr' + str(i), psnr_i))
+      loss += recon_cost
+
+    for i, state, gen_state in zip(
+        range(len(gen_states)), states[FLAGS.context_frames:],
+        gen_states[FLAGS.context_frames - 1:]):
+      state_cost = mean_squared_error(state, gen_state) * 1e-4
+      summaries.append(
+          tf.scalar_summary(prefix + '_state_cost' + str(i), state_cost))
+      loss += state_cost
+    summaries.append(tf.scalar_summary(prefix + '_psnr_all', psnr_all))
+    self.psnr_all = psnr_all
+
+    self.loss = loss = loss / np.float32(len(images) - FLAGS.context_frames)
+
+    summaries.append(tf.scalar_summary(prefix + '_loss', loss))
+
+    self.lr = tf.placeholder_with_default(FLAGS.learning_rate, ())
+
+    self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
+    self.summ_op = tf.merge_summary(summaries)
+
+
+def main(unused_argv):
+
+  print 'Constructing models and inputs.'
+  with tf.variable_scope('model', reuse=None) as training_scope:
+    images, actions, states = build_tfrecord_input(training=True)
+    model = Model(images, actions, states, FLAGS.sequence_length)
+
+  with tf.variable_scope('val_model', reuse=None):
+    val_images, val_actions, val_states = build_tfrecord_input(training=False)
+    val_model = Model(val_images, val_actions, val_states,
+                      FLAGS.sequence_length, training_scope)
+
+  print 'Constructing saver.'
+  # Make saver.
+  saver = tf.train.Saver(
+      tf.get_collection(tf.GraphKeys.VARIABLES), max_to_keep=0)
+
+  # Make training session.
+  sess = tf.InteractiveSession()
+  summary_writer = tf.train.SummaryWriter(
+      FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)
+
+  if FLAGS.pretrained_model:
+    saver.restore(sess, FLAGS.pretrained_model)
+
+  tf.train.start_queue_runners(sess)
+  sess.run(tf.initialize_all_variables())
+
+  tf.logging.info('iteration number, cost')
+
+  # Run training.
+  for itr in range(FLAGS.num_iterations):
+    # Generate new batch of data.
+    feed_dict = {model.prefix: 'train',
+                 model.iter_num: np.float32(itr),
+                 model.lr: FLAGS.learning_rate}
+    cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
+                                    feed_dict)
+
+    # Print info: iteration #, cost.
+    tf.logging.info(str(itr) + ' ' + str(cost))
+
+    if (itr) % VAL_INTERVAL == 2:
+      # Run through validation set.
+      feed_dict = {val_model.lr: 0.0,
+                   val_model.prefix: 'val',
+                   val_model.iter_num: np.float32(itr)}
+      _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
+                                     feed_dict)
+      summary_writer.add_summary(val_summary_str, itr)
+
+    if (itr) % SAVE_INTERVAL == 2:
+      tf.logging.info('Saving model.')
+      saver.save(sess, FLAGS.output_dir + '/model' + str(itr))
+
+    if (itr) % SUMMARY_INTERVAL == 2:
+      summary_writer.add_summary(summary_str, itr)
+
+  tf.logging.info('Saving model.')
+  saver.save(sess, FLAGS.output_dir + '/model')
+  tf.logging.info('Training complete')
+  tf.logging.flush()
+
+
+if __name__ == '__main__':
+  app.run()

+ 274 - 0
video_prediction/push_datafiles.txt

@@ -0,0 +1,274 @@
+push/push_testnovel/push_testnovel.tfrecord-00000-of-00005
+push/push_testnovel/push_testnovel.tfrecord-00001-of-00005
+push/push_testnovel/push_testnovel.tfrecord-00002-of-00005
+push/push_testnovel/push_testnovel.tfrecord-00003-of-00005
+push/push_testnovel/push_testnovel.tfrecord-00004-of-00005
+push/push_testseen/push_testseen.tfrecord-00000-of-00005
+push/push_testseen/push_testseen.tfrecord-00001-of-00005
+push/push_testseen/push_testseen.tfrecord-00002-of-00005
+push/push_testseen/push_testseen.tfrecord-00003-of-00005
+push/push_testseen/push_testseen.tfrecord-00004-of-00005
+push/push_train/push_train.tfrecord-00000-of-00264
+push/push_train/push_train.tfrecord-00001-of-00264
+push/push_train/push_train.tfrecord-00002-of-00264
+push/push_train/push_train.tfrecord-00003-of-00264
+push/push_train/push_train.tfrecord-00004-of-00264
+push/push_train/push_train.tfrecord-00005-of-00264
+push/push_train/push_train.tfrecord-00006-of-00264
+push/push_train/push_train.tfrecord-00007-of-00264
+push/push_train/push_train.tfrecord-00008-of-00264
+push/push_train/push_train.tfrecord-00009-of-00264
+push/push_train/push_train.tfrecord-00010-of-00264
+push/push_train/push_train.tfrecord-00011-of-00264
+push/push_train/push_train.tfrecord-00012-of-00264
+push/push_train/push_train.tfrecord-00013-of-00264
+push/push_train/push_train.tfrecord-00014-of-00264
+push/push_train/push_train.tfrecord-00015-of-00264
+push/push_train/push_train.tfrecord-00016-of-00264
+push/push_train/push_train.tfrecord-00017-of-00264
+push/push_train/push_train.tfrecord-00018-of-00264
+push/push_train/push_train.tfrecord-00019-of-00264
+push/push_train/push_train.tfrecord-00020-of-00264
+push/push_train/push_train.tfrecord-00021-of-00264
+push/push_train/push_train.tfrecord-00022-of-00264
+push/push_train/push_train.tfrecord-00023-of-00264
+push/push_train/push_train.tfrecord-00024-of-00264
+push/push_train/push_train.tfrecord-00025-of-00264
+push/push_train/push_train.tfrecord-00026-of-00264
+push/push_train/push_train.tfrecord-00027-of-00264
+push/push_train/push_train.tfrecord-00028-of-00264
+push/push_train/push_train.tfrecord-00029-of-00264
+push/push_train/push_train.tfrecord-00030-of-00264
+push/push_train/push_train.tfrecord-00031-of-00264
+push/push_train/push_train.tfrecord-00032-of-00264
+push/push_train/push_train.tfrecord-00033-of-00264
+push/push_train/push_train.tfrecord-00034-of-00264
+push/push_train/push_train.tfrecord-00035-of-00264
+push/push_train/push_train.tfrecord-00036-of-00264
+push/push_train/push_train.tfrecord-00037-of-00264
+push/push_train/push_train.tfrecord-00038-of-00264
+push/push_train/push_train.tfrecord-00039-of-00264
+push/push_train/push_train.tfrecord-00040-of-00264
+push/push_train/push_train.tfrecord-00041-of-00264
+push/push_train/push_train.tfrecord-00042-of-00264
+push/push_train/push_train.tfrecord-00043-of-00264
+push/push_train/push_train.tfrecord-00044-of-00264
+push/push_train/push_train.tfrecord-00045-of-00264
+push/push_train/push_train.tfrecord-00046-of-00264
+push/push_train/push_train.tfrecord-00047-of-00264
+push/push_train/push_train.tfrecord-00048-of-00264
+push/push_train/push_train.tfrecord-00049-of-00264
+push/push_train/push_train.tfrecord-00050-of-00264
+push/push_train/push_train.tfrecord-00051-of-00264
+push/push_train/push_train.tfrecord-00052-of-00264
+push/push_train/push_train.tfrecord-00053-of-00264
+push/push_train/push_train.tfrecord-00054-of-00264
+push/push_train/push_train.tfrecord-00055-of-00264
+push/push_train/push_train.tfrecord-00056-of-00264
+push/push_train/push_train.tfrecord-00057-of-00264
+push/push_train/push_train.tfrecord-00058-of-00264
+push/push_train/push_train.tfrecord-00059-of-00264
+push/push_train/push_train.tfrecord-00060-of-00264
+push/push_train/push_train.tfrecord-00061-of-00264
+push/push_train/push_train.tfrecord-00062-of-00264
+push/push_train/push_train.tfrecord-00063-of-00264
+push/push_train/push_train.tfrecord-00064-of-00264
+push/push_train/push_train.tfrecord-00065-of-00264
+push/push_train/push_train.tfrecord-00066-of-00264
+push/push_train/push_train.tfrecord-00067-of-00264
+push/push_train/push_train.tfrecord-00068-of-00264
+push/push_train/push_train.tfrecord-00069-of-00264
+push/push_train/push_train.tfrecord-00070-of-00264
+push/push_train/push_train.tfrecord-00071-of-00264
+push/push_train/push_train.tfrecord-00072-of-00264
+push/push_train/push_train.tfrecord-00073-of-00264
+push/push_train/push_train.tfrecord-00074-of-00264
+push/push_train/push_train.tfrecord-00075-of-00264
+push/push_train/push_train.tfrecord-00076-of-00264
+push/push_train/push_train.tfrecord-00077-of-00264
+push/push_train/push_train.tfrecord-00078-of-00264
+push/push_train/push_train.tfrecord-00079-of-00264
+push/push_train/push_train.tfrecord-00080-of-00264
+push/push_train/push_train.tfrecord-00081-of-00264
+push/push_train/push_train.tfrecord-00082-of-00264
+push/push_train/push_train.tfrecord-00083-of-00264
+push/push_train/push_train.tfrecord-00084-of-00264
+push/push_train/push_train.tfrecord-00085-of-00264
+push/push_train/push_train.tfrecord-00086-of-00264
+push/push_train/push_train.tfrecord-00087-of-00264
+push/push_train/push_train.tfrecord-00088-of-00264
+push/push_train/push_train.tfrecord-00089-of-00264
+push/push_train/push_train.tfrecord-00090-of-00264
+push/push_train/push_train.tfrecord-00091-of-00264
+push/push_train/push_train.tfrecord-00092-of-00264
+push/push_train/push_train.tfrecord-00093-of-00264
+push/push_train/push_train.tfrecord-00094-of-00264
+push/push_train/push_train.tfrecord-00095-of-00264
+push/push_train/push_train.tfrecord-00096-of-00264
+push/push_train/push_train.tfrecord-00097-of-00264
+push/push_train/push_train.tfrecord-00098-of-00264
+push/push_train/push_train.tfrecord-00099-of-00264
+push/push_train/push_train.tfrecord-00100-of-00264
+push/push_train/push_train.tfrecord-00101-of-00264
+push/push_train/push_train.tfrecord-00102-of-00264
+push/push_train/push_train.tfrecord-00103-of-00264
+push/push_train/push_train.tfrecord-00104-of-00264
+push/push_train/push_train.tfrecord-00105-of-00264
+push/push_train/push_train.tfrecord-00106-of-00264
+push/push_train/push_train.tfrecord-00107-of-00264
+push/push_train/push_train.tfrecord-00108-of-00264
+push/push_train/push_train.tfrecord-00109-of-00264
+push/push_train/push_train.tfrecord-00110-of-00264
+push/push_train/push_train.tfrecord-00111-of-00264
+push/push_train/push_train.tfrecord-00112-of-00264
+push/push_train/push_train.tfrecord-00113-of-00264
+push/push_train/push_train.tfrecord-00114-of-00264
+push/push_train/push_train.tfrecord-00115-of-00264
+push/push_train/push_train.tfrecord-00116-of-00264
+push/push_train/push_train.tfrecord-00117-of-00264
+push/push_train/push_train.tfrecord-00118-of-00264
+push/push_train/push_train.tfrecord-00119-of-00264
+push/push_train/push_train.tfrecord-00120-of-00264
+push/push_train/push_train.tfrecord-00121-of-00264
+push/push_train/push_train.tfrecord-00122-of-00264
+push/push_train/push_train.tfrecord-00123-of-00264
+push/push_train/push_train.tfrecord-00124-of-00264
+push/push_train/push_train.tfrecord-00125-of-00264
+push/push_train/push_train.tfrecord-00126-of-00264
+push/push_train/push_train.tfrecord-00127-of-00264
+push/push_train/push_train.tfrecord-00128-of-00264
+push/push_train/push_train.tfrecord-00129-of-00264
+push/push_train/push_train.tfrecord-00130-of-00264
+push/push_train/push_train.tfrecord-00131-of-00264
+push/push_train/push_train.tfrecord-00132-of-00264
+push/push_train/push_train.tfrecord-00133-of-00264
+push/push_train/push_train.tfrecord-00134-of-00264
+push/push_train/push_train.tfrecord-00135-of-00264
+push/push_train/push_train.tfrecord-00136-of-00264
+push/push_train/push_train.tfrecord-00137-of-00264
+push/push_train/push_train.tfrecord-00138-of-00264
+push/push_train/push_train.tfrecord-00139-of-00264
+push/push_train/push_train.tfrecord-00140-of-00264
+push/push_train/push_train.tfrecord-00141-of-00264
+push/push_train/push_train.tfrecord-00142-of-00264
+push/push_train/push_train.tfrecord-00143-of-00264
+push/push_train/push_train.tfrecord-00144-of-00264
+push/push_train/push_train.tfrecord-00145-of-00264
+push/push_train/push_train.tfrecord-00146-of-00264
+push/push_train/push_train.tfrecord-00147-of-00264
+push/push_train/push_train.tfrecord-00148-of-00264
+push/push_train/push_train.tfrecord-00149-of-00264
+push/push_train/push_train.tfrecord-00150-of-00264
+push/push_train/push_train.tfrecord-00151-of-00264
+push/push_train/push_train.tfrecord-00152-of-00264
+push/push_train/push_train.tfrecord-00153-of-00264
+push/push_train/push_train.tfrecord-00154-of-00264
+push/push_train/push_train.tfrecord-00155-of-00264
+push/push_train/push_train.tfrecord-00156-of-00264
+push/push_train/push_train.tfrecord-00157-of-00264
+push/push_train/push_train.tfrecord-00158-of-00264
+push/push_train/push_train.tfrecord-00159-of-00264
+push/push_train/push_train.tfrecord-00160-of-00264
+push/push_train/push_train.tfrecord-00161-of-00264
+push/push_train/push_train.tfrecord-00162-of-00264
+push/push_train/push_train.tfrecord-00163-of-00264
+push/push_train/push_train.tfrecord-00164-of-00264
+push/push_train/push_train.tfrecord-00165-of-00264
+push/push_train/push_train.tfrecord-00166-of-00264
+push/push_train/push_train.tfrecord-00167-of-00264
+push/push_train/push_train.tfrecord-00168-of-00264
+push/push_train/push_train.tfrecord-00169-of-00264
+push/push_train/push_train.tfrecord-00170-of-00264
+push/push_train/push_train.tfrecord-00171-of-00264
+push/push_train/push_train.tfrecord-00172-of-00264
+push/push_train/push_train.tfrecord-00173-of-00264
+push/push_train/push_train.tfrecord-00174-of-00264
+push/push_train/push_train.tfrecord-00175-of-00264
+push/push_train/push_train.tfrecord-00176-of-00264
+push/push_train/push_train.tfrecord-00177-of-00264
+push/push_train/push_train.tfrecord-00178-of-00264
+push/push_train/push_train.tfrecord-00179-of-00264
+push/push_train/push_train.tfrecord-00180-of-00264
+push/push_train/push_train.tfrecord-00181-of-00264
+push/push_train/push_train.tfrecord-00182-of-00264
+push/push_train/push_train.tfrecord-00183-of-00264
+push/push_train/push_train.tfrecord-00184-of-00264
+push/push_train/push_train.tfrecord-00185-of-00264
+push/push_train/push_train.tfrecord-00186-of-00264
+push/push_train/push_train.tfrecord-00187-of-00264
+push/push_train/push_train.tfrecord-00188-of-00264
+push/push_train/push_train.tfrecord-00189-of-00264
+push/push_train/push_train.tfrecord-00190-of-00264
+push/push_train/push_train.tfrecord-00191-of-00264
+push/push_train/push_train.tfrecord-00192-of-00264
+push/push_train/push_train.tfrecord-00193-of-00264
+push/push_train/push_train.tfrecord-00194-of-00264
+push/push_train/push_train.tfrecord-00195-of-00264
+push/push_train/push_train.tfrecord-00196-of-00264
+push/push_train/push_train.tfrecord-00197-of-00264
+push/push_train/push_train.tfrecord-00198-of-00264
+push/push_train/push_train.tfrecord-00199-of-00264
+push/push_train/push_train.tfrecord-00200-of-00264
+push/push_train/push_train.tfrecord-00201-of-00264
+push/push_train/push_train.tfrecord-00202-of-00264
+push/push_train/push_train.tfrecord-00203-of-00264
+push/push_train/push_train.tfrecord-00204-of-00264
+push/push_train/push_train.tfrecord-00205-of-00264
+push/push_train/push_train.tfrecord-00206-of-00264
+push/push_train/push_train.tfrecord-00207-of-00264
+push/push_train/push_train.tfrecord-00208-of-00264
+push/push_train/push_train.tfrecord-00209-of-00264
+push/push_train/push_train.tfrecord-00210-of-00264
+push/push_train/push_train.tfrecord-00211-of-00264
+push/push_train/push_train.tfrecord-00212-of-00264
+push/push_train/push_train.tfrecord-00213-of-00264
+push/push_train/push_train.tfrecord-00214-of-00264
+push/push_train/push_train.tfrecord-00215-of-00264
+push/push_train/push_train.tfrecord-00216-of-00264
+push/push_train/push_train.tfrecord-00217-of-00264
+push/push_train/push_train.tfrecord-00218-of-00264
+push/push_train/push_train.tfrecord-00219-of-00264
+push/push_train/push_train.tfrecord-00220-of-00264
+push/push_train/push_train.tfrecord-00221-of-00264
+push/push_train/push_train.tfrecord-00222-of-00264
+push/push_train/push_train.tfrecord-00223-of-00264
+push/push_train/push_train.tfrecord-00224-of-00264
+push/push_train/push_train.tfrecord-00225-of-00264
+push/push_train/push_train.tfrecord-00226-of-00264
+push/push_train/push_train.tfrecord-00227-of-00264
+push/push_train/push_train.tfrecord-00228-of-00264
+push/push_train/push_train.tfrecord-00229-of-00264
+push/push_train/push_train.tfrecord-00230-of-00264
+push/push_train/push_train.tfrecord-00231-of-00264
+push/push_train/push_train.tfrecord-00232-of-00264
+push/push_train/push_train.tfrecord-00233-of-00264
+push/push_train/push_train.tfrecord-00234-of-00264
+push/push_train/push_train.tfrecord-00235-of-00264
+push/push_train/push_train.tfrecord-00236-of-00264
+push/push_train/push_train.tfrecord-00237-of-00264
+push/push_train/push_train.tfrecord-00238-of-00264
+push/push_train/push_train.tfrecord-00239-of-00264
+push/push_train/push_train.tfrecord-00240-of-00264
+push/push_train/push_train.tfrecord-00241-of-00264
+push/push_train/push_train.tfrecord-00242-of-00264
+push/push_train/push_train.tfrecord-00243-of-00264
+push/push_train/push_train.tfrecord-00244-of-00264
+push/push_train/push_train.tfrecord-00245-of-00264
+push/push_train/push_train.tfrecord-00246-of-00264
+push/push_train/push_train.tfrecord-00247-of-00264
+push/push_train/push_train.tfrecord-00248-of-00264
+push/push_train/push_train.tfrecord-00249-of-00264
+push/push_train/push_train.tfrecord-00250-of-00264
+push/push_train/push_train.tfrecord-00251-of-00264
+push/push_train/push_train.tfrecord-00252-of-00264
+push/push_train/push_train.tfrecord-00253-of-00264
+push/push_train/push_train.tfrecord-00254-of-00264
+push/push_train/push_train.tfrecord-00255-of-00264
+push/push_train/push_train.tfrecord-00256-of-00264
+push/push_train/push_train.tfrecord-00257-of-00264
+push/push_train/push_train.tfrecord-00258-of-00264
+push/push_train/push_train.tfrecord-00259-of-00264
+push/push_train/push_train.tfrecord-00260-of-00264
+push/push_train/push_train.tfrecord-00261-of-00264
+push/push_train/push_train.tfrecord-00262-of-00264
+push/push_train/push_train.tfrecord-00263-of-00264