Browse Source

Open source the image-to-text model based on the "Show and Tell" paper.

Chris Shallue 9 years ago
parent
commit
4f9d102483

+ 1 - 0
README.md

@@ -19,3 +19,4 @@ To propose a model for inclusion please submit a pull request.
 - [syntaxnet](syntaxnet) -- neural models of natural language syntax
 - [textsum](textsum) -- sequence-to-sequence with attention model for text summarization.
 - [transformer](transformer) -- spatial transformer network, which allows the spatial manipulation of data within the network
+- [im2txt](im2txt) -- image-to-text neural network for image captioning.

+ 7 - 0
im2txt/.gitignore

@@ -0,0 +1,7 @@
+/bazel-bin
+/bazel-ci_build-cache
+/bazel-genfiles
+/bazel-out
+/bazel-im2txt
+/bazel-testlogs
+/bazel-tf

+ 331 - 0
im2txt/README.md

@@ -0,0 +1,331 @@
+# Show and Tell: A Neural Image Caption Generator
+
+A TensorFlow implementation of the image-to-text model described in
+
+*Show and Tell: A Neural Image Caption Generator*
+
+Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan.
+
+http://arxiv.org/abs/1411.4555
+
+## Contact
+***Author:*** Chris Shallue (shallue@google.com).
+
+***Pull requests and issues:*** @cshallue.
+
+## Contents
+* [Model Overview](#model-overview)
+    * [Introduction](#introduction)
+    * [Architecture](#architecture)
+* [Getting Started](#getting-started)
+    * [A Note on Hardware and Training Time](#a-note-on-hardware-and-training-time)
+    * [Install Required Packages](#install-required-packages)
+    * [Prepare the Training Data](#prepare-the-training-data)
+    * [Download the Inception v3 Checkpoint](#download-the-inception-v3-checkpoint)
+* [Training a Model](#training-a-model)
+    * [Initial Training](#initial-training)
+    * [Fine Tune the Inception v3 Model](#fine-tune-the-inception-v3-model)
+* [Generating Captions](#generating-captions)
+
+## Model Overview
+
+### Introduction
+
+The *Show and Tell* model is a deep neural network that learns how to describe
+the content of images. For example:
+
+<center>
+![Example captions](g3doc/example_captions.jpg)
+</center>
+
+### Architecture
+
+The *Show and Tell* model is an example of an *encoder-decoder* neural network.
+It works by first "encoding" an image into a fixed-length vector representation,
+and then "decoding" the representation into a natural language description.
+
+The image encoder is a deep convolutional neural network. This type of
+network is widely used for image tasks and is currently state-of-the-art for
+object recognition and detection. Our particular choice of network is the
+[*Inception v3*](http://arxiv.org/abs/1512.00567) image recognition model
+pretrained on the
+[ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/) image
+classification dataset.
+
+The decoder is a long short-term memory (LSTM) network. This type of network is
+commonly used for sequence modeling tasks such as language modeling and machine
+translation. In the *Show and Tell* model, the LSTM network is trained as a
+language model conditioned on the image encoding.
+
+Words in the captions are represented with an embedding model. Each word in the
+vocabulary is associated with a fixed-length vector representation that is
+learned during training.
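+
+As a minimal sketch of this embedding (illustrative only; the variable names are
+ours, and the sizes are the defaults from `im2txt/im2txt/configuration.py` in
+this commit), each caption word id is mapped to a vector by a lookup into a
+trainable matrix:
+
+```python
+import tensorflow as tf
+
+vocab_size = 12000     # default ModelConfig().vocab_size
+embedding_size = 512   # default ModelConfig().embedding_size
+
+# Trainable embedding matrix with one row per vocabulary word.
+embedding_map = tf.get_variable("word_embedding", [vocab_size, embedding_size])
+
+# caption_ids holds integer word ids, e.g. with shape [batch_size, padded_length].
+caption_ids = tf.placeholder(tf.int32, shape=[None, None])
+word_embeddings = tf.nn.embedding_lookup(embedding_map, caption_ids)
+```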
+
+The following diagram illustrates the model architecture.
+
+<center>
+![Show and Tell Architecture](g3doc/show_and_tell_architecture.png)
+</center>
+
+In this diagram, $$\{ s_0, s_1, ..., s_{N-1} \}$$ are the words of the caption
+and $$\{ w_e s_0, w_e s_1, ..., w_e s_{N-1} \}$$ are their corresponding word
+embedding vectors. The outputs $$\{ p_1, p_2, ..., p_N \}$$ of the LSTM are
+probability distributions generated by the model for the next word in the
+sentence. The terms $$\{ \log p_1(s_1), \log p_2(s_2), ..., \log p_N(s_N) \}$$
+are the log-likelihoods of the correct word at each step; the negated sum of
+these terms is the minimization objective of the model.
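+
+A minimal TensorFlow sketch of this objective (illustrative only; the names and
+shapes below are ours, not taken from the model code):
+
+```python
+import tensorflow as tf
+
+# Names and shapes are illustrative. One row per word position across the batch.
+logits = tf.placeholder(tf.float32, [None, 12000])  # unnormalized LSTM outputs
+targets = tf.placeholder(tf.int64, [None])          # ids of the correct next words s_t
+weights = tf.placeholder(tf.float32, [None])        # 1.0 for real words, 0.0 for padding
+
+# -log p_t(s_t) for each position, averaged over the non-padding words.
+losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
+batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
+                    tf.reduce_sum(weights),
+                    name="batch_loss")
+```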
+
+During the first phase of training the parameters of the *Inception v3* model
+are kept fixed: it is simply a static image encoder function. A single trainable
+layer is added on top of the *Inception v3* model to transform the image
+embedding into the word embedding vector space. The model is trained with
+respect to the parameters of the word embeddings, the parameters of the layer on
+top of *Inception v3* and the parameters of the LSTM. In the second phase of
+training, all parameters - including the parameters of *Inception v3* - are
+trained to jointly fine-tune the image encoder and the LSTM.
+
+Given a trained model and an image we use *beam search* to generate captions for
+that image. Captions are generated word-by-word, where at each step $$t$$ we use
+the set of sentences already generated with length $$t-1$$ to generate a new set
+of sentences with length $$t$$. We keep only the top $$k$$ candidates at each
+step, where the hyperparameter $$k$$ is called the *beam size*. We have found
+the best performance with $$k=3$$.
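+
+The following toy sketch (ours, not part of the model code) shows a single beam
+step; the full implementation, including start/end words and optional length
+normalization, is in `im2txt/inference_utils/caption_generator.py` in this
+repository:
+
+```python
+import heapq
+import math
+
+def beam_step(partial_captions, next_word_probs, k=3):
+  """One illustrative beam-search step.
+
+  partial_captions: list of (log_prob, word_id_list) pairs.
+  next_word_probs: function mapping a word_id_list to a dict of
+    {next_word_id: probability}; in the real model this is one LSTM
+    inference step.
+  Returns the k highest-scoring expanded captions.
+  """
+  candidates = []
+  for log_prob, sentence in partial_captions:
+    for word, p in next_word_probs(sentence).items():
+      if p > 0:  # avoid log(0)
+        candidates.append((log_prob + math.log(p), sentence + [word]))
+  return heapq.nlargest(k, candidates)
+```
+
+Here `k` plays the role of the `beam_size` argument of `CaptionGenerator`.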
+
+## Getting Started
+
+### A Note on Hardware and Training Time
+
+The time required to train the *Show and Tell* model depends on your specific
+hardware and computational capacity. In this guide we assume you will be running
+training on a single machine with a GPU. In our experience on an NVIDIA Tesla
+K20m GPU the initial training phase takes 1-2 weeks. The second training phase
+may take several additional weeks to achieve peak performance (but you can stop
+this phase early and still get reasonable results).
+
+It is possible to achieve a speed-up by implementing distributed training across
+a cluster of machines with GPUs, but that is not covered in this guide.
+
+Whilst it is possible to run this code on a CPU, beware that this may be
+approximately 10 times slower.
+
+### Install Required Packages
+First ensure that you have installed the following required packages:
+
+* **Bazel** ([instructions](http://bazel.io/docs/install.html)).
+* **TensorFlow** ([instructions](https://www.tensorflow.org/versions/r0.10/get_started/os_setup.html)).
+* **NumPy** ([instructions](http://www.scipy.org/install.html)).
+* **Natural Language Toolkit (NLTK)**:
+    * First install NLTK ([instructions](http://www.nltk.org/install.html)).
+    * Then install the NLTK data ([instructions](http://www.nltk.org/data.html)).
+
+### Prepare the Training Data
+
+To train the model you will need to provide training data in native TFRecord
+format. The TFRecord format consists of a set of sharded files containing
+serialized `tf.SequenceExample` protocol buffers. Each `tf.SequenceExample`
+proto contains an image (JPEG format), a caption and metadata such as the image
+id.
+
+Each caption is a list of words. During preprocessing, a dictionary is created
+that assigns each word in the vocabulary to an integer-valued id. Each caption
+is encoded as a list of integer word ids in the `tf.SequenceExample` protos.
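+
+For illustration, a `tf.SequenceExample` with the fields used by this model
+could be built as follows (a sketch; the complete preprocessing pipeline is in
+`im2txt/data/build_mscoco_data.py`):
+
+```python
+import tensorflow as tf
+
+def make_sequence_example(image_id, encoded_jpeg, caption_words, caption_ids):
+  # Context features: the image id and the JPEG-encoded image bytes.
+  context = tf.train.Features(feature={
+      "image/image_id": tf.train.Feature(
+          int64_list=tf.train.Int64List(value=[image_id])),
+      "image/data": tf.train.Feature(
+          bytes_list=tf.train.BytesList(value=[encoded_jpeg])),
+  })
+  # Feature lists: the tokenized caption and its integer word ids.
+  feature_lists = tf.train.FeatureLists(feature_list={
+      "image/caption": tf.train.FeatureList(feature=[
+          tf.train.Feature(bytes_list=tf.train.BytesList(value=[w]))
+          for w in caption_words]),
+      "image/caption_ids": tf.train.FeatureList(feature=[
+          tf.train.Feature(int64_list=tf.train.Int64List(value=[i]))
+          for i in caption_ids]),
+  })
+  return tf.train.SequenceExample(context=context, feature_lists=feature_lists)
+```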
+
+We have provided a script to download and preprocess the
+[MSCOCO](http://mscoco.org/) image captioning data set into this format. Downloading
+and preprocessing the data may take several hours depending on your network and
+computer speed. Please be patient.
+
+Before running the script, ensure that your hard disk has at least 150GB of
+available space for storing the downloaded and processed data.
+
+```shell
+# Location to save the MSCOCO data.
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+
+# Build the preprocessing script.
+bazel build im2txt/download_and_preprocess_mscoco
+
+# Run the preprocessing script.
+bazel-bin/im2txt/download_and_preprocess_mscoco "${MSCOCO_DIR}"
+```
+
+The final line of the output should read:
+
+```
+2016-09-01 16:47:47.296630: Finished processing all 20267 image-caption pairs in data set 'test'.
+```
+
+When the script finishes you will find 256 training, 4 validation and 8 testing
+files in `${MSCOCO_DIR}`. The files will match the patterns `train-?????-of-00256`,
+`val-?????-of-00004` and `test-?????-of-00008`, respectively.
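+
+Optionally, you can verify the shard counts with a short Python snippet
+(illustrative; it assumes the default `MSCOCO_DIR` used above):
+
+```python
+import glob
+import os
+
+mscoco_dir = os.path.expanduser("~/im2txt/data/mscoco")
+for pattern, expected in [("train-?????-of-00256", 256),
+                          ("val-?????-of-00004", 4),
+                          ("test-?????-of-00008", 8)]:
+  found = len(glob.glob(os.path.join(mscoco_dir, pattern)))
+  print("%s: found %d of %d expected files" % (pattern, found, expected))
+```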
+
+### Download the Inception v3 Checkpoint
+
+The *Show and Tell* model requires a pretrained *Inception v3* checkpoint file
+to initialize the parameters of its image encoder submodel.
+
+This checkpoint file is provided by the
+[TensorFlow-Slim image classification library](https://github.com/tensorflow/models/tree/master/slim#tensorflow-slim-image-classification-library)
+which provides a suite of pre-trained image classification models. You can read
+more about the models provided by the library
+[here](https://github.com/tensorflow/models/tree/master/slim#pre-trained-models).
+
+
+Run the following commands to download the *Inception v3* checkpoint.
+
+```shell
+# Location to save the Inception v3 checkpoint.
+INCEPTION_DIR="${HOME}/im2txt/data"
+mkdir -p ${INCEPTION_DIR}
+
+wget "http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz"
+tar -xvf "inception_v3_2016_08_28.tar.gz" -C ${INCEPTION_DIR}
+rm "inception_v3_2016_08_28.tar.gz"
+```
+
+Note that the *Inception v3* checkpoint will only be used for initializing the
+parameters of the *Show and Tell* model. Once the *Show and Tell* model starts
+training it will save its own checkpoint files containing the values of all its
+parameters (including copies of the *Inception v3* parameters). If training is
+stopped and restarted, the parameter values will be restored from the latest
+*Show and Tell* checkpoint and the *Inception v3* checkpoint will be ignored. In
+other words, the *Inception v3* checkpoint is only used in the 0-th global step
+(initialization) of training the *Show and Tell* model.
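+
+As a rough sketch of what this initialization amounts to (illustrative only; it
+assumes the model graph, including the *Inception v3* submodel under the
+variable scope `InceptionV3`, has already been built):
+
+```python
+import os
+import tensorflow as tf
+
+# Collect only the Inception v3 variables; everything else keeps its random
+# initial values. The scope name is an assumption for this sketch.
+inception_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope="InceptionV3")
+saver = tf.train.Saver(inception_variables)
+
+with tf.Session() as sess:
+  sess.run(tf.initialize_all_variables())
+  saver.restore(sess, os.path.expanduser("~/im2txt/data/inception_v3.ckpt"))
+```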
+
+## Training a Model
+
+### Initial Training
+
+Run the training script.
+
+```shell
+# Directory containing preprocessed MSCOCO data.
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+
+# Inception v3 checkpoint file.
+INCEPTION_CHECKPOINT="${HOME}/im2txt/data/inception_v3.ckpt"
+
+# Directory to save the model.
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Build the model.
+bazel build -c opt --config=cuda im2txt/...
+
+# Run the training script.
+bazel-bin/im2txt/train \
+  --input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
+  --inception_checkpoint_file="${INCEPTION_CHECKPOINT}" \
+  --train_dir="${MODEL_DIR}/train" \
+  --train_inception=false \
+  --number_of_steps=1000000
+```
+
+Run the evaluation script in a separate process. This will log evaluation
+metrics to TensorBoard which allows training progress to be monitored in
+real-time.
+
+Note that you may run out of memory if you run the evaluation script on the same
+GPU as the training script. You can run the command
+`export CUDA_VISIBLE_DEVICES=""` to force the evaluation script to run on CPU.
+If evaluation runs too slowly on CPU, you can decrease the value of
+`--num_eval_examples`.
+
+```shell
+MSCOCO_DIR="${HOME}/im2txt/data/mscoco"
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run the evaluation script. This will run in a loop, periodically loading the
+# latest model checkpoint file and computing evaluation metrics.
+bazel-bin/im2txt/evaluate \
+  --input_file_pattern="${MSCOCO_DIR}/val-?????-of-00004" \
+  --checkpoint_dir="${MODEL_DIR}/train" \
+  --eval_dir="${MODEL_DIR}/eval"
+```
+
+Run a TensorBoard server in a separate process for real-time monitoring of
+training progress and evaluation metrics.
+
+```shell
+MODEL_DIR="${HOME}/im2txt/model"
+
+# Run a TensorBoard server.
+tensorboard --logdir="${MODEL_DIR}"
+```
+
+### Fine Tune the Inception v3 Model
+
+Your model will already be able to generate reasonable captions after the first
+phase of training. Try it out! (See
+[Generating Captions](#generating-captions)).
+
+You can further improve the performance of the model by running a
+second training phase to jointly fine-tune the parameters of the *Inception v3*
+image submodel and the LSTM.
+
+```shell
+# Restart the training script with --train_inception=true.
+bazel-bin/im2txt/train \
+  --input_file_pattern="${MSCOCO_DIR}/train-?????-of-00256" \
+  --train_dir="${MODEL_DIR}/train" \
+  --train_inception=true \
+  --number_of_steps=3000000  # Additional 2M steps (assuming 1M in initial training).
+```
+
+Note that training will proceed much slower now, and the model will continue to
+improve by a small amount for a long time. We have found that it will improve
+slowly for an additional 2-2.5 million steps before it begins to overfit. This
+may take several weeks on a single GPU. If you don't care about absolutely
+optimal performance then feel free to halt training sooner by stopping the
+training script or passing a smaller value to the flag `--number_of_steps`. Your
+model will still work reasonably well.
+
+## Generating Captions
+
+Your trained *Show and Tell* model can generate captions for any JPEG image! The
+following command line will generate captions for an image from the test set.
+
+```shell
+# Directory containing model checkpoints.
+CHECKPOINT_DIR="${HOME}/im2txt/model/train"
+
+# Vocabulary file generated by the preprocessing script.
+VOCAB_FILE="${HOME}/im2txt/data/mscoco/word_counts.txt"
+
+# JPEG image file to caption.
+IMAGE_FILE="${HOME}/im2txt/data/mscoco/raw-data/val2014/COCO_val2014_000000224477.jpg"
+
+# Build the inference binary.
+bazel build -c opt im2txt/run_inference
+
+# Ignore GPU devices (only necessary if your GPU is currently memory
+# constrained, for example, by running the training script).
+export CUDA_VISIBLE_DEVICES=""
+
+# Run inference to generate captions.
+bazel-bin/im2txt/run_inference \
+  --checkpoint_path=${CHECKPOINT_DIR} \
+  --vocab_file=${VOCAB_FILE} \
+  --input_files=${IMAGE_FILE}
+```
+
+Example output:
+
+```shell
+Captions for image COCO_val2014_000000224477.jpg:
+  0) a man riding a wave on top of a surfboard . (p=0.040413)
+  1) a person riding a surf board on a wave (p=0.017452)
+  2) a man riding a wave on a surfboard in the ocean . (p=0.005743)
+```
+
+Note: you may get different results. Some variation between different models is
+expected.
+
+Here is the image:
+
+<center>
+![Surfer](g3doc/COCO_val2014_000000224477.jpg)
+</center>

+ 0 - 0
im2txt/WORKSPACE


BIN
im2txt/g3doc/COCO_val2014_000000224477.jpg


BIN
im2txt/g3doc/example_captions.jpg


BIN
im2txt/g3doc/show_and_tell_architecture.png


+ 96 - 0
im2txt/im2txt/BUILD

@@ -0,0 +1,96 @@
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = [
+        "//im2txt/...",
+    ],
+)
+
+py_binary(
+    name = "build_mscoco_data",
+    srcs = [
+        "data/build_mscoco_data.py",
+    ],
+)
+
+sh_binary(
+    name = "download_and_preprocess_mscoco",
+    srcs = ["data/download_and_preprocess_mscoco.sh"],
+    data = [
+        ":build_mscoco_data",
+    ],
+)
+
+py_library(
+    name = "configuration",
+    srcs = ["configuration.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "show_and_tell_model",
+    srcs = ["show_and_tell_model.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        "//im2txt/ops:image_embedding",
+        "//im2txt/ops:image_processing",
+        "//im2txt/ops:inputs",
+    ],
+)
+
+py_test(
+    name = "show_and_tell_model_test",
+    size = "large",
+    srcs = ["show_and_tell_model_test.py"],
+    deps = [
+        ":configuration",
+        ":show_and_tell_model",
+    ],
+)
+
+py_library(
+    name = "inference_wrapper",
+    srcs = ["inference_wrapper.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":show_and_tell_model",
+        "//im2txt/inference_utils:inference_wrapper_base",
+    ],
+)
+
+py_binary(
+    name = "train",
+    srcs = ["train.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":show_and_tell_model",
+    ],
+)
+
+py_binary(
+    name = "evaluate",
+    srcs = ["evaluate.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":show_and_tell_model",
+    ],
+)
+
+py_binary(
+    name = "run_inference",
+    srcs = ["run_inference.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":configuration",
+        ":inference_wrapper",
+        "//im2txt/inference_utils:caption_generator",
+        "//im2txt/inference_utils:vocabulary",
+    ],
+)

+ 105 - 0
im2txt/im2txt/configuration.py

@@ -0,0 +1,105 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image-to-text model and training configurations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class ModelConfig(object):
+  """Wrapper class for model hyperparameters."""
+
+  def __init__(self):
+    """Sets the default model hyperparameters."""
+    # File pattern of sharded TFRecord file containing SequenceExample protos.
+    # Must be provided in training and evaluation modes.
+    self.input_file_pattern = None
+
+    # Image format ("jpeg" or "png").
+    self.image_format = "jpeg"
+
+    # Approximate number of values per input shard. Used to ensure sufficient
+    # mixing between shards in training.
+    self.values_per_input_shard = 2300
+    # Minimum number of shards to keep in the input queue.
+    self.input_queue_capacity_factor = 2
+    # Number of threads for prefetching SequenceExample protos.
+    self.num_input_reader_threads = 1
+
+    # Name of the SequenceExample context feature containing image data.
+    self.image_feature_name = "image/data"
+    # Name of the SequenceExample feature list containing integer captions.
+    self.caption_feature_name = "image/caption_ids"
+
+    # Number of unique words in the vocab (plus 1, for <UNK>).
+    # The default value is larger than the expected actual vocab size to allow
+    # for differences between tokenizer versions used in preprocessing. There is
+    # no harm in using a value greater than the actual vocab size, but using a
+    # value less than the actual vocab size will result in an error.
+    self.vocab_size = 12000
+
+    # Number of threads for image preprocessing. Should be a multiple of 2.
+    self.num_preprocess_threads = 4
+
+    # Batch size.
+    self.batch_size = 32
+
+    # File containing an Inception v3 checkpoint to initialize the variables
+    # of the Inception model. Must be provided when starting training for the
+    # first time.
+    self.inception_checkpoint_file = None
+
+    # Dimensions of Inception v3 input images.
+    self.image_height = 299
+    self.image_width = 299
+
+    # Scale used to initialize model variables.
+    self.initializer_scale = 0.08
+
+    # LSTM input and output dimensionality, respectively.
+    self.embedding_size = 512
+    self.num_lstm_units = 512
+
+    # If < 1.0, the dropout keep probability applied to LSTM variables.
+    self.lstm_dropout_keep_prob = 0.7
+
+    # How many model checkpoints to keep.
+    self.max_checkpoints_to_keep = 5
+    self.keep_checkpoint_every_n_hours = 10000
+
+
+class TrainingConfig(object):
+  """Wrapper class for training hyperparameters."""
+
+  def __init__(self):
+    """Sets the default training hyperparameters."""
+    # Number of examples per epoch of training data.
+    self.num_examples_per_epoch = 586363
+
+    # Optimizer for training the model.
+    self.optimizer = "SGD"
+
+    # Learning rate for the initial phase of training.
+    self.initial_learning_rate = 2.0
+    self.learning_rate_decay_factor = 0.5
+    self.num_epochs_per_decay = 8.0
+
+    # Learning rate when fine tuning the Inception v3 parameters.
+    self.train_inception_learning_rate = 0.0005
+
+    # If not None, clip gradients to this value.
+    self.clip_gradients = 5.0
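
The configuration objects above are plain Python classes. A typical usage
pattern, mirroring how `evaluate.py` in this commit overrides
`input_file_pattern` (the file pattern below is illustrative):

```python
from im2txt import configuration

# Start from the default hyperparameters, then override per-run settings.
model_config = configuration.ModelConfig()
model_config.input_file_pattern = "/path/to/mscoco/train-?????-of-00256"

training_config = configuration.TrainingConfig()
training_config.initial_learning_rate = 2.0  # matches the default above
```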

+ 481 - 0
im2txt/im2txt/data/build_mscoco_data.py

@@ -0,0 +1,481 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts MSCOCO data to TFRecord file format with SequenceExample protos.
+
+The MSCOCO images are expected to reside in JPEG files located in the following
+directory structure:
+
+  train_image_dir/COCO_train2014_000000000151.jpg
+  train_image_dir/COCO_train2014_000000000260.jpg
+  ...
+
+and
+
+  val_image_dir/COCO_val2014_000000000042.jpg
+  val_image_dir/COCO_val2014_000000000073.jpg
+  ...
+
+The MSCOCO annotations JSON files are expected to reside in train_captions_file
+and val_captions_file respectively.
+
+This script converts the combined MSCOCO data into sharded data files consisting
+of 256, 4 and 8 TFRecord files, respectively:
+
+  output_dir/train-00000-of-00256
+  output_dir/train-00001-of-00256
+  ...
+  output_dir/train-00255-of-00256
+
+and
+
+  output_dir/val-00000-of-00004
+  ...
+  output_dir/val-00003-of-00004
+
+and
+
+  output_dir/test-00000-of-00008
+  ...
+  output_dir/test-00007-of-00008
+
+Each TFRecord file contains ~2300 records. Each record within the TFRecord file
+is a serialized SequenceExample proto consisting of precisely one image-caption
+pair. Note that each image has multiple captions (usually 5) and therefore each
+image is replicated multiple times in the TFRecord files.
+
+The SequenceExample proto contains the following fields:
+
+  context:
+    image/image_id: integer MSCOCO image identifier
+    image/data: string containing JPEG encoded image in RGB colorspace
+
+  feature_lists:
+    image/caption: list of strings containing the (tokenized) caption words
+    image/caption_ids: list of integer ids corresponding to the caption words
+
+The captions are tokenized using the NLTK (http://www.nltk.org/) word tokenizer.
+The vocabulary of word identifiers is constructed from the sorted list (by
+descending frequency) of word tokens in the training set. Only tokens appearing
+at least 4 times are considered; all other words get the "unknown" word id.
+
+NOTE: This script will consume around 100GB of disk space because each image
+in the MSCOCO dataset is replicated ~5 times (once per caption) in the output.
+This is done for two reasons:
+  1. In order to better shuffle the training data.
+  2. It makes it easier to perform asynchronous preprocessing of each image in
+     TensorFlow.
+
+Running this script using 16 threads may take around 1 hour on a HP Z420.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from collections import Counter
+from collections import namedtuple
+from datetime import datetime
+import json
+import os.path
+import random
+import sys
+import threading
+
+
+
+import nltk.tokenize
+import numpy as np
+import tensorflow as tf
+
+tf.flags.DEFINE_string("train_image_dir", "/tmp/train2014/",
+                       "Training image directory.")
+tf.flags.DEFINE_string("val_image_dir", "/tmp/val2014",
+                       "Validation image directory.")
+
+tf.flags.DEFINE_string("train_captions_file", "/tmp/captions_train2014.json",
+                       "Training captions JSON file.")
+tf.flags.DEFINE_string("val_captions_file", "/tmp/captions_train2014.json",
+                       "Validation captions JSON file.")
+
+tf.flags.DEFINE_string("output_dir", "/tmp/", "Output data directory.")
+
+tf.flags.DEFINE_integer("train_shards", 256,
+                        "Number of shards in training TFRecord files.")
+tf.flags.DEFINE_integer("val_shards", 4,
+                        "Number of shards in validation TFRecord files.")
+tf.flags.DEFINE_integer("test_shards", 8,
+                        "Number of shards in testing TFRecord files.")
+
+tf.flags.DEFINE_string("start_word", "<S>",
+                       "Special word added to the beginning of each sentence.")
+tf.flags.DEFINE_string("end_word", "</S>",
+                       "Special word added to the end of each sentence.")
+tf.flags.DEFINE_string("unknown_word", "<UNK>",
+                       "Special word meaning 'unknown'.")
+tf.flags.DEFINE_integer("min_word_count", 4,
+                        "The minimum number of occurrences of each word in the "
+                        "training set for inclusion in the vocabulary.")
+tf.flags.DEFINE_string("word_counts_output_file", "/tmp/word_counts.txt",
+                       "Output vocabulary file of word counts.")
+
+tf.flags.DEFINE_integer("num_threads", 8,
+                        "Number of threads to preprocess the images.")
+
+FLAGS = tf.flags.FLAGS
+
+ImageMetadata = namedtuple("ImageMetadata",
+                           ["image_id", "filename", "captions"])
+
+
+class Vocabulary(object):
+  """Simple vocabulary wrapper."""
+
+  def __init__(self, vocab, unk_id):
+    """Initializes the vocabulary.
+
+    Args:
+      vocab: A dictionary of word to word_id.
+      unk_id: Id of the special 'unknown' word.
+    """
+    self._vocab = vocab
+    self._unk_id = unk_id
+
+  def word_to_id(self, word):
+    """Returns the integer id of a word string."""
+    if word in self._vocab:
+      return self._vocab[word]
+    else:
+      return self._unk_id
+
+
+class ImageDecoder(object):
+  """Helper class for decoding images in TensorFlow."""
+
+  def __init__(self):
+    # Create a single TensorFlow Session for all image decoding calls.
+    self._sess = tf.Session()
+
+    # TensorFlow ops for JPEG decoding.
+    self._encoded_jpeg = tf.placeholder(dtype=tf.string)
+    self._decode_jpeg = tf.image.decode_jpeg(self._encoded_jpeg, channels=3)
+
+  def decode_jpeg(self, encoded_jpeg):
+    image = self._sess.run(self._decode_jpeg,
+                           feed_dict={self._encoded_jpeg: encoded_jpeg})
+    assert len(image.shape) == 3
+    assert image.shape[2] == 3
+    return image
+
+
+def _int64_feature(value):
+  """Wrapper for inserting an int64 Feature into a SequenceExample proto."""
+  return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+  """Wrapper for inserting a bytes Feature into a SequenceExample proto."""
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[str(value)]))
+
+
+def _int64_feature_list(values):
+  """Wrapper for inserting an int64 FeatureList into a SequenceExample proto."""
+  return tf.train.FeatureList(feature=[_int64_feature(v) for v in values])
+
+
+def _bytes_feature_list(values):
+  """Wrapper for inserting a bytes FeatureList into a SequenceExample proto."""
+  return tf.train.FeatureList(feature=[_bytes_feature(v) for v in values])
+
+
+def _to_sequence_example(image, decoder, vocab):
+  """Builds a SequenceExample proto for an image-caption pair.
+
+  Args:
+    image: An ImageMetadata object.
+    decoder: An ImageDecoder object.
+    vocab: A Vocabulary object.
+
+  Returns:
+    A SequenceExample proto.
+  """
+  with tf.gfile.FastGFile(image.filename, "r") as f:
+    encoded_image = f.read()
+
+  try:
+    decoder.decode_jpeg(encoded_image)
+  except (tf.errors.InvalidArgumentError, AssertionError):
+    print("Skipping file with invalid JPEG data: %s" % image.filename)
+    return
+
+  context = tf.train.Features(feature={
+      "image/image_id": _int64_feature(image.image_id),
+      "image/data": _bytes_feature(encoded_image),
+  })
+
+  assert len(image.captions) == 1
+  caption = image.captions[0]
+  caption_ids = [vocab.word_to_id(word) for word in caption]
+  feature_lists = tf.train.FeatureLists(feature_list={
+      "image/caption": _bytes_feature_list(caption),
+      "image/caption_ids": _int64_feature_list(caption_ids)
+  })
+  sequence_example = tf.train.SequenceExample(
+      context=context, feature_lists=feature_lists)
+
+  return sequence_example
+
+
+def _process_image_files(thread_index, ranges, name, images, decoder, vocab,
+                         num_shards):
+  """Processes and saves a subset of images as TFRecord files in one thread.
+
+  Args:
+    thread_index: Integer thread identifier within [0, len(ranges)].
+    ranges: A list of pairs of integers specifying the ranges of the dataset to
+      process in parallel.
+    name: Unique identifier specifying the dataset.
+    images: List of ImageMetadata.
+    decoder: An ImageDecoder object.
+    vocab: A Vocabulary object.
+    num_shards: Integer number of shards for the output files.
+  """
+  # Each thread produces N shards where N = num_shards / num_threads. For
+  # instance, if num_shards = 128, and num_threads = 2, then the first thread
+  # would produce shards [0, 64).
+  num_threads = len(ranges)
+  assert not num_shards % num_threads
+  num_shards_per_batch = int(num_shards / num_threads)
+
+  shard_ranges = np.linspace(ranges[thread_index][0], ranges[thread_index][1],
+                             num_shards_per_batch + 1).astype(int)
+  num_images_in_thread = ranges[thread_index][1] - ranges[thread_index][0]
+
+  counter = 0
+  for s in xrange(num_shards_per_batch):
+    # Generate a sharded version of the file name, e.g. 'train-00002-of-00010'
+    shard = thread_index * num_shards_per_batch + s
+    output_filename = "%s-%.5d-of-%.5d" % (name, shard, num_shards)
+    output_file = os.path.join(FLAGS.output_dir, output_filename)
+    writer = tf.python_io.TFRecordWriter(output_file)
+
+    shard_counter = 0
+    images_in_shard = np.arange(shard_ranges[s], shard_ranges[s + 1], dtype=int)
+    for i in images_in_shard:
+      image = images[i]
+
+      sequence_example = _to_sequence_example(image, decoder, vocab)
+      if sequence_example is not None:
+        writer.write(sequence_example.SerializeToString())
+        shard_counter += 1
+        counter += 1
+
+      if not counter % 1000:
+        print("%s [thread %d]: Processed %d of %d items in thread batch." %
+              (datetime.now(), thread_index, counter, num_images_in_thread))
+        sys.stdout.flush()
+
+    print("%s [thread %d]: Wrote %d image-caption pairs to %s" %
+          (datetime.now(), thread_index, shard_counter, output_file))
+    sys.stdout.flush()
+    shard_counter = 0
+  print("%s [thread %d]: Wrote %d image-caption pairs to %d shards." %
+        (datetime.now(), thread_index, counter, num_shards_per_batch))
+  sys.stdout.flush()
+
+
+def _process_dataset(name, images, vocab, num_shards):
+  """Processes a complete data set and saves it as a TFRecord.
+
+  Args:
+    name: Unique identifier specifying the dataset.
+    images: List of ImageMetadata.
+    vocab: A Vocabulary object.
+    num_shards: Integer number of shards for the output files.
+  """
+  # Break up each image into a separate entity for each caption.
+  images = [ImageMetadata(image.image_id, image.filename, [caption])
+            for image in images for caption in image.captions]
+
+  # Shuffle the ordering of images. Make the randomization repeatable.
+  random.seed(12345)
+  random.shuffle(images)
+
+  # Break the images into num_threads batches. Batch i is defined as
+  # images[ranges[i][0]:ranges[i][1]].
+  num_threads = min(num_shards, FLAGS.num_threads)
+  spacing = np.linspace(0, len(images), num_threads + 1).astype(np.int)
+  ranges = []
+  threads = []
+  for i in xrange(len(spacing) - 1):
+    ranges.append([spacing[i], spacing[i + 1]])
+
+  # Create a mechanism for monitoring when all threads are finished.
+  coord = tf.train.Coordinator()
+
+  # Create a utility for decoding JPEG images to run sanity checks.
+  decoder = ImageDecoder()
+
+  # Launch a thread for each batch.
+  print("Launching %d threads for spacings: %s" % (num_threads, ranges))
+  for thread_index in xrange(len(ranges)):
+    args = (thread_index, ranges, name, images, decoder, vocab, num_shards)
+    t = threading.Thread(target=_process_image_files, args=args)
+    t.start()
+    threads.append(t)
+
+  # Wait for all the threads to terminate.
+  coord.join(threads)
+  print("%s: Finished processing all %d image-caption pairs in data set '%s'." %
+        (datetime.now(), len(images), name))
+
+
+def _create_vocab(captions):
+  """Creates the vocabulary of word to word_id.
+
+  The vocabulary is saved to disk in a text file of word counts. The id of each
+  word in the file is its corresponding 0-based line number.
+
+  Args:
+    captions: A list of lists of strings.
+
+  Returns:
+    A Vocabulary object.
+  """
+  print("Creating vocabulary.")
+  counter = Counter()
+  for c in captions:
+    counter.update(c)
+  print("Total words:", len(counter))
+
+  # Filter uncommon words and sort by descending count.
+  word_counts = [x for x in counter.items() if x[1] >= FLAGS.min_word_count]
+  word_counts.sort(key=lambda x: x[1], reverse=True)
+  print("Words in vocabulary:", len(word_counts))
+
+  # Write out the word counts file.
+  with tf.gfile.FastGFile(FLAGS.word_counts_output_file, "w") as f:
+    f.write("\n".join(["%s %d" % (w, c) for w, c in word_counts]))
+  print("Wrote vocabulary file:", FLAGS.word_counts_output_file)
+
+  # Create the vocabulary dictionary.
+  reverse_vocab = [x[0] for x in word_counts]
+  unk_id = len(reverse_vocab)
+  vocab_dict = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
+  vocab = Vocabulary(vocab_dict, unk_id)
+
+  return vocab
+
+
+def _process_caption(caption):
+  """Processes a caption string into a list of tonenized words.
+
+  Args:
+    caption: A string caption.
+
+  Returns:
+    A list of strings; the tokenized caption.
+  """
+  tokenized_caption = [FLAGS.start_word]
+  tokenized_caption.extend(nltk.tokenize.word_tokenize(caption.lower()))
+  tokenized_caption.append(FLAGS.end_word)
+  return tokenized_caption
+
+
+def _load_and_process_metadata(captions_file, image_dir):
+  """Loads image metadata from a JSON file and processes the captions.
+
+  Args:
+    captions_file: JSON file containing caption annotations.
+    image_dir: Directory containing the image files.
+
+  Returns:
+    A list of ImageMetadata.
+  """
+  with tf.gfile.FastGFile(captions_file, "r") as f:
+    caption_data = json.load(f)
+
+  # Extract the filenames.
+  id_to_filename = [(x["id"], x["file_name"]) for x in caption_data["images"]]
+
+  # Extract the captions. Each image_id is associated with multiple captions.
+  id_to_captions = {}
+  for annotation in caption_data["annotations"]:
+    image_id = annotation["image_id"]
+    caption = annotation["caption"]
+    id_to_captions.setdefault(image_id, [])
+    id_to_captions[image_id].append(caption)
+
+  assert len(id_to_filename) == len(id_to_captions)
+  assert set([x[0] for x in id_to_filename]) == set(id_to_captions.keys())
+  print("Loaded caption metadata for %d images from %s" %
+        (len(id_to_filename), captions_file))
+
+  # Process the captions and combine the data into a list of ImageMetadata.
+  print("Proccessing captions.")
+  image_metadata = []
+  num_captions = 0
+  for image_id, base_filename in id_to_filename:
+    filename = os.path.join(image_dir, base_filename)
+    captions = [_process_caption(c) for c in id_to_captions[image_id]]
+    image_metadata.append(ImageMetadata(image_id, filename, captions))
+    num_captions += len(captions)
+  print("Finished processing %d captions for %d images in %s" %
+        (num_captions, len(id_to_filename), captions_file))
+
+  return image_metadata
+
+
+def main(unused_argv):
+  def _is_valid_num_shards(num_shards):
+    """Returns True if num_shards is compatible with FLAGS.num_threads."""
+    return num_shards < FLAGS.num_threads or not num_shards % FLAGS.num_threads
+
+  assert _is_valid_num_shards(FLAGS.train_shards), (
+      "Please make the FLAGS.num_threads commensurate with FLAGS.train_shards")
+  assert _is_valid_num_shards(FLAGS.val_shards), (
+      "Please make the FLAGS.num_threads commensurate with FLAGS.val_shards")
+  assert _is_valid_num_shards(FLAGS.test_shards), (
+      "Please make the FLAGS.num_threads commensurate with FLAGS.test_shards")
+
+  if not tf.gfile.IsDirectory(FLAGS.output_dir):
+    tf.gfile.MakeDirs(FLAGS.output_dir)
+
+  # Load image metadata from caption files.
+  mscoco_train_dataset = _load_and_process_metadata(FLAGS.train_captions_file,
+                                                    FLAGS.train_image_dir)
+  mscoco_val_dataset = _load_and_process_metadata(FLAGS.val_captions_file,
+                                                  FLAGS.val_image_dir)
+
+  # Redistribute the MSCOCO data as follows:
+  #   train_dataset = 100% of mscoco_train_dataset + 85% of mscoco_val_dataset.
+  #   val_dataset = 5% of mscoco_val_dataset (for validation during training).
+  #   test_dataset = 10% of mscoco_val_dataset (for final evaluation).
+  train_cutoff = int(0.85 * len(mscoco_val_dataset))
+  val_cutoff = int(0.90 * len(mscoco_val_dataset))
+  train_dataset = mscoco_train_dataset + mscoco_val_dataset[0:train_cutoff]
+  val_dataset = mscoco_val_dataset[train_cutoff:val_cutoff]
+  test_dataset = mscoco_val_dataset[val_cutoff:]
+
+  # Create vocabulary from the training captions.
+  train_captions = [c for image in train_dataset for c in image.captions]
+  vocab = _create_vocab(train_captions)
+
+  _process_dataset("train", train_dataset, vocab, FLAGS.train_shards)
+  _process_dataset("val", val_dataset, vocab, FLAGS.val_shards)
+  _process_dataset("test", test_dataset, vocab, FLAGS.test_shards)
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 84 - 0
im2txt/im2txt/data/download_and_preprocess_mscoco.sh

@@ -0,0 +1,84 @@
+#!/bin/bash
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Script to download and preprocess the MSCOCO data set.
+#
+# The outputs of this script are sharded TFRecord files containing serialized
+# SequenceExample protocol buffers. See build_mscoco_data.py for details of how
+# the SequenceExample protocol buffers are constructed.
+#
+# usage:
+#  ./download_and_preprocess_mscoco.sh [data dir]
+set -e
+
+if [ -z "$1" ]; then
+  echo "usage download_and_preproces_mscoco.sh [data dir]"
+  exit
+fi
+
+# Create the output directories.
+OUTPUT_DIR="${1%/}"
+SCRATCH_DIR="${OUTPUT_DIR}/raw-data"
+mkdir -p "${OUTPUT_DIR}"
+mkdir -p "${SCRATCH_DIR}"
+CURRENT_DIR=$(pwd)
+WORK_DIR="$0.runfiles/__main__/im2txt"
+
+# Helper function to download and unpack a .zip file.
+function download_and_unzip() {
+  local BASE_URL=${1}
+  local FILENAME=${2}
+
+  if [ ! -f ${FILENAME} ]; then
+    echo "Downloading ${FILENAME} to $(pwd)"
+    wget -nd -c "${BASE_URL}/${FILENAME}"
+  else
+    echo "Skipping download of ${FILENAME}"
+  fi
+  echo "Unzipping ${FILENAME}"
+  unzip -nq ${FILENAME}
+}
+
+cd ${SCRATCH_DIR}
+
+# Download the images.
+BASE_IMAGE_URL="http://msvocds.blob.core.windows.net/coco2014"
+
+TRAIN_IMAGE_FILE="train2014.zip"
+download_and_unzip ${BASE_IMAGE_URL} ${TRAIN_IMAGE_FILE}
+TRAIN_IMAGE_DIR="${SCRATCH_DIR}/train2014"
+
+VAL_IMAGE_FILE="val2014.zip"
+download_and_unzip ${BASE_IMAGE_URL} ${VAL_IMAGE_FILE}
+VAL_IMAGE_DIR="${SCRATCH_DIR}/val2014"
+
+# Download the captions.
+BASE_CAPTIONS_URL="http://msvocds.blob.core.windows.net/annotations-1-0-3"
+CAPTIONS_FILE="captions_train-val2014.zip"
+download_and_unzip ${BASE_CAPTIONS_URL} ${CAPTIONS_FILE}
+TRAIN_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_train2014.json"
+VAL_CAPTIONS_FILE="${SCRATCH_DIR}/annotations/captions_val2014.json"
+
+# Build TFRecords of the image data.
+cd "${CURRENT_DIR}"
+BUILD_SCRIPT="${WORK_DIR}/build_mscoco_data"
+"${BUILD_SCRIPT}" \
+  --train_image_dir="${TRAIN_IMAGE_DIR}" \
+  --val_image_dir="${VAL_IMAGE_DIR}" \
+  --train_captions_file="${TRAIN_CAPTIONS_FILE}" \
+  --val_captions_file="${VAL_CAPTIONS_FILE}" \
+  --output_dir="${OUTPUT_DIR}" \
+  --word_counts_output_file="${OUTPUT_DIR}/word_counts.txt"

+ 194 - 0
im2txt/im2txt/evaluate.py

@@ -0,0 +1,194 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Evaluate the model.
+
+This script should be run concurrently with training so that summaries show up
+in TensorBoard.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os.path
+import time
+
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", "",
+                       "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("checkpoint_dir", "",
+                       "Directory containing model checkpoints.")
+tf.flags.DEFINE_string("eval_dir", "", "Directory to write event logs.")
+
+tf.flags.DEFINE_integer("eval_interval_secs", 600,
+                        "Interval between evaluation runs.")
+tf.flags.DEFINE_integer("num_eval_examples", 10132,
+                        "Number of examples for evaluation.")
+
+tf.flags.DEFINE_integer("min_global_step", 5000,
+                        "Minimum global step to run evaluation.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def evaluate_model(sess, model, global_step, summary_writer, summary_op):
+  """Computes perplexity-per-word over the evaluation dataset.
+
+  Summaries and perplexity-per-word are written out to the eval directory.
+
+  Args:
+    sess: Session object.
+    model: Instance of ShowAndTellModel; the model to evaluate.
+    global_step: Integer; global step of the model checkpoint.
+    summary_writer: Instance of SummaryWriter.
+    summary_op: Op for generating model summaries.
+  """
+  # Log model summaries on a single batch.
+  summary_str = sess.run(summary_op)
+  summary_writer.add_summary(summary_str, global_step)
+
+  # Compute perplexity over the entire dataset.
+  num_eval_batches = int(
+      math.ceil(FLAGS.num_eval_examples / model.config.batch_size))
+
+  start_time = time.time()
+  sum_losses = 0.
+  sum_weights = 0.
+  for i in xrange(num_eval_batches):
+    cross_entropy_losses, weights = sess.run([
+        model.target_cross_entropy_losses,
+        model.target_cross_entropy_loss_weights
+    ])
+    sum_losses += np.sum(cross_entropy_losses * weights)
+    sum_weights += np.sum(weights)
+    if not i % 100:
+      tf.logging.info("Computed losses for %d of %d batches.", i + 1,
+                      num_eval_batches)
+  eval_time = time.time() - start_time
+
+  perplexity = math.exp(sum_losses / sum_weights)
+  tf.logging.info("Perplexity = %f (%.2g sec)", perplexity, eval_time)
+
+  # Log perplexity to the SummaryWriter.
+  summary = tf.Summary()
+  value = summary.value.add()
+  value.simple_value = perplexity
+  value.tag = "Perplexity"
+  summary_writer.add_summary(summary, global_step)
+
+  # Write the Events file to the eval directory.
+  summary_writer.flush()
+  tf.logging.info("Finished processing evaluation at global step %d.",
+                  global_step)
+
+
+def run_once(model, summary_writer, summary_op):
+  """Evaluates the latest model checkpoint.
+
+  Args:
+    model: Instance of ShowAndTellModel; the model to evaluate.
+    summary_writer: Instance of SummaryWriter.
+    summary_op: Op for generating model summaries.
+  """
+  model_path = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
+  if not model_path:
+    tf.logging.info("Skipping evaluation. No checkpoint found in: %s",
+                    FLAGS.checkpoint_dir)
+    return
+
+  with tf.Session() as sess:
+    # Load model from checkpoint.
+    tf.logging.info("Loading model from checkpoint: %s", model_path)
+    model.saver.restore(sess, model_path)
+    global_step = tf.train.global_step(sess, model.global_step.name)
+    tf.logging.info("Successfully loaded %s at global step = %d.",
+                    os.path.basename(model_path), global_step)
+    if global_step < FLAGS.min_global_step:
+      tf.logging.info("Skipping evaluation. Global step = %d < %d", global_step,
+                      FLAGS.min_global_step)
+      return
+
+    # Start the queue runners.
+    coord = tf.train.Coordinator()
+    threads = tf.train.start_queue_runners(coord=coord)
+
+    # Run evaluation on the latest checkpoint.
+    try:
+      evaluate_model(
+          sess=sess,
+          model=model,
+          global_step=global_step,
+          summary_writer=summary_writer,
+          summary_op=summary_op)
+    except Exception as e:  # pylint: disable=broad-except
+      tf.logging.error("Evaluation failed.")
+      coord.request_stop(e)
+
+    coord.request_stop()
+    coord.join(threads, stop_grace_period_secs=10)
+
+
+def run():
+  """Runs evaluation in a loop, and logs summaries to TensorBoard."""
+  # Create the evaluation directory if it doesn't exist.
+  eval_dir = FLAGS.eval_dir
+  if not tf.gfile.IsDirectory(eval_dir):
+    tf.logging.info("Creating eval directory: %s", eval_dir)
+    tf.gfile.MakeDirs(eval_dir)
+
+  g = tf.Graph()
+  with g.as_default():
+    # Build the model for evaluation.
+    model_config = configuration.ModelConfig()
+    model_config.input_file_pattern = FLAGS.input_file_pattern
+    model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
+    model.build()
+
+    # Create the summary operation and the summary writer.
+    summary_op = tf.merge_all_summaries()
+    summary_writer = tf.train.SummaryWriter(eval_dir)
+
+    g.finalize()
+
+    # Run a new evaluation run every eval_interval_secs.
+    while True:
+      start = time.time()
+      tf.logging.info("Starting evaluation at " + time.strftime(
+          "%Y-%m-%d-%H:%M:%S", time.localtime()))
+      run_once(model, summary_writer, summary_op)
+      time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
+      if time_to_next_eval > 0:
+        time.sleep(time_to_next_eval)
+
+
+def main(unused_argv):
+  assert FLAGS.input_file_pattern, "--input_file_pattern is required"
+  assert FLAGS.checkpoint_dir, "--checkpoint_dir is required"
+  assert FLAGS.eval_dir, "--eval_dir is required"
+  run()
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 31 - 0
im2txt/im2txt/inference_utils/BUILD

@@ -0,0 +1,31 @@
+package(default_visibility = ["//im2txt:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "inference_wrapper_base",
+    srcs = ["inference_wrapper_base.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "vocabulary",
+    srcs = ["vocabulary.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "caption_generator",
+    srcs = ["caption_generator.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "caption_generator_test",
+    srcs = ["caption_generator_test.py"],
+    deps = [
+        ":caption_generator",
+    ],
+)

+ 201 - 0
im2txt/im2txt/inference_utils/caption_generator.py

@@ -0,0 +1,201 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Class for generating captions from an image-to-text model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import heapq
+import math
+
+
+import numpy as np
+
+
+class Caption(object):
+  """Represents a complete or partial caption."""
+
+  def __init__(self, sentence, state, logprob, score, metadata=None):
+    """Initializes the Caption.
+
+    Args:
+      sentence: List of word ids in the caption.
+      state: Model state after generating the previous word.
+      logprob: Log-probability of the caption.
+      score: Score of the caption.
+      metadata: Optional metadata associated with the partial sentence. If not
+        None, a list of strings with the same length as 'sentence'.
+    """
+    self.sentence = sentence
+    self.state = state
+    self.logprob = logprob
+    self.score = score
+    self.metadata = metadata
+
+  def __cmp__(self, other):
+    """Compares Captions by score."""
+    assert isinstance(other, Caption)
+    if self.score == other.score:
+      return 0
+    elif self.score < other.score:
+      return -1
+    else:
+      return 1
+
+
+class TopN(object):
+  """Maintains the top n elements of an incrementally provided set."""
+
+  def __init__(self, n):
+    self._n = n
+    self._data = []
+
+  def size(self):
+    assert self._data is not None
+    return len(self._data)
+
+  def push(self, x):
+    """Pushes a new element."""
+    assert self._data is not None
+    if len(self._data) < self._n:
+      heapq.heappush(self._data, x)
+    else:
+      heapq.heappushpop(self._data, x)
+
+  def extract(self, sort=False):
+    """Extracts all elements from the TopN. This is a destructive operation.
+
+    The only method that can be called immediately after extract() is reset().
+
+    Args:
+      sort: Whether to return the elements in descending sorted order.
+
+    Returns:
+      A list of data; the top n elements provided to the set.
+    """
+    assert self._data is not None
+    data = self._data
+    self._data = None
+    if sort:
+      data.sort(reverse=True)
+    return data
+
+  def reset(self):
+    """Returns the TopN to an empty state."""
+    self._data = []
+
+
+class CaptionGenerator(object):
+  """Class to generate captions from an image-to-text model."""
+
+  def __init__(self,
+               model,
+               vocab,
+               beam_size=3,
+               max_caption_length=20,
+               length_normalization_factor=0.0):
+    """Initializes the generator.
+
+    Args:
+      model: Object encapsulating a trained image-to-text model. Must have
+        methods feed_image() and inference_step(). For example, an instance of
+        InferenceWrapperBase.
+      vocab: A Vocabulary object.
+      beam_size: Beam size to use when generating captions.
+      max_caption_length: The maximum caption length before stopping the search.
+      length_normalization_factor: If != 0, a number x such that captions are
+        scored by logprob/length^x, rather than logprob. This changes the
+        relative scores of captions depending on their lengths. For example, if
+        x > 0 then longer captions will be favored.
+    """
+    self.vocab = vocab
+    self.model = model
+
+    self.beam_size = beam_size
+    self.max_caption_length = max_caption_length
+    self.length_normalization_factor = length_normalization_factor
+
+  def beam_search(self, sess, encoded_image):
+    """Runs beam search caption generation on a single image.
+
+    Args:
+      sess: TensorFlow Session object.
+      encoded_image: An encoded image string.
+
+    Returns:
+      A list of Caption sorted by descending score.
+    """
+    # Feed in the image to get the initial state.
+    initial_state = self.model.feed_image(sess, encoded_image)
+
+    initial_beam = Caption(
+        sentence=[self.vocab.start_id],
+        state=initial_state[0],
+        logprob=0.0,
+        score=0.0,
+        metadata=[""])
+    partial_captions = TopN(self.beam_size)
+    partial_captions.push(initial_beam)
+    complete_captions = TopN(self.beam_size)
+
+    # Run beam search.
+    for _ in range(self.max_caption_length - 1):
+      partial_captions_list = partial_captions.extract()
+      partial_captions.reset()
+      input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
+      state_feed = np.array([c.state for c in partial_captions_list])
+
+      softmax, new_states, metadata = self.model.inference_step(sess,
+                                                                input_feed,
+                                                                state_feed)
+
+      for i, partial_caption in enumerate(partial_captions_list):
+        word_probabilities = softmax[i]
+        state = new_states[i]
+        # For this partial caption, get the beam_size most probable next words.
+        words_and_probs = list(enumerate(word_probabilities))
+        words_and_probs.sort(key=lambda x: -x[1])
+        words_and_probs = words_and_probs[0:self.beam_size]
+        # Each next word gives a new partial caption.
+        for w, p in words_and_probs:
+          if p < 1e-12:
+            continue  # Avoid log(0).
+          sentence = partial_caption.sentence + [w]
+          logprob = partial_caption.logprob + math.log(p)
+          score = logprob
+          if metadata:
+            metadata_list = partial_caption.metadata + [metadata[i]]
+          else:
+            metadata_list = None
+          if w == self.vocab.end_id:
+            if self.length_normalization_factor > 0:
+              score /= len(sentence)**self.length_normalization_factor
+            beam = Caption(sentence, state, logprob, score, metadata_list)
+            complete_captions.push(beam)
+          else:
+            beam = Caption(sentence, state, logprob, score, metadata_list)
+            partial_captions.push(beam)
+      if partial_captions.size() == 0:
+        # We have run out of partial candidates; happens when beam_size = 1.
+        break
+
+    # If we have no complete captions then fall back to the partial captions.
+    # But never output a mixture of complete and partial captions because a
+    # partial caption could have a higher score than all the complete captions.
+    if not complete_captions.size():
+      complete_captions = partial_captions
+
+    return complete_captions.extract(sort=True)

+ 178 - 0
im2txt/im2txt/inference_utils/caption_generator_test.py

@@ -0,0 +1,178 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unit tests for CaptionGenerator."""
+
+import math
+
+
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt.inference_utils import caption_generator
+
+
+class FakeVocab(object):
+  """Fake Vocabulary for testing purposes."""
+
+  def __init__(self):
+    self.start_id = 0  # Word id denoting sentence start.
+    self.end_id = 1  # Word id denoting sentence end.
+
+
+class FakeModel(object):
+  """Fake model for testing purposes."""
+
+  def __init__(self):
+    # Number of words in the vocab.
+    self._vocab_size = 12
+
+    # Dimensionality of the nominal model state.
+    self._state_size = 1
+
+    # Map of previous word to the probability distribution of the next word.
+    self._probabilities = {
+        0: {1: 0.1,
+            2: 0.2,
+            3: 0.3,
+            4: 0.4},
+        2: {5: 0.1,
+            6: 0.9},
+        3: {1: 0.1,
+            7: 0.4,
+            8: 0.5},
+        4: {1: 0.3,
+            9: 0.3,
+            10: 0.4},
+        5: {1: 1.0},
+        6: {1: 1.0},
+        7: {1: 1.0},
+        8: {1: 1.0},
+        9: {1: 0.5,
+            11: 0.5},
+        10: {1: 1.0},
+        11: {1: 1.0},
+    }
+
+  # pylint: disable=unused-argument
+
+  def feed_image(self, sess, encoded_image):
+    # Return a nominal model state.
+    return np.zeros([1, self._state_size])
+
+  def inference_step(self, sess, input_feed, state_feed):
+    # Compute the matrix of softmax distributions for the next batch of words.
+    batch_size = input_feed.shape[0]
+    softmax_output = np.zeros([batch_size, self._vocab_size])
+    for batch_index, word_id in enumerate(input_feed):
+      for next_word, probability in self._probabilities[word_id].items():
+        softmax_output[batch_index, next_word] = probability
+
+    # Nominal state and metadata.
+    new_state = np.zeros([batch_size, self._state_size])
+    metadata = None
+
+    return softmax_output, new_state, metadata
+
+  # pylint: enable=unused-argument
+
+
+class CaptionGeneratorTest(tf.test.TestCase):
+
+  def _assertExpectedCaptions(self,
+                              expected_captions,
+                              beam_size=3,
+                              max_caption_length=20,
+                              length_normalization_factor=0):
+    """Tests that beam search generates the expected captions.
+
+    Args:
+      expected_captions: A sequence of pairs (sentence, probability), where
+        sentence is a list of integer ids and probability is a float in [0, 1].
+      beam_size: Parameter passed to beam_search().
+      max_caption_length: Parameter passed to beam_search().
+      length_normalization_factor: Parameter passed to beam_search().
+    """
+    expected_sentences = [c[0] for c in expected_captions]
+    expected_probabilities = [c[1] for c in expected_captions]
+
+    # Generate captions.
+    generator = caption_generator.CaptionGenerator(
+        model=FakeModel(),
+        vocab=FakeVocab(),
+        beam_size=beam_size,
+        max_caption_length=max_caption_length,
+        length_normalization_factor=length_normalization_factor)
+    actual_captions = generator.beam_search(sess=None, encoded_image=None)
+
+    actual_sentences = [c.sentence for c in actual_captions]
+    actual_probabilities = [math.exp(c.logprob) for c in actual_captions]
+
+    self.assertEqual(expected_sentences, actual_sentences)
+    self.assertAllClose(expected_probabilities, actual_probabilities)
+
+  def testBeamSize(self):
+    # Beam size = 1.
+    expected = [([0, 4, 10, 1], 0.16)]
+    self._assertExpectedCaptions(expected, beam_size=1)
+
+    # Beam size = 2.
+    expected = [([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)]
+    self._assertExpectedCaptions(expected, beam_size=2)
+
+    # Beam size = 3.
+    expected = [
+        ([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
+    ]
+    self._assertExpectedCaptions(expected, beam_size=3)
+
+  def testMaxLength(self):
+    # Max length = 1.
+    expected = [([0], 1.0)]
+    self._assertExpectedCaptions(expected, max_caption_length=1)
+
+    # Max length = 2.
+    # There are no complete sentences, so partial sentences are returned.
+    expected = [([0, 4], 0.4), ([0, 3], 0.3), ([0, 2], 0.2)]
+    self._assertExpectedCaptions(expected, max_caption_length=2)
+
+    # Max length = 3.
+    # There is at least one complete sentence, so only complete sentences are
+    # returned.
+    expected = [([0, 4, 1], 0.12), ([0, 3, 1], 0.03)]
+    self._assertExpectedCaptions(expected, max_caption_length=3)
+
+    # Max length = 4.
+    expected = [
+        ([0, 2, 6, 1], 0.18), ([0, 4, 10, 1], 0.16), ([0, 3, 8, 1], 0.15)
+    ]
+    self._assertExpectedCaptions(expected, max_caption_length=4)
+
+  def testLengthNormalization(self):
+    # Length normalization factor = 3.
+    # The longest caption is returned first, despite having low probability,
+    # because it has the highest log(probability)/length**3.
+    expected = [
+        ([0, 4, 9, 11, 1], 0.06),
+        ([0, 2, 6, 1], 0.18),
+        ([0, 4, 10, 1], 0.16),
+        ([0, 3, 8, 1], 0.15),
+    ]
+    self._assertExpectedCaptions(
+        expected, beam_size=4, length_normalization_factor=3)
+
+
+if __name__ == '__main__':
+  tf.test.main()

+ 183 - 0
im2txt/im2txt/inference_utils/inference_wrapper_base.py

@@ -0,0 +1,183 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base wrapper class for performing inference with an image-to-text model.
+
+Subclasses must implement the following methods:
+
+  build_model():
+    Builds the model for inference and returns the model object.
+
+  feed_image():
+    Takes an encoded image and returns the initial model state, where "state"
+    is a numpy array whose specifics are defined by the subclass, e.g.
+    concatenated LSTM state. It's assumed that feed_image() will be called
+    precisely once at the start of inference for each image. Subclasses may
+    compute and/or save per-image internal context in this method.
+
+  inference_step():
+    Takes a batch of inputs and states at a single time-step. Returns the
+    softmax output corresponding to the inputs, and the new states of the batch.
+    Optionally also returns metadata about the current inference step, e.g. a
+    serialized numpy array containing activations from a particular model layer.
+
+Client usage:
+  1. Build the model inference graph via build_graph_from_config() or
+     build_graph_from_proto().
+  2. Call the resulting restore_fn to load the model checkpoint.
+  3. For each image in a batch of images:
+     a) Call feed_image() once to get the initial state.
+     b) For each step of caption generation, call inference_step().
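+
+For example, a subclass might be driven as follows (a minimal sketch, where
+model_config, checkpoint_path, encoded_image and input_feed are placeholders):
+
+  wrapper = SomeInferenceWrapperSubclass()
+  restore_fn = wrapper.build_graph_from_config(model_config, checkpoint_path)
+  with tf.Session() as sess:
+    restore_fn(sess)
+    initial_state = wrapper.feed_image(sess, encoded_image)
+    softmax, new_state, metadata = wrapper.inference_step(sess, input_feed,
+                                                          initial_state)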
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+
+import tensorflow as tf
+
+# pylint: disable=unused-argument
+
+
+class InferenceWrapperBase(object):
+  """Base wrapper class for performing inference with an image-to-text model."""
+
+  def __init__(self):
+    pass
+
+  def build_model(self, model_config):
+    """Builds the model for inference.
+
+    Args:
+      model_config: Object containing configuration for building the model.
+
+    Returns:
+      model: The model object.
+    """
+    tf.logging.fatal("Please implement build_model in subclass")
+
+  def _create_restore_fn(self, checkpoint_path, saver):
+    """Creates a function that restores a model from checkpoint.
+
+    Args:
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+      saver: Saver for restoring variables from the checkpoint file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+
+    Raises:
+      ValueError: If checkpoint_path does not refer to a checkpoint file or a
+        directory containing a checkpoint file.
+    """
+    if tf.gfile.IsDirectory(checkpoint_path):
+      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
+      if not checkpoint_path:
+        raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
+
+    def _restore_fn(sess):
+      tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
+      saver.restore(sess, checkpoint_path)
+      tf.logging.info("Successfully loaded checkpoint: %s",
+                      os.path.basename(checkpoint_path))
+
+    return _restore_fn
+
+  def build_graph_from_config(self, model_config, checkpoint_path):
+    """Builds the inference graph from a configuration object.
+
+    Args:
+      model_config: Object containing configuration for building the model.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    tf.logging.info("Building model.")
+    model = self.build_model(model_config)
+    saver = model.saver
+    if not saver:
+      saver = tf.train.Saver()
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def build_graph_from_proto(self, graph_def_file, saver_def_file,
+                             checkpoint_path):
+    """Builds the inference graph from serialized GraphDef and SaverDef protos.
+
+    Args:
+      graph_def_file: File containing a serialized GraphDef proto.
+      saver_def_file: File containing a serialized SaverDef proto.
+      checkpoint_path: Checkpoint file or a directory containing a checkpoint
+        file.
+
+    Returns:
+      restore_fn: A function such that restore_fn(sess) loads model variables
+        from the checkpoint file.
+    """
+    # Load the Graph.
+    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
+    graph_def = tf.GraphDef()
+    with tf.gfile.FastGFile(graph_def_file, "rb") as f:
+      graph_def.ParseFromString(f.read())
+    tf.import_graph_def(graph_def, name="")
+
+    # Load the Saver.
+    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
+    saver_def = tf.train.SaverDef()
+    with tf.gfile.FastGFile(saver_def_file, "rb") as f:
+      saver_def.ParseFromString(f.read())
+    saver = tf.train.Saver(saver_def=saver_def)
+
+    return self._create_restore_fn(checkpoint_path, saver)
+
+  def feed_image(self, sess, encoded_image):
+    """Feeds an image and returns the initial model state.
+
+    See comments at the top of file.
+
+    Args:
+      sess: TensorFlow Session object.
+      encoded_image: An encoded image string.
+
+    Returns:
+      state: A numpy array of shape [1, state_size].
+    """
+    tf.logging.fatal("Please implement feed_image in subclass")
+
+  def inference_step(self, sess, input_feed, state_feed):
+    """Runs one step of inference.
+
+    Args:
+      sess: TensorFlow Session object.
+      input_feed: A numpy array of shape [batch_size].
+      state_feed: A numpy array of shape [batch_size, state_size].
+
+    Returns:
+      softmax_output: A numpy array of shape [batch_size, vocab_size].
+      new_state: A numpy array of shape [batch_size, state_size].
+      metadata: Optional. If not None, a sequence of metadata strings, one per
+        batch element (e.g. serialized numpy arrays containing activations
+        from a particular model layer).
+    """
+    tf.logging.fatal("Please implement inference_step in subclass")
+
+# pylint: enable=unused-argument

+ 78 - 0
im2txt/im2txt/inference_utils/vocabulary.py

@@ -0,0 +1,78 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Vocabulary class for an image-to-text model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+class Vocabulary(object):
+  """Vocabulary class for an image-to-text model."""
+
+  def __init__(self,
+               vocab_file,
+               start_word="<S>",
+               end_word="</S>",
+               unk_word="<UNK>"):
+    """Initializes the vocabulary.
+
+    Args:
+      vocab_file: File containing the vocabulary, where the words are the first
+        whitespace-separated token on each line (other tokens are ignored) and
+        the word ids are the corresponding line numbers.
+      start_word: Special word denoting sentence start.
+      end_word: Special word denoting sentence end.
+      unk_word: Special word denoting unknown words.
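+
+    For example, a vocabulary file whose first lines are (the word counts are
+    illustrative only):
+
+      <S> 1000
+      </S> 1000
+      a 900
+
+    would give word_to_id("<S>") == 0, word_to_id("a") == 2 and
+    id_to_word(1) == "</S>".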
+    """
+    if not tf.gfile.Exists(vocab_file):
+      tf.logging.fatal("Vocab file %s not found.", vocab_file)
+    tf.logging.info("Initializing vocabulary from file: %s", vocab_file)
+
+    with tf.gfile.GFile(vocab_file, mode="r") as f:
+      reverse_vocab = list(f.readlines())
+    reverse_vocab = [line.split()[0] for line in reverse_vocab]
+    assert start_word in reverse_vocab
+    assert end_word in reverse_vocab
+    if unk_word not in reverse_vocab:
+      reverse_vocab.append(unk_word)
+    vocab = dict([(x, y) for (y, x) in enumerate(reverse_vocab)])
+
+    tf.logging.info("Created vocabulary with %d words" % len(vocab))
+
+    self.vocab = vocab  # vocab[word] = id
+    self.reverse_vocab = reverse_vocab  # reverse_vocab[id] = word
+
+    # Save special word ids.
+    self.start_id = vocab[start_word]
+    self.end_id = vocab[end_word]
+    self.unk_id = vocab[unk_word]
+
+  def word_to_id(self, word):
+    """Returns the integer word id of a word string."""
+    if word in self.vocab:
+      return self.vocab[word]
+    else:
+      return self.unk_id
+
+  def id_to_word(self, word_id):
+    """Returns the word string of an integer word id."""
+    if word_id >= len(self.reverse_vocab):
+      return self.reverse_vocab[self.unk_id]
+    else:
+      return self.reverse_vocab[word_id]

+ 51 - 0
im2txt/im2txt/inference_wrapper.py

@@ -0,0 +1,51 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Model wrapper class for performing inference with a ShowAndTellModel."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+
+from im2txt import show_and_tell_model
+from im2txt.inference_utils import inference_wrapper_base
+
+
+class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
+  """Model wrapper class for performing inference with a ShowAndTellModel."""
+
+  def __init__(self):
+    super(InferenceWrapper, self).__init__()
+
+  def build_model(self, model_config):
+    model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
+    model.build()
+    return model
+
+  def feed_image(self, sess, encoded_image):
+    initial_state = sess.run(fetches="lstm/initial_state:0",
+                             feed_dict={"image_feed:0": encoded_image})
+    return initial_state
+
+  def inference_step(self, sess, input_feed, state_feed):
+    softmax_output, state_output = sess.run(
+        fetches=["softmax:0", "lstm/state:0"],
+        feed_dict={
+            "input_feed:0": input_feed,
+            "lstm/state_feed:0": state_feed,
+        })
+    return softmax_output, state_output, None

+ 32 - 0
im2txt/im2txt/ops/BUILD

@@ -0,0 +1,32 @@
+package(default_visibility = ["//im2txt:internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+    name = "image_processing",
+    srcs = ["image_processing.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "image_embedding",
+    srcs = ["image_embedding.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "image_embedding_test",
+    size = "small",
+    srcs = ["image_embedding_test.py"],
+    deps = [
+        ":image_embedding",
+    ],
+)
+
+py_library(
+    name = "inputs",
+    srcs = ["inputs.py"],
+    srcs_version = "PY2AND3",
+)

+ 114 - 0
im2txt/im2txt/ops/image_embedding.py

@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image embedding ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_base
+
+slim = tf.contrib.slim
+
+
+def inception_v3(images,
+                 trainable=True,
+                 is_training=True,
+                 weight_decay=0.00004,
+                 stddev=0.1,
+                 dropout_keep_prob=0.8,
+                 use_batch_norm=True,
+                 batch_norm_params=None,
+                 add_summaries=True,
+                 scope="InceptionV3"):
+  """Builds an Inception V3 subgraph for image embeddings.
+
+  Args:
+    images: A float32 Tensor of shape [batch, height, width, channels].
+    trainable: Whether the inception submodel should be trainable or not.
+    is_training: Boolean indicating training mode or not.
+    weight_decay: Coefficient for weight regularization.
+    stddev: The standard deviation of the truncated normal weight initializer.
+    dropout_keep_prob: Dropout keep probability.
+    use_batch_norm: Whether to use batch normalization.
+    batch_norm_params: Parameters for batch normalization. See
+      tf.contrib.layers.batch_norm for details.
+    add_summaries: Whether to add activation summaries.
+    scope: Optional Variable scope.
+
+  Returns:
+    A float32 Tensor of shape [batch, 2048] containing the image embeddings
+    (the flattened output of the Inception v3 network).
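+
+  For example (a sketch with illustrative shapes):
+    images = tf.placeholder(tf.float32, [8, 299, 299, 3])
+    embeddings = inception_v3(images, trainable=False, is_training=False)
+    # embeddings has shape [8, 2048].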
+  """
+  # Only consider the inception model to be in training mode if it's trainable.
+  is_inception_model_training = trainable and is_training
+
+  if use_batch_norm:
+    # Default parameters for batch normalization.
+    if not batch_norm_params:
+      batch_norm_params = {
+          "is_training": is_inception_model_training,
+          "trainable": trainable,
+          # Decay for the moving averages.
+          "decay": 0.9997,
+          # Epsilon to prevent 0s in variance.
+          "epsilon": 0.001,
+          # Collection containing the moving mean and moving variance.
+          "variables_collections": {
+              "beta": None,
+              "gamma": None,
+              "moving_mean": ["moving_vars"],
+              "moving_variance": ["moving_vars"],
+          }
+      }
+  else:
+    batch_norm_params = None
+
+  if trainable:
+    weights_regularizer = tf.contrib.layers.l2_regularizer(weight_decay)
+  else:
+    weights_regularizer = None
+
+  with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
+    with slim.arg_scope(
+        [slim.conv2d, slim.fully_connected],
+        weights_regularizer=weights_regularizer,
+        trainable=trainable):
+      with slim.arg_scope(
+          [slim.conv2d],
+          weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+          activation_fn=tf.nn.relu,
+          normalizer_fn=slim.batch_norm,
+          normalizer_params=batch_norm_params):
+        net, end_points = inception_v3_base(images, scope=scope)
+        with tf.variable_scope("logits"):
+          shape = net.get_shape()
+          net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
+          net = slim.dropout(
+              net,
+              keep_prob=dropout_keep_prob,
+              is_training=is_inception_model_training,
+              scope="dropout")
+          net = slim.flatten(net, scope="flatten")
+
+  # Add summaries.
+  if add_summaries:
+    for v in end_points.values():
+      tf.contrib.layers.summaries.summarize_activation(v)
+
+  return net

+ 136 - 0
im2txt/im2txt/ops/image_embedding_test.py

@@ -0,0 +1,136 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow_models.im2txt.ops.image_embedding."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt.ops import image_embedding
+
+
+class InceptionV3Test(tf.test.TestCase):
+
+  def setUp(self):
+    super(InceptionV3Test, self).setUp()
+
+    batch_size = 4
+    height = 299
+    width = 299
+    num_channels = 3
+    self._images = tf.placeholder(tf.float32,
+                                  [batch_size, height, width, num_channels])
+    self._batch_size = batch_size
+
+  def _countInceptionParameters(self):
+    """Counts the number of parameters in the inception model at top scope."""
+    counter = {}
+    for v in tf.all_variables():
+      name_tokens = v.op.name.split("/")
+      if name_tokens[0] == "InceptionV3":
+        name = "InceptionV3/" + name_tokens[1]
+        num_params = v.get_shape().num_elements()
+        assert num_params
+        counter[name] = counter.get(name, 0) + num_params
+    return counter
+
+  def _verifyParameterCounts(self):
+    """Verifies the number of parameters in the inception model."""
+    param_counts = self._countInceptionParameters()
+    expected_param_counts = {
+        "InceptionV3/Conv2d_1a_3x3": 960,
+        "InceptionV3/Conv2d_2a_3x3": 9312,
+        "InceptionV3/Conv2d_2b_3x3": 18624,
+        "InceptionV3/Conv2d_3b_1x1": 5360,
+        "InceptionV3/Conv2d_4a_3x3": 138816,
+        "InceptionV3/Mixed_5b": 256368,
+        "InceptionV3/Mixed_5c": 277968,
+        "InceptionV3/Mixed_5d": 285648,
+        "InceptionV3/Mixed_6a": 1153920,
+        "InceptionV3/Mixed_6b": 1298944,
+        "InceptionV3/Mixed_6c": 1692736,
+        "InceptionV3/Mixed_6d": 1692736,
+        "InceptionV3/Mixed_6e": 2143872,
+        "InceptionV3/Mixed_7a": 1699584,
+        "InceptionV3/Mixed_7b": 5047872,
+        "InceptionV3/Mixed_7c": 6080064,
+    }
+    self.assertDictEqual(expected_param_counts, param_counts)
+
+  def _assertCollectionSize(self, expected_size, collection):
+    actual_size = len(tf.get_collection(collection))
+    if expected_size != actual_size:
+      self.fail("Found %d items in collection %s (expected %d)." %
+                (actual_size, collection, expected_size))
+
+  def testTrainableTrueIsTrainingTrue(self):
+    embeddings = image_embedding.inception_v3(
+        self._images, trainable=True, is_training=True)
+    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+    self._verifyParameterCounts()
+    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
+    self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
+    self._assertCollectionSize(188, tf.GraphKeys.UPDATE_OPS)
+    self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
+    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+  def testTrainableTrueIsTrainingFalse(self):
+    embeddings = image_embedding.inception_v3(
+        self._images, trainable=True, is_training=False)
+    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+    self._verifyParameterCounts()
+    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
+    self._assertCollectionSize(188, tf.GraphKeys.TRAINABLE_VARIABLES)
+    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+    self._assertCollectionSize(94, tf.GraphKeys.REGULARIZATION_LOSSES)
+    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+  def testTrainableFalseIsTrainingTrue(self):
+    embeddings = image_embedding.inception_v3(
+        self._images, trainable=False, is_training=True)
+    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+    self._verifyParameterCounts()
+    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
+    self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
+    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+    self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
+    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+  def testTrainableFalseIsTrainingFalse(self):
+    embeddings = image_embedding.inception_v3(
+        self._images, trainable=False, is_training=False)
+    self.assertEqual([self._batch_size, 2048], embeddings.get_shape().as_list())
+
+    self._verifyParameterCounts()
+    self._assertCollectionSize(376, tf.GraphKeys.VARIABLES)
+    self._assertCollectionSize(0, tf.GraphKeys.TRAINABLE_VARIABLES)
+    self._assertCollectionSize(0, tf.GraphKeys.UPDATE_OPS)
+    self._assertCollectionSize(0, tf.GraphKeys.REGULARIZATION_LOSSES)
+    self._assertCollectionSize(0, tf.GraphKeys.LOSSES)
+    self._assertCollectionSize(23, tf.GraphKeys.SUMMARIES)
+
+
+if __name__ == "__main__":
+  tf.test.main()

+ 134 - 0
im2txt/im2txt/ops/image_processing.py

@@ -0,0 +1,134 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Helper functions for image preprocessing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+def distort_image(image, thread_id):
+  """Perform random distortions on an image.
+
+  Args:
+    image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
+    thread_id: Preprocessing thread id used to select the ordering of color
+      distortions. The number of preprocessing threads must be even.
+
+  Returns:
+    distorted_image: A float32 Tensor of shape [height, width, 3] with values in
+      [0, 1].
+  """
+  # Randomly flip horizontally.
+  with tf.name_scope("flip_horizontal", values=[image]):
+    image = tf.image.random_flip_left_right(image)
+
+  # Randomly distort the colors based on thread id.
+  color_ordering = thread_id % 2
+  with tf.name_scope("distort_color", values=[image]):
+    if color_ordering == 0:
+      image = tf.image.random_brightness(image, max_delta=32. / 255.)
+      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+      image = tf.image.random_hue(image, max_delta=0.032)
+      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+    elif color_ordering == 1:
+      image = tf.image.random_brightness(image, max_delta=32. / 255.)
+      image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+      image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+      image = tf.image.random_hue(image, max_delta=0.032)
+
+    # The random_* ops do not necessarily clamp.
+    image = tf.clip_by_value(image, 0.0, 1.0)
+
+  return image
+
+
+def process_image(encoded_image,
+                  is_training,
+                  height,
+                  width,
+                  resize_height=346,
+                  resize_width=346,
+                  thread_id=0,
+                  image_format="jpeg"):
+  """Decode an image, resize and apply random distortions.
+
+  In training, images are distorted slightly differently depending on thread_id.
+
+  Args:
+    encoded_image: String Tensor containing the image.
+    is_training: Boolean; whether preprocessing for training or eval.
+    height: Height of the output image.
+    width: Width of the output image.
+    resize_height: If > 0, resize height before crop to final dimensions.
+    resize_width: If > 0, resize width before crop to final dimensions.
+    thread_id: Preprocessing thread id used to select the ordering of color
+      distortions. The number of preprocessing threads must be even.
+    image_format: "jpeg" or "png".
+
+  Returns:
+    A float32 Tensor of shape [height, width, 3] with values in [-1, 1].
+
+  Raises:
+    ValueError: If image_format is invalid.
+  """
+  # Helper function to log an image summary to the visualizer. Summaries are
+  # only logged in thread 0.
+  def image_summary(name, image):
+    if not thread_id:
+      tf.image_summary(name, tf.expand_dims(image, 0))
+
+  # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
+  with tf.name_scope("decode", values=[encoded_image]):
+    if image_format == "jpeg":
+      image = tf.image.decode_jpeg(encoded_image, channels=3)
+    elif image_format == "png":
+      image = tf.image.decode_png(encoded_image, channels=3)
+    else:
+      raise ValueError("Invalid image format: %s" % image_format)
+  image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+  image_summary("original_image", image)
+
+  # Resize image.
+  assert (resize_height > 0) == (resize_width > 0)
+  if resize_height:
+    image = tf.image.resize_images(image,
+                                   new_height=resize_height,
+                                   new_width=resize_width,
+                                   method=tf.image.ResizeMethod.BILINEAR)
+
+  # Crop to final dimensions.
+  if is_training:
+    image = tf.random_crop(image, [height, width, 3])
+  else:
+    # Central crop, assuming resize_height > height, resize_width > width.
+    image = tf.image.resize_image_with_crop_or_pad(image, height, width)
+
+  image_summary("resized_image", image)
+
+  # Randomly distort the image.
+  if is_training:
+    image = distort_image(image, thread_id)
+
+  image_summary("final_image", image)
+
+  # Rescale to [-1, 1] instead of [0, 1].
+  image = tf.sub(image, 0.5)
+  image = tf.mul(image, 2.0)
+  return image

+ 204 - 0
im2txt/im2txt/ops/inputs.py

@@ -0,0 +1,204 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Input ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+
+def parse_sequence_example(serialized, image_feature, caption_feature):
+  """Parses a tensorflow.SequenceExample into an image and caption.
+
+  Args:
+    serialized: A scalar string Tensor; a single serialized SequenceExample.
+    image_feature: Name of SequenceExample context feature containing image
+      data.
+    caption_feature: Name of SequenceExample feature list containing integer
+      captions.
+
+  Returns:
+    encoded_image: A scalar string Tensor containing a JPEG encoded image.
+    caption: A 1-D int64 Tensor with dynamically specified length.
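+
+  For example (the feature names here are illustrative, not fixed by this
+  function):
+    encoded_image, caption = parse_sequence_example(
+        serialized,
+        image_feature="image/data",
+        caption_feature="image/caption_ids")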
+  """
+  context, sequence = tf.parse_single_sequence_example(
+      serialized,
+      context_features={
+          image_feature: tf.FixedLenFeature([], dtype=tf.string)
+      },
+      sequence_features={
+          caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
+      })
+
+  encoded_image = context[image_feature]
+  caption = sequence[caption_feature]
+  return encoded_image, caption
+
+
+def prefetch_input_data(reader,
+                        file_pattern,
+                        is_training,
+                        batch_size,
+                        values_per_shard,
+                        input_queue_capacity_factor=16,
+                        num_reader_threads=1,
+                        shard_queue_name="filename_queue",
+                        value_queue_name="input_queue"):
+  """Prefetches string values from disk into an input queue.
+
+  In training, the capacity of the queue is important because a larger queue
+  means better mixing of training examples between shards. The minimum number
+  of values kept in the queue is values_per_shard *
+  input_queue_capacity_factor, where input_queue_capacity_factor should be
+  chosen to trade off better mixing against memory usage.
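+
+  For example, with values_per_shard=2300 and input_queue_capacity_factor=16
+  (illustrative values), at least 2300 * 16 = 36800 serialized values are kept
+  in the queue during training.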
+
+  Args:
+    reader: Instance of tf.ReaderBase.
+    file_pattern: Comma-separated list of file patterns (e.g.
+        /tmp/train_data-?????-of-00100).
+    is_training: Boolean; whether prefetching for training or eval.
+    batch_size: Model batch size used to determine queue capacity.
+    values_per_shard: Approximate number of values per shard.
+    input_queue_capacity_factor: Minimum number of values to keep in the queue
+      in multiples of values_per_shard. See comments above.
+    num_reader_threads: Number of reader threads to fill the queue.
+    shard_queue_name: Name for the shards filename queue.
+    value_queue_name: Name for the values input queue.
+
+  Returns:
+    A Queue containing prefetched string values.
+  """
+  data_files = []
+  for pattern in file_pattern.split(","):
+    data_files.extend(tf.gfile.Glob(pattern))
+  if not data_files:
+    tf.logging.fatal("Found no input files matching %s", file_pattern)
+  else:
+    tf.logging.info("Prefetching values from %d files matching %s",
+                    len(data_files), file_pattern)
+
+  if is_training:
+    filename_queue = tf.train.string_input_producer(
+        data_files, shuffle=True, capacity=16, name=shard_queue_name)
+    min_queue_examples = values_per_shard * input_queue_capacity_factor
+    capacity = min_queue_examples + 100 * batch_size
+    values_queue = tf.RandomShuffleQueue(
+        capacity=capacity,
+        min_after_dequeue=min_queue_examples,
+        dtypes=[tf.string],
+        name="random_" + value_queue_name)
+  else:
+    filename_queue = tf.train.string_input_producer(
+        data_files, shuffle=False, capacity=1, name=shard_queue_name)
+    capacity = values_per_shard + 3 * batch_size
+    values_queue = tf.FIFOQueue(
+        capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)
+
+  enqueue_ops = []
+  for _ in range(num_reader_threads):
+    _, value = reader.read(filename_queue)
+    enqueue_ops.append(values_queue.enqueue([value]))
+  tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
+      values_queue, enqueue_ops))
+  tf.scalar_summary(
+      "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
+      tf.cast(values_queue.size(), tf.float32) * (1. / capacity))
+
+  return values_queue
+
+
+def batch_with_dynamic_pad(images_and_captions,
+                           batch_size,
+                           queue_capacity,
+                           add_summaries=True):
+  """Batches input images and captions.
+
+  This function splits the caption into an input sequence and a target sequence,
+  where the target sequence is the input sequence right-shifted by 1. Input and
+  target sequences are batched and padded up to the maximum length of sequences
+  in the batch. A mask is created to distinguish real words from padding words.
+
+  Example:
+    Actual captions in the batch ('-' denotes padded character):
+      [
+        [ 1 2 3 4 5 ],
+        [ 1 2 3 4 - ],
+        [ 1 2 3 - - ],
+      ]
+
+    input_seqs:
+      [
+        [ 1 2 3 4 ],
+        [ 1 2 3 - ],
+        [ 1 2 - - ],
+      ]
+
+    target_seqs:
+      [
+        [ 2 3 4 5 ],
+        [ 2 3 4 - ],
+        [ 2 3 - - ],
+      ]
+
+    mask:
+      [
+        [ 1 1 1 1 ],
+        [ 1 1 1 0 ],
+        [ 1 1 0 0 ],
+      ]
+
+  Args:
+    images_and_captions: A list of pairs [image, caption], where image is a
+      Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
+      any length. Each pair will be processed and added to the queue in a
+      separate thread.
+    batch_size: Batch size.
+    queue_capacity: Queue capacity.
+    add_summaries: If true, add caption length summaries.
+
+  Returns:
+    images: A Tensor of shape [batch_size, height, width, channels].
+    input_seqs: An int32 Tensor of shape [batch_size, padded_length].
+    target_seqs: An int32 Tensor of shape [batch_size, padded_length].
+    mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
+  """
+  enqueue_list = []
+  for image, caption in images_and_captions:
+    caption_length = tf.shape(caption)[0]
+    input_length = tf.expand_dims(tf.sub(caption_length, 1), 0)
+
+    input_seq = tf.slice(caption, [0], input_length)
+    target_seq = tf.slice(caption, [1], input_length)
+    indicator = tf.ones(input_length, dtype=tf.int32)
+    enqueue_list.append([image, input_seq, target_seq, indicator])
+
+  images, input_seqs, target_seqs, mask = tf.train.batch_join(
+      enqueue_list,
+      batch_size=batch_size,
+      capacity=queue_capacity,
+      dynamic_pad=True,
+      name="batch_and_pad")
+
+  if add_summaries:
+    lengths = tf.add(tf.reduce_sum(mask, 1), 1)
+    tf.scalar_summary("caption_length/batch_min", tf.reduce_min(lengths))
+    tf.scalar_summary("caption_length/batch_max", tf.reduce_max(lengths))
+    tf.scalar_summary("caption_length/batch_mean", tf.reduce_mean(lengths))
+
+  return images, input_seqs, target_seqs, mask

+ 83 - 0
im2txt/im2txt/run_inference.py

@@ -0,0 +1,83 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Generate captions for images using default beam search parameters."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+
+
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import inference_wrapper
+from im2txt.inference_utils import caption_generator
+from im2txt.inference_utils import vocabulary
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("checkpoint_path", "",
+                       "Model checkpoint file or directory containing a "
+                       "model checkpoint file.")
+tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.")
+tf.flags.DEFINE_string("input_files", "",
+                       "File pattern or comma-separated list of file patterns "
+                       "of image files.")
+
+
+def main(_):
+  # Build the inference graph.
+  g = tf.Graph()
+  with g.as_default():
+    model = inference_wrapper.InferenceWrapper()
+    restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
+                                               FLAGS.checkpoint_path)
+  g.finalize()
+
+  # Create the vocabulary.
+  vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
+
+  filenames = []
+  for file_pattern in FLAGS.input_files.split(","):
+    filenames.extend(tf.gfile.Glob(file_pattern))
+  tf.logging.info("Running caption generation on %d files matching %s",
+                  len(filenames), FLAGS.input_files)
+
+  with tf.Session(graph=g) as sess:
+    # Load the model from checkpoint.
+    restore_fn(sess)
+
+    # Prepare the caption generator. Here we are implicitly using the default
+    # beam search parameters. See caption_generator.py for a description of the
+    # available beam search parameters.
+    generator = caption_generator.CaptionGenerator(model, vocab)
+
+    for filename in filenames:
+      with tf.gfile.GFile(filename, "r") as f:
+        image = f.read()
+      captions = generator.beam_search(sess, image)
+      print("Captions for image %s:" % os.path.basename(filename))
+      for i, caption in enumerate(captions):
+        # Ignore begin and end words.
+        sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
+        sentence = " ".join(sentence)
+        print("  %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
+
+
+if __name__ == "__main__":
+  tf.app.run()

+ 364 - 0
im2txt/im2txt/show_and_tell_model.py

@@ -0,0 +1,364 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
+
+"Show and Tell: A Neural Image Caption Generator"
+Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt.ops import image_embedding
+from im2txt.ops import image_processing
+from im2txt.ops import inputs as input_ops
+
+
+class ShowAndTellModel(object):
+  """Image-to-text implementation based on http://arxiv.org/abs/1411.4555.
+
+  "Show and Tell: A Neural Image Caption Generator"
+  Oriol Vinyals, Alexander Toshev, Samy Bengio, Dumitru Erhan
+  """
+
+  def __init__(self, config, mode, train_inception=False):
+    """Basic setup.
+
+    Args:
+      config: Object containing configuration parameters.
+      mode: "train", "eval" or "inference".
+      train_inception: Whether the inception submodel variables are trainable.
+    """
+    assert mode in ["train", "eval", "inference"]
+    self.config = config
+    self.mode = mode
+    self.train_inception = train_inception
+
+    # Reader for the input data.
+    self.reader = tf.TFRecordReader()
+
+    # To match the "Show and Tell" paper we initialize all variables with a
+    # random uniform initializer.
+    self.initializer = tf.random_uniform_initializer(
+        minval=-self.config.initializer_scale,
+        maxval=self.config.initializer_scale)
+
+    # A float32 Tensor with shape [batch_size, height, width, channels].
+    self.images = None
+
+    # An int32 Tensor with shape [batch_size, padded_length].
+    self.input_seqs = None
+
+    # An int32 Tensor with shape [batch_size, padded_length].
+    self.target_seqs = None
+
+    # An int32 0/1 Tensor with shape [batch_size, padded_length].
+    self.input_mask = None
+
+    # A float32 Tensor with shape [batch_size, embedding_size].
+    self.image_embeddings = None
+
+    # A float32 Tensor with shape [batch_size, padded_length, embedding_size].
+    self.seq_embeddings = None
+
+    # A float32 scalar Tensor; the total loss for the trainer to optimize.
+    self.total_loss = None
+
+    # A float32 Tensor with shape [batch_size * padded_length].
+    self.target_cross_entropy_losses = None
+
+    # A float32 Tensor with shape [batch_size * padded_length].
+    self.target_cross_entropy_loss_weights = None
+
+    # Collection of variables from the inception submodel.
+    self.inception_variables = []
+
+    # Function to restore the inception submodel from checkpoint.
+    self.init_fn = None
+
+    # Global step Tensor.
+    self.global_step = None
+
+  def is_training(self):
+    """Returns true if the model is built for training mode."""
+    return self.mode == "train"
+
+  def process_image(self, encoded_image, thread_id=0):
+    """Decodes and processes an image string.
+
+    Args:
+      encoded_image: A scalar string Tensor; the encoded image.
+      thread_id: Preprocessing thread id used to select the ordering of color
+        distortions.
+
+    Returns:
+      A float32 Tensor of shape [height, width, 3]; the processed image.
+    """
+    return image_processing.process_image(encoded_image,
+                                          is_training=self.is_training(),
+                                          height=self.config.image_height,
+                                          width=self.config.image_width,
+                                          thread_id=thread_id,
+                                          image_format=self.config.image_format)
+
+  def build_inputs(self):
+    """Input prefetching, preprocessing and batching.
+
+    Outputs:
+      self.images
+      self.input_seqs
+      self.target_seqs (training and eval only)
+      self.input_mask (training and eval only)
+    """
+    if self.mode == "inference":
+      # In inference mode, images and inputs are fed via placeholders.
+      image_feed = tf.placeholder(dtype=tf.string, shape=[], name="image_feed")
+      input_feed = tf.placeholder(dtype=tf.int64,
+                                  shape=[None],  # batch_size
+                                  name="input_feed")
+
+      # Process image and insert batch dimensions.
+      images = tf.expand_dims(self.process_image(image_feed), 0)
+      input_seqs = tf.expand_dims(input_feed, 1)
+
+      # No target sequences or input mask in inference mode.
+      target_seqs = None
+      input_mask = None
+    else:
+      # Prefetch serialized SequenceExample protos.
+      input_queue = input_ops.prefetch_input_data(
+          self.reader,
+          self.config.input_file_pattern,
+          is_training=self.is_training(),
+          batch_size=self.config.batch_size,
+          values_per_shard=self.config.values_per_input_shard,
+          input_queue_capacity_factor=self.config.input_queue_capacity_factor,
+          num_reader_threads=self.config.num_input_reader_threads)
+
+      # Image processing and random distortion. Split across multiple threads
+      # with each thread applying a slightly different distortion.
+      assert self.config.num_preprocess_threads % 2 == 0
+      images_and_captions = []
+      for thread_id in range(self.config.num_preprocess_threads):
+        serialized_sequence_example = input_queue.dequeue()
+        encoded_image, caption = input_ops.parse_sequence_example(
+            serialized_sequence_example,
+            image_feature=self.config.image_feature_name,
+            caption_feature=self.config.caption_feature_name)
+        image = self.process_image(encoded_image, thread_id=thread_id)
+        images_and_captions.append([image, caption])
+
+      # Batch inputs.
+      queue_capacity = (2 * self.config.num_preprocess_threads *
+                        self.config.batch_size)
+      images, input_seqs, target_seqs, input_mask = (
+          input_ops.batch_with_dynamic_pad(images_and_captions,
+                                           batch_size=self.config.batch_size,
+                                           queue_capacity=queue_capacity))
+
+    self.images = images
+    self.input_seqs = input_seqs
+    self.target_seqs = target_seqs
+    self.input_mask = input_mask
+
+  def build_image_embeddings(self):
+    """Builds the image model subgraph and generates image embeddings.
+
+    Inputs:
+      self.images
+
+    Outputs:
+      self.image_embeddings
+    """
+    inception_output = image_embedding.inception_v3(
+        self.images,
+        trainable=self.train_inception,
+        is_training=self.is_training())
+    self.inception_variables = tf.get_collection(
+        tf.GraphKeys.VARIABLES, scope="InceptionV3")
+
+    # Map inception output into embedding space.
+    with tf.variable_scope("image_embedding") as scope:
+      image_embeddings = tf.contrib.layers.fully_connected(
+          inputs=inception_output,
+          num_outputs=self.config.embedding_size,
+          activation_fn=None,
+          weights_initializer=self.initializer,
+          biases_initializer=None,
+          scope=scope)
+
+    # Save the embedding size in the graph.
+    tf.constant(self.config.embedding_size, name="embedding_size")
+
+    self.image_embeddings = image_embeddings
+
+  def build_seq_embeddings(self):
+    """Builds the input sequence embeddings.
+
+    Inputs:
+      self.input_seqs
+
+    Outputs:
+      self.seq_embeddings
+    """
+    with tf.variable_scope("seq_embedding"), tf.device("/cpu:0"):
+      embedding_map = tf.get_variable(
+          name="map",
+          shape=[self.config.vocab_size, self.config.embedding_size],
+          initializer=self.initializer)
+      seq_embeddings = tf.nn.embedding_lookup(embedding_map, self.input_seqs)
+
+    self.seq_embeddings = seq_embeddings
+
+  def build_model(self):
+    """Builds the model.
+
+    Inputs:
+      self.image_embeddings
+      self.seq_embeddings
+      self.target_seqs (training and eval only)
+      self.input_mask (training and eval only)
+
+    Outputs:
+      self.total_loss (training and eval only)
+      self.target_cross_entropy_losses (training and eval only)
+      self.target_cross_entropy_loss_weights (training and eval only)
+    """
+    # This LSTM cell has biases and outputs tanh(new_c) * sigmoid(o), but the
+    # modified LSTM in the "Show and Tell" paper has no biases and outputs
+    # new_c * sigmoid(o).
+    lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(
+        num_units=self.config.num_lstm_units, state_is_tuple=True)
+    if self.mode == "train":
+      lstm_cell = tf.nn.rnn_cell.DropoutWrapper(
+          lstm_cell,
+          input_keep_prob=self.config.lstm_dropout_keep_prob,
+          output_keep_prob=self.config.lstm_dropout_keep_prob)
+
+    with tf.variable_scope("lstm", initializer=self.initializer) as lstm_scope:
+      # Feed the image embeddings to set the initial LSTM state.
+      zero_state = lstm_cell.zero_state(
+          batch_size=self.image_embeddings.get_shape()[0], dtype=tf.float32)
+      _, initial_state = lstm_cell(self.image_embeddings, zero_state)
+
+      # Allow the LSTM variables to be reused.
+      lstm_scope.reuse_variables()
+
+      if self.mode == "inference":
+        # In inference mode, use concatenated states for convenient feeding and
+        # fetching.
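+        # For example, assuming num_lstm_units=512, the concatenated state has
+        # shape [batch_size, 2 * 512] = [batch_size, 1024].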
+        tf.concat(1, initial_state, name="initial_state")
+
+        # Placeholder for feeding a batch of concatenated states.
+        state_feed = tf.placeholder(dtype=tf.float32,
+                                    shape=[None, sum(lstm_cell.state_size)],
+                                    name="state_feed")
+        state_tuple = tf.split(1, 2, state_feed)
+
+        # Run a single LSTM step.
+        lstm_outputs, state_tuple = lstm_cell(
+            inputs=tf.squeeze(self.seq_embeddings, squeeze_dims=[1]),
+            state=state_tuple)
+
+        # Concatenate the resulting state.
+        tf.concat(1, state_tuple, name="state")
+      else:
+        # Run the batch of sequence embeddings through the LSTM.
+        sequence_length = tf.reduce_sum(self.input_mask, 1)
+        lstm_outputs, _ = tf.nn.dynamic_rnn(cell=lstm_cell,
+                                            inputs=self.seq_embeddings,
+                                            sequence_length=sequence_length,
+                                            initial_state=initial_state,
+                                            dtype=tf.float32,
+                                            scope=lstm_scope)
+
+    # Stack batches vertically.
+    lstm_outputs = tf.reshape(lstm_outputs, [-1, lstm_cell.output_size])
+
+    with tf.variable_scope("logits") as logits_scope:
+      logits = tf.contrib.layers.fully_connected(
+          inputs=lstm_outputs,
+          num_outputs=self.config.vocab_size,
+          activation_fn=None,
+          weights_initializer=self.initializer,
+          scope=logits_scope)
+
+    if self.mode == "inference":
+      tf.nn.softmax(logits, name="softmax")
+    else:
+      targets = tf.reshape(self.target_seqs, [-1])
+      weights = tf.to_float(tf.reshape(self.input_mask, [-1]))
+
+      # Compute losses.
+      losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits, targets)
+      batch_loss = tf.div(tf.reduce_sum(tf.mul(losses, weights)),
+                          tf.reduce_sum(weights),
+                          name="batch_loss")
+      tf.contrib.losses.add_loss(batch_loss)
+      total_loss = tf.contrib.losses.get_total_loss()
+
+      # Add summaries.
+      tf.scalar_summary("batch_loss", batch_loss)
+      tf.scalar_summary("total_loss", total_loss)
+      for var in tf.trainable_variables():
+        tf.histogram_summary(var.op.name, var)
+
+      self.total_loss = total_loss
+      self.target_cross_entropy_losses = losses  # Used in evaluation.
+      self.target_cross_entropy_loss_weights = weights  # Used in evaluation.
+
+  def setup_inception_initializer(self):
+    """Sets up the function to restore inception variables from checkpoint."""
+    if self.mode != "inference":
+      # Restore inception variables only.
+      saver = tf.train.Saver(self.inception_variables)
+
+      def restore_fn(sess):
+        tf.logging.info("Restoring Inception variables from checkpoint file %s",
+                        self.config.inception_checkpoint_file)
+        saver.restore(sess, self.config.inception_checkpoint_file)
+
+      self.init_fn = restore_fn
+
+  def setup_global_step(self):
+    """Sets up the global step Tensor."""
+    global_step = tf.Variable(
+        initial_value=0,
+        name="global_step",
+        trainable=False,
+        collections=[tf.GraphKeys.GLOBAL_STEP, tf.GraphKeys.VARIABLES])
+
+    self.global_step = global_step
+
+  def setup_saver(self):
+    """Sets up the Saver for loading and saving model checkpoints."""
+    self.saver = tf.train.Saver(
+        max_to_keep=self.config.max_checkpoints_to_keep,
+        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
+
+  def build(self):
+    """Creates all ops for training and evaluation."""
+    self.build_inputs()
+    self.build_image_embeddings()
+    self.build_seq_embeddings()
+    self.build_model()
+    self.setup_inception_initializer()
+    self.setup_global_step()
+    self.setup_saver()
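
The weighted reduction in `build_model()` above (`batch_loss = sum(losses * weights) / sum(weights)`) averages the cross-entropy over real target words only; padded positions contribute nothing because the input mask zeroes their weights. A minimal NumPy sketch of the same reduction, with made-up loss values, for readers who want to check the weighting:

```python
import numpy as np

# Hypothetical per-word cross-entropy losses for a batch, flattened to
# shape [batch_size * sequence_length] as in the model.
losses = np.array([2.0, 1.0, 3.0, 0.5, 4.0, 2.5])

# Input mask reshaped the same way: 1 for real target words, 0 for padding.
weights = np.array([1.0, 1.0, 1.0, 1.0, 0.0, 0.0])

# batch_loss = sum(losses * weights) / sum(weights): the mean loss per
# real (unpadded) target word.
batch_loss = np.sum(losses * weights) / np.sum(weights)
print(batch_loss)  # (2.0 + 1.0 + 3.0 + 0.5) / 4 = 1.625
```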

+ 200 - 0
im2txt/im2txt/show_and_tell_model_test.py

@@ -0,0 +1,200 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for tensorflow_models.im2txt.show_and_tell_model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import numpy as np
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+
+class ShowAndTellModel(show_and_tell_model.ShowAndTellModel):
+  """Subclass of ShowAndTellModel without the disk I/O."""
+
+  def build_inputs(self):
+    if self.mode == "inference":
+      # Inference mode doesn't read from disk, so defer to parent.
+      return super(ShowAndTellModel, self).build_inputs()
+    else:
+      # Replace disk I/O with random Tensors.
+      self.images = tf.random_uniform(
+          shape=[self.config.batch_size, self.config.image_height,
+                 self.config.image_width, 3],
+          minval=-1,
+          maxval=1)
+      self.input_seqs = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.target_seqs = tf.random_uniform(
+          [self.config.batch_size, 15],
+          minval=0,
+          maxval=self.config.vocab_size,
+          dtype=tf.int64)
+      self.input_mask = tf.ones_like(self.input_seqs)
+
+
+class ShowAndTellModelTest(tf.test.TestCase):
+
+  def setUp(self):
+    super(ShowAndTellModelTest, self).setUp()
+    self._model_config = configuration.ModelConfig()
+
+  def _countModelParameters(self):
+    """Counts the number of parameters in the model at top level scope."""
+    counter = {}
+    for v in tf.all_variables():
+      name = v.op.name.split("/")[0]
+      num_params = v.get_shape().num_elements()
+      assert num_params
+      counter[name] = counter.get(name, 0) + num_params
+    return counter
+
+  def _checkModelParameters(self):
+    """Verifies the number of parameters in the model."""
+    param_counts = self._countModelParameters()
+    expected_param_counts = {
+        "InceptionV3": 21802784,
+        # inception_output_size * embedding_size
+        "image_embedding": 1048576,
+        # vocab_size * embedding_size
+        "seq_embedding": 6144000,
+        # (embedding_size + num_lstm_units + 1) * 4 * num_lstm_units
+        "lstm": 2099200,
+        # (num_lstm_units + 1) * vocab_size
+        "logits": 6156000,
+        "global_step": 1,
+    }
+    self.assertDictEqual(expected_param_counts, param_counts)
+
+  def _checkOutputs(self, expected_shapes, feed_dict=None):
+    """Verifies that the model produces expected outputs.
+
+    Args:
+      expected_shapes: A dict mapping Tensor or Tensor name to expected output
+        shape.
+      feed_dict: Values of Tensors to feed into Session.run().
+    """
+    fetches = expected_shapes.keys()
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      outputs = sess.run(fetches, feed_dict)
+
+    for index, output in enumerate(outputs):
+      tensor = fetches[index]
+      expected = expected_shapes[tensor]
+      actual = output.shape
+      if expected != actual:
+        self.fail("Tensor %s has shape %s (expected %s)." %
+                  (tensor, actual, expected))
+
+  def testBuildForTraining(self):
+    model = ShowAndTellModel(self._model_config, mode="train")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, image_height, image_width, 3]
+        model.images: (32, 299, 299, 3),
+        # [batch_size, sequence_length]
+        model.input_seqs: (32, 15),
+        # [batch_size, sequence_length]
+        model.target_seqs: (32, 15),
+        # [batch_size, sequence_length]
+        model.input_mask: (32, 15),
+        # [batch_size, embedding_size]
+        model.image_embeddings: (32, 512),
+        # [batch_size, sequence_length, embedding_size]
+        model.seq_embeddings: (32, 15, 512),
+        # Scalar
+        model.total_loss: (),
+        # [batch_size * sequence_length]
+        model.target_cross_entropy_losses: (480,),
+        # [batch_size * sequence_length]
+        model.target_cross_entropy_loss_weights: (480,),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForEval(self):
+    model = ShowAndTellModel(self._model_config, mode="eval")
+    model.build()
+
+    self._checkModelParameters()
+
+    expected_shapes = {
+        # [batch_size, image_height, image_width, 3]
+        model.images: (32, 299, 299, 3),
+        # [batch_size, sequence_length]
+        model.input_seqs: (32, 15),
+        # [batch_size, sequence_length]
+        model.target_seqs: (32, 15),
+        # [batch_size, sequence_length]
+        model.input_mask: (32, 15),
+        # [batch_size, embedding_size]
+        model.image_embeddings: (32, 512),
+        # [batch_size, sequence_length, embedding_size]
+        model.seq_embeddings: (32, 15, 512),
+        # Scalar
+        model.total_loss: (),
+        # [batch_size * sequence_length]
+        model.target_cross_entropy_losses: (480,),
+        # [batch_size * sequence_length]
+        model.target_cross_entropy_loss_weights: (480,),
+    }
+    self._checkOutputs(expected_shapes)
+
+  def testBuildForInference(self):
+    model = ShowAndTellModel(self._model_config, mode="inference")
+    model.build()
+
+    self._checkModelParameters()
+
+    # Test feeding an image to get the initial LSTM state.
+    images_feed = np.random.rand(1, 299, 299, 3)
+    feed_dict = {model.images: images_feed}
+    expected_shapes = {
+        # [batch_size, embedding_size]
+        model.image_embeddings: (1, 512),
+        # [batch_size, 2 * num_lstm_units]
+        "lstm/initial_state:0": (1, 1024),
+    }
+    self._checkOutputs(expected_shapes, feed_dict)
+
+    # Test feeding a batch of inputs and LSTM states to get softmax output and
+    # LSTM states.
+    input_feed = np.random.randint(0, 10, size=3)
+    state_feed = np.random.rand(3, 1024)
+    feed_dict = {"input_feed:0": input_feed, "lstm/state_feed:0": state_feed}
+    expected_shapes = {
+        # [batch_size, 2 * num_lstm_units]
+        "lstm/state:0": (3, 1024),
+        # [batch_size, vocab_size]
+        "softmax:0": (3, 12000),
+    }
+    self._checkOutputs(expected_shapes, feed_dict)
+
+
+if __name__ == "__main__":
+  tf.test.main()
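
The `expected_param_counts` in `_checkModelParameters` follow from the hyperparameters implied by the shapes used in this test (embedding size 512, 512 LSTM units, a 12,000-word vocabulary, and a 2048-dimensional Inception v3 output). A quick arithmetic check of those formulas, as a standalone sketch rather than part of the test:

```python
# Hyperparameter values implied by the tensor shapes asserted above.
embedding_size = 512
num_lstm_units = 512
vocab_size = 12000
inception_output_size = 2048

image_embedding = inception_output_size * embedding_size           # 1,048,576
seq_embedding = vocab_size * embedding_size                        # 6,144,000
# Basic LSTM cell: 4 gates, each with a [input + hidden, hidden] weight
# matrix plus a bias vector.
lstm = (embedding_size + num_lstm_units + 1) * 4 * num_lstm_units  # 2,099,200
logits = (num_lstm_units + 1) * vocab_size                         # 6,156,000

assert image_embedding == 1048576
assert seq_embedding == 6144000
assert lstm == 2099200
assert logits == 6156000
```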

+ 111 - 0
im2txt/im2txt/train.py

@@ -0,0 +1,111 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Train the model."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from im2txt import configuration
+from im2txt import show_and_tell_model
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.flags.DEFINE_string("input_file_pattern", "",
+                       "File pattern of sharded TFRecord input files.")
+tf.flags.DEFINE_string("inception_checkpoint_file", "",
+                       "Path to a pretrained inception_v3 model.")
+tf.flags.DEFINE_string("train_dir", "",
+                       "Directory for saving and loading model checkpoints.")
+tf.flags.DEFINE_boolean("train_inception", False,
+                        "Whether to train inception submodel variables.")
+tf.flags.DEFINE_integer("number_of_steps", 1000000, "Number of training steps.")
+tf.flags.DEFINE_integer("log_every_n_steps", 1,
+                        "Frequency at which loss and global step are logged.")
+
+tf.logging.set_verbosity(tf.logging.INFO)
+
+
+def main(unused_argv):
+  assert FLAGS.input_file_pattern, "--input_file_pattern is required"
+  assert FLAGS.train_dir, "--train_dir is required"
+
+  model_config = configuration.ModelConfig()
+  model_config.input_file_pattern = FLAGS.input_file_pattern
+  model_config.inception_checkpoint_file = FLAGS.inception_checkpoint_file
+  training_config = configuration.TrainingConfig()
+
+  # Create training directory.
+  train_dir = FLAGS.train_dir
+  if not tf.gfile.IsDirectory(train_dir):
+    tf.logging.info("Creating training directory: %s", train_dir)
+    tf.gfile.MakeDirs(train_dir)
+
+  # Build the TensorFlow graph.
+  g = tf.Graph()
+  with g.as_default():
+    # Build the model.
+    model = show_and_tell_model.ShowAndTellModel(
+        model_config, mode="train", train_inception=FLAGS.train_inception)
+    model.build()
+
+    # Set up the learning rate.
+    learning_rate_decay_fn = None
+    if FLAGS.train_inception:
+      learning_rate = tf.constant(training_config.train_inception_learning_rate)
+    else:
+      learning_rate = tf.constant(training_config.initial_learning_rate)
+      if training_config.learning_rate_decay_factor > 0:
+        num_batches_per_epoch = (training_config.num_examples_per_epoch /
+                                 model_config.batch_size)
+        decay_steps = int(num_batches_per_epoch *
+                          training_config.num_epochs_per_decay)
+
+        def _learning_rate_decay_fn(learning_rate, global_step):
+          return tf.train.exponential_decay(
+              learning_rate,
+              global_step,
+              decay_steps=decay_steps,
+              decay_rate=training_config.learning_rate_decay_factor,
+              staircase=True)
+
+        learning_rate_decay_fn = _learning_rate_decay_fn
+
+    # Set up the training ops.
+    train_op = tf.contrib.layers.optimize_loss(
+        loss=model.total_loss,
+        global_step=model.global_step,
+        learning_rate=learning_rate,
+        optimizer=training_config.optimizer,
+        clip_gradients=training_config.clip_gradients,
+        learning_rate_decay_fn=learning_rate_decay_fn)
+
+  # Run training.
+  tf.contrib.slim.learning.train(
+      train_op,
+      train_dir,
+      log_every_n_steps=FLAGS.log_every_n_steps,
+      graph=g,
+      global_step=model.global_step,
+      number_of_steps=FLAGS.number_of_steps,
+      init_fn=model.init_fn,
+      saver=model.saver)
+
+
+if __name__ == "__main__":
+  tf.app.run()
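
With `staircase=True`, `tf.train.exponential_decay` multiplies the initial learning rate by `learning_rate_decay_factor` once every `decay_steps` training steps and holds it constant in between. A minimal sketch of the resulting schedule produced by `_learning_rate_decay_fn`, using hypothetical placeholder values for the `TrainingConfig` fields (the real values live in `configuration.py`):

```python
# All config values below are hypothetical placeholders, not the real
# TrainingConfig defaults.
initial_learning_rate = 2.0       # hypothetical
learning_rate_decay_factor = 0.5  # hypothetical
num_examples_per_epoch = 600000   # hypothetical
num_epochs_per_decay = 8.0        # hypothetical
batch_size = 32

num_batches_per_epoch = num_examples_per_epoch / batch_size
decay_steps = int(num_batches_per_epoch * num_epochs_per_decay)   # 150000

def decayed_learning_rate(global_step):
  # staircase=True: the rate drops by the decay factor at every multiple
  # of decay_steps.
  return initial_learning_rate * learning_rate_decay_factor ** (global_step // decay_steps)

for step in (0, decay_steps - 1, decay_steps, 3 * decay_steps):
  print(step, decayed_learning_rate(step))
# 0 -> 2.0, 149999 -> 2.0, 150000 -> 1.0, 450000 -> 0.25
```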