
Initial tf-slim checkin (#349)

nathansilberman 9 years ago
commit a5c4fd06d2

+ 147 - 0
slim/BUILD

@@ -0,0 +1,147 @@
+# Description:
+#   Contains files for loading, training and evaluating TF-Slim 2.0-based models.
+
+package(default_visibility = [":internal"])
+
+licenses(["notice"])  # Apache 2.0
+
+exports_files(["LICENSE"])
+
+package_group(
+    name = "internal",
+    packages = ["//slim/"],
+)
+
+py_library(
+    name = "dataset_utils",
+    srcs = ["datasets/dataset_utils.py"],
+)
+
+py_binary(
+    name = "download_and_convert_cifar10",
+    srcs = ["datasets/download_and_convert_cifar10.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "download_and_convert_flowers",
+    srcs = ["datasets/download_and_convert_flowers.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "download_and_convert_mnist",
+    srcs = ["datasets/download_and_convert_mnist.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "cifar10",
+    srcs = ["datasets/cifar10.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "flowers",
+    srcs = ["datasets/flowers.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "imagenet",
+    srcs = ["datasets/imagenet.py"],
+    deps = [":dataset_utils"],
+)
+
+py_binary(
+    name = "mnist",
+    srcs = ["datasets/mnist.py"],
+    deps = [":dataset_utils"],
+)
+
+py_library(
+    name = "dataset_factory",
+    srcs = ["datasets/dataset_factory.py"],
+    deps = [
+        ":cifar10",
+        ":flowers",
+        ":imagenet",
+        ":mnist",
+    ],
+)
+
+py_binary(
+    name = "eval",
+    srcs = ["eval.py"],
+    deps = [
+        ":dataset_factory",
+        ":model_deploy",
+        ":model_factory",
+        ":preprocessing_factory",
+    ],
+)
+
+py_library(
+    name = "model_deploy",
+    srcs = ["models/model_deploy.py"],
+)
+
+py_test(
+    name = "model_deploy_test",
+    srcs = ["models/model_deploy_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":model_deploy"],
+)
+
+py_library(
+    name = "cifar10_preprocessing",
+    srcs = ["models/cifar10_preprocessing.py"],
+)
+
+py_library(
+    name = "inception_preprocessing",
+    srcs = ["models/inception_preprocessing.py"],
+)
+
+py_library(
+    name = "lenet_preprocessing",
+    srcs = ["models/lenet_preprocessing.py"],
+)
+
+py_library(
+    name = "vgg_preprocessing",
+    srcs = ["models/vgg_preprocessing.py"],
+)
+
+py_library(
+    name = "preprocessing_factory",
+    srcs = ["models/preprocessing_factory.py"],
+    deps = [
+        ":cifar10_preprocessing",
+        ":inception_preprocessing",
+        ":lenet_preprocessing",
+        ":vgg_preprocessing",
+    ],
+)
+
+py_library(
+    name = "lenet",
+    srcs = ["nets/lenet.py"],
+)
+
+py_library(
+    name = "model_factory",
+    srcs = ["models/model_factory.py"],
+    deps = [":lenet"],
+)
+
+py_binary(
+    name = "train",
+    srcs = ["train.py"],
+    deps = [
+        ":dataset_factory",
+        ":model_deploy",
+        ":model_factory",
+        ":preprocessing_factory",
+    ],
+)

+ 398 - 0
slim/README.md

@@ -0,0 +1,398 @@
+# Image Classification Models in TF-Slim
+
+This directory contains scripts for training and evaluating models using
+TF-Slim. In particular the code base provides core binaries for:
+
+* Training a model from scratch on a given dataset.
+* Fine-tuning a model from a particular checkpoint on a given dataset.
+* Evaluating a trained model on a given dataset.
+
+All scripts are highly configurable via command-line flags. They support
+training and evaluation using a variety of architectures and datasets.
+
+# Getting Started
+
+**NOTE** Before doing anything, we first need to install TensorFlow from the
+latest nightly build. You can find the latest nightly binaries at
+[TensorFlow Installation](https://github.com/tensorflow/tensorflow#installation)
+under the header that reads "People who are a little more adventurous can
+also try our nightly binaries". Next, copy the link address that corresponds to
+your machine architecture and Python version. Finally, pip install (with
+`--upgrade`) the downloaded wheel.
+
+For example:
+
+```shell
+export TF_BINARY_URL=https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl
+
+sudo pip install --upgrade $TF_BINARY_URL
+```
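+
+To quickly verify the install, you can import TensorFlow and print its version
+(a minimal sanity check, not part of the provided scripts):
+
+```python
+# Confirm the nightly build is importable and report its version.
+import tensorflow as tf
+print(tf.__version__)  # e.g. 0.10.0rc0
+```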
+
+To compile the training and evaluation scripts, we also need to install Bazel.
+You can find step-by-step instructions
+[here](http://bazel.io/docs/install.html).
+
+Next, you'll need to install
+[tensorflow/models/slim](https://github.com/tensorflow/models/tree/master/slim).
+If you want to use the ImageNet dataset, you'll also need to install
+[tensorflow/models/inception](https://github.com/tensorflow/models/tree/master/inception).
+Note that the inception directory contains an older version of slim, which has
+been deprecated and can be safely ignored.
+
+# Datasets
+
+As part of this library, we've included scripts to download several popular
+datasets and convert them to TensorFlow's native TFRecord format. Each labeled
+image is represented as a
+[TF-Example](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/core/example/example.proto)
+protocol buffer.
+
+Dataset | Download Script | Dataset Specification | Description
+:------:|:---------------:|:---------------------:|:-----------
+[Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)|[Script](https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_cifar10.py)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/cifar10.py)|The Cifar10 dataset contains 50,000 training and 10,000 testing images of 10 different object classes.
+[Flowers](https://github.com/tensorflow/models/blob/master/inception/README.md)|[Script](https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_flowers.py)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/flowers.py)|The Flowers dataset contains 3670 images of flowers with 5 different labels.
+[MNIST](http://yann.lecun.com/exdb/mnist/)|[Script](https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_mnist.py)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/mnist.py)|The MNIST dataset contains 60,000 training and 10,000 testing grayscale images of digits.
+[ImageNet](http://www.image-net.org/)|[Script](https://github.com/tensorflow/models/blob/master/inception/inception/data/download_and_preprocess_imagenet.sh)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/imagenet.py)|The ImageNet dataset contains about 1.2 million training and 50,000 validation images with 1000 different labels.
+
+Below we describe the python scripts which download these datasets and convert
+them to TFRecord format. Once in this format, the data can easily be read by
+TensorFlow by providing a TF-Slim
+[Dataset](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/contrib/slim/python/slim/data/dataset.py)
+specification. We have included, as part of this release, the
+[Dataset](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/contrib/slim/python/slim/data/dataset.py)
+specifications for each of these datasets as well.
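+
+For example, once a dataset has been prepared (see below), it can be consumed
+roughly as follows. This is a minimal sketch using the `dataset_factory`
+included in this release and TF-Slim's `DatasetDataProvider`; the dataset
+directory path is hypothetical:
+
+```python
+import tensorflow as tf
+
+from slim.datasets import dataset_factory
+
+slim = tf.contrib.slim
+
+# Load the Dataset specification for the Cifar10 training split.
+dataset = dataset_factory.get_dataset('cifar10', 'train', '/tmp/cifar10')
+
+# Create a provider that reads and decodes the TF-Example protos.
+provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
+image, label = provider.get(['image', 'label'])
+```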
+
+## Preparing the Cifar10 Dataset
+
+In order to use the Cifar10 dataset, the data must first be downloaded and
+converted to the native TFRecord format.
+
+```shell
+# Specify the directory of the Cifar10 data:
+$ DATA_DIR=$HOME/cifar10
+
+# Build the dataset creation script.
+$ bazel build slim:download_and_convert_cifar10
+
+# Run the dataset creation.
+$ ./bazel-bin/slim/download_and_convert_cifar10 --dataset_dir="${DATA_DIR}"
+```
+
+The final lines of the script's output should read:
+
+```shell
+Reading file [cifar-10-batches-py/test_batch], image 10000/10000
+Finished extracting the Cifar10 dataset!
+```
+
+When the script finishes, you will find two TFRecord files created,
+`$DATA_DIR/cifar10_train.tfrecord` and `$DATA_DIR/cifar10_test.tfrecord`,
+which represent the training and testing sets respectively. You will also find
+a `$DATA_DIR/labels.txt` file which contains the mapping from integer labels
+to class names.
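+
+The `labels.txt` file stores one `label:class_name` pair per line, as written
+by `dataset_utils.write_label_file`. For Cifar10 it begins:
+
+```shell
+$ head -3 ${DATA_DIR}/labels.txt
+0:airplane
+1:automobile
+2:bird
+```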
+
+## Preparing the Flowers Dataset
+
+In order to use the Flowers dataset, the data must first be downloaded and
+converted to the native TFRecord format.
+
+```shell
+# Specify the directory of the Flowers data:
+$ DATA_DIR=$HOME/flowers
+
+# Build the dataset creation script.
+$ bazel build slim:download_and_convert_flowers
+
+# Run the dataset creation.
+$ ./bazel-bin/slim/download_and_convert_flowers --dataset_dir="${DATA_DIR}"
+```
+
+The final lines of the script's output should read:
+
+```shell
+>> Converting image 3320/3320 shard 4
+>> Converting image 350/350 shard 4
+
+Finished converting the Flowers dataset!
+```
+
+When the script finishes, you will find several TFRecord files created:
+
+```shell
+$ ls ${DATA_DIR}
+flowers_train-00000-of-00005.tfrecord
+flowers_train-00001-of-00005.tfrecord
+flowers_train-00002-of-00005.tfrecord
+flowers_train-00003-of-00005.tfrecord
+flowers_train-00004-of-00005.tfrecord
+flowers_validation-00000-of-00005.tfrecord
+flowers_validation-00001-of-00005.tfrecord
+flowers_validation-00002-of-00005.tfrecord
+flowers_validation-00003-of-00005.tfrecord
+flowers_validation-00004-of-00005.tfrecord
+labels.txt
+```
+
+These represent the training and validation data, sharded over 5 files each.
+You will also find the `$DATA_DIR/labels.txt` file which contains the mapping
+from integer labels to class names.
+
+## Preparing the MNIST Dataset
+
+In order to use the MNIST dataset, the data must first be downloaded and
+converted to the native TFRecord format.
+
+```shell
+# Specify the directory of the MNIST data:
+$ DATA_DIR=$HOME/mnist
+
+# Build the dataset creation script.
+$ bazel build slim:download_and_convert_mnist
+
+# Run the dataset creation.
+$ ./bazel-bin/slim/download_and_convert_mnist --dataset_dir="${DATA_DIR}"
+```
+
+The final lines of the script's output should read:
+
+```shell
+>> Converting image 10000/10000
+Finished extracting the MNIST dataset!
+```
+
+When the script finishes, you will find two TFRecord files created,
+`$DATA_DIR/mnist_train.tfrecord` and `$DATA_DIR/mnist_test.tfrecord`,
+which represent the training and testing sets respectively. You will also find
+a `$DATA_DIR/labels.txt` file which contains the mapping from integer labels
+to class names.
+
+## Preparing the ImageNet Dataset
+
+To use the ImageNet dataset, follow the instructions in the
+[tensorflow/models/inception](https://github.com/tensorflow/models/blob/master/inception/README.md#getting-started)
+repository. In particular, see the file
+[download_and_preprocess_imagenet.sh](https://github.com/tensorflow/models/blob/master/inception/inception/data/download_and_preprocess_imagenet.sh).
+
+# Pre-trained Models
+
+For convenience, we have provided a number of pre-trained image classification
+models, which are listed below. These neural networks have been trained on the
+ILSVRC-2012-CLS dataset, which comprises ~1.2 million images annotated
+with 1000 mutually exclusive class labels.
+
+In the table below, we present each of these models, the corresponding
+TensorFlow model file, the link to the model checkpoint, and the top-1 and
+top-5 accuracy.
+Note that the VGG and ResNet parameters have been converted from their original
+caffe formats
+([here](https://github.com/BVLC/caffe/wiki/Model-Zoo#models-used-by-the-vgg-team-in-ilsvrc-2014)
+and
+[here](https://github.com/KaimingHe/deep-residual-networks)), whereas the
+Inception parameters have been trained internally at Google. Also be aware
+that these accuracies were computed by evaluating with a single image crop.
+Some academic papers report higher accuracy by using multiple crops at
+multiple scales.
+
+Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
+:----:|:------------:|:----------:|:-------:|:--------:|
+[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v1.py)|[inception_v1.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_23.tar.gz)|69.8|89.6|
+[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v2.py)|[inception_v2.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_23.tar.gz)|73.9|91.8|
+[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py)|[inception_v3.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_23.tar.gz)|78.0|93.9|
+[ResNet 50](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_50.tar.gz](http://download.tensorflow.org/models/resnet_v1_50_2016_08_23.tar.gz)|75.2|92.2|
+[ResNet 101](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_101.tar.gz](http://download.tensorflow.org/models/resnet_v1_101_2016_08_23.tar.gz)|76.4|92.9|
+[ResNet 152](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_152.tar.gz](http://download.tensorflow.org/models/resnet_v1_152_2016_08_23.tar.gz)|76.8|93.2|
+[VGG 16](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_16.tar.gz](http://download.tensorflow.org/models/vgg_16_2016_08_23.tar.gz)|71.5|89.8|
+[VGG 19](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_19.tar.gz](http://download.tensorflow.org/models/vgg_19_2016_08_23.tar.gz)|71.1|89.8|
+
+
+# Training a model from scratch.
+
+**WARNING** Training a neural network from scratch is a computationally
+intensive task and, depending on your compute setup, may take days, weeks or
+even months.
+
+The training script provided allows users to train one of several architectures
+using a variety of optimizers on one of several datasets. Each of these
+choices is configurable, and datasets can be added by creating a
+[slim.Dataset](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/data/dataset.py)
+specification and using it in place of one of those provided.
+
+The following example demonstrates how to train Inception-V3 using SGD with
+Momentum on the ImageNet dataset.
+
+```shell
+# Specify the directory where the dataset is stored.
+DATASET_DIR=$HOME/imagenet
+
+# Specify the directory where the training logs are stored:
+TRAIN_DIR=$HOME/train_logs
+
+# Build the training script.
+$ bazel build slim:train
+
+# Run the training script.
+$ bazel-bin/slim/train \
+    --train_dir=${TRAIN_DIR} \
+    --dataset_name=imagenet \
+    --dataset_split_name=train \
+    --dataset_dir=${DATASET_DIR} \
+    --model_name=inception_v3
+```
+
+# Fine-tuning a model from an existing checkpoint
+
+Rather than training from scratch, we'll often want to start from a pre-trained
+model and fine-tune it.
+
+To indicate a checkpoint from which to fine-tune, we'll run the training
+script with the `--checkpoint_path` flag set to an absolute path to a
+checkpoint file.
+
+When fine-tuning a model, we need to be careful about restoring checkpoint
+weights. In particular, when we fine-tune a model on a new task with a
+different number of output labels, we won't be able to restore the final
+logits (classifier) layer, since its dimensions will differ from those of the
+pre-trained model. For example, if fine-tuning an ImageNet-trained model on
+Cifar10, the pre-trained logits layer will have dimensions `[2048 x 1001]` but
+our new logits layer will have dimensions `[2048 x 10]`. For this, we'll use
+the `--checkpoint_exclude_scopes` flag, which prevents the named variables
+from being loaded from the checkpoint.
+
+Keep in mind that warm-starting from a checkpoint affects the model's weights
+only during the initialization of the model. Once training has begun,
+a new checkpoint will be created in `${TRAIN_DIR}`. If the fine-tuning
+run is stopped and restarted, weights will be restored from this new
+checkpoint rather than from `${checkpoint_path}`. Consequently,
+the flags `--checkpoint_path` and `--checkpoint_exclude_scopes` are only used
+during the 0th global step (model initialization).
+
+```shell
+# Specify the directory where the dataset is stored.
+$ DATASET_DIR=$HOME/imagenet
+
+# Specify the directory where the training logs are stored:
+$ TRAIN_DIR=$HOME/train_logs
+
+# Specify the directory where the pre-trained model checkpoint was saved to:
+$ CHECKPOINT_PATH=$HOME/my_checkpoints/inception_v3.ckpt
+
+# Build the training script.
+$ bazel build slim:train
+
+# Run training. Use --checkpoint_exclude_scopes to avoid loading the weights
+# associated with the logits and auxiliary logits fully connected layers.
+$ bazel-bin/slim/train \
+    --train_dir=${TRAIN_DIR} \
+    --dataset_dir=${DATASET_DIR} \
+    --dataset_name=cifar10 \
+    --dataset_split_name=train \
+    --model_name=inception_v3 \
+    --checkpoint_path=${CHECKPOINT_PATH} \
+    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
+```
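+
+Under the hood, scope exclusion can be implemented with TF-Slim's variable
+helpers. The following is a minimal sketch of the idea (not the exact code
+used by `train.py`); the checkpoint path is hypothetical:
+
+```python
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+# Collect every model variable except those under the excluded scopes.
+variables_to_restore = slim.get_variables_to_restore(
+    exclude=['InceptionV3/Logits', 'InceptionV3/AuxLogits'])
+
+# Build an initializer function that restores only those variables from the
+# checkpoint; the remaining variables keep their fresh initialization.
+init_fn = slim.assign_from_checkpoint_fn(
+    '/path/to/inception_v3.ckpt', variables_to_restore)
+```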
+
+
+# Evaluating the provided checkpoints
+
+To evaluate the checkpoints provided with this release, one need only download
+the checkpoints and run the evaluation script.
+
+Note that the provided checkpoints contain the model's weights only. They do
+not contain variables associated with training, such as the moving averages
+of the weights or the global step. Consequently, when evaluating one of the
+pre-trained checkpoint files, one must pass the flag
+`--restore_global_step=False` so that the evaluation routine does not attempt
+to load a global step from a checkpoint file that doesn't contain one.
+
+```shell
+# Specify and create the directory containing the checkpoints:
+$ CHECKPOINT_DIR=/tmp/checkpoints
+$ mkdir ${CHECKPOINT_DIR}
+
+# Download, extract and copy the checkpoint file over:
+$ wget http://download.tensorflow.org/models/inception_v1_2016_08_23.tar.gz
+$ tar -xvf inception_v1_2016_08_23.tar.gz
+$ mv inception_v1.ckpt ${CHECKPOINT_DIR}
+$ rm inception_v1_2016_08_23.tar.gz
+
+# Specify the directory where the dataset is stored.
+$ DATASET_DIR=$HOME/imagenet
+
+# Compile the evaluation script:
+$ bazel build slim:eval
+
+# Run the evaluation script. Note that since the pre-trained checkpoints
+# provided do not contain a global step, we need to instruct the evaluation
+# routine not to attempt to load the global step.
+$ ./bazel-bin/slim/eval \
+    --alsologtostderr \
+    --checkpoint_path=${CHECKPOINT_DIR}/inception_v1.ckpt \
+    --dataset_dir=${DATASET_DIR} \
+    --dataset_name=imagenet \
+    --dataset_split_name=validation \
+    --model_name=inception_v1 \
+    --restore_global_step=False
+```
+
+# Troubleshooting
+
+#### The model runs out of CPU memory.
+
+See
+[Model Runs out of CPU memory](https://github.com/tensorflow/models/tree/master/inception#the-model-runs-out-of-cpu-memory).
+
+#### The model runs out of GPU memory.
+
+See
+[Adjusting Memory Demands](https://github.com/tensorflow/models/tree/master/inception#adjusting-memory-demands).
+
+#### The model training results in NaN's.
+
+See
+[Model Resulting in NaNs](https://github.com/tensorflow/models/tree/master/inception#the-model-training-results-in-nans).
+
+#### The ResNet and VGG Models have 1000 classes but the ImageNet dataset has 1001
+
+The ImageNet dataset provided has an additional background class which was used
+to help train Inception. If you try training or fine-tuning the VGG or ResNet
+models using the ImageNet dataset, you might encounter the following error:
+
+```bash
+InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [1000]
+```
+This is because the final layers of the VGG and ResNet models produce only
+1000 outputs rather than 1001.
+
+To fix this issue, you can set the `--labels_offset=1` flag. This results in
+the ImageNet labels being shifted down by one:
+
+```bash
+./bazel-bin/slim/train \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_dir=${DATASET_DIR} \
+  --dataset_name=imagenet \
+  --dataset_split_name=train \
+  --model_name=resnet_v1_50 \
+  --checkpoint_path=${CHECKPOINT_PATH} \
+  --labels_offset=1
+```
+
+#### I wish to train a model with a different image size.
+
+The preprocessing functions all take `height` and `width` as parameters. You
+can change the default values using the following snippet:
+
+```python
+image_preprocessing_fn = preprocessing_factory.get_preprocessing(
+    preprocessing_name,
+    height=MY_NEW_HEIGHT,
+    width=MY_NEW_WIDTH,
+    is_training=True)
+```
+
+#### What hardware specification are these hyper-parameters targeted for?
+
+See
+[Hardware Specifications](https://github.com/tensorflow/models/tree/master/inception#what-hardware-specification-are-these-hyper-parameters-targeted-for).
+

+ 98 - 0
slim/datasets/cifar10.py

@@ -0,0 +1,98 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides data for the Cifar10 dataset.
+
+The dataset scripts used to create the dataset can be found at:
+tensorflow/models/slim/datasets/download_and_convert_cifar10.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+slim = tf.contrib.slim
+
+_FILE_PATTERN = 'cifar10_%s.tfrecord'
+
+SPLITS_TO_SIZES = {'train': 50000, 'test': 10000}
+
+_NUM_CLASSES = 10
+
+_ITEMS_TO_DESCRIPTIONS = {
+    'image': 'A [32 x 32 x 3] color image.',
+    'label': 'A single integer between 0 and 9',
+}
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+  """Gets a dataset tuple with instructions for reading cifar10.
+
+  Args:
+    split_name: A train/test split name.
+    dataset_dir: The base directory of the dataset sources.
+    file_pattern: The file pattern to use when matching the dataset sources.
+      It is assumed that the pattern contains a '%s' string so that the split
+      name can be inserted.
+    reader: The TensorFlow reader type.
+
+  Returns:
+    A `Dataset` namedtuple.
+
+  Raises:
+    ValueError: if `split_name` is not a valid train/test split.
+  """
+  if split_name not in SPLITS_TO_SIZES:
+    raise ValueError('split name %s was not recognized.' % split_name)
+
+  if not file_pattern:
+    file_pattern = _FILE_PATTERN
+  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+  # Allowing None in the signature so that dataset_factory can use the default.
+  if reader is None:
+    reader = tf.TFRecordReader
+
+  keys_to_features = {
+      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
+      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
+      'image/class/label': tf.FixedLenFeature(
+          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
+  }
+
+  items_to_handlers = {
+      'image': slim.tfexample_decoder.Image(shape=[32, 32, 3]),
+      'label': slim.tfexample_decoder.Tensor('image/class/label'),
+  }
+
+  decoder = slim.tfexample_decoder.TFExampleDecoder(
+      keys_to_features, items_to_handlers)
+
+  labels_to_names = None
+  if dataset_utils.has_labels(dataset_dir):
+    labels_to_names = dataset_utils.read_label_file(dataset_dir)
+
+  return slim.dataset.Dataset(
+      data_sources=file_pattern,
+      reader=reader,
+      decoder=decoder,
+      num_samples=SPLITS_TO_SIZES[split_name],
+      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+      num_classes=_NUM_CLASSES,
+      labels_to_names=labels_to_names)

+ 57 - 0
slim/datasets/dataset_factory.py

@@ -0,0 +1,57 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A factory-pattern class which returns classification image/label pairs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from slim.datasets import cifar10
+from slim.datasets import flowers
+from slim.datasets import imagenet
+from slim.datasets import mnist
+
+datasets_map = {
+    'cifar10': cifar10,
+    'flowers': flowers,
+    'imagenet': imagenet,
+    'mnist': mnist,
+}
+
+
+def get_dataset(name, split_name, dataset_dir, file_pattern=None, reader=None):
+  """Given a dataset name and a split_name returns a Dataset.
+
+  Args:
+    name: String, the name of the dataset.
+    split_name: A train/test split name.
+    dataset_dir: The directory where the dataset files are stored.
+    file_pattern: The file pattern to use for matching the dataset source files.
+    reader: The subclass of tf.ReaderBase. If left as `None`, then the default
+      reader defined by each dataset is used.
+
+  Returns:
+    A `Dataset` class.
+
+  Raises:
+    ValueError: If the dataset `name` is unknown.
+  """
+  if name not in datasets_map:
+    raise ValueError('Name of dataset unknown: %s' % name)
+  return datasets_map[name].get_split(
+      split_name,
+      dataset_dir,
+      file_pattern,
+      reader)

+ 111 - 0
slim/datasets/dataset_utils.py

@@ -0,0 +1,111 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains utilities for downloading and converting datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+LABELS_FILENAME = 'labels.txt'
+
+
+def int64_feature(values):
+  """Returns a TF-Feature of int64s.
+
+  Args:
+    values: A scalar or list of values.
+
+  Returns:
+    a TF-Feature.
+  """
+  if not isinstance(values, (tuple, list)):
+    values = [values]
+  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
+
+
+def bytes_feature(values):
+  """Returns a TF-Feature of bytes.
+
+  Args:
+    values: A string.
+
+  Returns:
+    a TF-Feature.
+  """
+  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[values]))
+
+
+def image_to_tfexample(image_data, image_format, height, width, class_id):
+  """Builds a TF-Example proto from an encoded image and its class label."""
+  return tf.train.Example(features=tf.train.Features(feature={
+      'image/encoded': bytes_feature(image_data),
+      'image/format': bytes_feature(image_format),
+      'image/class/label': int64_feature(class_id),
+      'image/height': int64_feature(height),
+      'image/width': int64_feature(width),
+  }))
+
+
+def write_label_file(labels_to_class_names, dataset_dir,
+                     filename=LABELS_FILENAME):
+  """Writes a file with the list of class names.
+
+  Args:
+    labels_to_class_names: A map of (integer) labels to class names.
+    dataset_dir: The directory in which the labels file should be written.
+    filename: The filename where the class names are written.
+  """
+  labels_filename = os.path.join(dataset_dir, filename)
+  with tf.gfile.Open(labels_filename, 'w') as f:
+    for label in labels_to_class_names:
+      class_name = labels_to_class_names[label]
+      f.write('%d:%s\n' % (label, class_name))
+
+
+def has_labels(dataset_dir, filename=LABELS_FILENAME):
+  """Specifies whether or not the dataset directory contains a label map file.
+
+  Args:
+    dataset_dir: The directory in which the labels file is found.
+    filename: The filename where the class names are written.
+
+  Returns:
+    `True` if the labels file exists and `False` otherwise.
+  """
+  return tf.gfile.Exists(os.path.join(dataset_dir, filename))
+
+
+def read_label_file(dataset_dir, filename=LABELS_FILENAME):
+  """Reads the labels file and returns a mapping from ID to class name.
+
+  Args:
+    dataset_dir: The directory in which the labels file is found.
+    filename: The filename where the class names are written.
+
+  Returns:
+    A map from a label (integer) to class name.
+  """
+  labels_filename = os.path.join(dataset_dir, filename)
+  with tf.gfile.Open(labels_filename, 'r') as f:
+    lines = f.read()
+  lines = lines.split('\n')
+  lines = filter(None, lines)
+
+  labels_to_class_names = {}
+  for line in lines:
+    index = line.index(':')
+    labels_to_class_names[int(line[:index])] = line[index+1:]
+  return labels_to_class_names

+ 200 - 0
slim/datasets/download_and_convert_cifar10.py

@@ -0,0 +1,200 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos.
+
+This script downloads the cifar10 data, uncompresses it, reads the files
+that make up the cifar10 data and creates two TFRecord datasets: one for train
+and one for test. Each TFRecord dataset comprises a set of TF-Example
+protocol buffers, each of which contains a single image and label.
+
+The script should take several minutes to run.
+
+Usage:
+$ bazel build slim:download_and_convert_cifar10
+$ ./bazel-bin/slim/download_and_convert_cifar10 --dataset_dir=[DIRECTORY]
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cPickle
+import os
+import sys
+import tarfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir',
+    None,
+    'The directory where the output TFRecords and temporary files are saved.')
+
+FLAGS = tf.app.flags.FLAGS
+
+# The URL where the CIFAR data can be downloaded.
+_DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
+
+# The number of training files.
+_NUM_TRAIN_FILES = 5
+
+# The height and width of each image.
+_IMAGE_SIZE = 32
+
+# The names of the classes.
+_CLASS_NAMES = [
+    'airplane',
+    'automobile',
+    'bird',
+    'cat',
+    'deer',
+    'dog',
+    'frog',
+    'horse',
+    'ship',
+    'truck',
+]
+
+
+def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
+  """Loads data from the cifar10 pickle files and writes files to a TFRecord.
+
+  Args:
+    filename: The filename of the cifar10 pickle file.
+    tfrecord_writer: The TFRecord writer to use for writing.
+    offset: An offset into the absolute number of images previously written.
+
+  Returns:
+    The new offset.
+  """
+  with tf.gfile.Open(filename, 'rb') as f:
+    data = cPickle.load(f)
+
+  images = data['data']
+  num_images = images.shape[0]
+
+  images = images.reshape((num_images, 3, 32, 32))
+  labels = data['labels']
+
+  with tf.Graph().as_default():
+    image_placeholder = tf.placeholder(dtype=tf.uint8)
+    encoded_image = tf.image.encode_png(image_placeholder)
+
+    with tf.Session('') as sess:
+
+      for j in range(num_images):
+        sys.stdout.write('\r>> Reading file [%s] image %d/%d' % (
+            filename, offset + j + 1, offset + num_images))
+        sys.stdout.flush()
+
+        image = np.squeeze(images[j]).transpose((1, 2, 0))
+        label = labels[j]
+
+        png_string = sess.run(encoded_image,
+                              feed_dict={image_placeholder: image})
+
+        example = dataset_utils.image_to_tfexample(
+            png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
+        tfrecord_writer.write(example.SerializeToString())
+
+  return offset + num_images
+
+
+def _get_output_filename(split_name):
+  """Creates the output filename.
+
+  Args:
+    split_name: The name of the train/test split.
+
+  Returns:
+    An absolute file path.
+  """
+  return '%s/cifar10_%s.tfrecord' % (FLAGS.dataset_dir, split_name)
+
+
+def _download_and_uncompress_dataset(dataset_dir):
+  """Downloads cifar10 and uncompresses it locally.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  filename = _DATA_URL.split('/')[-1]
+  filepath = os.path.join(dataset_dir, filename)
+
+  if not os.path.exists(filepath):
+    def _progress(count, block_size, total_size):
+      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+          filename, float(count * block_size) / float(total_size) * 100.0))
+      sys.stdout.flush()
+    filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
+    print()
+    statinfo = os.stat(filepath)
+    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
+
+
+def _clean_up_temporary_files(dataset_dir):
+  """Removes temporary files used to create the dataset.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  filename = _DATA_URL.split('/')[-1]
+  filepath = os.path.join(dataset_dir, filename)
+  tf.gfile.Remove(filepath)
+
+  tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py')
+  tf.gfile.DeleteRecursively(tmp_dir)
+
+
+def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  if not tf.gfile.Exists(FLAGS.dataset_dir):
+    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+
+  _download_and_uncompress_dataset(FLAGS.dataset_dir)
+
+  # First, process the training data:
+  output_file = _get_output_filename('train')
+  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
+    offset = 0
+    for i in range(_NUM_TRAIN_FILES):
+      filename = os.path.join(FLAGS.dataset_dir,
+                              'cifar-10-batches-py',
+                              'data_batch_%d' % (i + 1))  # 1-indexed.
+      offset = _add_to_tfrecord(filename, tfrecord_writer, offset)
+
+  # Next, process the testing data:
+  output_file = _get_output_filename('test')
+  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
+    filename = os.path.join(FLAGS.dataset_dir,
+                            'cifar-10-batches-py',
+                            'test_batch')
+    _add_to_tfrecord(filename, tfrecord_writer)
+
+  # Finally, write the labels file:
+  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
+  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+
+  _clean_up_temporary_files(FLAGS.dataset_dir)
+  print('\nFinished converting the Cifar10 dataset!')
+
+if __name__ == '__main__':
+  tf.app.run()

+ 227 - 0
slim/datasets/download_and_convert_flowers.py

@@ -0,0 +1,227 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.
+
+This script downloads the Flowers data, uncompresses it, reads the files
+that make up the Flowers data and creates two TFRecord datasets: one for
+training and one for validation. Each TFRecord dataset comprises a set of
+TF-Example protocol buffers, each of which contains a single image and label.
+
+The script should take about a minute to run.
+
+Usage:
+
+$ bazel build slim:download_and_convert_flowers
+$ ./bazel-bin/slim/download_and_convert_flowers --dataset_dir=[DIRECTORY]
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import os
+import random
+import sys
+import tarfile
+
+from six.moves import urllib
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir',
+    None,
+    'The directory where the output TFRecords and temporary files are saved.')
+
+FLAGS = tf.app.flags.FLAGS
+
+# The URL where the Flowers data can be downloaded.
+_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
+
+# The number of images in the validation set.
+_NUM_VALIDATION = 350
+
+# Seed for repeatability.
+_RANDOM_SEED = 0
+
+# The number of shards per dataset split.
+_NUM_SHARDS = 5
+
+
+class ImageReader(object):
+  """Helper class that provides TensorFlow image coding utilities."""
+
+  def __init__(self):
+    # Initializes function that decodes RGB JPEG data.
+    self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
+    self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
+
+  def read_image_dims(self, sess, image_data):
+    image = self.decode_jpeg(sess, image_data)
+    return image.shape[0], image.shape[1]
+
+  def decode_jpeg(self, sess, image_data):
+    image = sess.run(self._decode_jpeg,
+                     feed_dict={self._decode_jpeg_data: image_data})
+    assert len(image.shape) == 3
+    assert image.shape[2] == 3
+    return image
+
+
+def _download_dataset(dataset_dir):
+  """Downloads the flowers data and uncompresses it locally.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  filename = _DATA_URL.split('/')[-1]
+  filepath = os.path.join(dataset_dir, filename)
+
+  if not os.path.exists(filepath):
+    def _progress(count, block_size, total_size):
+      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+          filename, float(count * block_size) / float(total_size) * 100.0))
+      sys.stdout.flush()
+    filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
+    print()
+    statinfo = os.stat(filepath)
+    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
+
+
+def _get_filenames_and_classes(dataset_dir):
+  """Returns a list of filenames and inferred class names.
+
+  Args:
+    dataset_dir: A directory containing a set of subdirectories representing
+      class names. Each subdirectory should contain PNG or JPG encoded images.
+
+  Returns:
+    A list of image file paths, rooted at `dataset_dir`, and the sorted list
+    of subdirectory names, representing the class names.
+  """
+  flower_root = os.path.join(dataset_dir, 'flower_photos')
+  directories = []
+  class_names = []
+  for filename in os.listdir(flower_root):
+    path = os.path.join(flower_root, filename)
+    if os.path.isdir(path):
+      directories.append(path)
+      class_names.append(filename)
+
+  photo_filenames = []
+  for directory in directories:
+    for filename in os.listdir(directory):
+      path = os.path.join(directory, filename)
+      photo_filenames.append(path)
+
+  return photo_filenames, sorted(class_names)
+
+
+def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
+  """Converts the given filenames to a TFRecord dataset.
+
+  Args:
+    split_name: The name of the dataset split, either 'train' or 'validation'.
+    filenames: A list of absolute paths to png or jpg images.
+    class_names_to_ids: A dictionary from class names (strings) to ids
+      (integers).
+    dataset_dir: The directory where the converted datasets are stored.
+  """
+  assert split_name in ['train', 'validation']
+
+  num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
+
+  with tf.Graph().as_default():
+    image_reader = ImageReader()
+
+    with tf.Session('') as sess:
+
+      for shard_id in range(_NUM_SHARDS):
+        output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
+            split_name, shard_id, _NUM_SHARDS)
+        output_filename = os.path.join(dataset_dir, output_filename)
+
+        with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
+          start_ndx = shard_id * num_per_shard
+          end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
+          for i in range(start_ndx, end_ndx):
+            sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
+                i+1, len(filenames), shard_id))
+            sys.stdout.flush()
+
+            # Read the encoded JPEG bytes:
+            image_data = tf.gfile.FastGFile(filenames[i], 'rb').read()
+            height, width = image_reader.read_image_dims(sess, image_data)
+
+            class_name = os.path.basename(os.path.dirname(filenames[i]))
+            class_id = class_names_to_ids[class_name]
+
+            example = dataset_utils.image_to_tfexample(
+                image_data, 'jpg', height, width, class_id)
+            tfrecord_writer.write(example.SerializeToString())
+
+  sys.stdout.write('\n')
+  sys.stdout.flush()
+
+
+def _clean_up_temporary_files(dataset_dir):
+  """Removes temporary files used to create the dataset.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  filename = _DATA_URL.split('/')[-1]
+  filepath = os.path.join(dataset_dir, filename)
+  tf.gfile.Remove(filepath)
+
+  tmp_dir = os.path.join(dataset_dir, 'flower_photos')
+  tf.gfile.DeleteRecursively(tmp_dir)
+
+
+def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  if not tf.gfile.Exists(FLAGS.dataset_dir):
+    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+
+  _download_dataset(FLAGS.dataset_dir)
+  photo_filenames, class_names = _get_filenames_and_classes(FLAGS.dataset_dir)
+  class_names_to_ids = dict(zip(class_names, range(len(class_names))))
+
+  # Divide into train and test:
+  random.seed(_RANDOM_SEED)
+  random.shuffle(photo_filenames)
+  training_filenames = photo_filenames[_NUM_VALIDATION:]
+  validation_filenames = photo_filenames[:_NUM_VALIDATION]
+
+  # First, convert the training and validation sets.
+  _convert_dataset('train', training_filenames, class_names_to_ids,
+                   FLAGS.dataset_dir)
+  _convert_dataset('validation', validation_filenames, class_names_to_ids,
+                   FLAGS.dataset_dir)
+
+  # Finally, write the labels file:
+  labels_to_class_names = dict(zip(range(len(class_names)), class_names))
+  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+
+  _clean_up_temporary_files(FLAGS.dataset_dir)
+  print('\nFinished converting the Flowers dataset!')
+
+if __name__ == '__main__':
+  tf.app.run()

+ 227 - 0
slim/datasets/download_and_convert_mnist.py

@@ -0,0 +1,227 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
+
+This script downloads the MNIST data, uncompresses it, reads the files
+that make up the MNIST data and creates two TFRecord datasets: one for train
+and one for test. Each TFRecord dataset comprises a set of TF-Example
+protocol buffers, each of which contains a single image and label.
+
+The script should take about a minute to run.
+
+Usage:
+
+$ bazel build slim:download_and_convert_mnist
+$ ./bazel-bin/slim/download_and_convert_mnist --dataset_dir=[DIRECTORY]
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import sys
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir',
+    None,
+    'The directory where the output TFRecords and temporary files are saved.')
+
+FLAGS = tf.app.flags.FLAGS
+
+# The URLs where the MNIST data can be downloaded.
+_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
+_TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
+_TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
+_TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
+_TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
+
+_IMAGE_SIZE = 28
+_NUM_CHANNELS = 1
+
+# The names of the classes.
+_CLASS_NAMES = [
+    'zero',
+    'one',
+    'two',
+    'three',
+    'four',
+    'five',
+    'six',
+    'seven',
+    'eight',
+    'nine',
+]
+
+
+def _extract_images(filename, num_images):
+  """Extract the images into a numpy array.
+
+  Args:
+    filename: The path to an MNIST images file.
+    num_images: The number of images in the file.
+
+  Returns:
+    A numpy array of shape [number_of_images, height, width, channels].
+  """
+  print('Extracting images from: ', filename)
+  with gzip.open(filename) as bytestream:
+    bytestream.read(16)
+    buf = bytestream.read(
+        _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS)
+    data = np.frombuffer(buf, dtype=np.uint8)
+    data = data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
+  return data
+
+
+def _extract_labels(filename, num_labels):
+  """Extract the labels into a vector of int64 label IDs.
+
+  Args:
+    filename: The path to an MNIST labels file.
+    num_labels: The number of labels in the file.
+
+  Returns:
+    A numpy array of shape [number_of_labels]
+  """
+  print('Extracting labels from: ', filename)
+  with gzip.open(filename) as bytestream:
+    bytestream.read(8)
+    buf = bytestream.read(1 * num_labels)
+    labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
+  return labels
+
+
+def _add_to_tfrecord(data_filename, labels_filename, num_images,
+                     tfrecord_writer):
+  """Loads data from the binary MNIST files and writes files to a TFRecord.
+
+  Args:
+    data_filename: The filename of the MNIST images.
+    labels_filename: The filename of the MNIST labels.
+    num_images: The number of images in the dataset.
+    tfrecord_writer: The TFRecord writer to use for writing.
+  """
+  images = _extract_images(data_filename, num_images)
+  labels = _extract_labels(labels_filename, num_images)
+
+  shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
+  with tf.Graph().as_default():
+    image = tf.placeholder(dtype=tf.uint8, shape=shape)
+    encoded_png = tf.image.encode_png(image)
+
+    with tf.Session('') as sess:
+      for j in range(num_images):
+        sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
+        sys.stdout.flush()
+
+        png_string = sess.run(encoded_png, feed_dict={image: images[j]})
+
+        example = dataset_utils.image_to_tfexample(
+            png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, labels[j])
+        tfrecord_writer.write(example.SerializeToString())
+
+
+def _get_output_filename(split_name):
+  """Creates the output filename.
+
+  Args:
+    split_name: The name of the train/test split.
+
+  Returns:
+    An absolute file path.
+  """
+  return '%s/mnist_%s.tfrecord' % (FLAGS.dataset_dir, split_name)
+
+
+def _download_dataset(dataset_dir):
+  """Downloads MNIST locally.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  for filename in [_TRAIN_DATA_FILENAME,
+                   _TRAIN_LABELS_FILENAME,
+                   _TEST_DATA_FILENAME,
+                   _TEST_LABELS_FILENAME]:
+    filepath = os.path.join(dataset_dir, filename)
+
+    if not os.path.exists(filepath):
+      print('Downloading file %s...' % filename)
+      def _progress(count, block_size, total_size):
+        sys.stdout.write('\r>> Downloading %.1f%%' % (
+            float(count * block_size) / float(total_size) * 100.0))
+        sys.stdout.flush()
+      filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
+                                               filepath,
+                                               _progress)
+      print()
+      with tf.gfile.GFile(filepath) as f:
+        size = f.Size()
+      print('Successfully downloaded', filename, size, 'bytes.')
+
+
+def _clean_up_temporary_files(dataset_dir):
+  """Removes temporary files used to create the dataset.
+
+  Args:
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  for filename in [_TRAIN_DATA_FILENAME,
+                   _TRAIN_LABELS_FILENAME,
+                   _TEST_DATA_FILENAME,
+                   _TEST_LABELS_FILENAME]:
+    filepath = os.path.join(dataset_dir, filename)
+    tf.gfile.Remove(filepath)
+
+
+def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  if not tf.gfile.Exists(FLAGS.dataset_dir):
+    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+
+  _download_dataset(FLAGS.dataset_dir)
+
+  # First, process the training data:
+  output_file = _get_output_filename('train')
+  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
+    data_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_DATA_FILENAME)
+    labels_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_LABELS_FILENAME)
+    _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
+
+  # Next, process the testing data:
+  output_file = _get_output_filename('test')
+  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
+    data_filename = os.path.join(FLAGS.dataset_dir, _TEST_DATA_FILENAME)
+    labels_filename = os.path.join(FLAGS.dataset_dir, _TEST_LABELS_FILENAME)
+    _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
+
+  # Finally, write the labels file:
+  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
+  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+
+  _clean_up_temporary_files(FLAGS.dataset_dir)
+  print('\nFinished converting the MNIST dataset!')
+
+if __name__ == '__main__':
+  tf.app.run()

+ 98 - 0
slim/datasets/flowers.py

@@ -0,0 +1,98 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides data for the Cifar10 dataset.
+
+The dataset scripts used to create the dataset can be found at:
+tensorflow/models/slim/data/create_cifar10_dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+slim = tf.contrib.slim
+
+_FILE_PATTERN = 'flowers_%s_*.tfrecord'
+
+SPLITS_TO_SIZES = {'train': 3320, 'validation': 350}
+
+_NUM_CLASSES = 5
+
+_ITEMS_TO_DESCRIPTIONS = {
+    'image': 'A color image of varying size.',
+    'label': 'A single integer between 0 and 4',
+}
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+  """Gets a dataset tuple with instructions for reading cifar10.
+
+  Args:
+    split_name: A train/validation split name.
+    dataset_dir: The base directory of the dataset sources.
+    file_pattern: The file pattern to use when matching the dataset sources.
+      It is assumed that the pattern contains a '%s' string so that the split
+      name can be inserted.
+    reader: The TensorFlow reader type.
+
+  Returns:
+    A `Dataset` namedtuple.
+
+  Raises:
+    ValueError: if `split_name` is not a valid train/validation split.
+  """
+  if split_name not in SPLITS_TO_SIZES:
+    raise ValueError('split name %s was not recognized.' % split_name)
+
+  if not file_pattern:
+    file_pattern = _FILE_PATTERN
+  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+  # Allowing None in the signature so that dataset_factory can use the default.
+  if reader is None:
+    reader = tf.TFRecordReader
+
+  keys_to_features = {
+      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
+      'image/format': tf.FixedLenFeature((), tf.string, default_value='png'),
+      'image/class/label': tf.FixedLenFeature(
+          [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)),
+  }
+
+  items_to_handlers = {
+      'image': slim.tfexample_decoder.Image(),
+      'label': slim.tfexample_decoder.Tensor('image/class/label'),
+  }
+
+  decoder = slim.tfexample_decoder.TFExampleDecoder(
+      keys_to_features, items_to_handlers)
+
+  labels_to_names = None
+  if dataset_utils.has_labels(dataset_dir):
+    labels_to_names = dataset_utils.read_label_file(dataset_dir)
+
+  return slim.dataset.Dataset(
+      data_sources=file_pattern,
+      reader=reader,
+      decoder=decoder,
+      num_samples=SPLITS_TO_SIZES[split_name],
+      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+      num_classes=_NUM_CLASSES,
+      labels_to_names=labels_to_names)

+ 128 - 0
slim/datasets/imagenet.py

@@ -0,0 +1,128 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides data for the ImageNet ILSVRC 2012 Dataset plus some bounding boxes.
+
+Some images have one or more bounding boxes associated with the label of the
+image. See details here: http://image-net.org/download-bboxes
+
+ImageNet is based upon WordNet 3.0. To uniquely identify a synset, we use
+"WordNet ID" (wnid), which is a concatenation of POS ( i.e. part of speech )
+and SYNSET OFFSET of WordNet. For more information, please refer to the
+WordNet documentation[http://wordnet.princeton.edu/wordnet/documentation/].
+
+"There are bounding boxes for over 3000 popular synsets available.
+For each synset, there are on average 150 images with bounding boxes."
+
+WARNING: Don't use for object detection; in this case all the bounding boxes
+of the image belong to just one class.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+# TODO(nsilberman): Add tfrecord file type once the script is updated.
+_FILE_PATTERN = '%s-*'
+
+_SPLITS_TO_SIZES = {
+    'train': 1281167,
+    'validation': 50000,
+}
+
+_ITEMS_TO_DESCRIPTIONS = {
+    'image': 'A color image of varying height and width.',
+    'label': 'The label id of the image, integer between 0 and 999',
+    'label_text': 'The text of the label.',
+    'object/bbox': 'A list of bounding boxes.',
+    'object/label': 'A list of labels, one per each object.',
+}
+
+_NUM_CLASSES = 1001
+
+# TODO(nsilberman): Add _LABELS_TO_NAMES
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+  """Gets a dataset tuple with instructions for reading ImageNet.
+
+  Args:
+    split_name: A train/validation split name.
+    dataset_dir: The base directory of the dataset sources.
+    file_pattern: The file pattern to use when matching the dataset sources.
+      It is assumed that the pattern contains a '%s' string so that the split
+      name can be inserted.
+    reader: The TensorFlow reader type.
+
+  Returns:
+    A `Dataset` namedtuple.
+
+  Raises:
+    ValueError: if `split_name` is not a valid train/validation split.
+  """
+  if split_name not in _SPLITS_TO_SIZES:
+    raise ValueError('split name %s was not recognized.' % split_name)
+
+  if not file_pattern:
+    file_pattern = _FILE_PATTERN
+  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+  # Allowing None in the signature so that dataset_factory can use the default.
+  if reader is None:
+    reader = tf.TFRecordReader
+
+  keys_to_features = {
+      'image/encoded': tf.FixedLenFeature(
+          (), tf.string, default_value=''),
+      'image/format': tf.FixedLenFeature(
+          (), tf.string, default_value='jpeg'),
+      'image/class/label': tf.FixedLenFeature(
+          [], dtype=tf.int64, default_value=-1),
+      'image/class/text': tf.FixedLenFeature(
+          [], dtype=tf.string, default_value=''),
+      'image/object/bbox/xmin': tf.VarLenFeature(
+          dtype=tf.float32),
+      'image/object/bbox/ymin': tf.VarLenFeature(
+          dtype=tf.float32),
+      'image/object/bbox/xmax': tf.VarLenFeature(
+          dtype=tf.float32),
+      'image/object/bbox/ymax': tf.VarLenFeature(
+          dtype=tf.float32),
+      'image/object/class/label': tf.VarLenFeature(
+          dtype=tf.int64),
+  }
+
+  items_to_handlers = {
+      'image': slim.tfexample_decoder.Image('image/encoded', 'image/format'),
+      'label': slim.tfexample_decoder.Tensor('image/class/label'),
+      'label_text': slim.tfexample_decoder.Tensor('image/class/text'),
+      'object/bbox': slim.tfexample_decoder.BoundingBox(
+          ['ymin', 'xmin', 'ymax', 'xmax'], 'image/object/bbox/'),
+      'object/label': slim.tfexample_decoder.Tensor('image/object/class/label'),
+  }
+
+  decoder = slim.tfexample_decoder.TFExampleDecoder(
+      keys_to_features, items_to_handlers)
+
+  return slim.dataset.Dataset(
+      data_sources=file_pattern,
+      reader=reader,
+      decoder=decoder,
+      num_samples=_SPLITS_TO_SIZES[split_name],
+      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+      num_classes=_NUM_CLASSES)

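The bounding-box handler above assembles the four sparse coordinate features into a single item. A hedged sketch of consuming it (path illustrative, same era assumptions as elsewhere in this checkin):

```python
import tensorflow as tf

from slim.datasets import imagenet

slim = tf.contrib.slim

DATA_DIR = '/tmp/imagenet'  # hypothetical location of the train-* shards

dataset = imagenet.get_split('train', DATA_DIR)
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
# 'object/bbox' is assembled by the BoundingBox handler from the four
# image/object/bbox/* features, ordered [ymin, xmin, ymax, xmax].
image, label, bbox = provider.get(['image', 'label', 'object/bbox'])
```
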
+ 98 - 0
slim/datasets/mnist.py

@@ -0,0 +1,98 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides data for the MNIST dataset.
+
+The dataset scripts used to create the dataset can be found at:
+tensorflow/models/slim/datasets/download_and_convert_mnist.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow as tf
+
+from slim.datasets import dataset_utils
+
+slim = tf.contrib.slim
+
+_FILE_PATTERN = 'mnist_%s.tfrecord'
+
+_SPLITS_TO_SIZES = {'train': 60000, 'test': 10000}
+
+_NUM_CLASSES = 10
+
+_ITEMS_TO_DESCRIPTIONS = {
+    'image': 'A [28 x 28 x 1] grayscale image.',
+    'label': 'A single integer between 0 and 9',
+}
+
+
+def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
+  """Gets a dataset tuple with instructions for reading MNIST.
+
+  Args:
+    split_name: A train/test split name.
+    dataset_dir: The base directory of the dataset sources.
+    file_pattern: The file pattern to use when matching the dataset sources.
+      It is assumed that the pattern contains a '%s' string so that the split
+      name can be inserted.
+    reader: The TensorFlow reader type.
+
+  Returns:
+    A `Dataset` namedtuple.
+
+  Raises:
+    ValueError: if `split_name` is not a valid train/test split.
+  """
+  if split_name not in _SPLITS_TO_SIZES:
+    raise ValueError('split name %s was not recognized.' % split_name)
+
+  if not file_pattern:
+    file_pattern = _FILE_PATTERN
+  file_pattern = os.path.join(dataset_dir, file_pattern % split_name)
+
+  # Allowing None in the signature so that dataset_factory can use the default.
+  if reader is None:
+    reader = tf.TFRecordReader
+
+  keys_to_features = {
+      'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''),
+      'image/format': tf.FixedLenFeature((), tf.string, default_value='raw'),
+      'image/class/label': tf.FixedLenFeature(
+          [1], tf.int64, default_value=tf.zeros([1], dtype=tf.int64)),
+  }
+
+  items_to_handlers = {
+      'image': slim.tfexample_decoder.Image(shape=[28, 28, 1], channels=1),
+      'label': slim.tfexample_decoder.Tensor('image/class/label', shape=[]),
+  }
+
+  decoder = slim.tfexample_decoder.TFExampleDecoder(
+      keys_to_features, items_to_handlers)
+
+  labels_to_names = None
+  if dataset_utils.has_labels(dataset_dir):
+    labels_to_names = dataset_utils.read_label_file(dataset_dir)
+
+  return slim.dataset.Dataset(
+      data_sources=file_pattern,
+      reader=reader,
+      decoder=decoder,
+      num_samples=_SPLITS_TO_SIZES[split_name],
+      num_classes=_NUM_CLASSES,
+      items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
+      labels_to_names=labels_to_names)

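Note how the label is stored as a `[1]`-shaped int64 feature but the `Tensor` handler's `shape=[]` delivers it as a scalar, so it batches cleanly. A short illustrative sketch (path hypothetical):

```python
import tensorflow as tf

from slim.datasets import mnist

slim = tf.contrib.slim

dataset = mnist.get_split('train', '/tmp/mnist')  # hypothetical path
provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
image, label = provider.get(['image', 'label'])
# image: a [28, 28, 1] tensor; label: a scalar int64 tensor.
images, labels = tf.train.batch([image, label], batch_size=32)
```
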
+ 193 - 0
slim/eval.py

@@ -0,0 +1,193 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generic evaluation script that trains a given model a specified dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import tensorflow as tf
+
+from slim.datasets import dataset_factory
+from slim.models import model_factory
+from slim.models import preprocessing_factory
+
+slim = tf.contrib.slim
+
+tf.app.flags.DEFINE_integer(
+    'batch_size', 100, 'The number of samples in each batch.')
+
+tf.app.flags.DEFINE_integer(
+    'max_num_batches', None,
+    'Max number of batches to evaluate. By default, all batches are used.')
+
+tf.app.flags.DEFINE_string(
+    'master', '', 'The address of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string(
+    'checkpoint_path', '/tmp/tfmodel/',
+    'The directory where the model was written to or an absolute path to a '
+    'checkpoint file.')
+
+tf.app.flags.DEFINE_bool(
+    'restore_global_step', True,
+    'Whether or not to restore the global step. When evaluating a model '
+    'checkpoint containing ONLY weights, set this flag to `False`.')
+
+tf.app.flags.DEFINE_string(
+    'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')
+
+tf.app.flags.DEFINE_integer(
+    'num_preprocessing_threads', 4,
+    'The number of threads used to create the batches.')
+
+tf.app.flags.DEFINE_string(
+    'dataset_name', 'imagenet', 'The name of the dataset to load.')
+
+tf.app.flags.DEFINE_string(
+    'dataset_split_name', 'train', 'The name of the train/test split.')
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir', None, 'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_integer(
+    'labels_offset', 0,
+    'An offset for the labels in the dataset. This flag is primarily used to '
+    'evaluate the VGG and ResNet architectures which do not use a background '
+    'class for the ImageNet dataset.')
+
+tf.app.flags.DEFINE_string(
+    'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
+
+tf.app.flags.DEFINE_string(
+    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
+    'as `None`, then the model_name flag is used.')
+
+tf.app.flags.DEFINE_float(
+    'moving_average_decay', None,
+    'The decay to use for the moving average. '
+    'If left as None, then moving averages are not used.')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  with tf.Graph().as_default():
+    tf_global_step = slim.get_or_create_global_step()
+
+    ######################
+    # Select the dataset #
+    ######################
+    dataset = dataset_factory.get_dataset(
+        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
+
+    ####################
+    # Select the model #
+    ####################
+    model_fn = model_factory.get_model(
+        FLAGS.model_name,
+        num_classes=(dataset.num_classes - FLAGS.labels_offset),
+        is_training=False)
+
+    ##############################################################
+    # Create a dataset provider that loads data from the dataset #
+    ##############################################################
+    provider = slim.dataset_data_provider.DatasetDataProvider(
+        dataset,
+        shuffle=False,
+        common_queue_capacity=2 * FLAGS.batch_size,
+        common_queue_min=FLAGS.batch_size)
+    [image, label] = provider.get(['image', 'label'])
+    label -= FLAGS.labels_offset
+
+    #####################################
+    # Select the preprocessing function #
+    #####################################
+    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
+    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
+        preprocessing_name,
+        is_training=False)
+
+    image = image_preprocessing_fn(image,
+                                   height=model_fn.default_image_size,
+                                   width=model_fn.default_image_size)
+
+    images, labels = tf.train.batch(
+        [image, label],
+        batch_size=FLAGS.batch_size,
+        num_threads=FLAGS.num_preprocessing_threads,
+        capacity=5 * FLAGS.batch_size)
+
+    ####################
+    # Define the model #
+    ####################
+    logits, _ = model_fn(images)
+
+    if FLAGS.moving_average_decay:
+      variable_averages = tf.train.ExponentialMovingAverage(
+          FLAGS.moving_average_decay, tf_global_step)
+      variables_to_restore = variable_averages.variables_to_restore(
+          slim.get_model_variables())
+
+      if FLAGS.restore_global_step:
+        variables_to_restore[tf_global_step.op.name] = tf_global_step
+    else:
+      exclude = None if FLAGS.restore_global_step else ['global_step']
+      variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
+
+    predictions = tf.argmax(logits, 1)
+    labels = tf.squeeze(labels)
+
+    # Define the metrics:
+    names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
+        'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
+        'Recall@5': slim.metrics.streaming_recall_at_k(
+            logits, labels, 5),
+    })
+
+    # Print the summaries to screen.
+    for name, value in names_to_values.items():
+      summary_name = 'eval/%s' % name
+      op = tf.scalar_summary(summary_name, value, collections=[])
+      op = tf.Print(op, [value], summary_name)
+      tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
+
+    # TODO(sguada) use num_epochs=1
+    if FLAGS.max_num_batches:
+      num_batches = FLAGS.max_num_batches
+    else:
+      # This ensures that we make a single pass over all of the data.
+      num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
+
+    if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
+      checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
+    else:
+      checkpoint_path = FLAGS.checkpoint_path
+
+    tf.logging.info('Evaluating %s' % checkpoint_path)
+
+    slim.evaluation.evaluate_once(
+        FLAGS.master,
+        checkpoint_path,
+        logdir=FLAGS.eval_dir,
+        num_evals=num_batches,
+        eval_op=list(names_to_updates.values()),
+        variables_to_restore=variables_to_restore)
+
+
+if __name__ == '__main__':
+  tf.app.run()

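The streaming metrics above keep running totals in local variables: `evaluate_once` runs the update ops once per batch, and the value ops are read at the end. A self-contained sketch of that mechanism under TF 0.x-era APIs, using constant stand-ins for the real input pipeline:

```python
import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

# Constant stand-ins for the logits/labels produced by the real pipeline.
logits = tf.constant(np.array([[0.1, 0.9], [0.8, 0.2]], np.float32))
labels = tf.constant(np.array([1, 0], np.int64))
predictions = tf.argmax(logits, 1)

names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
    'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
})

with tf.Session() as sess:
  # Streaming metrics accumulate counts in local variables.
  sess.run(tf.initialize_local_variables())
  sess.run(list(names_to_updates.values()))
  print(sess.run(names_to_values['Accuracy']))  # 1.0
```
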
+ 114 - 0
slim/models/cifar10_preprocessing.py

@@ -0,0 +1,114 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities to preprocess images.
+
+The preprocessing steps for VGG were introduced in the following technical
+report:
+
+  Very Deep Convolutional Networks For Large-Scale Image Recognition
+  Karen Simonyan and Andrew Zisserman
+  arXiv technical report, 2015
+  PDF: http://arxiv.org/pdf/1409.1556.pdf
+  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
+  CC-BY-4.0
+
+More information can be obtained from the VGG website:
+www.robots.ox.ac.uk/~vgg/research/very_deep/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+_PADDING = 2
+
+slim = tf.contrib.slim
+
+
+def preprocess_for_train(image,
+                         output_height,
+                         output_width,
+                         padding=_PADDING):
+  """Preprocesses the given image for training.
+
+  Note that the image is first padded by `padding` pixels on each side and
+  then randomly cropped back to [`output_height`, `output_width`].
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    padding: The amount of padding before and after each dimension of the image.
+
+  Returns:
+    A preprocessed image.
+  """
+  padded_image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
+  # Randomly crop a [height, width] section of the image.
+  distorted_image = tf.random_crop(padded_image,
+                                   [output_height, output_width, 3])
+
+  # Randomly flip the image horizontally.
+  distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+  # Because these operations are not commutative, consider randomizing
+  # the order of their operation.
+  distorted_image = tf.image.random_brightness(distorted_image,
+                                               max_delta=63)
+  distorted_image = tf.image.random_contrast(distorted_image,
+                                             lower=0.2, upper=1.8)
+
+  # Subtract off the mean and divide by the variance of the pixels.
+  return tf.image.per_image_whitening(distorted_image)
+
+
+def preprocess_for_eval(image, output_height, output_width):
+  """Preprocesses the given image for evaluation.
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+
+  Returns:
+    A preprocessed image.
+  """
+  resized_image = tf.image.resize_image_with_crop_or_pad(image,
+                                                         output_height,
+                                                         output_width)
+
+  # Subtract off the mean and divide by the variance of the pixels.
+  return tf.image.per_image_whitening(resized_image)
+
+
+def preprocess_image(image, output_height, output_width, is_training=False):
+  """Preprocesses the given image.
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    is_training: `True` if we're preprocessing the image for training and
+      `False` otherwise.
+
+  Returns:
+    A preprocessed image.
+  """
+  if is_training:
+    return preprocess_for_train(image, output_height, output_width)
+  else:
+    return preprocess_for_eval(image, output_height, output_width)

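A short sketch of calling this module directly (the random image is a stand-in for a decoded CIFAR-10 example; the `slim.models` import path follows this checkin's layout):

```python
import tensorflow as tf

from slim.models import cifar10_preprocessing

# Stand-in for a decoded 32x32 CIFAR-10 image.
image = tf.random_uniform([32, 32, 3], maxval=255, dtype=tf.float32)

# Training: pad, randomly crop/flip/distort, then whiten.
train_image = cifar10_preprocessing.preprocess_image(
    image, output_height=32, output_width=32, is_training=True)

# Evaluation: deterministic center crop/pad, then whiten.
eval_image = cifar10_preprocessing.preprocess_image(
    image, output_height=32, output_width=32, is_training=False)
```
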
+ 304 - 0
slim/models/inception_preprocessing.py

@@ -0,0 +1,304 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities to preprocess images for the Inception networks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+
+
+def apply_with_random_selector(x, func, num_cases):
+  """Computes func(x, sel), with sel sampled from [0...num_cases-1].
+
+  Args:
+    x: input Tensor.
+    func: Python function to apply.
+    num_cases: Python int32, number of cases to sample sel from.
+
+  Returns:
+    The result of func(x, sel), where func receives the value of the
+    selector as a python integer, but sel is sampled dynamically.
+  """
+  sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
+  # Pass the real x only to one of the func calls.
+  return control_flow_ops.merge([
+      func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
+      for case in range(num_cases)])[0]
+
+
+def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
+  """Distort the color of a Tensor image.
+
+  Each color distortion is non-commutative and thus ordering of the color ops
+  matters. Ideally we would randomly permute the ordering of the color ops.
+  Rather than adding that level of complication, we select a distinct ordering
+  of color ops for each preprocessing thread.
+
+  Args:
+    image: 3-D Tensor containing single image in [0, 1].
+    color_ordering: Python int, a type of distortion (valid values: 0-3).
+    fast_mode: Avoids slower ops (random_hue and random_contrast)
+    scope: Optional scope for name_scope.
+  Returns:
+    3-D Tensor color-distorted image on range [0, 1]
+  Raises:
+    ValueError: if color_ordering not in [0, 3]
+  """
+  with tf.name_scope(scope, 'distort_color', [image]):
+    if fast_mode:
+      if color_ordering == 0:
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+      else:
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+    else:
+      if color_ordering == 0:
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+        image = tf.image.random_hue(image, max_delta=0.2)
+        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+      elif color_ordering == 1:
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+        image = tf.image.random_hue(image, max_delta=0.2)
+      elif color_ordering == 2:
+        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+        image = tf.image.random_hue(image, max_delta=0.2)
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+      elif color_ordering == 3:
+        image = tf.image.random_hue(image, max_delta=0.2)
+        image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
+        image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
+        image = tf.image.random_brightness(image, max_delta=32. / 255.)
+      else:
+        raise ValueError('color_ordering must be in [0, 3]')
+
+    # The random_* ops do not necessarily clamp.
+    return tf.clip_by_value(image, 0.0, 1.0)
+
+
+def distorted_bounding_box_crop(image,
+                                bbox,
+                                min_object_covered=0.1,
+                                aspect_ratio_range=(0.75, 1.33),
+                                area_range=(0.05, 1.0),
+                                max_attempts=100,
+                                scope=None):
+  """Generates cropped_image using a one of the bboxes randomly distorted.
+
+  See `tf.image.sample_distorted_bounding_box` for more documentation.
+
+  Args:
+    image: 3-D Tensor of image (it will be converted to floats in [0, 1]).
+    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+      where each coordinate is [0, 1) and the coordinates are arranged
+      as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then the whole image
+      is used.
+    min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
+      area of the image must contain at least this fraction of any bounding box
+      supplied.
+    aspect_ratio_range: An optional list of `floats`. The cropped area of the
+      image must have an aspect ratio = width / height within this range.
+    area_range: An optional list of `floats`. The cropped area of the image
+      must contain a fraction of the supplied image within this range.
+    max_attempts: An optional `int`. Number of attempts at generating a cropped
+      region of the image of the specified constraints. After `max_attempts`
+      failures, return the entire image.
+    scope: Optional scope for name_scope.
+  Returns:
+    A tuple, a 3-D Tensor cropped_image and the distorted bbox
+  """
+  with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
+    # Each bounding box has shape [1, num_boxes, box coords] and
+    # the coordinates are ordered [ymin, xmin, ymax, xmax].
+
+    # A large fraction of image datasets contain a human-annotated bounding
+    # box delineating the region of the image containing the object of interest.
+    # We choose to create a new bounding box for the object which is a randomly
+    # distorted version of the human-annotated bounding box that obeys an
+    # allowed range of aspect ratios, sizes and overlap with the human-annotated
+    # bounding box. If no box is supplied, then we assume the bounding box is
+    # the entire image.
+    sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
+        tf.shape(image),
+        bounding_boxes=bbox,
+        min_object_covered=min_object_covered,
+        aspect_ratio_range=aspect_ratio_range,
+        area_range=area_range,
+        max_attempts=max_attempts,
+        use_image_if_no_bounding_boxes=True)
+    bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
+
+    # Crop the image to the specified bounding box.
+    cropped_image = tf.slice(image, bbox_begin, bbox_size)
+    return cropped_image, distort_bbox
+
+
+def preprocess_for_train(image, height, width, bbox,
+                         fast_mode=True,
+                         scope=None):
+  """Distort one image for training a network.
+
+  Distorting images provides a useful technique for augmenting the data
+  set during training in order to make the network invariant to aspects
+  of the image that do not affect the label.
+
+  Additionally, it creates image summaries to display the different
+  transformations applied to the image.
+
+  Args:
+    image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
+      [0, 1], otherwise it is converted to tf.float32 assuming that the range
+      is [0, MAX], where MAX is the largest positive representable number for
+      the int(8/16/32) data type (see `tf.image.convert_image_dtype` for
+      details).
+    height: integer
+    width: integer
+    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+      where each coordinate is [0, 1) and the coordinates are arranged
+      as [ymin, xmin, ymax, xmax].
+    fast_mode: Optional boolean, if True avoids slower transformations (i.e.
+      bi-cubic resizing, random_hue or random_contrast).
+    scope: Optional scope for name_scope.
+  Returns:
+    3-D float Tensor of distorted image used for training with range [-1, 1].
+  """
+  with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
+    if bbox is None:
+      bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
+                         dtype=tf.float32,
+                         shape=[1, 1, 4])
+    if image.dtype != tf.float32:
+      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+    # Each bounding box has shape [1, num_boxes, box coords] and
+    # the coordinates are ordered [ymin, xmin, ymax, xmax].
+    image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+                                                  bbox)
+    tf.image_summary('image_with_bounding_boxes', image_with_box)
+
+    distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
+    # Restore the shape since the dynamic slice based upon the bbox_size loses
+    # the third dimension.
+    distorted_image.set_shape([None, None, 3])
+    image_with_distorted_box = tf.image.draw_bounding_boxes(
+        tf.expand_dims(image, 0), distorted_bbox)
+    tf.image_summary('images_with_distorted_bounding_box',
+                     image_with_distorted_box)
+
+    # This resizing operation may distort the images because the aspect
+    # ratio is not respected. We select a resize method in a round robin
+    # fashion based on the thread number.
+    # Note that ResizeMethod contains 4 enumerated resizing methods.
+
+    # In fast_mode, we select only one case: bilinear resizing.
+    num_resize_cases = 1 if fast_mode else 4
+    distorted_image = apply_with_random_selector(
+        distorted_image,
+        lambda x, method: tf.image.resize_images(x, height, width, method),
+        num_cases=num_resize_cases)
+
+    tf.image_summary('cropped_resized_image',
+                     tf.expand_dims(distorted_image, 0))
+
+    # Randomly flip the image horizontally.
+    distorted_image = tf.image.random_flip_left_right(distorted_image)
+
+    # Randomly distort the colors. There are 4 ways to do it.
+    distorted_image = apply_with_random_selector(
+        distorted_image,
+        lambda x, ordering: distort_color(x, ordering, fast_mode),
+        num_cases=4)
+
+    tf.image_summary('final_distorted_image',
+                     tf.expand_dims(distorted_image, 0))
+    distorted_image = tf.sub(distorted_image, 0.5)
+    distorted_image = tf.mul(distorted_image, 2.0)
+    return distorted_image
+
+
+def preprocess_for_eval(image, height, width,
+                        central_fraction=0.875, scope=None):
+  """Prepare one image for evaluation.
+
+  If height and width are specified, the image is resized to that size with
+  resize_bilinear.
+
+  If central_fraction is specified, the central fraction of the input image
+  is cropped first.
+
+  Args:
+    image: 3-D Tensor of image. If dtype is tf.float32 then the range should be
+      [0, 1], otherwise it is converted to tf.float32 assuming that the range
+      is [0, MAX], where MAX is the largest positive representable number for
+      the int(8/16/32) data type (see `tf.image.convert_image_dtype` for
+      details)
+    height: integer
+    width: integer
+    central_fraction: Optional Float, fraction of the image to crop.
+    scope: Optional scope for name_scope.
+  Returns:
+    3-D float Tensor of prepared image.
+  """
+  with tf.name_scope(scope, 'eval_image', [image, height, width]):
+    if image.dtype != tf.float32:
+      image = tf.image.convert_image_dtype(image, dtype=tf.float32)
+    # Crop the central region of the image with an area containing 87.5% of
+    # the original image.
+    if central_fraction:
+      image = tf.image.central_crop(image, central_fraction=central_fraction)
+
+    if height and width:
+      # Resize the image to the specified height and width.
+      image = tf.expand_dims(image, 0)
+      image = tf.image.resize_bilinear(image, [height, width],
+                                       align_corners=False)
+      image = tf.squeeze(image, [0])
+    image = tf.sub(image, 0.5)
+    image = tf.mul(image, 2.0)
+    return image
+
+
+def preprocess_image(image, height, width,
+                     is_training=False,
+                     bbox=None,
+                     fast_mode=True):
+  """Pre-process one image for training or evaluation.
+
+  Args:
+    image: 3-D Tensor [height, width, channels] with the image.
+    height: integer, image expected height.
+    width: integer, image expected width.
+    is_training: Boolean. If True, the image is transformed for training;
+      otherwise it is transformed for evaluation.
+    bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
+      where each coordinate is [0, 1) and the coordinates are arranged as
+      [ymin, xmin, ymax, xmax].
+    fast_mode: Optional boolean, if True avoids slower transformations.
+
+  Returns:
+    3-D float Tensor containing an appropriately scaled image.
+  """
+  if is_training:
+    return preprocess_for_train(image, height, width, bbox, fast_mode)
+  else:
+    return preprocess_for_eval(image, height, width)

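A minimal sketch of the eval path, which is the deterministic one (dtype conversion, central crop, bilinear resize, rescale to [-1, 1]); the zero image is an illustrative stand-in for a decoded JPEG:

```python
import tensorflow as tf

from slim.models import inception_preprocessing

# Stand-in for a decoded image of arbitrary size.
image = tf.zeros([300, 400, 3], dtype=tf.uint8)

processed = inception_preprocessing.preprocess_image(
    image, height=299, width=299, is_training=False)
# processed: [299, 299, 3] float32 tensor with values in [-1, 1].
```
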
+ 44 - 0
slim/models/lenet_preprocessing.py

@@ -0,0 +1,44 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities for preprocessing."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def preprocess_image(image, output_height, output_width, is_training):
+  """Preprocesses the given image.
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    is_training: `True` if we're preprocessing the image for training and
+      `False` otherwise.
+
+  Returns:
+    A preprocessed image.
+  """
+  image = tf.to_float(image)
+  image = tf.image.resize_image_with_crop_or_pad(
+      image, output_height, output_width)
+  image = tf.sub(image, 128.0)
+  image = tf.div(image, 128.0)
+  return image

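The arithmetic here maps uint8 pixel values into roughly [-1, 1): (x - 128) / 128. A quick illustrative sketch:

```python
import tensorflow as tf

from slim.models import lenet_preprocessing

image = tf.zeros([28, 28, 1], dtype=tf.uint8)  # stand-in MNIST digit
processed = lenet_preprocessing.preprocess_image(
    image, output_height=28, output_width=28, is_training=True)
# All-zero input maps to -1.0 everywhere: (0 - 128) / 128.
```
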
+ 681 - 0
slim/models/model_deploy.py

@@ -0,0 +1,681 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Deploy Slim models across multiple clones and replicas.
+
+# TODO(sguada) docstring paragraph by (a) motivating the need for the file and
+# (b) defining clones.
+
+# TODO(sguada) describe the high-level components of model deployment.
+# E.g. "each model deployment is composed of several parts: a DeploymentConfig,
+# which captures A, B and C, an input_fn which loads data.. etc
+
+To easily train a model on multiple GPUs or across multiple machines this
+module provides a set of helper functions: `create_clones`,
+`optimize_clones` and `deploy`.
+
+Usage:
+
+  g = tf.Graph()
+
+  # Set up DeploymentConfig
+  config = slim.DeploymentConfig(num_clones=2, clone_on_cpu=True)
+
+  # Create the global step on the device storing the variables.
+  with tf.device(config.variables_device()):
+    global_step = slim.create_global_step()
+
+  # Define the inputs
+  with tf.device(config.inputs_device()):
+    images, labels = LoadData(...)
+    inputs_queue = slim.data.prefetch_queue((images, labels))
+
+  # Define the optimizer.
+  with tf.device(config.optimizer_device()):
+    optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum)
+
+  # Define the model including the loss.
+  def model_fn(inputs_queue):
+    images, labels = inputs_queue.dequeue()
+    predictions = CreateNetwork(images)
+    slim.losses.log_loss(predictions, labels)
+
+  model_dp = slim.deploy(config, model_fn, [inputs_queue], optimizer=optimizer)
+
+  # Run training.
+  slim.learning.train(model_dp.train_op, my_log_dir,
+                      summary_op=model_dp.summary_op)
+
+The Clone namedtuple holds together the values associated with each call to
+model_fn:
+  * outputs: The return values of the calls to `model_fn()`.
+  * scope: The scope used to create the clone.
+  * device: The device used to create the clone.
+
+The DeployedModel namedtuple holds together the values needed to train
+multiple clones:
+  * train_op: An operation that runs the optimizer training op and includes
+    all the update ops created by `model_fn`. Present only if an optimizer
+    was specified.
+  * summary_op: An operation that runs the summaries created by `model_fn`
+    and process_gradients.
+  * total_loss: A `Tensor` that contains the sum of all losses created by
+    `model_fn` plus the regularization losses.
+  * clones: List of `Clone` tuples returned by `create_clones()`.
+
+DeploymentConfig parameters:
+  * num_clones: Number of model clones to deploy in each replica.
+  * clone_on_cpu: True if clones should be placed on CPU.
+  * replica_id: Integer.  Index of the replica for which the model is
+      deployed.  Usually 0 for the chief replica.
+  * num_replicas: Number of replicas to use.
+  * num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
+  * worker_job_name: A name for the worker job.
+  * ps_job_name: A name for the parameter server job.
+
+TODO(sguada):
+  - describe side effect to the graph.
+  - what happens to summaries and update_ops.
+  - which graph collections are altered.
+  - write a tutorial on how to use this.
+  - analyze the possibility of calling deploy more than once.
+
+
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+
+slim = tf.contrib.slim
+
+
+__all__ = ['create_clones',
+           'deploy',
+           'optimize_clones',
+           'DeployedModel',
+           'DeploymentConfig',
+           'Clone',
+          ]
+
+
+# Namedtuple used to represent a clone during deployment.
+Clone = collections.namedtuple('Clone',
+                               ['outputs',  # Whatever model_fn() returned.
+                                'scope',  # The scope used to create it.
+                                'device',  # The device used to create.
+                               ])
+
+# Namedtuple used to represent a DeployedModel, returned by deploy().
+DeployedModel = collections.namedtuple('DeployedModel',
+                                       ['train_op',  # The `train_op`
+                                        'summary_op',  # The `summary_op`
+                                        'total_loss',  # The loss `Tensor`
+                                        'clones',  # A list of `Clones` tuples.
+                                       ])
+
+# Default parameters for DeploymentConfig
+_deployment_params = {'num_clones': 1,
+                      'clone_on_cpu': False,
+                      'replica_id': 0,
+                      'num_replicas': 1,
+                      'num_ps_tasks': 0,
+                      'worker_job_name': 'worker',
+                      'ps_job_name': 'ps'}
+
+
+def create_clones(config, model_fn, args=None, kwargs=None):
+  """Creates multiple clones according to config using a `model_fn`.
+
+  The returned values of `model_fn(*args, **kwargs)` are collected along with
+  the scope and device used to create it in a namedtuple
+  `Clone(outputs, scope, device)`
+
+  Note: it is assumed that any loss created by `model_fn` is collected in
+  the tf.GraphKeys.LOSSES collection.
+
+  To recover the losses, summaries or update_ops created by the clone use:
+  ```python
+    losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
+    summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, clone.scope)
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
+  ```
+
+  The deployment options are specified by the config object and support
+  deploying one or several clones on different GPUs and one or several replicas
+  of such clones.
+
+  The argument `model_fn` is called `config.num_clones` times to create the
+  model clones as `model_fn(*args, **kwargs)`.
+
+  If `config` specifies deployment on multiple replicas then the default
+  tensorflow device is set appropriately for each call to `model_fn` and for the
+  slim variable creation functions: model and global variables will be created
+  on the `ps` device, the clone operations will be on the `worker` device.
+
+  Args:
+    config: A DeploymentConfig object.
+    model_fn: A callable. Called as `model_fn(*args, **kwargs)`
+    args: Optional list of arguments to pass to `model_fn`.
+    kwargs: Optional dict of keyword arguments to pass to `model_fn`.
+
+  Returns:
+    A list of namedtuples `Clone`.
+  """
+  clones = []
+  args = args or []
+  kwargs = kwargs or {}
+  with slim.arg_scope([slim.model_variable, slim.variable],
+                      device=config.variables_device()):
+    # Create clones.
+    for i in range(0, config.num_clones):
+      with tf.name_scope(config.clone_scope(i)) as clone_scope:
+        clone_device = config.clone_device(i)
+        with tf.device(clone_device):
+          with tf.variable_scope(tf.get_variable_scope(),
+                                 reuse=True if i > 0 else None):
+            outputs = model_fn(*args, **kwargs)
+          clones.append(Clone(outputs, clone_scope, clone_device))
+  return clones
+
+
+def _gather_clone_loss(clone, num_clones, regularization_losses):
+  """Gather the loss for a single clone.
+
+  Args:
+    clone: A Clone namedtuple.
+    num_clones: The number of clones being deployed.
+    regularization_losses: Possibly empty list of regularization_losses
+      to add to the clone losses.
+
+  Returns:
+    A tensor for the total loss for the clone.  Can be None.
+  """
+  # The return value.
+  sum_loss = None
+  # Individual components of the loss that will need summaries.
+  clone_loss = None
+  regularization_loss = None
+  # Compute and aggregate losses on the clone device.
+  with tf.device(clone.device):
+    all_losses = []
+    clone_losses = tf.get_collection(tf.GraphKeys.LOSSES, clone.scope)
+    if clone_losses:
+      clone_loss = tf.add_n(clone_losses, name='clone_loss')
+      if num_clones > 1:
+        clone_loss = tf.div(clone_loss, 1.0 * num_clones,
+                            name='scaled_clone_loss')
+      all_losses.append(clone_loss)
+    if regularization_losses:
+      regularization_loss = tf.add_n(regularization_losses,
+                                     name='regularization_loss')
+      all_losses.append(regularization_loss)
+    if all_losses:
+      sum_loss = tf.add_n(all_losses)
+  # Add the summaries out of the clone device block.
+  if clone_loss is not None:
+    tf.scalar_summary(clone.scope + '/clone_loss', clone_loss,
+                      name='clone_loss')
+  if regularization_loss is not None:
+    tf.scalar_summary('regularization_loss', regularization_loss,
+                      name='regularization_loss')
+  return sum_loss
+
+
+def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
+                    kwargs=None):
+  """Compute losses and gradients for a single clone.
+
+  Args:
+    optimizer: A tf.Optimizer  object.
+    clone: A Clone namedtuple.
+    num_clones: The number of clones being deployed.
+    regularization_losses: Possibly empty list of regularization_losses
+      to add to the clone losses.
+    kwargs: Dict of kwargs to pass to compute_gradients().
+
+  Returns:
+    A tuple (clone_loss, clone_grads_and_vars).
+      - clone_loss: A tensor for the total loss for the clone.  Can be None.
+      - clone_grads_and_vars: List of (gradient, variable) for the clone.
+        Can be empty.
+  """
+  sum_loss = _gather_clone_loss(clone, num_clones, regularization_losses)
+  clone_grad = None
+  if sum_loss is not None:
+    with tf.device(clone.device):
+      clone_grad = optimizer.compute_gradients(sum_loss, **kwargs)
+  return sum_loss, clone_grad
+
+
+def optimize_clones(clones, optimizer,
+                    regularization_losses=None,
+                    kwargs=None):
+  """Compute clone losses and gradients for the given list of `Clones`.
+
+  Note: The regularization_losses are added to the first clone losses.
+
+  Args:
+   clones: List of `Clones` created by `create_clones()`.
+   optimizer: An `Optimizer` object.
+   regularization_losses: Optional list of regularization losses. If None it
+     will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
+     exclude them.
+   kwargs: Optional dict of keyword arguments to pass to `compute_gradients`.
+
+  Returns:
+   A tuple (total_loss, grads_and_vars).
+     - total_loss: A Tensor containing the average of the clone losses including
+       the regularization loss.
+     - grads_and_vars: A List of tuples (gradient, variable) containing the sum
+       of the gradients for each variable.
+
+  """
+  grads_and_vars = []
+  clones_losses = []
+  kwargs = kwargs or {}
+  num_clones = len(clones)
+  if regularization_losses is None:
+    regularization_losses = tf.get_collection(
+        tf.GraphKeys.REGULARIZATION_LOSSES)
+  for clone in clones:
+    with tf.name_scope(clone.scope):
+      clone_loss, clone_grad = _optimize_clone(
+          optimizer, clone, num_clones, regularization_losses, kwargs)
+      if clone_loss is not None:
+        clones_losses.append(clone_loss)
+        grads_and_vars.append(clone_grad)
+      # Only use regularization_losses for the first clone
+      regularization_losses = None
+  # Compute the total_loss summing all the clones_losses.
+  total_loss = tf.add_n(clones_losses, name='total_loss')
+  # Sum the gradients across clones.
+  grads_and_vars = _sum_clones_gradients(grads_and_vars)
+  return total_loss, grads_and_vars
+
+
+def deploy(config,
+           model_fn,
+           args=None,
+           kwargs=None,
+           optimizer=None,
+           summarize_gradients=False):
+  """Deploys a Slim-constructed model across multiple clones.
+
+  The deployment options are specified by the config object and support
+  deploying one or several clones on different GPUs and one or several replicas
+  of such clones.
+
+  The argument `model_fn` is called `config.num_clones` times to create the
+  model clones as `model_fn(*args, **kwargs)`.
+
+  The optional argument `optimizer` is an `Optimizer` object.  If not `None`,
+  the deployed model is configured for training with that optimizer.
+
+  If `config` specifies deployment on multiple replicas then the default
+  tensorflow device is set appropriately for each call to `model_fn` and for the
+  slim variable creation functions: model and global variables will be created
+  on the `ps` device, the clone operations will be on the `worker` device.
+
+  Args:
+    config: A `DeploymentConfig` object.
+    model_fn: A callable. Called as `model_fn(*args, **kwargs)`
+    args: Optional list of arguments to pass to `model_fn`.
+    kwargs: Optional dict of keyword arguments to pass to `model_fn`.
+    optimizer: Optional `Optimizer` object.  If passed the model is deployed
+      for training with that optimizer.
+    summarize_gradients: Whether or not to add summaries to the gradients.
+
+  Returns:
+    A `DeployedModel` namedtuple.
+
+  """
+  # Gather initial summaries.
+  summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+  # Create Clones.
+  clones = create_clones(config, model_fn, args, kwargs)
+  first_clone = clones[0]
+
+  # Gather update_ops from the first clone. These contain, for example,
+  # the updates for the batch_norm variables created by model_fn.
+  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone.scope)
+
+  train_op = None
+  total_loss = None
+  with tf.device(config.optimizer_device()):
+    if optimizer:
+      # Place the global step on the device storing the variables.
+      with tf.device(config.variables_device()):
+        global_step = slim.get_or_create_global_step()
+
+      # Compute the gradients for the clones.
+      total_loss, clones_gradients = optimize_clones(clones, optimizer)
+
+      if clones_gradients:
+        if summarize_gradients:
+          # Add summaries to the gradients.
+          summaries |= set(_add_gradients_summaries(clones_gradients))
+
+        # Create gradient updates.
+        grad_updates = optimizer.apply_gradients(clones_gradients,
+                                                 global_step=global_step)
+        update_ops.append(grad_updates)
+
+        update_op = tf.group(*update_ops)
+        train_op = control_flow_ops.with_dependencies([update_op], total_loss,
+                                                      name='train_op')
+    else:
+      clones_losses = []
+      regularization_losses = tf.get_collection(
+          tf.GraphKeys.REGULARIZATION_LOSSES)
+      for clone in clones:
+        with tf.name_scope(clone.scope):
+          clone_loss = _gather_clone_loss(clone, len(clones),
+                                          regularization_losses)
+          if clone_loss is not None:
+            clones_losses.append(clone_loss)
+          # Only use regularization_losses for the first clone
+          regularization_losses = None
+      if clones_losses:
+        total_loss = tf.add_n(clones_losses, name='total_loss')
+
+    # Add the summaries from the first clone. These contain the summaries
+    # created by model_fn and either optimize_clones() or _gather_clone_loss().
+    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                       first_clone.scope))
+
+    if total_loss is not None:
+      # Add total_loss to summary.
+      summaries.add(tf.scalar_summary('total_loss', total_loss,
+                                      name='total_loss'))
+
+    if summaries:
+      # Merge all summaries together.
+      summary_op = tf.merge_summary(list(summaries), name='summary_op')
+    else:
+      summary_op = None
+
+  return DeployedModel(train_op, summary_op, total_loss, clones)
+
+
+def _sum_clones_gradients(clone_grads):
+  """Calculate the sum gradient for each shared variable across all clones.
+
+  This function assumes that `clone_grads` has been scaled appropriately by
+  1 / num_clones.
+
+  Args:
+    clone_grads: A list of lists of tuples (gradient, variable), one list per
+      `Clone`.
+
+  Returns:
+     List of tuples of (gradient, variable) where the gradient has been summed
+     across all clones.
+  """
+  sum_grads = []
+  for grad_and_vars in zip(*clone_grads):
+    # Note that each grad_and_vars looks like the following:
+    #   ((grad_var0_clone0, var0), ... (grad_varN_cloneN, varN))
+    grads = []
+    var = grad_and_vars[0][1]
+    for g, v in grad_and_vars:
+      assert v == var
+      if g is not None:
+        grads.append(g)
+    if grads:
+      if len(grads) > 1:
+        sum_grad = tf.add_n(grads, name=var.op.name + '/sum_grads')
+      else:
+        sum_grad = grads[0]
+      sum_grads.append((sum_grad, var))
+  return sum_grads
+
+
+def _add_gradients_summaries(grads_and_vars):
+  """Add histogram summaries to gradients.
+
+  Note: The summaries are also added to the SUMMARIES collection.
+
+  Args:
+    grads_and_vars: A list of gradient to variable pairs (tuples).
+
+  Returns:
+    The _list_ of the added summaries for grads_and_vars.
+  """
+  summaries = []
+  for grad, var in grads_and_vars:
+    if grad is not None:
+      if isinstance(grad, tf.IndexedSlices):
+        grad_values = grad.values
+      else:
+        grad_values = grad
+      summaries.append(tf.histogram_summary(var.op.name + ':gradient',
+                                            grad_values))
+      summaries.append(tf.histogram_summary(var.op.name + ':gradient_norm',
+                                            tf.global_norm([grad_values])))
+    else:
+      tf.logging.info('Var %s has no gradient', var.op.name)
+  return summaries
+
+
+class DeploymentConfig(object):
+  """Configuration for deploying a model with `deploy()`.
+
+  You can pass an instance of this class to `deploy()` to specify exactly
+  how to deploy the model to build.  If you do not pass one, an instance built
+  from the default `_deployment_params` will be used.
+  """
+
+  def __init__(self,
+               num_clones=1,
+               clone_on_cpu=False,
+               replica_id=0,
+               num_replicas=1,
+               num_ps_tasks=0,
+               worker_job_name='worker',
+               ps_job_name='ps'):
+    """Create a DeploymentConfig.
+
+    The config describes how to deploy a model across multiple clones and
+    replicas.  The model will be replicated `num_clones` times in each replica.
+    If `clone_on_cpu` is True, each clone will be placed on CPU.
+
+    If `num_replicas` is 1, the model is deployed via a single process.  In that
+    case `worker_job_name`, `num_ps_tasks`, and `ps_job_name` are ignored.
+
+    If `num_replicas` is greater than 1, then `worker_job_name` and
+    `ps_job_name` must name the TensorFlow jobs for the workers and parameter
+    servers, and `num_ps_tasks` must be positive.
+
+    Args:
+      num_clones: Number of model clones to deploy in each replica.
+      clone_on_cpu: If True clones would be placed on CPU.
+      replica_id: Integer.  Index of the replica for which the model is
+        deployed.  Usually 0 for the chief replica.
+      num_replicas: Number of replicas to use.
+      num_ps_tasks: Number of tasks for the `ps` job. 0 to not use replicas.
+      worker_job_name: A name for the worker job.
+      ps_job_name: A name for the parameter server job.
+
+    Raises:
+      ValueError: If the arguments are invalid.
+    """
+    if num_replicas > 1:
+      if num_ps_tasks < 1:
+        raise ValueError('When using replicas num_ps_tasks must be positive')
+    if num_replicas > 1 or num_ps_tasks > 0:
+      if not worker_job_name:
+        raise ValueError('Must specify worker_job_name when using replicas')
+      if not ps_job_name:
+        raise ValueError('Must specify ps_job_name when using parameter server')
+    if replica_id >= num_replicas:
+      raise ValueError('replica_id must be less than num_replicas')
+    self._num_clones = num_clones
+    self._clone_on_cpu = clone_on_cpu
+    self._replica_id = replica_id
+    self._num_replicas = num_replicas
+    self._num_ps_tasks = num_ps_tasks
+    self._ps_device = '/job:' + ps_job_name if num_ps_tasks > 0 else ''
+    self._worker_device = '/job:' + worker_job_name if num_ps_tasks > 0 else ''
+
+  @property
+  def num_clones(self):
+    return self._num_clones
+
+  @property
+  def clone_on_cpu(self):
+    return self._clone_on_cpu
+
+  @property
+  def replica_id(self):
+    return self._replica_id
+
+  @property
+  def num_replicas(self):
+    return self._num_replicas
+
+  @property
+  def num_ps_tasks(self):
+    return self._num_ps_tasks
+
+  @property
+  def ps_device(self):
+    return self._ps_device
+
+  @property
+  def worker_device(self):
+    return self._worker_device
+
+  def caching_device(self):
+    """Returns the device to use for caching variables.
+
+    Variables are cached on the worker CPU when using replicas.
+
+    Returns:
+      A device string or None if the variables do not need to be cached.
+    """
+    if self._num_ps_tasks > 0:
+      return lambda op: op.device
+    else:
+      return None
+
+  def clone_device(self, clone_index):
+    """Device used to create the clone and all the ops inside the clone.
+
+    Args:
+      clone_index: Int, representing the clone_index.
+
+    Returns:
+      A value suitable for `tf.device()`.
+
+    Raises:
+      ValueError: if `clone_index` is greater than or equal to the number of
+        clones.
+    """
+    if clone_index >= self._num_clones:
+      raise ValueError('clone_index must be less than num_clones')
+    device = ''
+    if self._num_ps_tasks > 0:
+      device += self._worker_device
+    if self._clone_on_cpu:
+      device += '/device:CPU:0'
+    else:
+      if self._num_clones > 1:
+        device += '/device:GPU:%d' % clone_index
+    return device
+
+  def clone_scope(self, clone_index):
+    """Name scope to create the clone.
+
+    Args:
+      clone_index: Int, the index of the clone.
+
+    Returns:
+      A name_scope suitable for `tf.name_scope()`.
+
+    Raises:
+      ValueError: if `clone_index` is greater than or equal to the number of
+        clones.
+    """
+    if clone_index >= self._num_clones:
+      raise ValueError('clone_index must be less than num_clones')
+    scope = ''
+    if self._num_clones > 1:
+      scope = 'clone_%d' % clone_index
+    return scope
+
+  def optimizer_device(self):
+    """Device to use with the optimizer.
+
+    Returns:
+      A value suitable for `tf.device()`.
+    """
+    if self._num_ps_tasks > 0 or self._num_clones > 0:
+      return self._worker_device + '/device:CPU:0'
+    else:
+      return ''
+
+  def inputs_device(self):
+    """Device to use to build the inputs.
+
+    Returns:
+      A value suitable for `tf.device()`.
+    """
+    device = ''
+    if self._num_ps_tasks > 0:
+      device += self._worker_device
+    device += '/device:CPU:0'
+    return device
+
+  def variables_device(self):
+    """Returns the device to use for variables created inside the clone.
+
+    Returns:
+      A value suitable for `tf.device()`.
+    """
+    device = ''
+    if self._num_ps_tasks > 0:
+      device += self._ps_device
+    device += '/device:CPU:0'
+
+    class _PSDeviceChooser(object):
+      """Slim device chooser for variables when using PS."""
+
+      def __init__(self, device, tasks):
+        self._device = device
+        self._tasks = tasks
+        self._task = 0
+
+      def choose(self, op):
+        if op.device:
+          return op.device
+        node_def = op if isinstance(op, tf.NodeDef) else op.node_def
+        if node_def.op == 'Variable':
+          t = self._task
+          self._task = (self._task + 1) % self._tasks
+          d = '%s/task:%d' % (self._device, t)
+          return d
+        else:
+          return op.device
+
+    if not self._num_ps_tasks:
+      return device
+    else:
+      chooser = _PSDeviceChooser(device, self._num_ps_tasks)
+      return chooser.choose
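
For reference, a minimal sketch (not part of the commit) of how the device helpers above compose; the expected device strings mirror the unit tests in the next file:

```python
from slim.models import model_deploy

# Two clones per replica, one parameter server task.
config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=1)

print(config.clone_device(0))     # /job:worker/device:GPU:0
print(config.clone_device(1))     # /job:worker/device:GPU:1
print(config.clone_scope(1))      # clone_1
print(config.optimizer_device())  # /job:worker/device:CPU:0
print(config.inputs_device())     # /job:worker/device:CPU:0
# With num_ps_tasks > 0, variables_device() returns a chooser function that
# round-robins Variable ops across ps tasks (here a single task).
choose = config.variables_device()
```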

+ 565 - 0
slim/models/model_deploy_test.py

@@ -0,0 +1,565 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model_deploy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from slim.models import model_deploy
+
+slim = tf.contrib.slim
+
+
+class DeploymentConfigTest(tf.test.TestCase):
+
+  def testDefaults(self):
+    deploy_config = model_deploy.DeploymentConfig()
+
+    self.assertEqual(slim.get_variables(), [])
+    self.assertEqual(deploy_config.caching_device(), None)
+    self.assertDeviceEqual(deploy_config.clone_device(0), '')
+    self.assertEqual(deploy_config.clone_scope(0), '')
+    self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
+
+  def testCPUonly(self):
+    deploy_config = model_deploy.DeploymentConfig(clone_on_cpu=True)
+
+    self.assertEqual(deploy_config.caching_device(), None)
+    self.assertDeviceEqual(deploy_config.clone_device(0), 'CPU:0')
+    self.assertEqual(deploy_config.clone_scope(0), '')
+    self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
+
+  def testMultiGPU(self):
+    deploy_config = model_deploy.DeploymentConfig(num_clones=2)
+
+    self.assertEqual(deploy_config.caching_device(), None)
+    self.assertDeviceEqual(deploy_config.clone_device(0), 'GPU:0')
+    self.assertDeviceEqual(deploy_config.clone_device(1), 'GPU:1')
+    self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
+    self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
+    self.assertDeviceEqual(deploy_config.optimizer_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(), 'CPU:0')
+    self.assertDeviceEqual(deploy_config.variables_device(), 'CPU:0')
+
+  def testPS(self):
+    deploy_config = model_deploy.DeploymentConfig(num_clones=1, num_ps_tasks=1)
+
+    self.assertDeviceEqual(deploy_config.clone_device(0),
+                           '/job:worker')
+    self.assertEqual(deploy_config.clone_scope(0), '')
+    self.assertDeviceEqual(deploy_config.optimizer_device(),
+                           '/job:worker/device:CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(),
+                           '/job:worker/device:CPU:0')
+    with tf.device(deploy_config.variables_device()):
+      a = tf.Variable(0)
+      b = tf.Variable(0)
+      c = tf.no_op()
+      d = slim.variable('a', [],
+                        caching_device=deploy_config.caching_device())
+    self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
+    self.assertDeviceEqual(a.device, a.value().device)
+    self.assertDeviceEqual(b.device, '/job:ps/task:0/device:CPU:0')
+    self.assertDeviceEqual(b.device, b.value().device)
+    self.assertDeviceEqual(c.device, '')
+    self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
+    self.assertDeviceEqual(d.value().device, '')
+
+  def testMultiGPUPS(self):
+    deploy_config = model_deploy.DeploymentConfig(num_clones=2, num_ps_tasks=1)
+
+    self.assertEqual(deploy_config.caching_device()(tf.no_op()), '')
+    self.assertDeviceEqual(deploy_config.clone_device(0),
+                           '/job:worker/device:GPU:0')
+    self.assertDeviceEqual(deploy_config.clone_device(1),
+                           '/job:worker/device:GPU:1')
+    self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
+    self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
+    self.assertDeviceEqual(deploy_config.optimizer_device(),
+                           '/job:worker/device:CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(),
+                           '/job:worker/device:CPU:0')
+
+  def testReplicasPS(self):
+    deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
+                                                  num_ps_tasks=2)
+
+    self.assertDeviceEqual(deploy_config.clone_device(0),
+                           '/job:worker')
+    self.assertEqual(deploy_config.clone_scope(0), '')
+    self.assertDeviceEqual(deploy_config.optimizer_device(),
+                           '/job:worker/device:CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(),
+                           '/job:worker/device:CPU:0')
+
+  def testReplicasMultiGPUPS(self):
+    deploy_config = model_deploy.DeploymentConfig(num_replicas=2,
+                                                  num_clones=2,
+                                                  num_ps_tasks=2)
+    self.assertDeviceEqual(deploy_config.clone_device(0),
+                           '/job:worker/device:GPU:0')
+    self.assertDeviceEqual(deploy_config.clone_device(1),
+                           '/job:worker/device:GPU:1')
+    self.assertEqual(deploy_config.clone_scope(0), 'clone_0')
+    self.assertEqual(deploy_config.clone_scope(1), 'clone_1')
+    self.assertDeviceEqual(deploy_config.optimizer_device(),
+                           '/job:worker/device:CPU:0')
+    self.assertDeviceEqual(deploy_config.inputs_device(),
+                           '/job:worker/device:CPU:0')
+
+  def testVariablesPS(self):
+    deploy_config = model_deploy.DeploymentConfig(num_ps_tasks=2)
+
+    with tf.device(deploy_config.variables_device()):
+      a = tf.Variable(0)
+      b = tf.Variable(0)
+      c = tf.no_op()
+      d = slim.variable('a', [],
+                        caching_device=deploy_config.caching_device())
+
+    self.assertDeviceEqual(a.device, '/job:ps/task:0/device:CPU:0')
+    self.assertDeviceEqual(a.device, a.value().device)
+    self.assertDeviceEqual(b.device, '/job:ps/task:1/device:CPU:0')
+    self.assertDeviceEqual(b.device, b.value().device)
+    self.assertDeviceEqual(c.device, '')
+    self.assertDeviceEqual(d.device, '/job:ps/task:0/device:CPU:0')
+    self.assertDeviceEqual(d.value().device, '')
+
+
+def LogisticClassifier(inputs, labels, scope=None, reuse=None):
+  with tf.variable_scope(scope, 'LogisticClassifier', [inputs, labels],
+                         reuse=reuse):
+    predictions = slim.fully_connected(inputs, 1, activation_fn=tf.sigmoid,
+                                       scope='fully_connected')
+    slim.losses.log_loss(predictions, labels)
+    return predictions
+
+
+def BatchNormClassifier(inputs, labels, scope=None, reuse=None):
+  with tf.variable_scope(scope, 'BatchNormClassifier', [inputs, labels],
+                         reuse=reuse):
+    inputs = slim.batch_norm(inputs, decay=0.1)
+    predictions = slim.fully_connected(inputs, 1,
+                                       activation_fn=tf.sigmoid,
+                                       scope='fully_connected')
+    slim.losses.log_loss(predictions, labels)
+    return predictions
+
+
+class CreatecloneTest(tf.test.TestCase):
+
+  def setUp(self):
+    # Create an easy training set:
+    np.random.seed(0)
+
+    self._inputs = np.zeros((16, 4))
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+    self._logdir = self.get_temp_dir()
+
+    for i in range(16):
+      j = int(2 * self._labels[i] + np.random.randint(0, 2))
+      self._inputs[i, j] = 1
+
+  def testCreateLogisticClassifier(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = LogisticClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      clone = clones[0]
+      self.assertEqual(len(slim.get_variables()), 2)
+      for v in slim.get_variables():
+        self.assertDeviceEqual(v.device, 'CPU:0')
+        self.assertDeviceEqual(v.value().device, 'CPU:0')
+      self.assertEqual(clone.outputs.op.name,
+                       'LogisticClassifier/fully_connected/Sigmoid')
+      self.assertEqual(clone.scope, '')
+      self.assertDeviceEqual(clone.device, '')
+      self.assertEqual(len(slim.losses.get_losses()), 1)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(update_ops, [])
+
+  def testCreateSingleclone(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      clone = clones[0]
+      self.assertEqual(len(slim.get_variables()), 5)
+      for v in slim.get_variables():
+        self.assertDeviceEqual(v.device, 'CPU:0')
+        self.assertDeviceEqual(v.value().device, 'CPU:0')
+      self.assertEqual(clone.outputs.op.name,
+                       'BatchNormClassifier/fully_connected/Sigmoid')
+      self.assertEqual(clone.scope, '')
+      self.assertDeviceEqual(clone.device, '')
+      self.assertEqual(len(slim.losses.get_losses()), 1)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), 2)
+
+  def testCreateMulticlone(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      num_clones = 4
+      deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      for v in slim.get_variables():
+        self.assertDeviceEqual(v.device, 'CPU:0')
+        self.assertDeviceEqual(v.value().device, 'CPU:0')
+      self.assertEqual(len(clones), num_clones)
+      for i, clone in enumerate(clones):
+        self.assertEqual(
+            clone.outputs.op.name,
+            'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
+        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, clone.scope)
+        self.assertEqual(len(update_ops), 2)
+        self.assertEqual(clone.scope, 'clone_%d/' % i)
+        self.assertDeviceEqual(clone.device, 'GPU:%d' % i)
+
+  def testCreateOnecloneWithPS(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1,
+                                                    num_ps_tasks=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(clones), 1)
+      clone = clones[0]
+      self.assertEqual(clone.outputs.op.name,
+                       'BatchNormClassifier/fully_connected/Sigmoid')
+      self.assertDeviceEqual(clone.device, '/job:worker')
+      self.assertEqual(clone.scope, '')
+      self.assertEqual(len(slim.get_variables()), 5)
+      for v in slim.get_variables():
+        self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
+        self.assertDeviceEqual(v.device, v.value().device)
+
+  def testCreateMulticloneWithPS(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=2,
+                                                    num_ps_tasks=2)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      for i, v in enumerate(slim.get_variables()):
+        t = i % 2
+        self.assertDeviceEqual(v.device, '/job:ps/task:%d/device:CPU:0' % t)
+        self.assertDeviceEqual(v.device, v.value().device)
+      self.assertEqual(len(clones), 2)
+      for i, clone in enumerate(clones):
+        self.assertEqual(
+            clone.outputs.op.name,
+            'clone_%d/BatchNormClassifier/fully_connected/Sigmoid' % i)
+        self.assertEqual(clone.scope, 'clone_%d/' % i)
+        self.assertDeviceEqual(clone.device, '/job:worker/device:GPU:%d' % i)
+
+
+class OptimizeclonesTest(tf.test.TestCase):
+
+  def setUp(self):
+    # Create an easy training set:
+    np.random.seed(0)
+
+    self._inputs = np.zeros((16, 4))
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+    self._logdir = self.get_temp_dir()
+
+    for i in range(16):
+      j = int(2 * self._labels[i] + np.random.randint(0, 2))
+      self._inputs[i, j] = 1
+
+  def testCreateLogisticClassifier(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = LogisticClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(slim.get_variables()), 2)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(update_ops, [])
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
+      self.assertEqual(total_loss.op.name, 'total_loss')
+      for g, v in grads_and_vars:
+        self.assertDeviceEqual(g.device, '')
+        self.assertDeviceEqual(v.device, 'CPU:0')
+
+  def testCreateSingleclone(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), 2)
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
+      self.assertEqual(total_loss.op.name, 'total_loss')
+      for g, v in grads_and_vars:
+        self.assertDeviceEqual(g.device, '')
+        self.assertDeviceEqual(v.device, 'CPU:0')
+
+  def testCreateMulticlone(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      clone_args = (tf_inputs, tf_labels)
+      num_clones = 4
+      deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, clone_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), num_clones * 2)
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
+      self.assertEqual(total_loss.op.name, 'total_loss')
+      for g, v in grads_and_vars:
+        self.assertDeviceEqual(g.device, '')
+        self.assertDeviceEqual(v.device, 'CPU:0')
+
+  def testCreateMulticloneCPU(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      model_args = (tf_inputs, tf_labels)
+      num_clones = 4
+      deploy_config = model_deploy.DeploymentConfig(num_clones=num_clones,
+                                                    clone_on_cpu=True)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), num_clones * 2)
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
+      self.assertEqual(total_loss.op.name, 'total_loss')
+      for g, v in grads_and_vars:
+        self.assertDeviceEqual(g.device, '')
+        self.assertDeviceEqual(v.device, 'CPU:0')
+
+  def testCreateOnecloneWithPS(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      model_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=1,
+                                                    num_ps_tasks=1)
+
+      self.assertEqual(slim.get_variables(), [])
+      clones = model_deploy.create_clones(deploy_config, model_fn, model_args)
+      self.assertEqual(len(slim.get_variables()), 5)
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), 2)
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+      total_loss, grads_and_vars = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+      self.assertEqual(len(grads_and_vars), len(tf.trainable_variables()))
+      self.assertEqual(total_loss.op.name, 'total_loss')
+      for g, v in grads_and_vars:
+        self.assertDeviceEqual(g.device, '/job:worker')
+        self.assertDeviceEqual(v.device, '/job:ps/task:0/CPU:0')
+
+
+class DeployTest(tf.test.TestCase):
+
+  def setUp(self):
+    # Create an easy training set:
+    np.random.seed(0)
+
+    self._inputs = np.zeros((16, 4))
+    self._labels = np.random.randint(0, 2, size=(16, 1)).astype(np.float32)
+    self._logdir = self.get_temp_dir()
+
+    for i in range(16):
+      j = int(2 * self._labels[i] + np.random.randint(0, 2))
+      self._inputs[i, j] = 1
+
+  def testLocalTrainOp(self):
+    g = tf.Graph()
+    with g.as_default():
+      tf.set_random_seed(0)
+      tf_inputs = tf.constant(self._inputs, dtype=tf.float32)
+      tf_labels = tf.constant(self._labels, dtype=tf.float32)
+
+      model_fn = BatchNormClassifier
+      model_args = (tf_inputs, tf_labels)
+      deploy_config = model_deploy.DeploymentConfig(num_clones=2,
+                                                    clone_on_cpu=True)
+
+      optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
+
+      self.assertEqual(slim.get_variables(), [])
+      model = model_deploy.deploy(deploy_config, model_fn, model_args,
+                                  optimizer=optimizer)
+
+      update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
+      self.assertEqual(len(update_ops), 4)
+      self.assertEqual(len(model.clones), 2)
+      self.assertEqual(model.total_loss.op.name, 'total_loss')
+      self.assertEqual(model.summary_op.op.name, 'summary_op/summary_op')
+      self.assertEqual(model.train_op.op.name, 'train_op')
+
+      with tf.Session() as sess:
+        sess.run(tf.initialize_all_variables())
+        moving_mean = tf.contrib.framework.get_variables_by_name(
+            'moving_mean')[0]
+        moving_variance = tf.contrib.framework.get_variables_by_name(
+            'moving_variance')[0]
+        initial_loss = sess.run(model.total_loss)
+        initial_mean, initial_variance = sess.run([moving_mean,
+                                                   moving_variance])
+        self.assertAllClose(initial_mean, [0.0, 0.0, 0.0, 0.0])
+        self.assertAllClose(initial_variance, [1.0, 1.0, 1.0, 1.0])
+        for _ in range(10):
+          sess.run(model.train_op)
+        final_loss = sess.run(model.total_loss)
+        self.assertLess(final_loss, initial_loss / 10.0)
+
+        final_mean, final_variance = sess.run([moving_mean,
+                                               moving_variance])
+        self.assertAllClose(final_mean, [0.125, 0.25, 0.375, 0.25])
+        self.assertAllClose(final_variance, [0.109375, 0.1875,
+                                             0.234375, 0.1875])
+
+  def testNoSummariesOnGPU(self):
+    with tf.Graph().as_default():
+      deploy_config = model_deploy.DeploymentConfig(num_clones=2)
+
+      # clone function creates a fully_connected layer with a regularizer loss.
+      def ModelFn():
+        inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
+        reg = tf.contrib.layers.l2_regularizer(0.001)
+        tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
+
+      model = model_deploy.deploy(
+          deploy_config, ModelFn,
+          optimizer=tf.train.GradientDescentOptimizer(1.0))
+      # The model summary op should have a few summary inputs and all of them
+      # should be on the CPU.
+      self.assertTrue(model.summary_op.op.inputs)
+      for inp in model.summary_op.op.inputs:
+        self.assertEqual('/device:CPU:0', inp.device)
+
+  def testNoSummariesOnGPUForEvals(self):
+    with tf.Graph().as_default():
+      deploy_config = model_deploy.DeploymentConfig(num_clones=2)
+
+      # clone function creates a fully_connected layer with a regularizer loss.
+      def ModelFn():
+        inputs = tf.constant(1.0, shape=(10, 20), dtype=tf.float32)
+        reg = tf.contrib.layers.l2_regularizer(0.001)
+        tf.contrib.layers.fully_connected(inputs, 30, weights_regularizer=reg)
+
+      # No optimizer here, it's an eval.
+      model = model_deploy.deploy(deploy_config, ModelFn)
+      # The model summary op should have a few summary inputs and all of them
+      # should be on the CPU.
+      self.assertTrue(model.summary_op.op.inputs)
+      for inp in model.summary_op.op.inputs:
+        self.assertEqual('/device:CPU:0', inp.device)
+
+
+if __name__ == '__main__':
+  tf.test.main()
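
Distilled from `testLocalTrainOp` above, a hedged sketch of the intended end-to-end training pattern; the wrapper name `train_once` is illustrative only:

```python
import tensorflow as tf

from slim.models import model_deploy


def train_once(model_fn, model_args):
  # model_fn is assumed to register its loss via slim.losses, as the
  # BatchNormClassifier and LogisticClassifier helpers above do.
  config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
  optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0)
  model = model_deploy.deploy(config, model_fn, model_args,
                              optimizer=optimizer)
  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    for _ in range(10):
      sess.run(model.train_op)  # sums clone losses and applies gradients
    return sess.run(model.total_loss)
```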

+ 140 - 0
slim/models/model_factory.py

@@ -0,0 +1,140 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a factory for building various models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim import nets
+from slim.nets import lenet
+
+slim = tf.contrib.slim
+
+
+def get_model(name, num_classes, weight_decay=0.0, is_training=False):
+  """Returns a model_fn such as `logits, end_points = model_fn(images)`.
+
+  Args:
+    name: The name of the model.
+    num_classes: The number of classes to use for classification.
+    weight_decay: The l2 coefficient for the model weights.
+    is_training: `True` if the model is being used for training and `False`
+      otherwise.
+
+  Returns:
+    model_fn: A function that applies the model to a batch of images. It has
+      the following signature:
+        logits, end_points = model_fn(images)
+  Raises:
+    ValueError: If model `name` is not recognized.
+  """
+  if name == 'inception_v1':
+    default_image_size = nets.inception.inception_v1.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.inception.inception_v1_arg_scope(
+          weight_decay=weight_decay)):
+        return nets.inception.inception_v1(images,
+                                           num_classes,
+                                           is_training=is_training)
+    model_fn = func
+  elif name == 'inception_v2':
+    default_image_size = nets.inception.inception_v2.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.inception.inception_v2_arg_scope(
+          weight_decay=weight_decay)):
+        return nets.inception.inception_v2(images,
+                                           num_classes=num_classes,
+                                           is_training=is_training)
+    model_fn = func
+  elif name == 'inception_v3':
+    default_image_size = nets.inception.inception_v3.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.inception.inception_v3_arg_scope(
+          weight_decay=weight_decay)):
+        return nets.inception.inception_v3(images,
+                                           num_classes=num_classes,
+                                           is_training=is_training)
+    model_fn = func
+  elif name == 'lenet':
+    default_image_size = lenet.lenet.default_image_size
+    def func(images):
+      with slim.arg_scope(lenet.lenet_arg_scope(weight_decay=weight_decay)):
+        return lenet.lenet(images,
+                           num_classes=num_classes,
+                           is_training=is_training)
+    model_fn = func
+  elif name == 'resnet_v1_50':
+    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
+          is_training, weight_decay=weight_decay)):
+        net, end_points = nets.resnet_v1.resnet_v1_50(
+            images, num_classes=num_classes)
+        net = tf.squeeze(net, squeeze_dims=[1, 2])
+        return net, end_points
+    model_fn = func
+  elif name == 'resnet_v1_101':
+    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
+          is_training, weight_decay=weight_decay)):
+        net, end_points = nets.resnet_v1.resnet_v1_101(
+            images, num_classes=num_classes)
+        net = tf.squeeze(net, squeeze_dims=[1, 2])
+        return net, end_points
+    model_fn = func
+  elif name == 'resnet_v1_152':
+    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
+          is_training, weight_decay=weight_decay)):
+        net, end_points = nets.resnet_v1.resnet_v1_152(
+            images, num_classes=num_classes)
+        net = tf.squeeze(net, squeeze_dims=[1, 2])
+        return net, end_points
+    model_fn = func
+  elif name == 'vgg_a':
+    default_image_size = nets.vgg.vgg_a.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
+        return nets.vgg.vgg_a(images,
+                              num_classes=num_classes,
+                              is_training=is_training)
+    model_fn = func
+  elif name == 'vgg_16':
+    default_image_size = nets.vgg.vgg_16.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
+        return nets.vgg.vgg_16(images,
+                               num_classes=num_classes,
+                               is_training=is_training)
+    model_fn = func
+  elif name == 'vgg_19':
+    default_image_size = nets.vgg.vgg_19.default_image_size
+    def func(images):
+      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
+        return nets.vgg.vgg_19(images,
+                               num_classes=num_classes,
+                               is_training=is_training)
+    model_fn = func
+  else:
+    raise ValueError('Model name [%s] was not recognized' % name)
+
+  model_fn.default_image_size = default_image_size
+
+  return model_fn
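
A minimal usage sketch (not part of the commit); the placeholder shape and class count are illustrative:

```python
import tensorflow as tf

from slim.models import model_factory

model_fn = model_factory.get_model('vgg_16', num_classes=1000,
                                   weight_decay=0.0005, is_training=False)
size = model_fn.default_image_size  # attached by the factory above
images = tf.placeholder(tf.float32, [None, size, size, 3])
logits, end_points = model_fn(images)
```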

+ 70 - 0
slim/models/preprocessing_factory.py

@@ -0,0 +1,70 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a factory for building various models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from slim.models import cifar10_preprocessing
+from slim.models import inception_preprocessing
+from slim.models import lenet_preprocessing
+from slim.models import vgg_preprocessing
+
+slim = tf.contrib.slim
+
+
+def get_preprocessing(name, is_training=False):
+  """Returns preprocessing_fn(image, height, width, **kwargs).
+
+  Args:
+    name: The name of the preprocessing function.
+    is_training: `True` if the model is being used for training and `False`
+      otherwise.
+
+  Returns:
+    preprocessing_fn: A function that preprocesses a single image (pre-batch).
+      It has the following signature:
+        image = preprocessing_fn(image, output_height, output_width, ...).
+
+  Raises:
+    ValueError: If preprocessing `name` is not recognized.
+  """
+  preprocessing_fn_map = {
+      'cifar10': cifar10_preprocessing,
+      'inception': inception_preprocessing,
+      'inception_v1': inception_preprocessing,
+      'inception_v2': inception_preprocessing,
+      'inception_v3': inception_preprocessing,
+      'lenet': lenet_preprocessing,
+      'resnet_v1_50': vgg_preprocessing,
+      'resnet_v1_101': vgg_preprocessing,
+      'resnet_v1_152': vgg_preprocessing,
+      'vgg': vgg_preprocessing,
+      'vgg_a': vgg_preprocessing,
+      'vgg_16': vgg_preprocessing,
+      'vgg_19': vgg_preprocessing,
+  }
+
+  if name not in preprocessing_fn_map:
+    raise ValueError('Preprocessing name [%s] was not recognized' % name)
+
+  def preprocessing_fn(image, output_height, output_width, **kwargs):
+    return preprocessing_fn_map[name].preprocess_image(
+        image, output_height, output_width, is_training=is_training, **kwargs)
+
+  return preprocessing_fn
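
A short sketch of pairing the factory with a decoded image; `'vgg_16'` maps to `vgg_preprocessing`, whose signature appears later in this commit:

```python
import tensorflow as tf

from slim.models import preprocessing_factory

preprocessing_fn = preprocessing_factory.get_preprocessing('vgg_16',
                                                           is_training=False)
raw_image = tf.placeholder(tf.uint8, [None, None, 3])  # a decoded image
image = preprocessing_fn(raw_image, 224, 224)
batch = tf.expand_dims(image, 0)  # preprocessing is per-image, pre-batch
```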

+ 200 - 0
slim/models/resnet_preprocessing.py

@@ -0,0 +1,200 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities to preprocess images for the ResNet networks."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.slim import nets
+from tensorflow.python.ops import control_flow_ops
+
+slim = tf.contrib.slim
+
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94
+
+_CROP_HEIGHT = nets.resnet_v1.resnet_v1.default_image_size
+_CROP_WIDTH = nets.resnet_v1.resnet_v1.default_image_size
+_RESIZE_SIDE = 256
+
+
+def _mean_image_subtraction(image, means):
+  """Subtracts the given means from each image channel.
+
+  For example:
+    means = [123.68, 116.779, 103.939]
+    image = _mean_image_subtraction(image, means)
+
+  Note that the rank of `image` must be known.
+
+  Args:
+    image: a tensor of size [height, width, C].
+    means: a C-vector of values to subtract from each channel.
+
+  Returns:
+    the centered image.
+
+  Raises:
+    ValueError: If the rank of `image` is unknown, if `image` has a rank other
+      than three or if the number of channels in `image` doesn't match the
+      number of values in `means`.
+  """
+  if image.get_shape().ndims != 3:
+    raise ValueError('Input must be of size [height, width, C>0]')
+  num_channels = image.get_shape().as_list()[-1]
+  if len(means) != num_channels:
+    raise ValueError('len(means) must match the number of channels')
+
+  channels = tf.split(2, num_channels, image)
+  for i in range(num_channels):
+    channels[i] -= means[i]
+  return tf.concat(2, channels)
+
+
+def _smallest_size_at_least(height, width, smallest_side):
+  """Computes new shape with the smallest side equal to `smallest_side`.
+
+  Computes new shape with the smallest side equal to `smallest_side` while
+  preserving the original aspect ratio.
+
+  Args:
+    height: an int32 scalar tensor indicating the current height.
+    width: an int32 scalar tensor indicating the current width.
+    smallest_side: a python integer indicating the smallest side of the new
+      shape.
+
+  Returns:
+    new_height: an int32 scalar tensor indicating the new height.
+    new_width: an int32 scalar tensor indicating the new width.
+  """
+  height = tf.to_float(height)
+  width = tf.to_float(width)
+  smallest_side = float(smallest_side)
+  scale = tf.cond(tf.greater(height, width),
+                  lambda: smallest_side / width,
+                  lambda: smallest_side / height)
+  new_height = tf.to_int32(height * scale)
+  new_width = tf.to_int32(width * scale)
+  return new_height, new_width
+
+
+def _aspect_preserving_resize(image, smallest_side):
+  """Resize images preserving the original aspect ratio.
+
+  Args:
+    image: a 3-D image tensor.
+    smallest_side: a python integer indicating the size of the smallest side
+      after resize.
+
+  Returns:
+    resized_image: a 3-D tensor containing the resized image.
+  """
+  shape = tf.shape(image)
+  height = shape[0]
+  width = shape[1]
+  new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
+  image = tf.expand_dims(image, 0)
+  resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
+                                           align_corners=False)
+  resized_image = tf.squeeze(resized_image)
+  resized_image.set_shape([None, None, 3])
+  return resized_image
+
+
+def _crop(image, offset_height, offset_width, crop_height, crop_width):
+  """Crops the given image using the provided offsets and sizes.
+
+  Note that the method doesn't assume we know the input image size but it does
+  assume we know the input image rank.
+
+  Args:
+    image: an image of shape [height, width, channels].
+    offset_height: a scalar tensor indicating the height offset.
+    offset_width: a scalar tensor indicating the width offset.
+    crop_height: the height of the cropped image.
+    crop_width: the width of the cropped image.
+
+  Returns:
+    the cropped (and resized) image.
+
+  Raises:
+    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
+      less than the crop size.
+  """
+  original_shape = tf.shape(image)
+
+  rank_assertion = tf.Assert(
+      tf.equal(tf.rank(image), 3),
+      ['Rank of image must be equal to 3.'])
+  cropped_shape = control_flow_ops.with_dependencies(
+      [rank_assertion],
+      tf.pack([crop_height, crop_width, original_shape[2]]))
+
+  size_assertion = tf.Assert(
+      tf.logical_and(
+          tf.greater_equal(original_shape[0], crop_height),
+          tf.greater_equal(original_shape[1], crop_width)),
+      ['Crop size greater than the image size.'])
+
+  offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
+
+  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
+  # define the crop size.
+  image = control_flow_ops.with_dependencies(
+      [size_assertion],
+      tf.slice(image, offsets, cropped_shape))
+  return tf.reshape(image, cropped_shape)
+
+
+def _central_crop(image_list, crop_height, crop_width):
+  """Performs central crops of the given image list.
+
+  Args:
+    image_list: a list of image tensors of the same dimension but possibly
+      varying channel.
+    crop_height: the height of the image following the crop.
+    crop_width: the width of the image following the crop.
+
+  Returns:
+    the list of cropped images.
+  """
+  outputs = []
+  for image in image_list:
+    image_height = tf.shape(image)[0]
+    image_width = tf.shape(image)[1]
+
+    offset_height = (image_height - crop_height) / 2
+    offset_width = (image_width - crop_width) / 2
+
+    outputs.append(_crop(image, offset_height, offset_width,
+                         crop_height, crop_width))
+  return outputs
+
+
+def preprocess_image(image,
+                     height=_CROP_HEIGHT,
+                     width=_CROP_WIDTH,
+                     is_training=False,  # pylint: disable=unused-argument
+                     resize_side=_RESIZE_SIDE):
+  image = _aspect_preserving_resize(image, resize_side)
+  image = _central_crop([image], height, width)[0]
+  image.set_shape([height, width, 3])
+  image = tf.to_float(image)
+  image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
+  return image
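
A worked example (a sketch, not part of the commit) tracing a 400x600 input through `preprocess_image` above:

```python
import tensorflow as tf

from slim.models import resnet_preprocessing

raw = tf.placeholder(tf.uint8, [400, 600, 3])
image = resnet_preprocessing.preprocess_image(raw)
# _aspect_preserving_resize: scale = 256 / 400 = 0.64  ->  256 x 384.
# _central_crop + set_shape: the ResNet default crop, 224 x 224 x 3,
# then tf.to_float and per-channel mean subtraction.
```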

+ 370 - 0
slim/models/vgg_preprocessing.py

@@ -0,0 +1,370 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides utilities to preprocess images.
+
+The preprocessing steps for VGG were introduced in the following technical
+report:
+
+  Very Deep Convolutional Networks For Large-Scale Image Recognition
+  Karen Simonyan and Andrew Zisserman
+  arXiv technical report, 2015
+  PDF: http://arxiv.org/pdf/1409.1556.pdf
+  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
+  CC-BY-4.0
+
+More information can be obtained from the VGG website:
+www.robots.ox.ac.uk/~vgg/research/very_deep/
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+
+slim = tf.contrib.slim
+
+_R_MEAN = 123.68
+_G_MEAN = 116.78
+_B_MEAN = 103.94
+
+_RESIZE_SIDE_MIN = 256
+_RESIZE_SIDE_MAX = 512
+
+
+def _crop(image, offset_height, offset_width, crop_height, crop_width):
+  """Crops the given image using the provided offsets and sizes.
+
+  Note that the method doesn't assume we know the input image size but it does
+  assume we know the input image rank.
+
+  Args:
+    image: an image of shape [height, width, channels].
+    offset_height: a scalar tensor indicating the height offset.
+    offset_width: a scalar tensor indicating the width offset.
+    crop_height: the height of the cropped image.
+    crop_width: the width of the cropped image.
+
+  Returns:
+    the cropped (and resized) image.
+
+  Raises:
+    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
+      less than the crop size.
+  """
+  original_shape = tf.shape(image)
+
+  rank_assertion = tf.Assert(
+      tf.equal(tf.rank(image), 3),
+      ['Rank of image must be equal to 3.'])
+  cropped_shape = control_flow_ops.with_dependencies(
+      [rank_assertion],
+      tf.pack([crop_height, crop_width, original_shape[2]]))
+
+  size_assertion = tf.Assert(
+      tf.logical_and(
+          tf.greater_equal(original_shape[0], crop_height),
+          tf.greater_equal(original_shape[1], crop_width)),
+      ['Crop size greater than the image size.'])
+
+  offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
+
+  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
+  # define the crop size.
+  image = control_flow_ops.with_dependencies(
+      [size_assertion],
+      tf.slice(image, offsets, cropped_shape))
+  return tf.reshape(image, cropped_shape)
+
+
+def _random_crop(image_list, crop_height, crop_width):
+  """Crops the given list of images.
+
+  The function applies the same crop to each image in the list. This can be
+  effectively applied when there are multiple image inputs of the same
+  dimension such as:
+
+    image, depths, normals = _random_crop([image, depths, normals], 120, 150)
+
+  Args:
+    image_list: a list of image tensors of the same dimension but possibly
+      varying channel.
+    crop_height: the new height.
+    crop_width: the new width.
+
+  Returns:
+    the image_list with cropped images.
+
+  Raises:
+    ValueError: if there are multiple image inputs provided with different size
+      or the images are smaller than the crop dimensions.
+  """
+  if not image_list:
+    raise ValueError('Empty image_list.')
+
+  # Compute the rank assertions.
+  rank_assertions = []
+  for i in range(len(image_list)):
+    image_rank = tf.rank(image_list[i])
+    rank_assert = tf.Assert(
+        tf.equal(image_rank, 3),
+        ['Wrong rank for tensor %s [expected] [actual]',
+         image_list[i].name, 3, image_rank])
+    rank_assertions.append(rank_assert)
+
+  image_shape = control_flow_ops.with_dependencies(
+      [rank_assertions[0]],
+      tf.shape(image_list[0]))
+  image_height = image_shape[0]
+  image_width = image_shape[1]
+  crop_size_assert = tf.Assert(
+      tf.logical_and(
+          tf.greater_equal(image_height, crop_height),
+          tf.greater_equal(image_width, crop_width)),
+      ['Crop size greater than the image size.'])
+
+  asserts = [rank_assertions[0], crop_size_assert]
+
+  for i in range(1, len(image_list)):
+    image = image_list[i]
+    asserts.append(rank_assertions[i])
+    shape = control_flow_ops.with_dependencies([rank_assertions[i]],
+                                               tf.shape(image))
+    height = shape[0]
+    width = shape[1]
+
+    height_assert = tf.Assert(
+        tf.equal(height, image_height),
+        ['Wrong height for tensor %s [expected][actual]',
+         image.name, height, image_height])
+    width_assert = tf.Assert(
+        tf.equal(width, image_width),
+        ['Wrong width for tensor %s [expected][actual]',
+         image.name, width, image_width])
+    asserts.extend([height_assert, width_assert])
+
+  # Create a random bounding box.
+  #
+  # Use tf.random_uniform and not numpy.random.rand as doing the former would
+  # generate random numbers at graph eval time, unlike the latter which
+  # generates random numbers at graph definition time.
+  max_offset_height = control_flow_ops.with_dependencies(
+      asserts, tf.reshape(image_height - crop_height + 1, []))
+  max_offset_width = control_flow_ops.with_dependencies(
+      asserts, tf.reshape(image_width - crop_width + 1, []))
+  offset_height = tf.random_uniform(
+      [], maxval=max_offset_height, dtype=tf.int32)
+  offset_width = tf.random_uniform(
+      [], maxval=max_offset_width, dtype=tf.int32)
+
+  return [_crop(image, offset_height, offset_width,
+                crop_height, crop_width) for image in image_list]
+
+
+def _central_crop(image_list, crop_height, crop_width):
+  """Performs central crops of the given image list.
+
+  Args:
+    image_list: a list of image tensors of the same dimension but possibly
+      varying channel.
+    crop_height: the height of the image following the crop.
+    crop_width: the width of the image following the crop.
+
+  Returns:
+    the list of cropped images.
+  """
+  outputs = []
+  for image in image_list:
+    image_height = tf.shape(image)[0]
+    image_width = tf.shape(image)[1]
+
+    offset_height = (image_height - crop_height) / 2
+    offset_width = (image_width - crop_width) / 2
+
+    outputs.append(_crop(image, offset_height, offset_width,
+                         crop_height, crop_width))
+  return outputs
+
+
+def _mean_image_subtraction(image, means):
+  """Subtracts the given means from each image channel.
+
+  For example:
+    means = [123.68, 116.779, 103.939]
+    image = _mean_image_subtraction(image, means)
+
+  Note that the rank of `image` must be known.
+
+  Args:
+    image: a tensor of size [height, width, C].
+    means: a C-vector of values to subtract from each channel.
+
+  Returns:
+    the centered image.
+
+  Raises:
+    ValueError: If the rank of `image` is unknown, if `image` has a rank other
+      than three or if the number of channels in `image` doesn't match the
+      number of values in `means`.
+  """
+  if image.get_shape().ndims != 3:
+    raise ValueError('Input must be of size [height, width, C>0]')
+  num_channels = image.get_shape().as_list()[-1]
+  if len(means) != num_channels:
+    raise ValueError('len(means) must match the number of channels')
+
+  channels = tf.split(2, num_channels, image)
+  for i in range(num_channels):
+    channels[i] -= means[i]
+  return tf.concat(2, channels)
+
+
+def _smallest_size_at_least(height, width, smallest_side):
+  """Computes new shape with the smallest side equal to `smallest_side`.
+
+  Computes new shape with the smallest side equal to `smallest_side` while
+  preserving the original aspect ratio.
+
+  Args:
+    height: an int32 scalar tensor indicating the current height.
+    width: an int32 scalar tensor indicating the current width.
+    smallest_side: A python integer or scalar `Tensor` indicating the size of
+      the smallest side after resize.
+
+  Returns:
+    new_height: an int32 scalar tensor indicating the new height.
+    new_width: an int32 scalar tensor indicating the new width.
+  """
+  smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
+
+  height = tf.to_float(height)
+  width = tf.to_float(width)
+  smallest_side = tf.to_float(smallest_side)
+
+  scale = tf.cond(tf.greater(height, width),
+                  lambda: smallest_side / width,
+                  lambda: smallest_side / height)
+  new_height = tf.to_int32(height * scale)
+  new_width = tf.to_int32(width * scale)
+  return new_height, new_width
+
+
+def _aspect_preserving_resize(image, smallest_side):
+  """Resize images preserving the original aspect ratio.
+
+  Args:
+    image: A 3-D image `Tensor`.
+    smallest_side: A python integer or scalar `Tensor` indicating the size of
+      the smallest side after resize.
+
+  Returns:
+    resized_image: A 3-D tensor containing the resized image.
+  """
+  smallest_side = tf.convert_to_tensor(smallest_side, dtype=tf.int32)
+
+  shape = tf.shape(image)
+  height = shape[0]
+  width = shape[1]
+  new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
+  image = tf.expand_dims(image, 0)
+  resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
+                                           align_corners=False)
+  resized_image = tf.squeeze(resized_image)
+  resized_image.set_shape([None, None, 3])
+  return resized_image
+
+
+def preprocess_for_train(image,
+                         output_height,
+                         output_width,
+                         resize_side_min=_RESIZE_SIDE_MIN,
+                         resize_side_max=_RESIZE_SIDE_MAX):
+  """Preprocesses the given image for training.
+
+  Note that the smallest side of the resized image is sampled from
+    [`resize_side_min`, `resize_side_max`].
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    resize_side_min: The lower bound for the smallest side of the image for
+      aspect-preserving resizing.
+    resize_side_max: The upper bound for the smallest side of the image for
+      aspect-preserving resizing.
+
+  Returns:
+    A preprocessed image.
+  """
+  resize_side = tf.random_uniform(
+      [], minval=resize_side_min, maxval=resize_side_max+1, dtype=tf.int32)
+
+  image = _aspect_preserving_resize(image, resize_side)
+  image = _random_crop([image], output_height, output_width)[0]
+  image.set_shape([output_height, output_width, 3])
+  image = tf.to_float(image)
+  image = tf.image.random_flip_left_right(image)
+  return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
+
+
+def preprocess_for_eval(image, output_height, output_width, resize_side):
+  """Preprocesses the given image for evaluation.
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    resize_side: The smallest side of the image for aspect-preserving resizing.
+
+  Returns:
+    A preprocessed image.
+  """
+  image = _aspect_preserving_resize(image, resize_side)
+  image = _central_crop([image], output_height, output_width)[0]
+  image.set_shape([output_height, output_width, 3])
+  image = tf.to_float(image)
+  return _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
+
+
+def preprocess_image(image, output_height, output_width, is_training=False,
+                     resize_side_min=_RESIZE_SIDE_MIN,
+                     resize_side_max=_RESIZE_SIDE_MAX):
+  """Preprocesses the given image.
+
+  Args:
+    image: A `Tensor` representing an image of arbitrary size.
+    output_height: The height of the image after preprocessing.
+    output_width: The width of the image after preprocessing.
+    is_training: `True` if we're preprocessing the image for training and
+      `False` otherwise.
+    resize_side_min: The lower bound for the smallest side of the image for
+      aspect-preserving resizing. If `is_training` is `False`, then this value
+      is used for rescaling.
+    resize_side_max: The upper bound for the smallest side of the image for
+      aspect-preserving resizing. If `is_training` is `False`, this value is
+      ignored. Otherwise, the resize side is sampled from
+        [resize_side_min, resize_side_max].
+
+  Returns:
+    A preprocessed image.
+  """
+  if is_training:
+    return preprocess_for_train(image, output_height, output_width,
+                                resize_side_min, resize_side_max)
+  else:
+    return preprocess_for_eval(image, output_height, output_width,
+                               resize_side_min)
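
A brief sketch contrasting the two branches of `preprocess_image`; the 224x224 output size is illustrative:

```python
import tensorflow as tf

from slim.models import vgg_preprocessing

raw = tf.placeholder(tf.uint8, [None, None, 3])
# Training: resize side drawn from [256, 512], random crop, random flip.
train_image = vgg_preprocessing.preprocess_for_train(raw, 224, 224)
# Eval: deterministic resize of the smallest side, then a central crop.
eval_image = vgg_preprocessing.preprocess_for_eval(raw, 224, 224,
                                                   resize_side=256)
```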

+ 92 - 0
slim/nets/lenet.py

@@ -0,0 +1,92 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a variant of the LeNet model definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def lenet(images, num_classes=10, is_training=False,
+          dropout_keep_prob=0.5,
+          prediction_fn=slim.softmax,
+          scope='LeNet'):
+  """Creates a variant of the LeNet model.
+
+  Note that since the output is a set of 'logits', the values fall in the
+  interval of (-infinity, infinity). Consequently, to convert the outputs to a
+  probability distribution over the classes, one will need to convert them
+  using the softmax function:
+        logits = lenet.lenet(images, is_training=False)
+        probabilities = tf.nn.softmax(logits)
+        predictions = tf.argmax(logits, 1)
+
+  Args:
+    images: A `Tensor` of size [batch_size, height, width, channels]
+      containing a batch of images.
+    num_classes: the number of classes in the dataset.
+    is_training: specifies whether or not we're currently training the model.
+      This variable will determine the behaviour of the dropout layer.
+    dropout_keep_prob: the probability that each activation value is retained.
+    prediction_fn: a function to get predictions out of logits.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the pre-softmax activations, a tensor of size
+      [batch_size, `num_classes`]
+    end_points: a dictionary from components of the network to the corresponding
+      activation.
+  """
+  end_points = {}
+
+  with tf.variable_scope(scope, 'LeNet', [images, num_classes]):
+    net = slim.conv2d(images, 32, [5, 5], scope='conv1')
+    net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
+    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
+    net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
+    net = slim.flatten(net)
+    end_points['Flatten'] = net
+
+    net = slim.fully_connected(net, 1024, scope='fc3')
+    net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                       scope='dropout3')
+    logits = slim.fully_connected(net, num_classes, activation_fn=None,
+                                  scope='fc4')
+
+  end_points['Logits'] = logits
+  end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
+
+  return logits, end_points
+lenet.default_image_size = 28
+
+
+def lenet_arg_scope(weight_decay=0.0):
+  """Defines the default lenet argument scope.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+
+  Returns:
+    An `arg_scope` to use for the lenet model.
+  """
+  with slim.arg_scope(
+      [slim.conv2d, slim.fully_connected],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
+      activation_fn=tf.nn.relu) as sc:
+    return sc

+ 43 - 0
slim/scripts/train_lenet_on_mnist.sh

@@ -0,0 +1,43 @@
+#!/bin/bash
+#
+# Before running this script, make sure you've followed the instructions for
+# downloading and converting the MNIST dataset.
+# See slim/datasets/download_and_convert_mnist.py.
+#
+# Usage:
+# ./slim/scripts/train_lenet_on_mnist.sh
+
+# Compile the training and evaluation binaries
+bazel build slim:train
+bazel build slim:eval
+
+# Where the checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/lenet-model
+
+# Where the dataset was saved to.
+DATASET_DIR=/tmp/mnist
+
+# Run training.
+./bazel-bin/slim/train \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=mnist \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=lenet \
+  --preprocessing_name=lenet \
+  --max_number_of_steps=20000 \
+  --learning_rate=0.01 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --optimizer=sgd \
+  --learning_rate_decay_factor=1.0 \
+  --weight_decay=0
+
+# Run evaluation.
+./bazel-bin/slim/eval \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=mnist \
+  --dataset_split_name=test \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=lenet
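+
+# Optional (not part of the original recipe): monitor training progress by
+# pointing TensorBoard at the training directory:
+#   tensorboard --logdir=${TRAIN_DIR}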

+ 540 - 0
slim/train.py

@@ -0,0 +1,540 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Generic training script that trains a given model a specified dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops
+from slim.datasets import dataset_factory
+from slim.models import model_deploy
+from slim.models import model_factory
+from slim.models import preprocessing_factory
+
+slim = tf.contrib.slim
+
+tf.app.flags.DEFINE_string(
+    'master', '', 'The address of the TensorFlow master to use.')
+
+tf.app.flags.DEFINE_string(
+    'train_dir', '/tmp/tfmodel/',
+    'Directory where checkpoints and event logs are written to.')
+
+tf.app.flags.DEFINE_integer('num_clones', 1,
+                            'Number of model clones to deploy.')
+
+tf.app.flags.DEFINE_boolean('clone_on_cpu', False,
+                            'Use CPUs to deploy clones.')
+
+tf.app.flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas.')
+
+tf.app.flags.DEFINE_integer(
+    'num_ps_tasks', 0,
+    'The number of parameter servers. If the value is 0, then the parameters '
+    'are handled locally by the worker.')
+
+tf.app.flags.DEFINE_integer(
+    'num_readers', 4,
+    'The number of parallel readers that read data from the dataset.')
+
+tf.app.flags.DEFINE_integer(
+    'num_preprocessing_threads', 4,
+    'The number of threads used to create the batches.')
+
+tf.app.flags.DEFINE_integer(
+    'log_every_n_steps', 5,
+    'The frequency with which logs are printed.')
+
+tf.app.flags.DEFINE_integer(
+    'save_summaries_secs', 600,
+    'The frequency with which summaries are saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+    'save_interval_secs', 600,
+    'The frequency with which the model is saved, in seconds.')
+
+tf.app.flags.DEFINE_integer(
+    'task', 0, 'Task id of the replica running the training.')
+
+######################
+# Optimization Flags #
+######################
+
+tf.app.flags.DEFINE_float(
+    'weight_decay', 0.00004, 'The weight decay on the model weights.')
+
+tf.app.flags.DEFINE_string(
+    'optimizer', 'rmsprop',
+    'The name of the optimizer, one of "adadelta", "adagrad", "adam", '
+    '"ftrl", "momentum", "sgd" or "rmsprop".')
+
+tf.app.flags.DEFINE_float(
+    'adadelta_rho', 0.95,
+    'The decay rate for adadelta.')
+
+tf.app.flags.DEFINE_float(
+    'adagrad_initial_accumulator_value', 0.1,
+    'Starting value for the AdaGrad accumulators.')
+
+tf.app.flags.DEFINE_float(
+    'adam_beta1', 0.9,
+    'The exponential decay rate for the 1st moment estimates.')
+
+tf.app.flags.DEFINE_float(
+    'adam_beta2', 0.999,
+    'The exponential decay rate for the 2nd moment estimates.')
+
+tf.app.flags.DEFINE_float('opt_epsilon', 1.0, 'Epsilon term for the optimizer.')
+
+tf.app.flags.DEFINE_float('ftrl_learning_rate_power', -0.5,
+                          'The learning rate power.')
+
+tf.app.flags.DEFINE_float(
+    'ftrl_initial_accumulator_value', 0.1,
+    'Starting value for the FTRL accumulators.')
+
+tf.app.flags.DEFINE_float(
+    'ftrl_l1', 0.0, 'The FTRL l1 regularization strength.')
+
+tf.app.flags.DEFINE_float(
+    'ftrl_l2', 0.0, 'The FTRL l2 regularization strength.')
+
+tf.app.flags.DEFINE_float(
+    'momentum', 0.9,
+    'The momentum for the MomentumOptimizer and RMSPropOptimizer.')
+
+tf.app.flags.DEFINE_float('rmsprop_momentum', 0.9, 'Momentum.')
+
+tf.app.flags.DEFINE_float('rmsprop_decay', 0.9, 'Decay term for RMSProp.')
+
+#######################
+# Learning Rate Flags #
+#######################
+
+tf.app.flags.DEFINE_string(
+    'learning_rate_decay_type',
+    'exponential',
+    'Specifies how the learning rate is decayed. One of "fixed", "exponential",'
+    ' or "polynomial"')
+
+tf.app.flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+
+tf.app.flags.DEFINE_float(
+    'end_learning_rate', 0.0001,
+    'The minimal end learning rate used by a polynomial decay learning rate.')
+
+tf.app.flags.DEFINE_float(
+    'label_smoothing', 0.0, 'The amount of label smoothing.')
+
+tf.app.flags.DEFINE_float(
+    'learning_rate_decay_factor', 0.94, 'Learning rate decay factor.')
+
+tf.app.flags.DEFINE_float(
+    'num_epochs_per_decay', 2.0,
+    'Number of epochs after which learning rate decays.')
+
+tf.app.flags.DEFINE_bool(
+    'sync_replicas', False,
+    'Whether or not to synchronize the replicas during training.')
+
+tf.app.flags.DEFINE_integer(
+    'replicas_to_aggregate', 1,
+    'The number of gradients to collect before updating params.')
+
+tf.app.flags.DEFINE_float(
+    'moving_average_decay', None,
+    'The decay to use for the moving average. '
+    'If left as None, then moving averages are not used.')
+
+#################
+# Dataset Flags #
+#################
+
+tf.app.flags.DEFINE_string(
+    'dataset_name', 'imagenet', 'The name of the dataset to load.')
+
+tf.app.flags.DEFINE_string(
+    'dataset_split_name', 'train', 'The name of the train/test split.')
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir', None, 'The directory where the dataset files are stored.')
+
+tf.app.flags.DEFINE_integer(
+    'labels_offset', 0,
+    'An offset for the labels in the dataset. This flag is primarily used to '
+    'evaluate the VGG and ResNet architectures which do not use a background '
+    'class for the ImageNet dataset.')
+
+tf.app.flags.DEFINE_string(
+    'model_name', 'inception_v3', 'The name of the architecture to train.')
+
+tf.app.flags.DEFINE_string(
+    'preprocessing_name', None, 'The name of the preprocessing to use. If left '
+    'as `None`, then the model_name flag is used.')
+
+tf.app.flags.DEFINE_integer(
+    'batch_size', 32, 'The number of samples in each batch.')
+
+tf.app.flags.DEFINE_integer(
+    'train_image_size', None,
+    'Train image size. If None, the model\'s default image size is used.')
+
+tf.app.flags.DEFINE_integer('max_number_of_steps', None,
+                            'The maximum number of training steps.')
+
+#####################
+# Fine-Tuning Flags #
+#####################
+
+tf.app.flags.DEFINE_string(
+    'checkpoint_path', None,
+    'The path to a checkpoint from which to fine-tune.')
+
+tf.app.flags.DEFINE_string(
+    'checkpoint_exclude_scopes', None,
+    'Comma-separated list of variable scopes to exclude when restoring '
+    'from a checkpoint.')
+
+FLAGS = tf.app.flags.FLAGS
+
+
+def _configure_learning_rate(num_samples_per_epoch, global_step):
+  """Configures the learning rate.
+
+  Args:
+    num_samples_per_epoch: The number of samples in each epoch of training.
+    global_step: The global_step tensor.
+
+  Returns:
+    A `Tensor` representing the learning rate.
+
+  Raises:
+    ValueError: if FLAGS.learning_rate_decay_type is not recognized.
+  """
+  decay_steps = int(num_samples_per_epoch / FLAGS.batch_size *
+                    FLAGS.num_epochs_per_decay)
+  if FLAGS.sync_replicas:
+    decay_steps /= FLAGS.replicas_to_aggregate
+
+  if FLAGS.learning_rate_decay_type == 'exponential':
+    return tf.train.exponential_decay(FLAGS.learning_rate,
+                                      global_step,
+                                      decay_steps,
+                                      FLAGS.learning_rate_decay_factor,
+                                      staircase=True,
+                                      name='exponential_decay_learning_rate')
+  elif FLAGS.learning_rate_decay_type == 'fixed':
+    return tf.constant(FLAGS.learning_rate, name='fixed_learning_rate')
+  elif FLAGS.learning_rate_decay_type == 'polynomial':
+    return tf.train.polynomial_decay(FLAGS.learning_rate,
+                                     global_step,
+                                     decay_steps,
+                                     FLAGS.end_learning_rate,
+                                     power=1.0,
+                                     cycle=False,
+                                     name='polynomial_decay_learning_rate')
+  else:
+    raise ValueError('learning_rate_decay_type [%s] was not recognized' %
+                     FLAGS.learning_rate_decay_type)
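+
+# Worked example (illustrative numbers): with 50,000 training samples,
+# batch_size=32 and num_epochs_per_decay=2.0, decay_steps is
+# int(50000 / 32 * 2.0) = 3125, so the 'exponential' schedule multiplies the
+# learning rate by learning_rate_decay_factor once every 3125 steps.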
+
+
+def _configure_optimizer(learning_rate):
+  """Configures the optimizer used for training.
+
+  Args:
+    learning_rate: A scalar or `Tensor` learning rate.
+
+  Returns:
+    An instance of an optimizer.
+
+  Raises:
+    ValueError: if FLAGS.optimizer is not recognized.
+  """
+  if FLAGS.optimizer == 'adadelta':
+    optimizer = tf.train.AdadeltaOptimizer(
+        learning_rate,
+        rho=FLAGS.adadelta_rho,
+        epsilon=FLAGS.opt_epsilon)
+  elif FLAGS.optimizer == 'adagrad':
+    optimizer = tf.train.AdagradOptimizer(
+        learning_rate,
+        initial_accumulator_value=FLAGS.adagrad_initial_accumulator_value)
+  elif FLAGS.optimizer == 'adam':
+    optimizer = tf.train.AdamOptimizer(
+        learning_rate,
+        beta1=FLAGS.adam_beta1,
+        beta2=FLAGS.adam_beta2,
+        epsilon=FLAGS.opt_epsilon)
+  elif FLAGS.optimizer == 'ftrl':
+    optimizer = tf.train.FtrlOptimizer(
+        learning_rate,
+        learning_rate_power=FLAGS.ftrl_learning_rate_power,
+        initial_accumulator_value=FLAGS.ftrl_initial_accumulator_value,
+        l1_regularization_strength=FLAGS.ftrl_l1,
+        l2_regularization_strength=FLAGS.ftrl_l2)
+  elif FLAGS.optimizer == 'momentum':
+    optimizer = tf.train.MomentumOptimizer(
+        learning_rate,
+        momentum=FLAGS.momentum,
+        name='Momentum')
+  elif FLAGS.optimizer == 'rmsprop':
+    optimizer = tf.train.RMSPropOptimizer(
+        learning_rate,
+        decay=FLAGS.rmsprop_decay,
+        momentum=FLAGS.rmsprop_momentum,
+        epsilon=FLAGS.opt_epsilon)
+  elif FLAGS.optimizer == 'sgd':
+    optimizer = tf.train.GradientDescentOptimizer(learning_rate)
+  else:
+    raise ValueError('Optimizer [%s] was not recognized' % FLAGS.optimizer)
+  return optimizer
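+
+# For instance (with the flag defaults above), the 'rmsprop' branch builds
+# tf.train.RMSPropOptimizer(learning_rate, decay=0.9, momentum=0.9,
+# epsilon=1.0).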
+
+
+def _add_variables_summaries(learning_rate):
+  summaries = []
+  for variable in slim.get_model_variables():
+    summaries.append(tf.histogram_summary(variable.op.name, variable))
+  summaries.append(tf.scalar_summary('training/Learning Rate', learning_rate))
+  return summaries
+
+
+def _get_init_fn():
+  """Returns a function run by the chief worker to warm-start the training.
+
+  Note that the init_fn is only run when initializing the model during the very
+  first global step.
+
+  Returns:
+    An init function run by the supervisor.
+  """
+  if FLAGS.checkpoint_path is None:
+    return None
+
+  # Warn the user if a checkpoint already exists in the train_dir: in that
+  # case --checkpoint_path is ignored and training resumes from the existing
+  # checkpoint.
+  if tf.train.latest_checkpoint(FLAGS.train_dir):
+    tf.logging.info(
+        'Ignoring --checkpoint_path because a checkpoint already exists in %s'
+        % FLAGS.train_dir)
+    return None
+
+  exclusions = []
+  if FLAGS.checkpoint_exclude_scopes:
+    exclusions = [scope.strip()
+                  for scope in FLAGS.checkpoint_exclude_scopes.split(',')]
+
+  # TODO(sguada) variables.filter_variables()
+  variables_to_restore = []
+  for var in slim.get_model_variables():
+    excluded = False
+    for exclusion in exclusions:
+      if var.op.name.startswith(exclusion):
+        excluded = True
+        break
+    if not excluded:
+      variables_to_restore.append(var)
+
+  return slim.assign_from_checkpoint_fn(
+      FLAGS.checkpoint_path,
+      variables_to_restore)
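+
+# Fine-tuning sketch (the checkpoint path and scope name below are
+# illustrative placeholders, not part of this change): to restore everything
+# except a model's final logits layer, one could pass
+#   --checkpoint_path=/tmp/my_model.ckpt \
+#   --checkpoint_exclude_scopes=MyModel/Logits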
+
+
+def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  with tf.Graph().as_default():
+    #######################
+    # Config model_deploy #
+    #######################
+    deploy_config = model_deploy.DeploymentConfig(
+        num_clones=FLAGS.num_clones,
+        clone_on_cpu=FLAGS.clone_on_cpu,
+        replica_id=FLAGS.task,
+        num_replicas=FLAGS.worker_replicas,
+        num_ps_tasks=FLAGS.num_ps_tasks)
+
+    # Create global_step
+    with tf.device(deploy_config.variables_device()):
+      global_step = slim.create_global_step()
+
+    ######################
+    # Select the dataset #
+    ######################
+    dataset = dataset_factory.get_dataset(
+        FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
+
+    ####################
+    # Select the model #
+    ####################
+    model_fn = model_factory.get_model(
+        FLAGS.model_name,
+        num_classes=(dataset.num_classes - FLAGS.labels_offset),
+        weight_decay=FLAGS.weight_decay,
+        is_training=True)
+
+    #####################################
+    # Select the preprocessing function #
+    #####################################
+    preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
+    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
+        preprocessing_name,
+        is_training=True)
+
+    ##############################################################
+    # Create a dataset provider that loads data from the dataset #
+    ##############################################################
+    with tf.device(deploy_config.inputs_device()):
+      provider = slim.dataset_data_provider.DatasetDataProvider(
+          dataset,
+          num_readers=FLAGS.num_readers,
+          common_queue_capacity=20 * FLAGS.batch_size,
+          common_queue_min=10 * FLAGS.batch_size)
+      [image, label] = provider.get(['image', 'label'])
+      label -= FLAGS.labels_offset
+
+      if FLAGS.train_image_size is None:
+        train_image_size = model_fn.default_image_size
+      else:
+        train_image_size = FLAGS.train_image_size
+
+      image = image_preprocessing_fn(image, train_image_size, train_image_size)
+
+      images, labels = tf.train.batch(
+          [image, label],
+          batch_size=FLAGS.batch_size,
+          num_threads=FLAGS.num_preprocessing_threads,
+          capacity=5 * FLAGS.batch_size)
+      labels = slim.one_hot_encoding(
+          labels, dataset.num_classes - FLAGS.labels_offset)
+      batch_queue = slim.prefetch_queue.prefetch_queue(
+          [images, labels], capacity=2 * deploy_config.num_clones)
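+      # Shapes at this point (descriptive note): `images` is
+      # [batch_size, train_image_size, train_image_size, 3] and `labels` is
+      # one-hot with depth dataset.num_classes - FLAGS.labels_offset.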
+
+    ####################
+    # Define the model #
+    ####################
+    def clone_fn(batch_queue):
+      """Allows data parallelism by creating multiple clones of the model_fn."""
+      images, labels = batch_queue.dequeue()
+      logits, end_points = model_fn(images)
+
+      #############################
+      # Specify the loss function #
+      #############################
+      if 'AuxLogits' in end_points:
+        slim.losses.softmax_cross_entropy(
+            end_points['AuxLogits'], labels,
+            label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
+      slim.losses.softmax_cross_entropy(
+          logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
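+      # The cross-entropy ops above add themselves to the
+      # tf.GraphKeys.LOSSES collection; optimize_clones() below sums them,
+      # along with any regularization losses, into each clone's total loss.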
+
+    # Gather initial summaries.
+    summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
+
+    clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
+    first_clone_scope = deploy_config.clone_scope(0)
+    # Gather update_ops from the first clone. These contain, for example,
+    # the updates for the batch_norm variables created by model_fn.
+    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
+
+    # Add summaries for losses.
+    for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
+      tf.scalar_summary('losses/%s' % loss.op.name, loss)
+
+    # Add summaries for variables.
+    for variable in slim.get_model_variables():
+      summaries.add(tf.histogram_summary(variable.op.name, variable))
+
+    #################################
+    # Configure the moving averages #
+    #################################
+    if FLAGS.moving_average_decay:
+      moving_average_variables = slim.get_model_variables()
+      variable_averages = tf.train.ExponentialMovingAverage(
+          FLAGS.moving_average_decay, global_step)
+    else:
+      moving_average_variables, variable_averages = None, None
+
+    #########################################
+    # Configure the optimization procedure. #
+    #########################################
+    with tf.device(deploy_config.optimizer_device()):
+      learning_rate = _configure_learning_rate(dataset.num_samples, global_step)
+      optimizer = _configure_optimizer(learning_rate)
+      summaries.add(tf.scalar_summary('learning_rate', learning_rate,
+                                      name='learning_rate'))
+
+    if FLAGS.sync_replicas:
+      # If sync_replicas is enabled, the averaging will be done in the chief
+      # queue runner.
+      optimizer = tf.train.SyncReplicasOptimizer(
+          opt=optimizer,
+          replicas_to_aggregate=FLAGS.replicas_to_aggregate,
+          variable_averages=variable_averages,
+          variables_to_average=moving_average_variables,
+          replica_id=tf.constant(FLAGS.task, tf.int32, shape=()),
+          total_num_replicas=FLAGS.worker_replicas)
+    elif FLAGS.moving_average_decay:
+      # Update ops executed locally by trainer.
+      update_ops.append(variable_averages.apply(moving_average_variables))
+
+    # TODO(sguada) Refactor into function that takes the clones and optimizer
+    #  and returns a train_tensor and summary_op
+    total_loss, clones_gradients = model_deploy.optimize_clones(clones,
+                                                                optimizer)
+    # Add total_loss to summary.
+    summaries.add(tf.scalar_summary('total_loss', total_loss,
+                                    name='total_loss'))
+
+    # Create gradient updates.
+    grad_updates = optimizer.apply_gradients(clones_gradients,
+                                             global_step=global_step)
+    update_ops.append(grad_updates)
+
+    update_op = tf.group(*update_ops)
+    train_tensor = control_flow_ops.with_dependencies([update_op], total_loss,
+                                                      name='train_op')
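+    # Evaluating train_tensor first runs every op in update_ops (the gradient
+    # application plus any batch-norm or moving-average updates) and then
+    # yields the value of total_loss.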
+
+    # Add the summaries from the first clone. These contain the summaries
+    # created by model_fn and either optimize_clones() or _gather_clone_loss().
+    summaries |= set(tf.get_collection(tf.GraphKeys.SUMMARIES,
+                                       first_clone_scope))
+
+    # Merge all summaries together.
+    summary_op = tf.merge_summary(list(summaries), name='summary_op')
+
+    ###########################
+    # Kicks off the training. #
+    ###########################
+    slim.learning.train(
+        train_tensor,
+        logdir=FLAGS.train_dir,
+        master=FLAGS.master,
+        is_chief=(FLAGS.task == 0),
+        init_fn=_get_init_fn(),
+        summary_op=summary_op,
+        number_of_steps=FLAGS.max_number_of_steps,
+        log_every_n_steps=FLAGS.log_every_n_steps,
+        save_summaries_secs=FLAGS.save_summaries_secs,
+        save_interval_secs=FLAGS.save_interval_secs,
+        sync_optimizer=optimizer if FLAGS.sync_replicas else None)
+
+
+if __name__ == '__main__':
+  tf.app.run()