
Full code refactor and added all networks

Nathan Silberman 9 years ago
parent
commit
65fad62dc6
56 files changed, 8060 additions and 818 deletions
  1. slim/BUILD (+224 -34)
  2. slim/README.md (+205 -228)
  3. slim/datasets/__init__.py (+1 -0)
  4. slim/datasets/cifar10.py (+1 -1)
  5. slim/datasets/dataset_factory.py (+4 -4)
  6. slim/datasets/dataset_utils.py (+25 -0)
  7. slim/datasets/download_and_convert_cifar10.py (+26 -31)
  8. slim/datasets/download_and_convert_flowers.py (+37 -52)
  9. slim/datasets/download_and_convert_mnist.py (+28 -34)
  10. slim/datasets/flowers.py (+1 -1)
  11. slim/datasets/imagenet.py (+67 -2)
  12. slim/datasets/mnist.py (+1 -1)
  13. slim/deployment/__init__.py (+1 -0)
  14. slim/models/model_deploy.py (+8 -8)
  15. slim/models/model_deploy_test.py (+1 -1)
  16. slim/download_and_convert_data.py (+74 -0)
  17. slim/eval.py (+21 -23)
  18. slim/models/model_factory.py (+0 -140)
  19. slim/models/resnet_preprocessing.py (+0 -200)
  20. slim/nets/__init__.py (+1 -0)
  21. slim/nets/alexnet.py (+125 -0)
  22. slim/nets/alexnet_test.py (+145 -0)
  23. slim/nets/cifarnet.py (+112 -0)
  24. slim/nets/inception.py (+33 -0)
  25. slim/nets/inception_resnet_v2.py (+280 -0)
  26. slim/nets/inception_resnet_v2_test.py (+136 -0)
  27. slim/nets/inception_v1.py (+340 -0)
  28. slim/nets/inception_v1_test.py (+210 -0)
  29. slim/nets/inception_v2.py (+545 -0)
  30. slim/nets/inception_v2_test.py (+262 -0)
  31. slim/nets/inception_v3.py (+587 -0)
  32. slim/nets/inception_v3_test.py (+292 -0)
  33. slim/nets/lenet.py (+2 -1)
  34. slim/nets/nets_factory.py (+107 -0)
  35. slim/nets/nets_factory_test.py (+46 -0)
  36. slim/nets/overfeat.py (+118 -0)
  37. slim/nets/overfeat_test.py (+145 -0)
  38. slim/nets/resnet_utils.py (+254 -0)
  39. slim/nets/resnet_v1.py (+295 -0)
  40. slim/nets/resnet_v1_test.py (+450 -0)
  41. slim/nets/resnet_v2.py (+302 -0)
  42. slim/nets/resnet_v2_test.py (+453 -0)
  43. slim/nets/vgg.py (+244 -0)
  44. slim/nets/vgg_test.py (+455 -0)
  45. slim/preprocessing/__init__.py (+1 -0)
  46. slim/models/cifar10_preprocessing.py (+17 -17)
  47. slim/preprocessing/inception_preprocessing.py (+0 -0)
  48. slim/preprocessing/lenet_preprocessing.py (+0 -0)
  49. slim/models/preprocessing_factory.py (+6 -5)
  50. slim/preprocessing/vgg_preprocessing.py (+0 -0)
  51. slim/scripts/finetune_inception_v1_on_flowers.sh (+89 -0)
  52. slim/scripts/finetune_inception_v3_on_flowers.sh (+91 -0)
  53. slim/scripts/train_cifarnet_on_cifar10.sh (+49 -0)
  54. slim/scripts/train_lenet_on_mnist.sh (+16 -11)
  55. slim/slim_walkthough.ipynb (+1058 -0)
  56. slim/train.py (+69 -24)

+ 224 - 34
slim/BUILD

@@ -1,5 +1,5 @@
 # Description:
-#   Contains files for loading, training and evaluating TF-Slim 2.0-based models.
+#   Contains files for loading, training and evaluating TF-Slim-based models.
 
 package(default_visibility = [":internal"])
 
@@ -7,35 +7,42 @@ licenses(["notice"])  # Apache 2.0
 
 exports_files(["LICENSE"])
 
-package_group(
-    name = "internal",
-    packages = ["//slim/"],
-)
+package_group(name = "internal")
 
 py_library(
     name = "dataset_utils",
     srcs = ["datasets/dataset_utils.py"],
 )
 
-py_binary(
+py_library(
     name = "download_and_convert_cifar10",
     srcs = ["datasets/download_and_convert_cifar10.py"],
     deps = [":dataset_utils"],
 )
 
-py_binary(
+py_library(
     name = "download_and_convert_flowers",
     srcs = ["datasets/download_and_convert_flowers.py"],
     deps = [":dataset_utils"],
 )
 
-py_binary(
+py_library(
     name = "download_and_convert_mnist",
     srcs = ["datasets/download_and_convert_mnist.py"],
     deps = [":dataset_utils"],
 )
 
 py_binary(
+    name = "download_and_convert_data",
+    srcs = ["download_and_convert_data.py"],
+    deps = [
+        ":download_and_convert_cifar10",
+        ":download_and_convert_flowers",
+        ":download_and_convert_mnist",
+    ],
+)
+
+py_binary(
     name = "cifar10",
     srcs = ["datasets/cifar10.py"],
     deps = [":dataset_utils"],
@@ -70,78 +77,261 @@ py_library(
     ],
 )
 
-py_binary(
-    name = "eval",
-    srcs = ["eval.py"],
-    deps = [
-        ":dataset_factory",
-        ":model_deploy",
-        ":model_factory",
-        ":preprocessing_factory",
-    ],
-)
-
 py_library(
     name = "model_deploy",
-    srcs = ["models/model_deploy.py"],
+    srcs = ["deployment/model_deploy.py"],
 )
 
 py_test(
     name = "model_deploy_test",
-    srcs = ["models/model_deploy_test.py"],
+    srcs = ["deployment/model_deploy_test.py"],
     srcs_version = "PY2AND3",
     deps = [":model_deploy"],
 )
 
 py_library(
-    name = "cifar10_preprocessing",
-    srcs = ["models/cifar10_preprocessing.py"],
+    name = "cifarnet_preprocessing",
+    srcs = ["preprocessing/cifarnet_preprocessing.py"],
 )
 
 py_library(
     name = "inception_preprocessing",
-    srcs = ["models/inception_preprocessing.py"],
+    srcs = ["preprocessing/inception_preprocessing.py"],
 )
 
 py_library(
     name = "lenet_preprocessing",
-    srcs = ["models/lenet_preprocessing.py"],
+    srcs = ["preprocessing/lenet_preprocessing.py"],
 )
 
 py_library(
     name = "vgg_preprocessing",
-    srcs = ["models/vgg_preprocessing.py"],
+    srcs = ["preprocessing/vgg_preprocessing.py"],
 )
 
 py_library(
     name = "preprocessing_factory",
-    srcs = ["models/preprocessing_factory.py"],
+    srcs = ["preprocessing/preprocessing_factory.py"],
     deps = [
-        ":cifar10_preprocessing",
+        ":cifarnet_preprocessing",
         ":inception_preprocessing",
         ":lenet_preprocessing",
         ":vgg_preprocessing",
     ],
 )
 
+# Typical network definitions.
+
+py_library(
+    name = "nets",
+    deps = [
+        ":alexnet",
+        ":cifarnet",
+        ":inception",
+        ":lenet",
+        ":overfeat",
+        ":resnet_v1",
+        ":resnet_v2",
+        ":vgg",
+    ],
+)
+
+py_library(
+    name = "alexnet",
+    srcs = ["nets/alexnet.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "alexnet_test",
+    size = "medium",
+    srcs = ["nets/alexnet_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":alexnet"],
+)
+
+py_library(
+    name = "cifarnet",
+    srcs = ["nets/cifarnet.py"],
+)
+
+py_library(
+    name = "inception",
+    srcs = ["nets/inception.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":inception_resnet_v2",
+        ":inception_v1",
+        ":inception_v2",
+        ":inception_v3",
+    ],
+)
+
+py_library(
+    name = "inception_v1",
+    srcs = ["nets/inception_v1.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "inception_v2",
+    srcs = ["nets/inception_v2.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "inception_v3",
+    srcs = ["nets/inception_v3.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "inception_resnet_v2",
+    srcs = ["nets/inception_resnet_v2.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "inception_v1_test",
+    size = "large",
+    srcs = ["nets/inception_v1_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [":inception"],
+)
+
+py_test(
+    name = "inception_v2_test",
+    size = "large",
+    srcs = ["nets/inception_v2_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [":inception"],
+)
+
+py_test(
+    name = "inception_v3_test",
+    size = "large",
+    srcs = ["nets/inception_v3_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [":inception"],
+)
+
+py_test(
+    name = "inception_resnet_v2_test",
+    size = "large",
+    srcs = ["nets/inception_resnet_v2_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [":inception"],
+)
+
 py_library(
     name = "lenet",
     srcs = ["nets/lenet.py"],
 )
 
 py_library(
-    name = "model_factory",
-    srcs = ["models/model_factory.py"],
-    deps = [":lenet"],
+    name = "overfeat",
+    srcs = ["nets/overfeat.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "overfeat_test",
+    size = "medium",
+    srcs = ["nets/overfeat_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":overfeat"],
+)
+
+py_library(
+    name = "resnet_utils",
+    srcs = ["nets/resnet_utils.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "resnet_v1",
+    srcs = ["nets/resnet_v1.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":resnet_utils",
+    ],
+)
+
+py_test(
+    name = "resnet_v1_test",
+    size = "medium",
+    srcs = ["nets/resnet_v1_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":resnet_v1"],
+)
+
+py_library(
+    name = "resnet_v2",
+    srcs = ["nets/resnet_v2.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":resnet_utils",
+    ],
+)
+
+py_test(
+    name = "resnet_v2_test",
+    size = "medium",
+    srcs = ["nets/resnet_v2_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":resnet_v2"],
+)
+
+py_library(
+    name = "vgg",
+    srcs = ["nets/vgg.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_test(
+    name = "vgg_test",
+    size = "medium",
+    srcs = ["nets/vgg_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":vgg"],
+)
+
+py_library(
+    name = "nets_factory",
+    srcs = ["nets/nets_factory.py"],
+    deps = [":nets"],
+)
+
+py_test(
+    name = "nets_factory_test",
+    size = "medium",
+    srcs = ["nets/nets_factory_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [":nets_factory"],
+)
+
+py_binary(
+    name = "train_image_classifier",
+    srcs = ["train_image_classifier.py"],
+    deps = [
+        ":dataset_factory",
+        ":model_deploy",
+        ":nets_factory",
+        ":preprocessing_factory",
+    ],
 )
 
 py_binary(
-    name = "train",
-    srcs = ["train.py"],
+    name = "eval_image_classifier",
+    srcs = ["eval_image_classifier.py"],
     deps = [
         ":dataset_factory",
         ":model_deploy",
-        ":model_factory",
+        ":nets_factory",
         ":preprocessing_factory",
     ],
 )
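
The new `nets_factory` target ties all of the network definitions together. Below is a minimal sketch of how the factory is typically consumed, assuming the `get_network_fn` helper defined in `nets/nets_factory.py` (the exact argument names are an assumption):

```python
# A hedged sketch of consuming the new nets factory; get_network_fn and
# its signature are assumed from nets/nets_factory.py in this commit.
import tensorflow as tf

from nets import nets_factory

# Build a constructor for Inception V3 with ImageNet's 1001 classes.
network_fn = nets_factory.get_network_fn(
    'inception_v3', num_classes=1001, is_training=False)

# Applying the constructor to an image batch yields logits and endpoints.
images = tf.placeholder(tf.float32, [1, 299, 299, 3])
logits, end_points = network_fn(images)
```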

+ 205 - 228
slim/README.md

@@ -1,119 +1,114 @@
-# Image Classification Models in TF-Slim
-
-This directory contains scripts for training and evaluating models using
-TF-Slim. In particular the code base provides core binaries for:
-
-* Training a model from scratch on a given dataset.
-* Fine-tuning a model from a particular checkpoint on a given dataset.
-* Evaluating a trained model on a given dataset.
-
-All scripts are highly configurable via command-line flags. They support
-training and evaluation using a variety of architectures and datasets.
-
-# Getting Started
-
-**NOTE** Before doing anything, we first need to build TensorFlow from the
-latest nightly build. You can find the latest nightly binaries at
+# TensorFlow-Slim image classification library
+
+[TF-slim](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/slim)
+is a new lightweight high-level API of TensorFlow (`tensorflow.contrib.slim`)
+for defining, training and evaluating complex
+models. This directory contains
+code for training and evaluating several widely used Convolutional Neural
+Network (CNN) image classification models using TF-slim.
+It contains scripts that will allow
+you to train models from scratch or fine-tune them from pre-trained network
+weights. It also contains code for downloading standard image datasets,
+converting them
+to TensorFlow's native TFRecord format and reading them in using TF-Slim's
+data reading and queueing utilities. You can easily train any model on any of
+these datasets, as we demonstrate below. We've also included a
+[jupyter notebook](https://github.com/tensorflow/models/tree/master/slim/slim_walkthrough.ipynb),
+which provides working examples of how to use TF-Slim for image classification.
+
+
+## Table of contents
+
+<a href="#Install">Installation and setup</a><br>
+<a href='#Data'>Preparing the datasets</a><br>
+<a href='#Pretrained'>Using pre-trained models</a><br>
+<a href='#Training'>Training from scratch</a><br>
+<a href='#Tuning'>Fine tuning to a new task</a><br>
+<a href='#Eval'>Evaluating performance</a><br>
+
+# Installation
+<a id='Install'></a>
+
+In this section, we describe the steps required to install the appropriate
+prerequisite packages.
+
+## Installing latest version of TF-slim
+
+As of 8/28/16, the latest [stable release of TF](https://www.tensorflow.org/versions/r0.10/get_started/os_setup.html#pip-installation)
+is r0.10, which contains most of TF-Slim but not some later additions. To obtain the
+latest version, you must install the most recent nightly build of
+TensorFlow. You can find the latest nightly binaries at
 [TensorFlow Installation](https://github.com/tensorflow/tensorflow#installation)
-under the header that reads "People who are a little more adventurous can
-also try our nightly binaries". Next, copy the link address that corresponds to
-the appropriate machine architecture and python version. Finally, pip install
-(upgrade) using the appropriate file.
-
-For example:
+in the section that reads "People who are a little more adventurous can
+also try our nightly binaries". Copy the link address that corresponds to
+the appropriate machine architecture and python version, and pip install
+it. For example:
 
 ```shell
 export TF_BINARY_URL=https://ci.tensorflow.org/view/Nightly/job/nightly-matrix-cpu/TF_BUILD_CONTAINER_TYPE=CPU,TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tensorflow-0.10.0rc0-cp27-none-linux_x86_64.whl
-
 sudo pip install --upgrade $TF_BINARY_URL
 ```
 
-To compile the training and evaluation scripts, we also need to install bazel.
-You can find step-by-step instructions
-[here](http://bazel.io/docs/install.html).
-
-Next, you'll need to install
-[tensorflow/models/slim](https://github.com/tensorflow/models/tree/master/slim).
-If you want to use the ImageNet dataset, you'll also need to install
-[tensorflow/models/inception](https://github.com/tensorflow/models/tree/master/inception).
-Note that this directory contains an older version of slim which has been
-deprecated and can be safely ignored.
+To test that this has worked, execute the following command; it should run
+without raising any errors.
 
-# Datasets
-
-As part of this library, we've included scripts to download several popular
-datasets and convert them to TensorFlow's native TFRecord format. Each labeled
-image is represented as a
-[TF-Example](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/core/example/example.proto)
-protocol buffer.
+```
+python -c "import tensorflow.contrib.slim as slim; eval = slim.evaluation.evaluate_once"
+```
 
-Dataset | Download Script | Dataset Specification | Description
-:------:|:---------------:|:---------------------:|:-----------
-[Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html)|[Script](https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_cifar10.py)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/cifar10.py)|The cifar10 dataset contains 60,000 training and 10,000 testing images of 10 different object classes.
-[Flowers](https://github.com/tensorflow/models/blob/master/inception/README.md)|[Script](https://github.com/tensorflow/models/blob/master/inception/inception/data/download_and_preprocess_flowers.sh)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/flowers.py)|The Flowers dataset contains 2500 images of flowers with 5 different labels.
-[MNIST](http://yann.lecun.com/exdb/mnist/)|[Script](https://github.com/tensorflow/models/blob/master/slim/datasets/download_and_convert_mnist.py)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/mnist.py)|The MNIST dataset contains 60,000 training 10,000 testing grayscale images of digits.
-[ImageNet](http://www.image-net.org/)|[Script](https://github.com/tensorflow/models/blob/master/inception/inception/data/download_and_preprocess_imagenet.sh)|[Code](https://github.com/tensorflow/models/blob/master/slim/datasets/imagenet.py)|The ImageNet dataset contains about 1.2 million training and 50,000 validation images with 1000 different labels.
+## Installing the TF-slim image models library
 
-Below we describe the python scripts which download these datasets and convert
-to TF Record format. Once in this format, the data can easily be read by
-TensorFlow by providing a TF-Slim
-[Dataset](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/contrib/slim/python/slim/data/dataset.py)
-specification. We have included, as a part of the release, the
-[Dataset](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/contrib/slim/python/slim/data/dataset.py)
-specifications for each of these datasets as well.
+To use TF-Slim for image classification, you also have to install
+the [TF-Slim image models library](https://github.com/tensorflow/models/tree/master/slim),
+which is not part of the core TF library.
+To do this, check out the
+[tensorflow/models](https://github.com/tensorflow/models/) repository as follows:
 
-## Preparing the Cifar10 Dataset
+```bash
+cd $HOME/workspace
+git clone https://github.com/tensorflow/models/
+```
 
-In order to use the Cifar10 dataset, the data must first be downloaded and
-converted to the native TFRecord format.
+This will put the TF-Slim image models library in `$HOME/workspace/models/slim`.
+(It will also create a directory called
+[models/inception](https://github.com/tensorflow/models/tree/master/inception),
+which contains an older version of slim; you can safely ignore this.)
 
-```shell
-# Specify the directory of the Cifar10 data:
-$ DATA_DIR=$HOME/cifar10
+To verify that this has worked, execute the following commands; they should
+run without raising any errors.
 
-# Build the dataset creation script.
-$ bazel build slim:download_and_convert_cifar10
-
-# Run the dataset creation.
-$ ./bazel-bin/slim/download_and_convert_cifar10 --dataset_dir="${DATA_DIR}"
 ```
-
-The final line of the output script should read:
-
-```shell
-Reading file [cifar-10-batches-py/test_batch], image 10000/10000
-Finished extracting the Cifar10 dataset!
+cd $HOME/workspace/models/slim
+python -c "from nets import cifarnet; mynet = cifarnet.cifarnet"
 ```
 
-When the script finishes you will find two TFRecord files created,
-`$DATA_DIR/cifar10_train.tfrecord` and `$DATA_DIR/cifar10_test.tfrecord`,
-which represent the training and testing sets respectively. You will also find
-a `$DATA_DIR/labels.txt` file which contains the mapping from integer labels
-to class names.
 
-## Preparing the Flowers Dataset
+# Preparing the datasets
 
-In order to use the Flowers dataset, the data must first be downloaded and
-converted to the native TFRecord format.
-
-```shell
-# Specify the directory of the Flowers data:
-$ DATA_DIR=$HOME/flowers
+As part of this library, we've included scripts to download several popular
+image datasets (listed below) and convert them to TFRecord format.
 
-# Build the dataset creation script.
-$ bazel build slim:download_and_convert_flowers
+Dataset | Training Set Size | Testing Set Size | Number of Classes | Comments
+:------:|:---------------:|:---------------------:|:-----------:|:-----------:
+Flowers|3320 | 350 | 5 | Various sizes (source: Flickr)
+[Cifar10](https://www.cs.toronto.edu/~kriz/cifar.html) | 60k| 10k | 10 |32x32 color
+[MNIST](http://yann.lecun.com/exdb/mnist/)| 60k | 10k | 10 | 28x28 gray
+[ImageNet](http://www.image-net.org/challenges/LSVRC/2012/)|1.2M| 50k | 1000 | Various sizes
 
-# Run the dataset creation.
-$ ./bazel-bin/slim/download_and_convert_flowers --dataset_dir="${DATA_DIR}"
-```
+## Downloading and converting to TFRecord format
 
-The final lines of the output script should read:
+For each dataset, we'll need to download the raw data and convert it to
+TensorFlow's native
+[TFRecord](https://www.tensorflow.org/versions/r0.10/api_docs/python/python_io.html#tfrecords-format-details)
+format. Each TFRecord contains a
+[TF-Example](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/core/example/example.proto)
+protocol buffer. Below we demonstrate how to do this for the Flowers dataset.
 
 ```shell
->> Converting image 3320/3320 shard 4
->> Converting image 350/350 shard 4
-
-Finished converting the Flowers dataset!
+$ DATA_DIR=/tmp/data/flowers
+$ python download_and_convert_data.py \
+    --dataset_name=flowers \
+    --dataset_dir="${DATA_DIR}"
 ```
 
 When the script finishes you will find several TFRecord files created:
@@ -121,14 +116,10 @@ When the script finishes you will find several TFRecord files created:
 ```shell
 $ ls ${DATA_DIR}
 flowers_train-00000-of-00005.tfrecord
-flowers_train-00001-of-00005.tfrecord
-flowers_train-00002-of-00005.tfrecord
-flowers_train-00003-of-00005.tfrecord
+...
 flowers_train-00004-of-00005.tfrecord
 flowers_validation-00000-of-00005.tfrecord
-flowers_validation-00001-of-00005.tfrecord
-flowers_validation-00002-of-00005.tfrecord
-flowers_validation-00003-of-00005.tfrecord
+...
 flowers_validation-00004-of-00005.tfrecord
 labels.txt
 ```
@@ -137,100 +128,108 @@ These represent the training and validation data, sharded over 5 files each.
 You will also find the `$DATA_DIR/labels.txt` file which contains the mapping
 from integer labels to class names.
 
-## Preparing the MNIST Dataset
+You can use the same script to create the mnist and cifar10 datasets.
+However, for ImageNet, you have to follow the instructions
+[here](https://github.com/tensorflow/models/blob/master/inception/README.md#getting-started).
+Note that you first have to sign up for an account at image-net.org.
+Also, the download can take several hours, and uses about 500MB.
 
-In order to use the MNIST dataset, the data must first be downloaded and
-converted to the native TFRecord format.
-
-```shell
-# Specify the directory of the MNIST data:
-$ DATA_DIR=$HOME/mnist
 
-# Build the dataset creation script.
-$ bazel build slim:download_and_convert_mnist
+## Creating a TF-Slim Dataset Descriptor.
 
-# Run the dataset creation.
-$ ./bazel-bin/slim/download_and_convert_mnist --dataset_dir="${DATA_DIR}"
-```
+Once the TFRecord files have been created, you can easily define a Slim
+[Dataset](https://github.com/tensorflow/tensorflow/blob/r0.10/tensorflow/contrib/slim/python/slim/data/dataset.py),
+which stores pointers to the data file, as well as various other pieces of
+metadata, such as the class labels, the train/test split, and how to parse the
+TFExample protos. We have included the TF-Slim Dataset descriptors
+for
+[Cifar10](https://github.com/tensorflow/models/blob/master/slim/datasets/cifar10.py),
+[ImageNet](https://github.com/tensorflow/models/blob/master/slim/datasets/imagenet.py),
+[Flowers](https://github.com/tensorflow/models/blob/master/slim/datasets/flowers.py),
+and
+[MNIST](https://github.com/tensorflow/models/blob/master/slim/datasets/mnist.py).
+An example of how to load data using a TF-Slim dataset descriptor using a
+TF-Slim
+[DatasetDataProvider](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/data/dataset_data_provider.py)
+is found below:
 
-The final line of the output script should read:
+```python
+import tensorflow as tf
+from datasets import flowers
 
-```shell
->> Converting image 10000/10000
-Finished extracting the MNIST dataset!
-```
+slim = tf.contrib.slim
 
-When the script finishes you will find two TFRecord files created,
-`$DATA_DIR/mnist_train.tfrecord` and `$DATA_DIR/mnist_test.tfrecord`,
-which represent the training and testing sets respectively.  You will also find
-a `$DATA_DIR/labels.txt` file which contains the mapping from integer labels
-to class names.
+# Selects the 'validation' dataset.
+dataset = flowers.get_split('validation', DATA_DIR)
 
-## Preparing the ImageNet Dataset
+# Creates a TF-Slim DataProvider which reads the dataset in the background
+# during both training and testing.
+provider = slim.dataset_data_provider.DatasetDataProvider(dataset)
+[image, label] = provider.get(['image', 'label'])
+```
 
-To use the ImageNet dataset, follow the instructions in the
-[tensorflow/models/inception](https://github.com/tensorflow/models/blob/master/inception/README.md#getting-started)
-repository. In particular see file
-[download_and_preprocess_imagenet.sh](https://github.com/tensorflow/models/blob/master/inception/inception/data/download_and_preprocess_imagenet.sh)
 
-## Pre-trained Models
+# Pre-trained Models
+<a id='Pretrained'></a>
 
-For convenience, we have provided a number of pre-trained image classification
-models which are listed below. These neural networks been trained on the
-ILSVRC-2012-CLS dataset which is comprised of ~1.2 million images and annotated
-with 1000 mutually exclusive class labels.
+Neural nets work best when they have many parameters, making them powerful
+function approximators.
+However, this means they must be trained on very large datasets. Because
+training models from scratch can be a very computationally intensive process
+requiring days or even weeks, we provide various pre-trained models,
+as listed below. These CNNs have been trained on the
+[ILSVRC-2012-CLS](http://www.image-net.org/challenges/LSVRC/2012/)
+image classification dataset.
 
-In the table below, we present each of these models, the corresponding
-TensorFlow model file, the link to the model checkpoint and the top 1 and top 5
-accuracy.
+In the table below, we list each model, the corresponding
+TensorFlow model file, the link to the model checkpoint, and the top 1 and top 5
+accuracy (on the ImageNet test set).
 Note that the VGG and ResNet parameters have been converted from their original
 caffe formats
 ([here](https://github.com/BVLC/caffe/wiki/Model-Zoo#models-used-by-the-vgg-team-in-ilsvrc-2014)
 and
-[here](https://github.com/KaimingHe/deep-residual-networks)), whereas the Inception parameters have been trained internally at
+[here](https://github.com/KaimingHe/deep-residual-networks)),
+whereas the Inception parameters have been trained internally at
 Google. Also be aware that these accuracies were computed by evaluating using a
 single image crop. Some academic papers report higher accuracy by using multiple
 crops at multiple scales.
 
 Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
 :----:|:------------:|:----------:|:-------:|:--------:|
-[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v1.py)|[inception_v1.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_23.tar.gz)|69.8|89.6|
-[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v2.py)|[inception_v2.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_23.tar.gz)|73.9|91.8|
-[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py)|[inception_v3.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_23.tar.gz)|78.0|93.9|
-[ResNet 50](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_50.tar.gz](http://download.tensorflow.org/models/resnet_v1_50_2016_08_23.tar.gz)|75.2|92.2|
-[ResNet 101](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_101.tar.gz](http://download.tensorflow.org/models/resnet_v1_101_2016_08_23.tar.gz)|76.4|92.9|
-[ResNet 152](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_152.tar.gz](http://download.tensorflow.org/models/resnet_v1_152_2016_08_23.tar.gz)|76.8|93.2|
-[VGG 16](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_16.tar.gz](http://download.tensorflow.org/models/vgg_16_2016_08_23.tar.gz)|71.5|89.8|
-[VGG 19](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_19.tar.gz](http://download.tensorflow.org/models/vgg_19_2016_08_23.tar.gz)|71.1|89.8|
-
+[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v1.py)|[inception_v1.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz)|69.8|89.6|
+[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v2.py)|[inception_v2.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz)|73.9|91.8|
+[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py)|[inception_v3.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)|78.0|93.9|
+[ResNet 50](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_50.tar.gz](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz)|75.2|92.2|
+[ResNet 101](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_101.tar.gz](http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz)|76.4|92.9|
+[ResNet 152](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_152.tar.gz](http://download.tensorflow.org/models/resnet_v1_152_2016_08_28.tar.gz)|76.8|93.2|
+[VGG 16](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_16.tar.gz](http://download.tensorflow.org/models/vgg_16_2016_08_28.tar.gz)|71.5|89.8|
+[VGG 19](http://arxiv.org/abs/1409.1556.pdf)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/vgg.py)|[vgg_19.tar.gz](http://download.tensorflow.org/models/vgg_19_2016_08_28.tar.gz)|71.1|89.8|
 
-# Training a model from scratch.
 
-**WARNING** Training a neural network network from scratch is a computationally
-intensive task and depending on your compute setup may take days, weeks or even
-months.
+Here is an example of how to download the Inception V3 checkpoint:
 
-The training script provided allows users to train one of several architecures
-using one of a variety of optimizers on one of several datasets. Each of these
-choices is configurable and datasets can be added by creating a
-[slim.Dataset](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/data/dataset.py)
-specification and using it in the place of one of those provided.
+```shell
+$ CHECKPOINT_DIR=/tmp/checkpoints
+$ mkdir ${CHECKPOINT_DIR}
+$ wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
+$ tar -xvf inception_v3_2016_08_28.tar.gz
+$ mv inception_v3.ckpt ${CHECKPOINT_DIR}
+$ rm inception_v3_2016_08_28.tar.gz
+```
 
-The following example demonstrates how to train Inception-V3 using SGD with
-Momentum on the ImageNet dataset.
 
-```shell
-# Specify the directory where the dataset is stored.
-DATASET_DIR=$HOME/imagenet
 
-# Specify the directory where the training logs are stored:
-TRAIN_DIR=$HOME/train_logs
+# Training a model from scratch.
+<a id='Training'></a>
 
-# Build the training script.
-$ bazel build slim/train
+We provide an easy way to train a model from scratch using any TF-Slim dataset.
+The following example demonstrates how to train Inception V3 using the default
+parameters on the ImageNet dataset.
 
-# run it
-$ bazel-bin/slim/train \
+```shell
+DATASET_DIR=/tmp/imagenet
+TRAIN_DIR=/tmp/train_logs
+python train_image_classifier.py \
     --train_dir=${TRAIN_DIR} \
     --dataset_name=imagenet \
     --dataset_split_name=train \
@@ -238,11 +237,18 @@ $ bazel-bin/slim/train \
     --model_name=inception_v3
 ```
 
+This process may take several days, depending on your hardware setup.
+For convenience, we provide a way to train a model on multiple GPUs,
+and/or multiple CPUs, either synchronously or asynchronously.
+See [model_deploy](https://github.com/tensorflow/models/blob/master/slim/deployment/model_deploy.py)
+for details.
+
+
 # Fine-tuning a model from an existing checkpoint
+<a id='Tuning'></a>
 
 Rather than training from scratch, we'll often want to start from a pre-trained
 model and fine-tune it.
-
 To indicate a checkpoint from which to fine-tune, we'll call training with
 the `--checkpoint_path` flag and assign it an absolute path to a checkpoint
 file.
@@ -255,8 +261,8 @@ hinders certain variables from being loaded. When fine-tuning on a
 classification task using a different number of classes than the trained model,
 the new model will have a final 'logits' layer whose dimensions differ from the
 pre-trained model. For example, if fine-tuning an ImageNet-trained model on
-Cifar10, the pre-trained logits layer will have dimensions `[2048 x 1001]` but
-our new logits layer will have dimensions `[2048 x 10]`. Consequently, this
+Flowers, the pre-trained logits layer will have dimensions `[2048 x 1001]` but
+our new logits layer will have dimensions `[2048 x 5]`. Consequently, this
 flag indicates to TF-Slim to avoid loading these weights from the checkpoint.
 
 Keep in mind that warm-starting from a checkpoint affects the model's weights
@@ -265,76 +271,56 @@ a new checkpoint will be created in `${TRAIN_DIR}`. If the fine-tuning
 training is stopped and restarted, this new checkpoint will be the one from
 which weights are restored and not the `${checkpoint_path}$`. Consequently,
 the flags `--checkpoint_path` and `--checkpoint_exclude_scopes` are only used
-during the `0-`th global step (model initialization).
+during the `0`-th global step (model initialization). Typically, for
+fine-tuning one only wants to train a subset of layers, so the flag
+`--trainable_scopes` allows you to specify which subset of layers should be
+trained; the rest remain frozen.
 
-```shell
-# Specify the directory where the dataset is stored.
-$ DATASET_DIR=$HOME/imagenet
+Below we give an example of
+[fine-tuning inception-v3 on flowers](https://github.com/tensorflow/models/blob/master/slim/scripts/finetune_inception_v3_on_flowers.sh).
+Inception V3 was trained on ImageNet with 1000 class labels, but the flowers
+dataset only has 5 classes. Since the dataset is quite small, we will only
+train the new layers.
 
-# Specify the directory where the training logs are stored:
-$ TRAIN_DIR=$HOME/train_logs
 
-# Specify the directory where the pre-trained model checkpoint was saved to:
-$ CHECKPOINT_PATH=$HOME/my_checkpoints/inception_v3.ckpt
-
-# Build the training script.
-$ bazel build slim/train
-
-# Run training. Use --checkpoint_exclude_scopes to avoid loading the weights
-# associated with the logits and auxiliary logits fully connected layers.
-$ bazel-bin/slim/train \
+```shell
+$ DATASET_DIR=/tmp/flowers
+$ TRAIN_DIR=/tmp/flowers-models/inception_v3
+$ CHECKPOINT_PATH=/tmp/my_checkpoints/inception_v3.ckpt
+$ python train_image_classifier.py \
     --train_dir=${TRAIN_DIR} \
     --dataset_dir=${DATASET_DIR} \
-    --dataset_name=cifar10 \
+    --dataset_name=flowers \
     --dataset_split_name=train \
     --model_name=inception_v3 \
     --checkpoint_path=${CHECKPOINT_PATH} \
-    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits
+    --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits \
+    --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits/Logits
 ```
 
 
-## Evaluating the provided Checkpoints:
-
-To evaluate the checkpoints provided with this release, one need only download
-the checkpoints and run the evaluation script.
 
-Note that the provided checkpoints contain the model's weights only. They do
-not contain variables associated with training, such as weight's moving averages
-or the global step. Consequently, when evaluating one of the pre-trained
-checkpoint files, one must specify the flag `--restore_global_step=False` to
-indicate to the evaluation routine to avoid attempting to load a global step
-from the checkpoint file that doesn't contain one.
-
-```shell
-# Specify and create the directory containing the checkpoints:
-$ CHECKPOINT_DIR=/tmp/checkpoints
-$ mkdir ${CHECKPOINT_DIR}
+# Evaluating performance of a model
+<a id='Eval'></a>
 
-# Download, extract and copy the checkpoint file over:
-$ wget http://download.tensorflow.org/models/inception_v1_2016_08_23.tar.gz
-$ tar -xvf inception_v1_2016_08_23.tar.gz
-$ mv inception_v1.ckpt ${CHECKPOINT_DIR}
-$ rm inception_v1_2016_08_23.tar.gz
+To evaluate the performance of a model (whether pretrained or your own),
+you can use the eval_image_classifier.py script, as shown below.
 
-# Specify the directory where the dataset is stored.
-$ DATASET_DIR=$HOME/imagenet
+Below we give an example of downloading the pretrained Inception V3 model and
+evaluating it on the ImageNet dataset.
 
-# Compile the evaluation script:
-$ bazel build slim/eval
-
-# Run the evaluation script. Note that since the pre-trained checkpoints
-# provided do not contain a global step, we need to instruct the evaluation
-# routine not to attempt to load the global step.
-$ ./bazel-bin/slim/eval \
+```shell
+CHECKPOINT_FILE=${CHECKPOINT_DIR}/inception_v3.ckpt  # Example
+$ python eval_image_classifier.py \
     --alsologtostderr \
-    --checkpoint_path=${CHECKPOINT_DIR}/inception_v1.ckpt \
+    --checkpoint_path=${CHECKPOINT_FILE} \
     --dataset_dir=${DATASET_DIR} \
     --dataset_name=imagenet \
     --dataset_split_name=validation \
-    --model_name=inception_v1 \
-    --restore_global_step=False
+    --model_name=inception_v3
 ```
 
+
+
 # Troubleshooting
 
 #### The model runs out of CPU memory.
@@ -354,9 +340,10 @@ See
 
 #### The ResNet and VGG Models have 1000 classes but the ImageNet dataset has 1001
 
-The ImageNet dataset provied has an additional background class which was used
-to help train Inception. If you try training or fine-tuning the VGG or ResNet
-models using the ImageNet dataset, you might encounter the following error:
+The ImageNet dataset provided has an empty background class, which can be used
+to fine-tune the model to other tasks. If you try training or fine-tuning the
+VGG or ResNet models using the ImageNet dataset, you might encounter the
+following error:
 
 ```bash
 InvalidArgumentError: Assign requires shapes of both tensors to match. lhs shape= [1001] rhs shape= [1000]
@@ -367,16 +354,6 @@ outputs rather than 1001.
 To fix this issue, you can set the `--labels_offset=1` flag. This results in
 the ImageNet labels being shifted down by one:
 
-```bash
-./bazel-bin/slim/train \
-  --train_dir=${TRAIN_DIR} \
-  --dataset_dir=${DATASET_DIR} \
-  --dataset_name=imagenet \
-  --dataset_split_name=train \
-  --model_name=resnet_v1_50 \
-  --checkpoint_path=${CHECKPOINT_PATH}
-  --labels_offset=1
-```
 
 #### I wish to train a model with a different image size.
 

+ 1 - 0
slim/datasets/__init__.py

@@ -0,0 +1 @@
+

+ 1 - 1
slim/datasets/cifar10.py

@@ -25,7 +25,7 @@ from __future__ import print_function
 import os
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
+from datasets import dataset_utils
 
 slim = tf.contrib.slim
 

+ 4 - 4
slim/datasets/dataset_factory.py

@@ -18,10 +18,10 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-from slim.datasets import cifar10
-from slim.datasets import flowers
-from slim.datasets import imagenet
-from slim.datasets import mnist
+from datasets import cifar10
+from datasets import flowers
+from datasets import imagenet
+from datasets import mnist
 
 datasets_map = {
     'cifar10': cifar10,
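
A minimal sketch of how a caller selects a dataset by name through the factory, assuming a `get_dataset` helper that resolves the name via `datasets_map` and calls the module's `get_split`:

```python
# Sketch: looking up a dataset descriptor by name through the factory.
# The get_dataset helper and its signature are assumptions.
from datasets import dataset_factory

dataset = dataset_factory.get_dataset(
    'flowers', split_name='validation', dataset_dir='/tmp/data/flowers')
print(dataset.num_classes, dataset.num_samples)
```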

+ 25 - 0
slim/datasets/dataset_utils.py

@@ -18,6 +18,10 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+import sys
+import tarfile
+
+from six.moves import urllib
 import tensorflow as tf
 
 LABELS_FILENAME = 'labels.txt'
@@ -59,6 +63,27 @@ def image_to_tfexample(image_data, image_format, height, width, class_id):
   }))
 
 
+def download_and_uncompress_tarball(tarball_url, dataset_dir):
+  """Downloads the `tarball_url` and uncompresses it locally.
+
+  Args:
+    tarball_url: The URL of a tarball file.
+    dataset_dir: The directory where the temporary files are stored.
+  """
+  filename = tarball_url.split('/')[-1]
+  filepath = os.path.join(dataset_dir, filename)
+
+  def _progress(count, block_size, total_size):
+    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
+        filename, float(count * block_size) / float(total_size) * 100.0))
+    sys.stdout.flush()
+  filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
+  print()
+  statinfo = os.stat(filepath)
+  print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+  tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
+
+
 def write_label_file(labels_to_class_names, dataset_dir,
                      filename=LABELS_FILENAME):
   """Writes a file with the list of class names.

+ 26 - 31
slim/datasets/download_and_convert_cifar10.py

@@ -14,16 +14,13 @@
 # ==============================================================================
 r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos.
 
-This script downloads the cifar10 data, uncompresses it, reads the files
+This module downloads the cifar10 data, uncompresses it, reads the files
 that make up the cifar10 data and creates two TFRecord datasets: one for train
 and one for test. Each TFRecord dataset is comprised of a set of TF-Example
 protocol buffers, each of which contain a single image and label.
 
 The script should take several minutes to run.
 
-Usage:
-$ bazel build slim:download_and_convert_cifar10
-$ .bazel-bin/slim/download_and_convert_cifar10 --dataset_dir=[DIRECTORY]
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -38,14 +35,7 @@ import numpy as np
 from six.moves import urllib
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
-
-tf.app.flags.DEFINE_string(
-    'dataset_dir',
-    None,
-    'The directory where the output TFRecords and temporary files are saved.')
-
-FLAGS = tf.app.flags.FLAGS
+from datasets import dataset_utils
 
 # The URL where the CIFAR data can be downloaded.
 _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
@@ -115,16 +105,17 @@ def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
   return offset + num_images
 
 
-def _get_output_filename(split_name):
+def _get_output_filename(dataset_dir, split_name):
   """Creates the output filename.
 
   Args:
+    dataset_dir: The dataset directory where the dataset is stored.
     split_name: The name of the train/test split.
 
   Returns:
     An absolute file path.
   """
-  return '%s/cifar10_%s.tfrecord' % (FLAGS.dataset_dir, split_name)
+  return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name)
 
 
 def _download_and_uncompress_dataset(dataset_dir):
@@ -162,39 +153,43 @@ def _clean_up_temporary_files(dataset_dir):
   tf.gfile.DeleteRecursively(tmp_dir)
 
 
-def main(_):
-  if not FLAGS.dataset_dir:
-    raise ValueError('You must supply the dataset directory with --dataset_dir')
+def run(dataset_dir):
+  """Runs the download and conversion operation.
 
-  if not tf.gfile.Exists(FLAGS.dataset_dir):
-    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+  Args:
+    dataset_dir: The dataset directory where the dataset is stored.
+  """
+  if not tf.gfile.Exists(dataset_dir):
+    tf.gfile.MakeDirs(dataset_dir)
 
-  _download_and_uncompress_dataset(FLAGS.dataset_dir)
+  training_filename = _get_output_filename(dataset_dir, 'train')
+  testing_filename = _get_output_filename(dataset_dir, 'test')
+
+  if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
+    print('Dataset files already exist. Exiting without re-creating them.')
+    return
+
+  dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
 
   # First, process the training data:
-  output_file = _get_output_filename('train')
-  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
+  with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
     offset = 0
     for i in range(_NUM_TRAIN_FILES):
-      filename = os.path.join(FLAGS.dataset_dir,
+      filename = os.path.join(dataset_dir,
                               'cifar-10-batches-py',
                               'data_batch_%d' % (i + 1))  # 1-indexed.
       offset = _add_to_tfrecord(filename, tfrecord_writer, offset)
 
   # Next, process the testing data:
-  output_file = _get_output_filename('test')
-  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
-    filename = os.path.join(FLAGS.dataset_dir,
+  with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
+    filename = os.path.join(dataset_dir,
                             'cifar-10-batches-py',
                             'test_batch')
     _add_to_tfrecord(filename, tfrecord_writer)
 
   # Finally, write the labels file:
   labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
-  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
 
-  _clean_up_temporary_files(FLAGS.dataset_dir)
+  _clean_up_temporary_files(dataset_dir)
   print('\nFinished converting the Cifar10 dataset!')
-
-if __name__ == '__main__':
-  tf.app.run()
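
With the converters refactored from standalone binaries into modules exposing `run(dataset_dir)`, the new `download_and_convert_data.py` binary can dispatch on a `--dataset_name` flag. A plausible sketch of that dispatcher, with flag names taken from the README example above (the actual implementation in this commit may differ):

```python
# Sketch of the dispatcher in download_and_convert_data.py; flag names
# follow the README example, the dispatch logic itself is assumed.
import tensorflow as tf

from datasets import download_and_convert_cifar10
from datasets import download_and_convert_flowers
from datasets import download_and_convert_mnist

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dataset_name', None, 'The name of the dataset.')
tf.app.flags.DEFINE_string('dataset_dir', None, 'Where to store the data.')


def main(_):
  converters = {
      'cifar10': download_and_convert_cifar10,
      'flowers': download_and_convert_flowers,
      'mnist': download_and_convert_mnist,
  }
  if FLAGS.dataset_name not in converters:
    raise ValueError('Unknown dataset: %s' % FLAGS.dataset_name)
  converters[FLAGS.dataset_name].run(FLAGS.dataset_dir)


if __name__ == '__main__':
  tf.app.run()
```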

+ 37 - 52
slim/datasets/download_and_convert_flowers.py

@@ -14,17 +14,13 @@
 # ==============================================================================
 r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.
 
-This script downloads the Flowers data, uncompresses it, reads the files
+This module downloads the Flowers data, uncompresses it, reads the files
 that make up the Flowers data and creates two TFRecord datasets: one for train
 and one for test. Each TFRecord dataset is comprised of a set of TF-Example
 protocol buffers, each of which contain a single image and label.
 
 The script should take about a minute to run.
 
-Usage:
-
-$ bazel build slim:download_and_convert_flowers
-$ .bazel-bin/slim/download_and_convert_flowers --dataset_dir=[DIRECTORY]
 """
 
 from __future__ import absolute_import
@@ -35,19 +31,10 @@ import math
 import os
 import random
 import sys
-import tarfile
 
-from six.moves import urllib
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
-
-tf.app.flags.DEFINE_string(
-    'dataset_dir',
-    None,
-    'The directory where the output TFRecords and temporary files are saved.')
-
-FLAGS = tf.app.flags.FLAGS
+from datasets import dataset_utils
 
 # The URL where the Flowers data can be downloaded.
 _DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'
@@ -82,27 +69,6 @@ class ImageReader(object):
     return image
 
 
-def _download_dataset(dataset_dir):
-  """Downloads the flowers data and uncompresses it locally.
-
-  Args:
-    dataset_dir: The directory where the temporary files are stored.
-  """
-  filename = _DATA_URL.split('/')[-1]
-  filepath = os.path.join(dataset_dir, filename)
-
-  if not os.path.exists(filepath):
-    def _progress(count, block_size, total_size):
-      sys.stdout.write('\r>> Downloading %s %.1f%%' % (
-          filename, float(count * block_size) / float(total_size) * 100.0))
-      sys.stdout.flush()
-    filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
-    print()
-    statinfo = os.stat(filepath)
-    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
-    tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
-
-
 def _get_filenames_and_classes(dataset_dir):
   """Returns a list of filenames and inferred class names.
 
@@ -132,6 +98,12 @@ def _get_filenames_and_classes(dataset_dir):
   return photo_filenames, sorted(class_names)
 
 
+def _get_dataset_filename(dataset_dir, split_name, shard_id):
+  output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
+      split_name, shard_id, _NUM_SHARDS)
+  return os.path.join(dataset_dir, output_filename)
+
+
 def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
   """Converts the given filenames to a TFRecord dataset.
 
@@ -152,9 +124,8 @@ def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
     with tf.Session('') as sess:
 
       for shard_id in range(_NUM_SHARDS):
-        output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
-            split_name, shard_id, _NUM_SHARDS)
-        output_filename = os.path.join(dataset_dir, output_filename)
+        output_filename = _get_dataset_filename(
+            dataset_dir, split_name, shard_id)
 
         with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
           start_ndx = shard_id * num_per_shard
@@ -193,15 +164,31 @@ def _clean_up_temporary_files(dataset_dir):
   tf.gfile.DeleteRecursively(tmp_dir)
 
 
-def main(_):
-  if not FLAGS.dataset_dir:
-    raise ValueError('You must supply the dataset directory with --dataset_dir')
+def _dataset_exists(dataset_dir):
+  for split_name in ['train', 'validation']:
+    for shard_id in range(_NUM_SHARDS):
+      output_filename = _get_dataset_filename(
+          dataset_dir, split_name, shard_id)
+      if not tf.gfile.Exists(output_filename):
+        return False
+  return True
+
+
+def run(dataset_dir):
+  """Runs the download and conversion operation.
+
+  Args:
+    dataset_dir: The dataset directory where the dataset is stored.
+  """
+  if not tf.gfile.Exists(dataset_dir):
+    tf.gfile.MakeDirs(dataset_dir)
 
-  if not tf.gfile.Exists(FLAGS.dataset_dir):
-    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+  if _dataset_exists(dataset_dir):
+    print('Dataset files already exist. Exiting without re-creating them.')
+    return
 
-  _download_dataset(FLAGS.dataset_dir)
-  photo_filenames, class_names = _get_filenames_and_classes(FLAGS.dataset_dir)
+  dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
+  photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
   class_names_to_ids = dict(zip(class_names, range(len(class_names))))
 
   # Divide into train and test:
@@ -212,16 +199,14 @@ def main(_):
 
   # First, convert the training and validation sets.
   _convert_dataset('train', training_filenames, class_names_to_ids,
-                   FLAGS.dataset_dir)
+                   dataset_dir)
   _convert_dataset('validation', validation_filenames, class_names_to_ids,
-                   FLAGS.dataset_dir)
+                   dataset_dir)
 
   # Finally, write the labels file:
   labels_to_class_names = dict(zip(range(len(class_names)), class_names))
-  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
 
-  _clean_up_temporary_files(FLAGS.dataset_dir)
+  _clean_up_temporary_files(dataset_dir)
   print('\nFinished converting the Flowers dataset!')
 
-if __name__ == '__main__':
-  tf.app.run()
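
Because the converter is now an importable module rather than a binary, it can also be run directly from Python, for example from the dispatcher sketched earlier or from a notebook:

```python
# Sketch: running the Flowers conversion directly in module form.
from datasets import download_and_convert_flowers

download_and_convert_flowers.run('/tmp/data/flowers')
```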

+ 28 - 34
slim/datasets/download_and_convert_mnist.py

@@ -14,17 +14,13 @@
 # ==============================================================================
 r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
 
-This script downloads the MNIST data, uncompresses it, reads the files
+This module downloads the MNIST data, uncompresses it, reads the files
 that make up the MNIST data and creates two TFRecord datasets: one for train
 and one for test. Each TFRecord dataset is comprised of a set of TF-Example
 protocol buffers, each of which contain a single image and label.
 
 The script should take about a minute to run.
 
-Usage:
-
-$ bazel build slim:download_and_convert_mnist
-$ .bazel-bin/slim/download_and_convert_mnist --dataset_dir=[DIRECTORY]
 """
 from __future__ import absolute_import
 from __future__ import division
@@ -38,14 +34,7 @@ import numpy as np
 from six.moves import urllib
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
-
-tf.app.flags.DEFINE_string(
-    'dataset_dir',
-    None,
-    'The directory where the output TFRecords and temporary files are saved.')
-
-FLAGS = tf.app.flags.FLAGS
+from datasets import dataset_utils
 
 # The URLs where the MNIST data can be downloaded.
 _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
@@ -140,16 +129,17 @@ def _add_to_tfrecord(data_filename, labels_filename, num_images,
         tfrecord_writer.write(example.SerializeToString())
 
 
-def _get_output_filename(split_name):
+def _get_output_filename(dataset_dir, split_name):
   """Creates the output filename.
 
   Args:
+    dataset_dir: The directory where the temporary files are stored.
     split_name: The name of the train/test split.
 
   Returns:
     An absolute file path.
   """
-  return '%s/mnist_%s.tfrecord' % (FLAGS.dataset_dir, split_name)
+  return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)
 
 
 def _download_dataset(dataset_dir):
@@ -193,35 +183,39 @@ def _clean_up_temporary_files(dataset_dir):
     tf.gfile.Remove(filepath)
 
 
-def main(_):
-  if not FLAGS.dataset_dir:
-    raise ValueError('You must supply the dataset directory with --dataset_dir')
+def run(dataset_dir):
+  """Runs the download and conversion operation.
 
-  if not tf.gfile.Exists(FLAGS.dataset_dir):
-    tf.gfile.MakeDirs(FLAGS.dataset_dir)
+  Args:
+    dataset_dir: The dataset directory where the dataset is stored.
+  """
+  if not tf.gfile.Exists(dataset_dir):
+    tf.gfile.MakeDirs(dataset_dir)
+
+  training_filename = _get_output_filename(dataset_dir, 'train')
+  testing_filename = _get_output_filename(dataset_dir, 'test')
 
-  _download_dataset(FLAGS.dataset_dir)
+  if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
+    print('Dataset files already exist. Exiting without re-creating them.')
+    return
+
+  _download_dataset(dataset_dir)
 
   # First, process the training data:
-  output_file = _get_output_filename('train')
-  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
-    data_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_DATA_FILENAME)
-    labels_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_LABELS_FILENAME)
+  with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
+    data_filename = os.path.join(dataset_dir, _TRAIN_DATA_FILENAME)
+    labels_filename = os.path.join(dataset_dir, _TRAIN_LABELS_FILENAME)
     _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
 
   # Next, process the testing data:
-  output_file = _get_output_filename('test')
-  with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
-    data_filename = os.path.join(FLAGS.dataset_dir, _TEST_DATA_FILENAME)
-    labels_filename = os.path.join(FLAGS.dataset_dir, _TEST_LABELS_FILENAME)
+  with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
+    data_filename = os.path.join(dataset_dir, _TEST_DATA_FILENAME)
+    labels_filename = os.path.join(dataset_dir, _TEST_LABELS_FILENAME)
     _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
 
   # Finally, write the labels file:
   labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
-  dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
+  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
 
-  _clean_up_temporary_files(FLAGS.dataset_dir)
+  _clean_up_temporary_files(dataset_dir)
   print('\nFinished converting the MNIST dataset!')
-
-if __name__ == '__main__':
-  tf.app.run()

+ 1 - 1
slim/datasets/flowers.py

@@ -25,7 +25,7 @@ from __future__ import print_function
 import os
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
+from datasets import dataset_utils
 
 slim = tf.contrib.slim
 

+ 67 - 2
slim/datasets/imagenet.py

@@ -33,8 +33,11 @@ from __future__ import division
 from __future__ import print_function
 
 import os
+from six.moves import urllib
 import tensorflow as tf
 
+from datasets import dataset_utils
+
 slim = tf.contrib.slim
 
 # TODO(nsilberman): Add tfrecord file type once the script is updated.
@@ -55,7 +58,61 @@ _ITEMS_TO_DESCRIPTIONS = {
 
 _NUM_CLASSES = 1001
 
-# TODO(nsilberman): Add _LABELS_TO_NAMES
+
+def create_readable_names_for_imagenet_labels():
+  """Create a dict mapping label id to human readable string.
+
+  Returns:
+      labels_to_names: dictionary where keys are integers from 0 to 1000
+      and values are human-readable names.
+
+We retrieve a synset file, which contains a list of valid synset labels used
+  by the ILSVRC competition. There is one synset per line, e.g.:
+          #   n01440764
+          #   n01443537
+We also retrieve a synset_to_human_file, which contains a mapping from synsets
+to human-readable names for every synset in ImageNet. These are stored in
+TSV format, as follows:
+          #   n02119247    black fox
+          #   n02119359    silver fox
+  We assign each synset (in alphabetical order) an integer, starting from 1
+  (since 0 is reserved for the background class).
+
+  Code is based on
+  https://github.com/tensorflow/models/blob/master/inception/inception/data/build_imagenet_data.py#L463
+  """
+
+  # pylint: disable=g-line-too-long
+  base_url = 'https://raw.githubusercontent.com/tensorflow/models/master/inception/inception/data/'
+  synset_url = '{}imagenet_lsvrc_2015_synsets.txt'.format(base_url)
+  synset_to_human_url = '{}imagenet_metadata.txt'.format(base_url)
+
+  filename, _ = urllib.request.urlretrieve(synset_url)
+  synset_list = [s.strip() for s in open(filename).readlines()]
+  num_synsets_in_ilsvrc = len(synset_list)
+  assert num_synsets_in_ilsvrc == 1000
+
+  filename, _ = urllib.request.urlretrieve(synset_to_human_url)
+  synset_to_human_list = open(filename).readlines()
+  num_synsets_in_all_imagenet = len(synset_to_human_list)
+  assert num_synsets_in_all_imagenet == 21842
+
+  synset_to_human = {}
+  for s in synset_to_human_list:
+    parts = s.strip().split('\t')
+    assert len(parts) == 2
+    synset = parts[0]
+    human = parts[1]
+    synset_to_human[synset] = human
+
+  label_index = 1
+  labels_to_names = {0: 'background'}
+  for synset in synset_list:
+    name = synset_to_human[synset]
+    labels_to_names[label_index] = name
+    label_index += 1
+
+  return labels_to_names
 
 
 def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
@@ -119,10 +176,18 @@ def get_split(split_name, dataset_dir, file_pattern=None, reader=None):
   decoder = slim.tfexample_decoder.TFExampleDecoder(
       keys_to_features, items_to_handlers)
 
+  labels_to_names = None
+  if dataset_utils.has_labels(dataset_dir):
+    labels_to_names = dataset_utils.read_label_file(dataset_dir)
+  else:
+    labels_to_names = create_readable_names_for_imagenet_labels()
+    dataset_utils.write_label_file(labels_to_names, dataset_dir)
+
   return slim.dataset.Dataset(
       data_sources=file_pattern,
       reader=reader,
       decoder=decoder,
       num_samples=_SPLITS_TO_SIZES[split_name],
       items_to_descriptions=_ITEMS_TO_DESCRIPTIONS,
-      num_classes=_NUM_CLASSES)
+      num_classes=_NUM_CLASSES,
+      labels_to_names=labels_to_names)
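
A quick sketch of what the new label mapping yields (assumes network access to fetch the two synset files; the sample name follows the standard ILSVRC synset list):

```python
from datasets import imagenet

# Label 0 is reserved for the background class; the 1000 synsets fill
# indices 1..1000 in alphabetical order, so the dict has 1001 entries.
names = imagenet.create_readable_names_for_imagenet_labels()
assert len(names) == 1001
print(names[0])  # 'background'
print(names[1])  # 'tench, Tinca tinca' (synset n01440764)
```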

+ 1 - 1
slim/datasets/mnist.py

@@ -25,7 +25,7 @@ from __future__ import print_function
 import os
 import tensorflow as tf
 
-from slim.datasets import dataset_utils
+from datasets import dataset_utils
 
 slim = tf.contrib.slim
 

+ 1 - 0
slim/deployment/__init__.py

@@ -0,0 +1 @@
+

+ 8 - 8
slim/models/model_deploy.py

@@ -30,7 +30,7 @@ Usage:
   g = tf.Graph()
 
   # Set up DeploymentConfig
-  config = slim.DeploymentConfig(num_clones=2, clone_on_cpu=True)
+  config = model_deploy.DeploymentConfig(num_clones=2, clone_on_cpu=True)
 
   # Create the global step on the device storing the variables.
   with tf.device(config.variables_device()):
@@ -51,7 +51,8 @@ Usage:
     predictions = CreateNetwork(images)
     slim.losses.log_loss(predictions, labels)
 
-  model_dp = slim.deploy(config, model_fn, [inputs_queue], optimizer=optimizer)
+  model_dp = model_deploy.deploy(config, model_fn, [inputs_queue],
+                                 optimizer=optimizer)
 
   # Run training.
   slim.learning.train(model_dp.train_op, my_log_dir,
@@ -240,7 +241,7 @@ def _gather_clone_loss(clone, num_clones, regularization_losses):
 
 
 def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
-                    kwargs=None):
+                    **kwargs):
   """Compute losses and gradients for a single clone.
 
   Args:
@@ -249,7 +250,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
     num_clones: The number of clones being deployed.
     regularization_losses: Possibly empty list of regularization_losses
       to add to the clone losses.
-    kwargs: Dict of kwarg to pass to compute_gradients().
+    **kwargs: Dict of kwarg to pass to compute_gradients().
 
   Returns:
     A tuple (clone_loss, clone_grads_and_vars).
@@ -267,7 +268,7 @@ def _optimize_clone(optimizer, clone, num_clones, regularization_losses,
 
 def optimize_clones(clones, optimizer,
                     regularization_losses=None,
-                    kwargs=None):
+                    **kwargs):
   """Compute clone losses and gradients for the given list of `Clones`.
 
   Note: The regularization_losses are added to the first clone losses.
@@ -278,7 +279,7 @@ def optimize_clones(clones, optimizer,
    regularization_losses: Optional list of regularization losses. If None it
      will gather them from tf.GraphKeys.REGULARIZATION_LOSSES. Pass `[]` to
      exclude them.
-   kwargs: Optional list of keyword arguments to pass to `compute_gradients`.
+   **kwargs: Optional list of keyword arguments to pass to `compute_gradients`.
 
   Returns:
    A tuple (total_loss, grads_and_vars).
@@ -290,7 +291,6 @@ def optimize_clones(clones, optimizer,
   """
   grads_and_vars = []
   clones_losses = []
-  kwargs = kwargs or {}
   num_clones = len(clones)
   if regularization_losses is None:
     regularization_losses = tf.get_collection(
@@ -298,7 +298,7 @@ def optimize_clones(clones, optimizer,
   for clone in clones:
     with tf.name_scope(clone.scope):
       clone_loss, clone_grad = _optimize_clone(
-          optimizer, clone, num_clones, regularization_losses, kwargs)
+          optimizer, clone, num_clones, regularization_losses, **kwargs)
       if clone_loss is not None:
         clones_losses.append(clone_loss)
         grads_and_vars.append(clone_grad)
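
With the switch from a `kwargs` dict to true `**kwargs`, extra arguments now flow straight through to the optimizer's `compute_gradients()`. A fragment continuing the usage example from the module docstring above (the `gate_gradients` choice is illustrative):

```python
# Before this change: optimize_clones(clones, optimizer,
#                                     kwargs={'gate_gradients': ...})
# After: keyword arguments are forwarded to compute_gradients() directly.
total_loss, grads_and_vars = model_deploy.optimize_clones(
    clones, optimizer, gate_gradients=tf.train.Optimizer.GATE_NONE)
```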

+ 1 - 1
slim/models/model_deploy_test.py

@@ -21,7 +21,7 @@ from __future__ import print_function
 import numpy as np
 import tensorflow as tf
 
-from slim.models import model_deploy
+from deployment import model_deploy
 
 slim = tf.contrib.slim
 

+ 74 - 0
slim/download_and_convert_data.py

@@ -0,0 +1,74 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+r"""Downloads and converts a particular dataset.
+
+Usage:
+```shell
+
+$ python download_and_convert_data.py \
+    --dataset_name=mnist \
+    --dataset_dir=/tmp/mnist
+
+$ python download_and_convert_data.py \
+    --dataset_name=cifar10 \
+    --dataset_dir=/tmp/cifar10
+
+$ python download_and_convert_data.py \
+    --dataset_name=flowers \
+    --dataset_dir=/tmp/flowers
+```
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from datasets import download_and_convert_cifar10
+from datasets import download_and_convert_flowers
+from datasets import download_and_convert_mnist
+
+FLAGS = tf.app.flags.FLAGS
+
+tf.app.flags.DEFINE_string(
+    'dataset_name',
+    None,
+    'The name of the dataset to convert, one of "cifar10", "flowers", "mnist".')
+
+tf.app.flags.DEFINE_string(
+    'dataset_dir',
+    None,
+    'The directory where the output TFRecords and temporary files are saved.')
+
+
+def main(_):
+  if not FLAGS.dataset_name:
+    raise ValueError('You must supply the dataset name with --dataset_name')
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  if FLAGS.dataset_name == 'cifar10':
+    download_and_convert_cifar10.run(FLAGS.dataset_dir)
+  elif FLAGS.dataset_name == 'flowers':
+    download_and_convert_flowers.run(FLAGS.dataset_dir)
+  elif FLAGS.dataset_name == 'mnist':
+    download_and_convert_mnist.run(FLAGS.dataset_dir)
+  else:
+    raise ValueError(
+        'dataset_name [%s] was not recognized.' % FLAGS.dataset_name)
+
+if __name__ == '__main__':
+  tf.app.run()
+

+ 21 - 23
slim/eval.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Generic evaluation script that trains a given model a specified dataset."""
+"""Generic evaluation script that evaluates a model using a given dataset."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -21,9 +21,9 @@ from __future__ import print_function
 import math
 import tensorflow as tf
 
-from slim.datasets import dataset_factory
-from slim.models import model_factory
-from slim.models import preprocessing_factory
+from datasets import dataset_factory
+from nets import nets_factory
+from preprocessing import preprocessing_factory
 
 slim = tf.contrib.slim
 
@@ -42,11 +42,6 @@ tf.app.flags.DEFINE_string(
     'The directory where the model was written to or an absolute path to a '
     'checkpoint file.')
 
-tf.app.flags.DEFINE_bool(
-    'restore_global_step', True,
-    'Whether or not to restore the global step. When evaluating a model '
-    'checkpoint containing ONLY weights, set this flag to `False`.')
-
 tf.app.flags.DEFINE_string(
     'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')
 
@@ -58,11 +53,10 @@ tf.app.flags.DEFINE_string(
     'dataset_name', 'imagenet', 'The name of the dataset to load.')
 
 tf.app.flags.DEFINE_string(
-    'dataset_split_name', 'train', 'The name of the train/test split.')
+    'dataset_split_name', 'test', 'The name of the train/test split.')
 
 tf.app.flags.DEFINE_string(
     'dataset_dir', None, 'The directory where the dataset files are stored.')
-tf.app.flags.MarkFlagAsRequired('dataset_dir')
 
 tf.app.flags.DEFINE_integer(
     'labels_offset', 0,
@@ -82,10 +76,17 @@ tf.app.flags.DEFINE_float(
     'The decay to use for the moving average. '
     'If left as None, then moving averages are not used.')
 
+tf.app.flags.DEFINE_integer(
+    'eval_image_size', None,
+    'Eval image size. If None, the network default image size is used.')
+
 FLAGS = tf.app.flags.FLAGS
 
 
 def main(_):
+  if not FLAGS.dataset_dir:
+    raise ValueError('You must supply the dataset directory with --dataset_dir')
+
+  tf.logging.set_verbosity(tf.logging.INFO)
   with tf.Graph().as_default():
     tf_global_step = slim.get_or_create_global_step()
 
@@ -98,7 +99,7 @@ def main(_):
     ####################
     # Select the model #
     ####################
-    model_fn = model_factory.get_model(
+    network_fn = nets_factory.get_network_fn(
         FLAGS.model_name,
         num_classes=(dataset.num_classes - FLAGS.labels_offset),
         is_training=False)
@@ -122,9 +123,9 @@ def main(_):
         preprocessing_name,
         is_training=False)
 
-    image = image_preprocessing_fn(image,
-                                   height=model_fn.default_image_size,
-                                   width=model_fn.default_image_size)
+    eval_image_size = FLAGS.eval_image_size or network_fn.default_image_size
+
+    image = image_preprocessing_fn(image, eval_image_size, eval_image_size)
 
     images, labels = tf.train.batch(
         [image, label],
@@ -135,19 +136,16 @@ def main(_):
     ####################
     # Define the model #
     ####################
-    logits, _ = model_fn(images)
+    logits, _ = network_fn(images)
 
     if FLAGS.moving_average_decay:
       variable_averages = tf.train.ExponentialMovingAverage(
           FLAGS.moving_average_decay, tf_global_step)
       variables_to_restore = variable_averages.variables_to_restore(
           slim.get_model_variables())
-
-      if FLAGS.restore_global_step:
-        variables_to_restore[tf_global_step.op.name] = tf_global_step
+      variables_to_restore[tf_global_step.op.name] = tf_global_step
     else:
-      exclude = None if FLAGS.restore_global_step else ['global_step']
-      variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
+      variables_to_restore = slim.get_variables_to_restore()
 
     predictions = tf.argmax(logits, 1)
     labels = tf.squeeze(labels)
@@ -181,8 +179,8 @@ def main(_):
     tf.logging.info('Evaluating %s' % checkpoint_path)
 
     slim.evaluation.evaluate_once(
-        FLAGS.master,
-        checkpoint_path,
+        master=FLAGS.master,
+        checkpoint_path=checkpoint_path,
         logdir=FLAGS.eval_dir,
         num_evals=num_batches,
         eval_op=names_to_updates.values(),
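
For reference, a hedged sketch of the factory pattern the revised script now follows (model choice and input are illustrative; 299 is inception_v3's default image size):

```python
import tensorflow as tf
from nets import nets_factory
from preprocessing import preprocessing_factory

network_fn = nets_factory.get_network_fn(
    'inception_v3', num_classes=1001, is_training=False)
preprocess_fn = preprocessing_factory.get_preprocessing(
    'inception_v3', is_training=False)

# --eval_image_size falls back to the network's own default when unset.
size = network_fn.default_image_size  # 299 here
image = tf.random_uniform((500, 400, 3))  # stand-in for a decoded image
image = preprocess_fn(image, size, size)
logits, _ = network_fn(tf.expand_dims(image, 0))
```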

+ 0 - 140
slim/models/model_factory.py

@@ -1,140 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Contains a factory for building various models."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.slim import nets
-from slim.nets import lenet
-
-slim = tf.contrib.slim
-
-
-def get_model(name, num_classes, weight_decay=0.0, is_training=False):
-  """Returns a model_fn such as `logits, end_points = model_fn(images)`.
-
-  Args:
-    name: The name of the model.
-    num_classes: The number of classes to use for classification.
-    weight_decay: The l2 coefficient for the model weights.
-    is_training: `True` if the model is being used for training and `False`
-      otherwise.
-
-  Returns:
-    model_fn: A function that applies the model to a batch of images. It has
-      the following signature:
-        logits, end_points = model_fn(images)
-  Raises:
-    ValueError: If model `name` is not recognized.
-  """
-  if name == 'inception_v1':
-    default_image_size = nets.inception.inception_v1.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.inception.inception_v1_arg_scope(
-          weight_decay=weight_decay)):
-        return nets.inception.inception_v1(images,
-                                           num_classes,
-                                           is_training=is_training)
-    model_fn = func
-  elif name == 'inception_v2':
-    default_image_size = nets.inception.inception_v2.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.inception.inception_v2_arg_scope(
-          weight_decay=weight_decay)):
-        return nets.inception.inception_v2(images,
-                                           num_classes=num_classes,
-                                           is_training=is_training)
-    model_fn = func
-  elif name == 'inception_v3':
-    default_image_size = nets.inception.inception_v3.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.inception.inception_v3_arg_scope(
-          weight_decay=weight_decay)):
-        return nets.inception.inception_v3(images,
-                                           num_classes=num_classes,
-                                           is_training=is_training)
-    model_fn = func
-  elif name == 'lenet':
-    default_image_size = lenet.lenet.default_image_size
-    def func(images):
-      with slim.arg_scope(lenet.lenet_arg_scope(weight_decay=weight_decay)):
-        return lenet.lenet(images,
-                           num_classes=num_classes,
-                           is_training=is_training)
-    model_fn = func
-  elif name == 'resnet_v1_50':
-    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
-          is_training, weight_decay=weight_decay)):
-        net, end_points = nets.resnet_v1.resnet_v1_50(
-            images, num_classes=num_classes)
-        net = tf.squeeze(net, squeeze_dims=[1, 2])
-        return net, end_points
-    model_fn = func
-  elif name == 'resnet_v1_101':
-    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
-          is_training, weight_decay=weight_decay)):
-        net, end_points = nets.resnet_v1.resnet_v1_101(
-            images, num_classes=num_classes)
-        net = tf.squeeze(net, squeeze_dims=[1, 2])
-        return net, end_points
-    model_fn = func
-  elif name == 'resnet_v1_152':
-    default_image_size = nets.resnet_v1.resnet_v1.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.resnet_v1.resnet_arg_scope(
-          is_training, weight_decay=weight_decay)):
-        net, end_points = nets.resnet_v1.resnet_v1_152(
-            images, num_classes=num_classes)
-        net = tf.squeeze(net, squeeze_dims=[1, 2])
-        return net, end_points
-    model_fn = func
-  elif name == 'vgg_a':
-    default_image_size = nets.vgg.vgg_a.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
-        return nets.vgg.vgg_a(images,
-                              num_classes=num_classes,
-                              is_training=is_training)
-    model_fn = func
-  elif name == 'vgg_16':
-    default_image_size = nets.vgg.vgg_16.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
-        return nets.vgg.vgg_16(images,
-                               num_classes=num_classes,
-                               is_training=is_training)
-    model_fn = func
-  elif name == 'vgg_19':
-    default_image_size = nets.vgg.vgg_19.default_image_size
-    def func(images):
-      with slim.arg_scope(nets.vgg.vgg_arg_scope(weight_decay)):
-        return nets.vgg.vgg_19(images,
-                               num_classes=num_classes,
-                               is_training=is_training)
-    model_fn = func
-  else:
-    raise ValueError('Model name [%s] was not recognized' % name)
-
-  model_fn.default_image_size = default_image_size
-
-  return model_fn
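
The per-model `if/elif` chain deleted above is superseded by `nets/nets_factory.py`, added in this commit. A sketch of the replacement call, assuming the factory mirrors the deleted `get_model()` signature (argument values are illustrative):

```python
import tensorflow as tf
from nets import nets_factory

images = tf.random_uniform((4, 224, 224, 3))
# get_network_fn plays the role of the deleted get_model(): it returns a
# callable with a default_image_size attribute attached.
network_fn = nets_factory.get_network_fn(
    'resnet_v1_50', num_classes=1000, weight_decay=0.0, is_training=False)
logits, end_points = network_fn(images)
size = network_fn.default_image_size  # 224 for resnet_v1_50
```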

+ 0 - 200
slim/models/resnet_preprocessing.py

@@ -1,200 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Provides utilities to preprocess images for the ResNet networks."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import tensorflow as tf
-
-from tensorflow.contrib.slim import nets
-from tensorflow.python.ops import control_flow_ops
-
-slim = tf.contrib.slim
-
-_R_MEAN = 123.68
-_G_MEAN = 116.78
-_B_MEAN = 103.94
-
-_CROP_HEIGHT = nets.resnet_v1.resnet_v1.default_image_size
-_CROP_WIDTH = nets.resnet_v1.resnet_v1.default_image_size
-_RESIZE_SIDE = 256
-
-
-def _mean_image_subtraction(image, means):
-  """Subtracts the given means from each image channel.
-
-  For example:
-    means = [123.68, 116.779, 103.939]
-    image = _mean_image_subtraction(image, means)
-
-  Note that the rank of `image` must be known.
-
-  Args:
-    image: a tensor of size [height, width, C].
-    means: a C-vector of values to subtract from each channel.
-
-  Returns:
-    the centered image.
-
-  Raises:
-    ValueError: If the rank of `image` is unknown, if `image` has a rank other
-      than three or if the number of channels in `image` doesn't match the
-      number of values in `means`.
-  """
-  if image.get_shape().ndims != 3:
-    raise ValueError('Input must be of size [height, width, C>0]')
-  num_channels = image.get_shape().as_list()[-1]
-  if len(means) != num_channels:
-    raise ValueError('len(means) must match the number of channels')
-
-  channels = tf.split(2, num_channels, image)
-  for i in range(num_channels):
-    channels[i] -= means[i]
-  return tf.concat(2, channels)
-
-
-def _smallest_size_at_least(height, width, smallest_side):
-  """Computes new shape with the smallest side equal to `smallest_side`.
-
-  Computes new shape with the smallest side equal to `smallest_side` while
-  preserving the original aspect ratio.
-
-  Args:
-    height: an int32 scalar tensor indicating the current height.
-    width: an int32 scalar tensor indicating the current width.
-    smallest_side: an python integer indicating the smallest side of the new
-      shape.
-
-  Returns:
-    new_height: an int32 scalar tensor indicating the new height.
-    new_width: and int32 scalar tensor indicating the new width.
-  """
-  height = tf.to_float(height)
-  width = tf.to_float(width)
-  smallest_side = float(smallest_side)
-  scale = tf.cond(tf.greater(height, width),
-                  lambda: smallest_side / width,
-                  lambda: smallest_side / height)
-  new_height = tf.to_int32(height * scale)
-  new_width = tf.to_int32(width * scale)
-  return new_height, new_width
-
-
-def _aspect_preserving_resize(image, smallest_side):
-  """Resize images preserving the original aspect ratio.
-
-  Args:
-    image: a 3-D image tensor.
-    smallest_side: a python integer indicating the size of the smallest side
-      after resize.
-
-  Returns:
-    resized_image: a 3-D tensor containing the resized image.
-  """
-  shape = tf.shape(image)
-  height = shape[0]
-  width = shape[1]
-  new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
-  image = tf.expand_dims(image, 0)
-  resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
-                                           align_corners=False)
-  resized_image = tf.squeeze(resized_image)
-  resized_image.set_shape([None, None, 3])
-  return resized_image
-
-
-def _crop(image, offset_height, offset_width, crop_height, crop_width):
-  """Crops the given image using the provided offsets and sizes.
-
-  Note that the method doesn't assume we know the input image size but it does
-  assume we know the input image rank.
-
-  Args:
-    image: an image of shape [height, width, channels].
-    offset_height: a scalar tensor indicating the height offset.
-    offset_width: a scalar tensor indicating the width offset.
-    crop_height: the height of the cropped image.
-    crop_width: the width of the cropped image.
-
-  Returns:
-    the cropped (and resized) image.
-
-  Raises:
-    InvalidArgumentError: if the rank is not 3 or if the image dimensions are
-      less than the crop size.
-  """
-  original_shape = tf.shape(image)
-
-  rank_assertion = tf.Assert(
-      tf.equal(tf.rank(image), 3),
-      ['Rank of image must be equal to 3.'])
-  cropped_shape = control_flow_ops.with_dependencies(
-      [rank_assertion],
-      tf.pack([crop_height, crop_width, original_shape[2]]))
-
-  size_assertion = tf.Assert(
-      tf.logical_and(
-          tf.greater_equal(original_shape[0], crop_height),
-          tf.greater_equal(original_shape[1], crop_width)),
-      ['Crop size greater than the image size.'])
-
-  offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
-
-  # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
-  # define the crop size.
-  image = control_flow_ops.with_dependencies(
-      [size_assertion],
-      tf.slice(image, offsets, cropped_shape))
-  return tf.reshape(image, cropped_shape)
-
-
-def _central_crop(image_list, crop_height, crop_width):
-  """Performs central crops of the given image list.
-
-  Args:
-    image_list: a list of image tensors of the same dimension but possibly
-      varying channel.
-    crop_height: the height of the image following the crop.
-    crop_width: the width of the image following the crop.
-
-  Returns:
-    the list of cropped images.
-  """
-  outputs = []
-  for image in image_list:
-    image_height = tf.shape(image)[0]
-    image_width = tf.shape(image)[1]
-
-    offset_height = (image_height - crop_height) / 2
-    offset_width = (image_width - crop_width) / 2
-
-    outputs.append(_crop(image, offset_height, offset_width,
-                         crop_height, crop_width))
-  return outputs
-
-
-def preprocess_image(image,
-                     height=_CROP_HEIGHT,
-                     width=_CROP_WIDTH,
-                     is_training=False,  # pylint: disable=unused-argument
-                     resize_side=_RESIZE_SIDE):
-  image = _aspect_preserving_resize(image, resize_side)
-  image = _central_crop([image], height, width)[0]
-  image.set_shape([height, width, 3])
-  image = tf.to_float(image)
-  image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
-  return image

+ 1 - 0
slim/nets/__init__.py

@@ -0,0 +1 @@
+

+ 125 - 0
slim/nets/alexnet.py

@@ -0,0 +1,125 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a model definition for AlexNet.
+
+This work was first described in:
+  ImageNet Classification with Deep Convolutional Neural Networks
+  Alex Krizhevsky, Ilya Sutskever and Geoffrey E. Hinton
+
+and later refined in:
+  One weird trick for parallelizing convolutional neural networks
+  Alex Krizhevsky, 2014
+
+Here we provide the implementation proposed in "One weird trick" rather than
+the one in "ImageNet Classification"; as per the former paper, the LRN layers
+have been removed.
+
+Usage:
+  with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
+    outputs, end_points = alexnet.alexnet_v2(inputs)
+
+@@alexnet_v2
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def alexnet_v2_arg_scope(weight_decay=0.0005):
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      activation_fn=tf.nn.relu,
+                      biases_initializer=tf.constant_initializer(0.1),
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+      with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
+        return arg_sc
+
+
+def alexnet_v2(inputs,
+               num_classes=1000,
+               is_training=True,
+               dropout_keep_prob=0.5,
+               spatial_squeeze=True,
+               scope='alexnet_v2'):
+  """AlexNet version 2.
+
+  Described in: http://arxiv.org/pdf/1404.5997v2.pdf
+  Parameters from:
+  github.com/akrizhevsky/cuda-convnet2/blob/master/layers/
+  layers-imagenet-1gpu.cfg
+
+  Note: All the fully_connected layers have been transformed to conv2d layers.
+        To use in classification mode, resize input to 224x224. To use in fully
+        convolutional mode, set spatial_squeeze to false.
+        The LRN layers have been removed and the initializers changed from
+        random_normal_initializer to xavier_initializer.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether or not the model is being trained.
+    dropout_keep_prob: the probability that activations are kept in the dropout
+      layers during training.
+    spatial_squeeze: whether or not to squeeze the spatial dimensions of the
+      outputs. Useful to remove unnecessary dimensions for classification.
+    scope: Optional scope for the variables.
+
+  Returns:
+    the last op containing the logits and the end_points dict.
+  """
+  with tf.variable_scope(scope, 'alexnet_v2', [inputs]) as sc:
+    end_points_collection = sc.name + '_end_points'
+    # Collect outputs for conv2d, fully_connected and max_pool2d.
+    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+                        outputs_collections=[end_points_collection]):
+      net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
+                        scope='conv1')
+      net = slim.max_pool2d(net, [3, 3], 2, scope='pool1')
+      net = slim.conv2d(net, 192, [5, 5], scope='conv2')
+      net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
+      net = slim.conv2d(net, 384, [3, 3], scope='conv3')
+      net = slim.conv2d(net, 384, [3, 3], scope='conv4')
+      net = slim.conv2d(net, 256, [3, 3], scope='conv5')
+      net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
+
+      # Use conv2d instead of fully_connected layers.
+      with slim.arg_scope([slim.conv2d],
+                          weights_initializer=trunc_normal(0.005),
+                          biases_initializer=tf.constant_initializer(0.1)):
+        net = slim.conv2d(net, 4096, [5, 5], padding='VALID',
+                          scope='fc6')
+        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                           scope='dropout6')
+        net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                           scope='dropout7')
+        net = slim.conv2d(net, num_classes, [1, 1],
+                          activation_fn=None,
+                          normalizer_fn=None,
+                          biases_initializer=tf.zeros_initializer,
+                          scope='fc8')
+
+      # Convert end_points_collection into an end_points dict.
+      end_points = dict(tf.get_collection(end_points_collection))
+      if spatial_squeeze:
+        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+        end_points[sc.name + '/fc8'] = net
+      return net, end_points
+alexnet_v2.default_image_size = 224
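
Per the `spatial_squeeze` note above, the network also runs fully convolutionally on larger inputs; a sketch using the same shapes the test below checks:

```python
import tensorflow as tf
from nets import alexnet

slim = tf.contrib.slim

# A 300x400 input yields a 4x7 grid of class scores when the spatial
# squeeze is disabled (see testFullyConvolutional below).
inputs = tf.random_uniform((1, 300, 400, 3))
with slim.arg_scope(alexnet.alexnet_v2_arg_scope()):
  logits, _ = alexnet.alexnet_v2(inputs, 1000, spatial_squeeze=False)
# logits.get_shape() == [1, 4, 7, 1000]
```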

+ 145 - 0
slim/nets/alexnet_test.py

@@ -0,0 +1,145 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.alexnet."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import alexnet
+
+slim = tf.contrib.slim
+
+
+class AlexnetV2Test(tf.test.TestCase):
+
+  def testBuild(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = alexnet.alexnet_v2(inputs, num_classes)
+      self.assertEquals(logits.op.name, 'alexnet_v2/fc8/squeezed')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testFullyConvolutional(self):
+    batch_size = 1
+    height, width = 300, 400
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = alexnet.alexnet_v2(inputs, num_classes, spatial_squeeze=False)
+      self.assertEquals(logits.op.name, 'alexnet_v2/fc8/BiasAdd')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, 4, 7, num_classes])
+
+  def testEndPoints(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = alexnet.alexnet_v2(inputs, num_classes)
+      expected_names = ['alexnet_v2/conv1',
+                        'alexnet_v2/pool1',
+                        'alexnet_v2/conv2',
+                        'alexnet_v2/pool2',
+                        'alexnet_v2/conv3',
+                        'alexnet_v2/conv4',
+                        'alexnet_v2/conv5',
+                        'alexnet_v2/pool5',
+                        'alexnet_v2/fc6',
+                        'alexnet_v2/fc7',
+                        'alexnet_v2/fc8'
+                       ]
+      self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+  def testModelVariables(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      alexnet.alexnet_v2(inputs, num_classes)
+      expected_names = ['alexnet_v2/conv1/weights',
+                        'alexnet_v2/conv1/biases',
+                        'alexnet_v2/conv2/weights',
+                        'alexnet_v2/conv2/biases',
+                        'alexnet_v2/conv3/weights',
+                        'alexnet_v2/conv3/biases',
+                        'alexnet_v2/conv4/weights',
+                        'alexnet_v2/conv4/biases',
+                        'alexnet_v2/conv5/weights',
+                        'alexnet_v2/conv5/biases',
+                        'alexnet_v2/fc6/weights',
+                        'alexnet_v2/fc6/biases',
+                        'alexnet_v2/fc7/weights',
+                        'alexnet_v2/fc7/biases',
+                        'alexnet_v2/fc8/weights',
+                        'alexnet_v2/fc8/biases',
+                       ]
+      model_variables = [v.op.name for v in slim.get_model_variables()]
+      self.assertSetEqual(set(model_variables), set(expected_names))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 2
+    eval_batch_size = 1
+    train_height, train_width = 224, 224
+    eval_height, eval_width = 300, 400
+    num_classes = 1000
+    with self.test_session():
+      train_inputs = tf.random_uniform(
+          (train_batch_size, train_height, train_width, 3))
+      logits, _ = alexnet.alexnet_v2(train_inputs)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [train_batch_size, num_classes])
+      tf.get_variable_scope().reuse_variables()
+      eval_inputs = tf.random_uniform(
+          (eval_batch_size, eval_height, eval_width, 3))
+      logits, _ = alexnet.alexnet_v2(eval_inputs, is_training=False,
+                                     spatial_squeeze=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [eval_batch_size, 4, 7, num_classes])
+      logits = tf.reduce_mean(logits, [1, 2])
+      predictions = tf.argmax(logits, 1)
+      self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+  def testForward(self):
+    batch_size = 1
+    height, width = 224, 224
+    with self.test_session() as sess:
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = alexnet.alexnet_v2(inputs)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits)
+      self.assertTrue(output.any())
+
+if __name__ == '__main__':
+  tf.test.main()

+ 112 - 0
slim/nets/cifarnet.py

@@ -0,0 +1,112 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a variant of the CIFAR-10 model definition."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(stddev=stddev)
+
+
+def cifarnet(images, num_classes=10, is_training=False,
+             dropout_keep_prob=0.5,
+             prediction_fn=slim.softmax,
+             scope='CifarNet'):
+  """Creates a variant of the CifarNet model.
+
+  Note that since the output is a set of 'logits', the values fall in the
+  interval of (-infinity, infinity). Consequently, to convert the outputs to a
+  probability distribution over the classes, one will need to convert them
+  using the softmax function:
+
+        logits, _ = cifarnet.cifarnet(images, is_training=False)
+        probabilities = tf.nn.softmax(logits)
+        predictions = tf.argmax(logits, 1)
+
+  Args:
+    images: A batch of `Tensors` of size [batch_size, height, width, channels].
+    num_classes: the number of classes in the dataset.
+    is_training: specifies whether or not we're currently training the model.
+      This variable will determine the behaviour of the dropout layer.
+    dropout_keep_prob: the fraction of activation values that are retained.
+    prediction_fn: a function to get predictions out of logits.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the pre-softmax activations, a tensor of size
+      [batch_size, `num_classes`]
+    end_points: a dictionary from components of the network to the corresponding
+      activation.
+  """
+  end_points = {}
+
+  with tf.variable_scope(scope, 'CifarNet', [images, num_classes]):
+    net = slim.conv2d(images, 64, [5, 5], scope='conv1')
+    end_points['conv1'] = net
+    net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
+    end_points['pool1'] = net
+    net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm1')
+    net = slim.conv2d(net, 64, [5, 5], scope='conv2')
+    end_points['conv2'] = net
+    net = tf.nn.lrn(net, 4, bias=1.0, alpha=0.001/9.0, beta=0.75, name='norm2')
+    net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
+    end_points['pool2'] = net
+    net = slim.flatten(net)
+    end_points['Flatten'] = net
+    net = slim.fully_connected(net, 384, scope='fc3')
+    end_points['fc3'] = net
+    net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                       scope='dropout3')
+    net = slim.fully_connected(net, 192, scope='fc4')
+    end_points['fc4'] = net
+    logits = slim.fully_connected(net, num_classes,
+                                  biases_initializer=tf.zeros_initializer,
+                                  weights_initializer=trunc_normal(1/192.0),
+                                  weights_regularizer=None,
+                                  activation_fn=None,
+                                  scope='logits')
+
+    end_points['Logits'] = logits
+    end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
+
+  return logits, end_points
+cifarnet.default_image_size = 32
+
+
+def cifarnet_arg_scope(weight_decay=0.004):
+  """Defines the default cifarnet argument scope.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+
+  Returns:
+    An `arg_scope` to use for the cifarnet model.
+  """
+  with slim.arg_scope(
+      [slim.conv2d],
+      weights_initializer=tf.truncated_normal_initializer(stddev=5e-2),
+      activation_fn=tf.nn.relu):
+    with slim.arg_scope(
+        [slim.fully_connected],
+        biases_initializer=tf.constant_initializer(0.1),
+        weights_initializer=trunc_normal(0.04),
+        weights_regularizer=slim.l2_regularizer(weight_decay),
+        activation_fn=tf.nn.relu) as sc:
+      return sc
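
Putting the two pieces together, a hedged sketch of building CifarNet under its arg scope (the input tensor is illustrative):

```python
import tensorflow as tf
from nets import cifarnet

slim = tf.contrib.slim

images = tf.random_uniform((128, 32, 32, 3))  # CIFAR-10-sized batch
with slim.arg_scope(cifarnet.cifarnet_arg_scope(weight_decay=0.004)):
  logits, end_points = cifarnet.cifarnet(images, num_classes=10,
                                         is_training=True)
probabilities = end_points['Predictions']  # softmax over the logits
```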

+ 33 - 0
slim/nets/inception.py

@@ -0,0 +1,33 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Brings inception_v1, inception_v2 and inception_v3 under one namespace."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+from nets.inception_resnet_v2 import inception_resnet_v2
+from nets.inception_resnet_v2 import inception_resnet_v2_arg_scope
+from nets.inception_v1 import inception_v1
+from nets.inception_v1 import inception_v1_arg_scope
+from nets.inception_v1 import inception_v1_base
+from nets.inception_v2 import inception_v2
+from nets.inception_v2 import inception_v2_arg_scope
+from nets.inception_v2 import inception_v2_base
+from nets.inception_v3 import inception_v3
+from nets.inception_v3 import inception_v3_arg_scope
+from nets.inception_v3 import inception_v3_base
+# pylint: enable=unused-import
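
Callers can now reach every variant through a single import; a short sketch (input tensor illustrative):

```python
import tensorflow as tf
from nets import inception

slim = tf.contrib.slim

images = tf.random_uniform((8, 299, 299, 3))
# Any re-exported model/arg_scope pair works the same way; v3 shown here.
with slim.arg_scope(inception.inception_v3_arg_scope()):
  logits, end_points = inception.inception_v3(images, num_classes=1001)
```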

+ 280 - 0
slim/nets/inception_resnet_v2.py

@@ -0,0 +1,280 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition of the Inception Resnet V2 architecture.
+
+As described in http://arxiv.org/abs/1602.07261.
+
+  Inception-v4, Inception-ResNet and the Impact of Residual Connections
+    on Learning
+  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def block35(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+  """Builds the 35x35 resnet block."""
+  with tf.variable_scope(scope, 'Block35', [net], reuse=reuse):
+    with tf.variable_scope('Branch_0'):
+      tower_conv = slim.conv2d(net, 32, 1, scope='Conv2d_1x1')
+    with tf.variable_scope('Branch_1'):
+      tower_conv1_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
+      tower_conv1_1 = slim.conv2d(tower_conv1_0, 32, 3, scope='Conv2d_0b_3x3')
+    with tf.variable_scope('Branch_2'):
+      tower_conv2_0 = slim.conv2d(net, 32, 1, scope='Conv2d_0a_1x1')
+      tower_conv2_1 = slim.conv2d(tower_conv2_0, 48, 3, scope='Conv2d_0b_3x3')
+      tower_conv2_2 = slim.conv2d(tower_conv2_1, 64, 3, scope='Conv2d_0c_3x3')
+    mixed = tf.concat(3, [tower_conv, tower_conv1_1, tower_conv2_2])
+    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
+                     activation_fn=None, scope='Conv2d_1x1')
+    net += scale * up
+    if activation_fn:
+      net = activation_fn(net)
+  return net
+
+
+def block17(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+  """Builds the 17x17 resnet block."""
+  with tf.variable_scope(scope, 'Block17', [net], reuse=reuse):
+    with tf.variable_scope('Branch_0'):
+      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
+    with tf.variable_scope('Branch_1'):
+      tower_conv1_0 = slim.conv2d(net, 128, 1, scope='Conv2d_0a_1x1')
+      tower_conv1_1 = slim.conv2d(tower_conv1_0, 160, [1, 7],
+                                  scope='Conv2d_0b_1x7')
+      tower_conv1_2 = slim.conv2d(tower_conv1_1, 192, [7, 1],
+                                  scope='Conv2d_0c_7x1')
+    mixed = tf.concat(3, [tower_conv, tower_conv1_2])
+    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
+                     activation_fn=None, scope='Conv2d_1x1')
+    net += scale * up
+    if activation_fn:
+      net = activation_fn(net)
+  return net
+
+
+def block8(net, scale=1.0, activation_fn=tf.nn.relu, scope=None, reuse=None):
+  """Builds the 8x8 resnet block."""
+  with tf.variable_scope(scope, 'Block8', [net], reuse=reuse):
+    with tf.variable_scope('Branch_0'):
+      tower_conv = slim.conv2d(net, 192, 1, scope='Conv2d_1x1')
+    with tf.variable_scope('Branch_1'):
+      tower_conv1_0 = slim.conv2d(net, 192, 1, scope='Conv2d_0a_1x1')
+      tower_conv1_1 = slim.conv2d(tower_conv1_0, 224, [1, 3],
+                                  scope='Conv2d_0b_1x3')
+      tower_conv1_2 = slim.conv2d(tower_conv1_1, 256, [3, 1],
+                                  scope='Conv2d_0c_3x1')
+    mixed = tf.concat(3, [tower_conv, tower_conv1_2])
+    up = slim.conv2d(mixed, net.get_shape()[3], 1, normalizer_fn=None,
+                     activation_fn=None, scope='Conv2d_1x1')
+    net += scale * up
+    if activation_fn:
+      net = activation_fn(net)
+  return net
+
+
+def inception_resnet_v2(inputs, num_classes=1001, is_training=True,
+                        dropout_keep_prob=0.8,
+                        reuse=None,
+                        scope='InceptionResnetV2'):
+  """Creates the Inception Resnet V2 model.
+
+  Args:
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+    num_classes: number of predicted classes.
+    is_training: whether the model is being trained or not.
+    dropout_keep_prob: float, the fraction to keep before final layer.
+    reuse: whether or not the network and its variables should be reused. To
+      be able to reuse them, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the logits outputs of the model.
+    end_points: the set of end_points from the inception model.
+  """
+  end_points = {}
+
+  with tf.variable_scope(scope, 'InceptionResnetV2', [inputs], reuse=reuse):
+    with slim.arg_scope([slim.batch_norm, slim.dropout],
+                        is_training=is_training):
+      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                          stride=1, padding='SAME'):
+
+        # 149 x 149 x 32
+        net = slim.conv2d(inputs, 32, 3, stride=2, padding='VALID',
+                          scope='Conv2d_1a_3x3')
+        end_points['Conv2d_1a_3x3'] = net
+        # 147 x 147 x 32
+        net = slim.conv2d(net, 32, 3, padding='VALID',
+                          scope='Conv2d_2a_3x3')
+        end_points['Conv2d_2a_3x3'] = net
+        # 147 x 147 x 64
+        net = slim.conv2d(net, 64, 3, scope='Conv2d_2b_3x3')
+        end_points['Conv2d_2b_3x3'] = net
+        # 73 x 73 x 64
+        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
+                              scope='MaxPool_3a_3x3')
+        end_points['MaxPool_3a_3x3'] = net
+        # 73 x 73 x 80
+        net = slim.conv2d(net, 80, 1, padding='VALID',
+                          scope='Conv2d_3b_1x1')
+        end_points['Conv2d_3b_1x1'] = net
+        # 71 x 71 x 192
+        net = slim.conv2d(net, 192, 3, padding='VALID',
+                          scope='Conv2d_4a_3x3')
+        end_points['Conv2d_4a_3x3'] = net
+        # 35 x 35 x 192
+        net = slim.max_pool2d(net, 3, stride=2, padding='VALID',
+                              scope='MaxPool_5a_3x3')
+        end_points['MaxPool_5a_3x3'] = net
+
+        # 35 x 35 x 320
+        with tf.variable_scope('Mixed_5b'):
+          with tf.variable_scope('Branch_0'):
+            tower_conv = slim.conv2d(net, 96, 1, scope='Conv2d_1x1')
+          with tf.variable_scope('Branch_1'):
+            tower_conv1_0 = slim.conv2d(net, 48, 1, scope='Conv2d_0a_1x1')
+            tower_conv1_1 = slim.conv2d(tower_conv1_0, 64, 5,
+                                        scope='Conv2d_0b_5x5')
+          with tf.variable_scope('Branch_2'):
+            tower_conv2_0 = slim.conv2d(net, 64, 1, scope='Conv2d_0a_1x1')
+            tower_conv2_1 = slim.conv2d(tower_conv2_0, 96, 3,
+                                        scope='Conv2d_0b_3x3')
+            tower_conv2_2 = slim.conv2d(tower_conv2_1, 96, 3,
+                                        scope='Conv2d_0c_3x3')
+          with tf.variable_scope('Branch_3'):
+            tower_pool = slim.avg_pool2d(net, 3, stride=1, padding='SAME',
+                                         scope='AvgPool_0a_3x3')
+            tower_pool_1 = slim.conv2d(tower_pool, 64, 1,
+                                       scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [tower_conv, tower_conv1_1,
+                              tower_conv2_2, tower_pool_1])
+
+        end_points['Mixed_5b'] = net
+        net = slim.repeat(net, 10, block35, scale=0.17)
+
+        # 17 x 17 x 1088
+        with tf.variable_scope('Mixed_6a'):
+          with tf.variable_scope('Branch_0'):
+            tower_conv = slim.conv2d(net, 384, 3, stride=2, padding='VALID',
+                                     scope='Conv2d_1a_3x3')
+          with tf.variable_scope('Branch_1'):
+            tower_conv1_0 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
+            tower_conv1_1 = slim.conv2d(tower_conv1_0, 256, 3,
+                                        scope='Conv2d_0b_3x3')
+            tower_conv1_2 = slim.conv2d(tower_conv1_1, 384, 3,
+                                        stride=2, padding='VALID',
+                                        scope='Conv2d_1a_3x3')
+          with tf.variable_scope('Branch_2'):
+            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
+                                         scope='MaxPool_1a_3x3')
+          net = tf.concat(3, [tower_conv, tower_conv1_2, tower_pool])
+
+        end_points['Mixed_6a'] = net
+        net = slim.repeat(net, 20, block17, scale=0.10)
+
+        # Auxiliary tower
+        with tf.variable_scope('AuxLogits'):
+          aux = slim.avg_pool2d(net, 5, stride=3, padding='VALID',
+                                scope='Conv2d_1a_3x3')
+          aux = slim.conv2d(aux, 128, 1, scope='Conv2d_1b_1x1')
+          aux = slim.conv2d(aux, 768, aux.get_shape()[1:3],
+                            padding='VALID', scope='Conv2d_2a_5x5')
+          aux = slim.flatten(aux)
+          aux = slim.fully_connected(aux, num_classes, activation_fn=None,
+                                     scope='Logits')
+          end_points['AuxLogits'] = aux
+
+        with tf.variable_scope('Mixed_7a'):
+          with tf.variable_scope('Branch_0'):
+            tower_conv = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
+            tower_conv_1 = slim.conv2d(tower_conv, 384, 3, stride=2,
+                                       padding='VALID', scope='Conv2d_1a_3x3')
+          with tf.variable_scope('Branch_1'):
+            tower_conv1 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
+            tower_conv1_1 = slim.conv2d(tower_conv1, 288, 3, stride=2,
+                                        padding='VALID', scope='Conv2d_1a_3x3')
+          with tf.variable_scope('Branch_2'):
+            tower_conv2 = slim.conv2d(net, 256, 1, scope='Conv2d_0a_1x1')
+            tower_conv2_1 = slim.conv2d(tower_conv2, 288, 3,
+                                        scope='Conv2d_0b_3x3')
+            tower_conv2_2 = slim.conv2d(tower_conv2_1, 320, 3, stride=2,
+                                        padding='VALID', scope='Conv2d_1a_3x3')
+          with tf.variable_scope('Branch_3'):
+            tower_pool = slim.max_pool2d(net, 3, stride=2, padding='VALID',
+                                         scope='MaxPool_1a_3x3')
+          net = tf.concat(3, [tower_conv_1, tower_conv1_1,
+                              tower_conv2_2, tower_pool])
+
+        end_points['Mixed_7a'] = net
+
+        net = slim.repeat(net, 9, block8, scale=0.20)
+        net = block8(net, activation_fn=None)
+
+        net = slim.conv2d(net, 1536, 1, scope='Conv2d_7b_1x1')
+        end_points['Conv2d_7b_1x1'] = net
+
+        with tf.variable_scope('Logits'):
+          end_points['PrePool'] = net
+          net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
+                                scope='AvgPool_1a_8x8')
+          net = slim.flatten(net)
+
+          net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                             scope='Dropout')
+
+          end_points['PreLogitsFlatten'] = net
+          logits = slim.fully_connected(net, num_classes, activation_fn=None,
+                                        scope='Logits')
+          end_points['Logits'] = logits
+          end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
+
+    return logits, end_points
+inception_resnet_v2.default_image_size = 299
+
+
+def inception_resnet_v2_arg_scope(weight_decay=0.00004,
+                                  batch_norm_decay=0.9997,
+                                  batch_norm_epsilon=0.001):
+  """Yields the scope with the default parameters for inception_resnet_v2.
+
+  Args:
+    weight_decay: the weight decay for weights variables.
+    batch_norm_decay: decay for the moving average of batch_norm momentums.
+    batch_norm_epsilon: small float added to variance to avoid dividing by zero.
+
+  Returns:
+    an arg_scope with the parameters needed for inception_resnet_v2.
+  """
+  # Set weight_decay for weights in conv2d and fully_connected layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay),
+                      biases_regularizer=slim.l2_regularizer(weight_decay)):
+
+    batch_norm_params = {
+        'decay': batch_norm_decay,
+        'epsilon': batch_norm_epsilon,
+    }
+    # Set activation_fn and parameters for batch_norm.
+    with slim.arg_scope([slim.conv2d], activation_fn=tf.nn.relu,
+                        normalizer_fn=slim.batch_norm,
+                        normalizer_params=batch_norm_params) as scope:
+      return scope

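Usage note: the builder and its arg_scope are meant to be used together, so that every conv2d above picks up the L2 regularizer, ReLU activation and batch_norm parameters. A minimal sketch, following the import convention of the tests below (and assuming nets/inception.py re-exports the arg_scope alongside the builder):

import tensorflow as tf

from nets import inception

slim = tf.contrib.slim

images = tf.random_uniform((8, 299, 299, 3))  # 299 is the default_image_size
with slim.arg_scope(inception.inception_resnet_v2_arg_scope()):
  logits, end_points = inception.inception_resnet_v2(
      images, num_classes=1000, is_training=True)

# end_points exposes the intermediate activations registered above, e.g.:
aux_logits = end_points['AuxLogits']   # [8, 1000] auxiliary classifier head
features = end_points['PrePool']       # [8, 8, 8, 1536] pre-pooling features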
+ 136 - 0
slim/nets/inception_resnet_v2_test.py

@@ -0,0 +1,136 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.inception_resnet_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import inception
+
+
+class InceptionTest(tf.test.TestCase):
+
+  def testBuildLogits(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = inception.inception_resnet_v2(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testBuildEndPoints(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = inception.inception_resnet_v2(inputs, num_classes)
+      self.assertTrue('Logits' in end_points)
+      logits = end_points['Logits']
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      self.assertTrue('AuxLogits' in end_points)
+      aux_logits = end_points['AuxLogits']
+      self.assertListEqual(aux_logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      pre_pool = end_points['PrePool']
+      self.assertListEqual(pre_pool.get_shape().as_list(),
+                           [batch_size, 8, 8, 1536])
+
+  def testVariablesSetDevice(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      # Force all Variables to reside on the device.
+      with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
+        inception.inception_resnet_v2(inputs, num_classes)
+      with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
+        inception.inception_resnet_v2(inputs, num_classes)
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+        self.assertDeviceEqual(v.device, '/cpu:0')
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+        self.assertDeviceEqual(v.device, '/gpu:0')
+
+  def testHalfSizeImages(self):
+    batch_size = 5
+    height, width = 150, 150
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, end_points = inception.inception_resnet_v2(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      pre_pool = end_points['PrePool']
+      self.assertListEqual(pre_pool.get_shape().as_list(),
+                           [batch_size, 3, 3, 1536])
+
+  def testUnknownBatchSize(self):
+    batch_size = 1
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session() as sess:
+      inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+      logits, _ = inception.inception_resnet_v2(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionResnetV2/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [None, num_classes])
+      images = tf.random_uniform((batch_size, height, width, 3))
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEquals(output.shape, (batch_size, num_classes))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session() as sess:
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = inception.inception_resnet_v2(eval_inputs,
+                                                num_classes,
+                                                is_training=False)
+      predictions = tf.argmax(logits, 1)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (batch_size,))
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 5
+    eval_batch_size = 2
+    height, width = 150, 150
+    num_classes = 1000
+    with self.test_session() as sess:
+      train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+      inception.inception_resnet_v2(train_inputs, num_classes)
+      eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+      logits, _ = inception.inception_resnet_v2(eval_inputs,
+                                                num_classes,
+                                                is_training=False,
+                                                reuse=True)
+      predictions = tf.argmax(logits, 1)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (eval_batch_size,))
+
+
+if __name__ == '__main__':
+  tf.test.main()

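The reuse pattern exercised by testTrainEvalWithReuse above is the standard way to share one set of weights between a training graph and an evaluation graph; a minimal sketch:

import tensorflow as tf

from nets import inception

# The first call creates the variables under the 'InceptionResnetV2' scope.
train_images = tf.random_uniform((32, 299, 299, 3))
inception.inception_resnet_v2(train_images, num_classes=1000)

# The second call reuses those variables; is_training=False switches
# batch_norm and dropout to inference behaviour.
eval_images = tf.random_uniform((8, 299, 299, 3))
eval_logits, _ = inception.inception_resnet_v2(
    eval_images, num_classes=1000, is_training=False, reuse=True)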
+ 340 - 0
slim/nets/inception_v1.py

@@ -0,0 +1,340 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition for inception v1 classification network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def inception_v1_base(inputs,
+                      final_endpoint='Mixed_5c',
+                      scope='InceptionV1'):
+  """Defines the Inception V1 base architecture.
+
+  This architecture is defined in:
+    Going deeper with convolutions
+    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+    http://arxiv.org/pdf/1409.4842v1.pdf.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    final_endpoint: specifies the endpoint to construct the network up to. It
+      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
+      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
+      'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
+      'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b', 'Mixed_5c']
+    scope: Optional variable_scope.
+
+  Returns:
+    net: the output tensor corresponding to final_endpoint.
+    end_points: a dictionary from components of the network to the
+      corresponding activation.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values.
+  """
+  end_points = {}
+  with tf.variable_scope(scope, 'InceptionV1', [inputs]):
+    with slim.arg_scope(
+        [slim.conv2d, slim.fully_connected],
+        weights_initializer=trunc_normal(0.01)):
+      with slim.arg_scope([slim.conv2d, slim.max_pool2d],
+                          stride=1, padding='SAME'):
+        end_point = 'Conv2d_1a_7x7'
+        net = slim.conv2d(inputs, 64, [7, 7], stride=2, scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+        end_point = 'MaxPool_2a_3x3'
+        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+        end_point = 'Conv2d_2b_1x1'
+        net = slim.conv2d(net, 64, [1, 1], scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+        end_point = 'Conv2d_2c_3x3'
+        net = slim.conv2d(net, 192, [3, 3], scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+        end_point = 'MaxPool_3a_3x3'
+        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_3b'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 128, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 32, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 32, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_3c'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 192, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'MaxPool_4a_3x3'
+        net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_4b'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 96, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 208, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 16, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 48, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_4c'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_4d'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 128, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 256, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 24, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_4e'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 112, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 144, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 288, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 64, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 64, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_4f'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'MaxPool_5a_2x2'
+        net = slim.max_pool2d(net, [2, 2], stride=2, scope=end_point)
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_5b'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 256, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 160, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 320, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 32, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0a_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+
+        end_point = 'Mixed_5c'
+        with tf.variable_scope(end_point):
+          with tf.variable_scope('Branch_0'):
+            branch_0 = slim.conv2d(net, 384, [1, 1], scope='Conv2d_0a_1x1')
+          with tf.variable_scope('Branch_1'):
+            branch_1 = slim.conv2d(net, 192, [1, 1], scope='Conv2d_0a_1x1')
+            branch_1 = slim.conv2d(branch_1, 384, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_2'):
+            branch_2 = slim.conv2d(net, 48, [1, 1], scope='Conv2d_0a_1x1')
+            branch_2 = slim.conv2d(branch_2, 128, [3, 3], scope='Conv2d_0b_3x3')
+          with tf.variable_scope('Branch_3'):
+            branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+            branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
+          net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if final_endpoint == end_point: return net, end_points
+    raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+
+def inception_v1(inputs,
+                 num_classes=1000,
+                 is_training=True,
+                 dropout_keep_prob=0.8,
+                 prediction_fn=slim.softmax,
+                 spatial_squeeze=True,
+                 reuse=None,
+                 scope='InceptionV1'):
+  """Defines the Inception V1 architecture.
+
+  This architecture is defined in:
+
+    Going deeper with convolutions
+    Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+    Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+    http://arxiv.org/pdf/1409.4842v1.pdf.
+
+  The default image size used to train this network is 224x224.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether is training or not.
+    dropout_keep_prob: the fraction of activation values that are retained.
+    prediction_fn: a function to get predictions out of logits.
+    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the pre-softmax activations, a tensor of size
+      [batch_size, num_classes]
+    end_points: a dictionary from components of the network to the corresponding
+      activation.
+  """
+  # Final pooling and prediction
+  with tf.variable_scope(scope, 'InceptionV1', [inputs, num_classes],
+                         reuse=reuse) as scope:
+    with slim.arg_scope([slim.batch_norm, slim.dropout],
+                        is_training=is_training):
+      net, end_points = inception_v1_base(inputs, scope=scope)
+      with tf.variable_scope('Logits'):
+        net = slim.avg_pool2d(net, [7, 7], stride=1, scope='MaxPool_0a_7x7')
+        net = slim.dropout(net,
+                           dropout_keep_prob, scope='Dropout_0b')
+        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+                             normalizer_fn=None, scope='Conv2d_0c_1x1')
+        if spatial_squeeze:
+          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
+
+        end_points['Logits'] = logits
+        end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
+  return logits, end_points
+inception_v1.default_image_size = 224
+
+
+def inception_v1_arg_scope(weight_decay=0.00004,
+                           use_batch_norm=True):
+  """Defines the default InceptionV1 arg scope.
+
+  Note: Although the original paper didn't use batch_norm, we found it useful.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    use_batch_norm: "If `True`, batch_norm is applied after each convolution.
+
+  Returns:
+    An `arg_scope` to use for the inception v1 model.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': 0.9997,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+  }
+  if use_batch_norm:
+    normalizer_fn = slim.batch_norm
+    normalizer_params = batch_norm_params
+  else:
+    normalizer_fn = None
+    normalizer_params = {}
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=normalizer_fn,
+        normalizer_params=normalizer_params) as sc:
+      return sc

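As with the other networks in this commit, inception_v1 is wired for classification out of the box; a minimal sketch of the intended call pattern:

import tensorflow as tf

from nets import inception

slim = tf.contrib.slim

images = tf.random_uniform((4, 224, 224, 3))  # 224 is the default_image_size
with slim.arg_scope(inception.inception_v1_arg_scope()):
  logits, end_points = inception.inception_v1(images, num_classes=1000)

# 'Predictions' holds prediction_fn(logits), i.e. the class probabilities.
probabilities = end_points['Predictions']  # shape [4, 1000]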
+ 210 - 0
slim/nets/inception_v1_test.py

@@ -0,0 +1,210 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nets.inception_v1."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from nets import inception
+
+slim = tf.contrib.slim
+
+
+class InceptionV1Test(tf.test.TestCase):
+
+  def testBuildClassificationNetwork(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v1(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue('Predictions' in end_points)
+    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testBuildBaseNetwork(self):
+    batch_size = 5
+    height, width = 224, 224
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    mixed_5c, end_points = inception.inception_v1_base(inputs)
+    self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
+    self.assertListEqual(mixed_5c.get_shape().as_list(),
+                         [batch_size, 7, 7, 1024])
+    expected_endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
+                          'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b',
+                          'Mixed_3c', 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c',
+                          'Mixed_4d', 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2',
+                          'Mixed_5b', 'Mixed_5c']
+    self.assertItemsEqual(end_points.keys(), expected_endpoints)
+
+  def testBuildOnlyUptoFinalEndpoint(self):
+    batch_size = 5
+    height, width = 224, 224
+    endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
+                 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
+                 'MaxPool_4a_3x3', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d',
+                 'Mixed_4e', 'Mixed_4f', 'MaxPool_5a_2x2', 'Mixed_5b',
+                 'Mixed_5c']
+    for index, endpoint in enumerate(endpoints):
+      with tf.Graph().as_default():
+        inputs = tf.random_uniform((batch_size, height, width, 3))
+        out_tensor, end_points = inception.inception_v1_base(
+            inputs, final_endpoint=endpoint)
+        self.assertTrue(out_tensor.op.name.startswith(
+            'InceptionV1/' + endpoint))
+        self.assertItemsEqual(endpoints[:index+1], end_points)
+
+  def testBuildAndCheckAllEndPointsUptoMixed5c(self):
+    batch_size = 5
+    height, width = 224, 224
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v1_base(inputs,
+                                                final_endpoint='Mixed_5c')
+    endpoints_shapes = {'Conv2d_1a_7x7': [5, 112, 112, 64],
+                        'MaxPool_2a_3x3': [5, 56, 56, 64],
+                        'Conv2d_2b_1x1': [5, 56, 56, 64],
+                        'Conv2d_2c_3x3': [5, 56, 56, 192],
+                        'MaxPool_3a_3x3': [5, 28, 28, 192],
+                        'Mixed_3b': [5, 28, 28, 256],
+                        'Mixed_3c': [5, 28, 28, 480],
+                        'MaxPool_4a_3x3': [5, 14, 14, 480],
+                        'Mixed_4b': [5, 14, 14, 512],
+                        'Mixed_4c': [5, 14, 14, 512],
+                        'Mixed_4d': [5, 14, 14, 512],
+                        'Mixed_4e': [5, 14, 14, 528],
+                        'Mixed_4f': [5, 14, 14, 832],
+                        'MaxPool_5a_2x2': [5, 7, 7, 832],
+                        'Mixed_5b': [5, 7, 7, 832],
+                        'Mixed_5c': [5, 7, 7, 1024]}
+
+    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
+    for endpoint_name in endpoints_shapes:
+      expected_shape = endpoints_shapes[endpoint_name]
+      self.assertTrue(endpoint_name in end_points)
+      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
+                           expected_shape)
+
+  def testModelHasExpectedNumberOfParameters(self):
+    batch_size = 5
+    height, width = 224, 224
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    with slim.arg_scope(inception.inception_v1_arg_scope()):
+      inception.inception_v1_base(inputs)
+    total_params, _ = slim.model_analyzer.analyze_vars(
+        slim.get_model_variables())
+    self.assertAlmostEqual(5607184, total_params)
+
+  def testHalfSizeImages(self):
+    batch_size = 5
+    height, width = 112, 112
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    mixed_5c, _ = inception.inception_v1_base(inputs)
+    self.assertTrue(mixed_5c.op.name.startswith('InceptionV1/Mixed_5c'))
+    self.assertListEqual(mixed_5c.get_shape().as_list(),
+                         [batch_size, 4, 4, 1024])
+
+  def testUnknownImageShape(self):
+    tf.reset_default_graph()
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
+    with self.test_session() as sess:
+      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
+      logits, end_points = inception.inception_v1(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      pre_pool = end_points['Mixed_5c']
+      feed_dict = {inputs: input_np}
+      tf.initialize_all_variables().run()
+      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
+      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
+
+  def testUnknownBatchSize(self):
+    batch_size = 1
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+    logits, _ = inception.inception_v1(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV1/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [None, num_classes])
+    images = tf.random_uniform((batch_size, height, width, 3))
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEquals(output.shape, (batch_size, num_classes))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+
+    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, _ = inception.inception_v1(eval_inputs, num_classes,
+                                       is_training=False)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (batch_size,))
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 5
+    eval_batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+
+    train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+    inception.inception_v1(train_inputs, num_classes)
+    eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+    logits, _ = inception.inception_v1(eval_inputs, num_classes, reuse=True)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (eval_batch_size,))
+
+  def testLogitsNotSqueezed(self):
+    num_classes = 25
+    images = tf.random_uniform([1, 224, 224, 3])
+    logits, _ = inception.inception_v1(images,
+                                       num_classes=num_classes,
+                                       spatial_squeeze=False)
+
+    with self.test_session() as sess:
+      tf.initialize_all_variables().run()
+      logits_out = sess.run(logits)
+      self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
+
+
+if __name__ == '__main__':
+  tf.test.main()

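testBuildOnlyUptoFinalEndpoint above also documents the feature-extractor use of the base network: construction can stop at any named endpoint. A minimal sketch:

import tensorflow as tf

from nets import inception

images = tf.random_uniform((1, 224, 224, 3))
net, end_points = inception.inception_v1_base(images,
                                              final_endpoint='Mixed_4f')
# net is the 'Mixed_4f' activation, here of shape [1, 14, 14, 832];
# end_points holds every endpoint built up to and including 'Mixed_4f'.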
+ 545 - 0
slim/nets/inception_v2.py

@@ -0,0 +1,545 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition for inception v2 classification network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def inception_v2_base(inputs,
+                      final_endpoint='Mixed_5c',
+                      min_depth=16,
+                      depth_multiplier=1.0,
+                      scope=None):
+  """Inception v2 (6a2).
+
+  Constructs an Inception v2 network from inputs to the given final endpoint.
+  This method can construct the network up to the layer inception(5b) as
+  described in http://arxiv.org/abs/1502.03167.
+
+  Args:
+    inputs: a tensor of shape [batch_size, height, width, channels].
+    final_endpoint: specifies the endpoint to construct the network up to. It
+      can be one of ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
+      'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c', 'Mixed_4a',
+      'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a', 'Mixed_5b',
+      'Mixed_5c'].
+    min_depth: Minimum depth value (number of channels) for all convolution ops.
+      Enforced when depth_multiplier < 1, and not an active constraint when
+      depth_multiplier >= 1.
+    depth_multiplier: Float multiplier for the depth (number of channels)
+      for all convolution ops. The value must be greater than zero. Typical
+      usage will be to set this value in (0, 1) to reduce the number of
+      parameters or computation cost of the model.
+    scope: Optional variable_scope.
+
+  Returns:
+    tensor_out: output tensor corresponding to the final_endpoint.
+    end_points: a set of activations for external use, for example summaries or
+                losses.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values,
+                or depth_multiplier <= 0
+  """
+
+  # end_points will collect relevant activations for external use, for example
+  # summaries or losses.
+  end_points = {}
+
+  # Used to find thinned depths for each layer.
+  if depth_multiplier <= 0:
+    raise ValueError('depth_multiplier is not greater than zero.')
+  depth = lambda d: max(int(d * depth_multiplier), min_depth)
+
+  with tf.variable_scope(scope, 'InceptionV2', [inputs]):
+    with slim.arg_scope(
+        [slim.conv2d, slim.max_pool2d, slim.avg_pool2d, slim.separable_conv2d],
+        stride=1, padding='SAME'):
+
+      # Note that sizes in the comments below assume an input spatial size of
+      # 224x224; however, the inputs can be of any size greater than 32x32.
+
+      # 224 x 224 x 3
+      end_point = 'Conv2d_1a_7x7'
+      # depthwise_multiplier here is different from depth_multiplier.
+      # depthwise_multiplier determines the output channels of the initial
+      # depthwise conv (see docs for tf.nn.separable_conv2d), while
+      # depth_multiplier controls the # channels of the subsequent 1x1
+      # convolution. Must have
+      #   in_channels * depthwise_multiplier <= out_channels
+      # so that the separable convolution is not overparameterized.
+      depthwise_multiplier = min(int(depth(64) / 3), 8)
+      net = slim.separable_conv2d(
+          inputs, depth(64), [7, 7], depth_multiplier=depthwise_multiplier,
+          stride=2, weights_initializer=trunc_normal(1.0),
+          scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 112 x 112 x 64
+      end_point = 'MaxPool_2a_3x3'
+      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 56 x 56 x 64
+      end_point = 'Conv2d_2b_1x1'
+      net = slim.conv2d(net, depth(64), [1, 1], scope=end_point,
+                        weights_initializer=trunc_normal(0.1))
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 56 x 56 x 64
+      end_point = 'Conv2d_2c_3x3'
+      net = slim.conv2d(net, depth(192), [3, 3], scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 56 x 56 x 192
+      end_point = 'MaxPool_3a_3x3'
+      net = slim.max_pool2d(net, [3, 3], scope=end_point, stride=2)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 28 x 28 x 192
+      # Inception module.
+      end_point = 'Mixed_3b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(64), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(32), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 28 x 28 x 256
+      end_point = 'Mixed_3c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 28 x 28 x 320
+      end_point = 'Mixed_4a'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(
+              net, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_0 = slim.conv2d(branch_0, depth(160), [3, 3], stride=2,
+                                 scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(
+              branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
+          branch_1 = slim.conv2d(
+              branch_1, depth(96), [3, 3], stride=2, scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.max_pool2d(
+              net, [3, 3], stride=2, scope='MaxPool_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1, branch_2])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 14 x 14 x 576
+      end_point = 'Mixed_4b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(224), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(64), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(
+              branch_1, depth(96), [3, 3], scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(96), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 14 x 14 x 576
+      end_point = 'Mixed_4c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(96), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(128), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(96), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(128), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 14 x 14 x 576
+      end_point = 'Mixed_4d'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(160), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(160), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(96), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+
+      # 14 x 14 x 576
+      end_point = 'Mixed_4e'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(96), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(192), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(160), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(192), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(96), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 14 x 14 x 576
+      end_point = 'Mixed_5a'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(
+              net, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_0 = slim.conv2d(branch_0, depth(192), [3, 3], stride=2,
+                                 scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(192), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_1 = slim.conv2d(branch_1, depth(256), [3, 3], stride=2,
+                                 scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.max_pool2d(net, [3, 3], stride=2,
+                                     scope='MaxPool_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1, branch_2])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+      # 7 x 7 x 1024
+      end_point = 'Mixed_5b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(192), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(160), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+
+      # 7 x 7 x 1024
+      end_point = 'Mixed_5c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(352), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(
+              net, depth(192), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(320), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(
+              net, depth(192), [1, 1],
+              weights_initializer=trunc_normal(0.09),
+              scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(224), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.max_pool2d(net, [3, 3], scope='MaxPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(128), [1, 1],
+              weights_initializer=trunc_normal(0.1),
+              scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+        end_points[end_point] = net
+        if end_point == final_endpoint: return net, end_points
+    raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+
+def inception_v2(inputs,
+                 num_classes=1000,
+                 is_training=True,
+                 dropout_keep_prob=0.8,
+                 min_depth=16,
+                 depth_multiplier=1.0,
+                 prediction_fn=slim.softmax,
+                 spatial_squeeze=True,
+                 reuse=None,
+                 scope='InceptionV2'):
+  """Inception v2 model for classification.
+
+  Constructs an Inception v2 network for classification as described in
+  http://arxiv.org/abs/1502.03167.
+
+  The default image size used to train this network is 224x224.
+
+  Args:
+    inputs: a tensor of shape [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether is training or not.
+    dropout_keep_prob: the fraction of activation values that are retained.
+    min_depth: Minimum depth value (number of channels) for all convolution ops.
+      Enforced when depth_multiplier < 1, and not an active constraint when
+      depth_multiplier >= 1.
+    depth_multiplier: Float multiplier for the depth (number of channels)
+      for all convolution ops. The value must be greater than zero. Typical
+      usage will be to set this value in (0, 1) to reduce the number of
+      parameters or computation cost of the model.
+    prediction_fn: a function to get predictions out of logits.
+    spatial_squeeze: if True, logits is of shape [B, C]; if False, logits is
+        of shape [B, 1, 1, C], where B is batch_size and C is number of classes.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the pre-softmax activations, a tensor of size
+      [batch_size, num_classes]
+    end_points: a dictionary from components of the network to the corresponding
+      activation.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values,
+                or depth_multiplier <= 0
+  """
+  if depth_multiplier <= 0:
+    raise ValueError('depth_multiplier is not greater than zero.')
+
+  # Final pooling and prediction
+  with tf.variable_scope(scope, 'InceptionV2', [inputs, num_classes],
+                         reuse=reuse) as scope:
+    with slim.arg_scope([slim.batch_norm, slim.dropout],
+                        is_training=is_training):
+      net, end_points = inception_v2_base(
+          inputs, scope=scope, min_depth=min_depth,
+          depth_multiplier=depth_multiplier)
+      with tf.variable_scope('Logits'):
+        kernel_size = _reduced_kernel_size_for_small_input(net, [7, 7])
+        net = slim.avg_pool2d(net, kernel_size, padding='VALID',
+                              scope='AvgPool_1a_{}x{}'.format(*kernel_size))
+        # 1 x 1 x 1024
+        net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
+        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+                             normalizer_fn=None, scope='Conv2d_1c_1x1')
+        if spatial_squeeze:
+          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
+      end_points['Logits'] = logits
+      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
+  return logits, end_points
+inception_v2.default_image_size = 224
+
+
+def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
+  """Define kernel size which is automatically reduced for small input.
+
+  If the shape of the input images is unknown at graph construction time, this
+  function assumes that the input images are large enough.
+
+  Args:
+    input_tensor: input tensor of size [batch_size, height, width, channels].
+    kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
+
+  Returns:
+    a tensor with the kernel size.
+
+  TODO(jrru): Make this function work with unknown shapes. Theoretically, this
+  can be done with the code below. Problems are two-fold: (1) If the shape was
+  known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot
+  handle tensors that define the kernel size.
+      shape = tf.shape(input_tensor)
+      return tf.pack([tf.minimum(shape[1], kernel_size[0]),
+                        tf.minimum(shape[2], kernel_size[1])])
+
+  """
+  shape = input_tensor.get_shape().as_list()
+  if shape[1] is None or shape[2] is None:
+    kernel_size_out = kernel_size
+  else:
+    kernel_size_out = [min(shape[1], kernel_size[0]),
+                       min(shape[2], kernel_size[1])]
+  return kernel_size_out
+
+
+def inception_v2_arg_scope(weight_decay=0.00004):
+  """Defines the default InceptionV2 arg scope.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+
+  Returns:
+    An `arg_scope` to use for the inception v3 model.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': 0.9997,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+  }
+
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=slim.batch_norm,
+        normalizer_params=batch_norm_params) as sc:
+      return sc
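
For reference, a minimal inference-time sketch of how the new `inception_v2`
entry point and its arg scope are meant to be combined (assuming the
`nets.inception` re-exports added in this commit; the 1000-class head is just
the docstring default, not a claim about any released checkpoint):

```python
import tensorflow as tf

from nets import inception

slim = tf.contrib.slim

image_size = inception.inception_v2.default_image_size  # 224
images = tf.placeholder(tf.float32, [None, image_size, image_size, 3])

# The arg scope supplies the batch-norm, initializer and weight-decay defaults.
with slim.arg_scope(inception.inception_v2_arg_scope()):
  logits, end_points = inception.inception_v2(
      images, num_classes=1000, is_training=False)

probabilities = end_points['Predictions']  # softmax over the classes
```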

+ 262 - 0
slim/nets/inception_v2_test.py

@@ -0,0 +1,262 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nets.inception_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from nets import inception
+
+slim = tf.contrib.slim
+
+
+class InceptionV2Test(tf.test.TestCase):
+
+  def testBuildClassificationNetwork(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v2(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue('Predictions' in end_points)
+    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testBuildBaseNetwork(self):
+    batch_size = 5
+    height, width = 224, 224
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    mixed_5c, end_points = inception.inception_v2_base(inputs)
+    self.assertTrue(mixed_5c.op.name.startswith('InceptionV2/Mixed_5c'))
+    self.assertListEqual(mixed_5c.get_shape().as_list(),
+                         [batch_size, 7, 7, 1024])
+    expected_endpoints = ['Mixed_3b', 'Mixed_3c', 'Mixed_4a', 'Mixed_4b',
+                          'Mixed_4c', 'Mixed_4d', 'Mixed_4e', 'Mixed_5a',
+                          'Mixed_5b', 'Mixed_5c', 'Conv2d_1a_7x7',
+                          'MaxPool_2a_3x3', 'Conv2d_2b_1x1', 'Conv2d_2c_3x3',
+                          'MaxPool_3a_3x3']
+    self.assertItemsEqual(end_points.keys(), expected_endpoints)
+
+  def testBuildOnlyUptoFinalEndpoint(self):
+    batch_size = 5
+    height, width = 224, 224
+    endpoints = ['Conv2d_1a_7x7', 'MaxPool_2a_3x3', 'Conv2d_2b_1x1',
+                 'Conv2d_2c_3x3', 'MaxPool_3a_3x3', 'Mixed_3b', 'Mixed_3c',
+                 'Mixed_4a', 'Mixed_4b', 'Mixed_4c', 'Mixed_4d', 'Mixed_4e',
+                 'Mixed_5a', 'Mixed_5b', 'Mixed_5c']
+    for index, endpoint in enumerate(endpoints):
+      with tf.Graph().as_default():
+        inputs = tf.random_uniform((batch_size, height, width, 3))
+        out_tensor, end_points = inception.inception_v2_base(
+            inputs, final_endpoint=endpoint)
+        self.assertTrue(out_tensor.op.name.startswith(
+            'InceptionV2/' + endpoint))
+        self.assertItemsEqual(endpoints[:index+1], end_points)
+
+  def testBuildAndCheckAllEndPointsUptoMixed5c(self):
+    batch_size = 5
+    height, width = 224, 224
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v2_base(inputs,
+                                                final_endpoint='Mixed_5c')
+    endpoints_shapes = {'Mixed_3b': [batch_size, 28, 28, 256],
+                        'Mixed_3c': [batch_size, 28, 28, 320],
+                        'Mixed_4a': [batch_size, 14, 14, 576],
+                        'Mixed_4b': [batch_size, 14, 14, 576],
+                        'Mixed_4c': [batch_size, 14, 14, 576],
+                        'Mixed_4d': [batch_size, 14, 14, 576],
+                        'Mixed_4e': [batch_size, 14, 14, 576],
+                        'Mixed_5a': [batch_size, 7, 7, 1024],
+                        'Mixed_5b': [batch_size, 7, 7, 1024],
+                        'Mixed_5c': [batch_size, 7, 7, 1024],
+                        'Conv2d_1a_7x7': [batch_size, 112, 112, 64],
+                        'MaxPool_2a_3x3': [batch_size, 56, 56, 64],
+                        'Conv2d_2b_1x1': [batch_size, 56, 56, 64],
+                        'Conv2d_2c_3x3': [batch_size, 56, 56, 192],
+                        'MaxPool_3a_3x3': [batch_size, 28, 28, 192]}
+    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
+    for endpoint_name in endpoints_shapes:
+      expected_shape = endpoints_shapes[endpoint_name]
+      self.assertTrue(endpoint_name in end_points)
+      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
+                           expected_shape)
+
+  def testModelHasExpectedNumberOfParameters(self):
+    batch_size = 5
+    height, width = 224, 224
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    with slim.arg_scope(inception.inception_v2_arg_scope()):
+      inception.inception_v2_base(inputs)
+    total_params, _ = slim.model_analyzer.analyze_vars(
+        slim.get_model_variables())
+    self.assertAlmostEqual(10173112, total_params)
+
+  def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v2(inputs, num_classes)
+
+    endpoint_keys = [key for key in end_points.keys()
+                     if key.startswith('Mixed') or key.startswith('Conv')]
+
+    _, end_points_with_multiplier = inception.inception_v2(
+        inputs, num_classes, scope='depth_multiplied_net',
+        depth_multiplier=0.5)
+
+    for key in endpoint_keys:
+      original_depth = end_points[key].get_shape().as_list()[3]
+      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
+      self.assertEqual(0.5 * original_depth, new_depth)
+
+  def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v2(inputs, num_classes)
+
+    endpoint_keys = [key for key in end_points.keys()
+                     if key.startswith('Mixed') or key.startswith('Conv')]
+
+    _, end_points_with_multiplier = inception.inception_v2(
+        inputs, num_classes, scope='depth_multiplied_net',
+        depth_multiplier=2.0)
+
+    for key in endpoint_keys:
+      original_depth = end_points[key].get_shape().as_list()[3]
+      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
+      self.assertEqual(2.0 * original_depth, new_depth)
+
+  def testRaiseValueErrorWithInvalidDepthMultiplier(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    with self.assertRaises(ValueError):
+      _ = inception.inception_v2(inputs, num_classes, depth_multiplier=-0.1)
+    with self.assertRaises(ValueError):
+      _ = inception.inception_v2(inputs, num_classes, depth_multiplier=0.0)
+
+  def testHalfSizeImages(self):
+    batch_size = 5
+    height, width = 112, 112
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v2(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    pre_pool = end_points['Mixed_5c']
+    self.assertListEqual(pre_pool.get_shape().as_list(),
+                         [batch_size, 4, 4, 1024])
+
+  def testUnknownImageShape(self):
+    tf.reset_default_graph()
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
+    with self.test_session() as sess:
+      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
+      logits, end_points = inception.inception_v2(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      pre_pool = end_points['Mixed_5c']
+      feed_dict = {inputs: input_np}
+      tf.initialize_all_variables().run()
+      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
+      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 7, 7, 1024])
+
+  def testUnknownBatchSize(self):
+    batch_size = 1
+    height, width = 224, 224
+    num_classes = 1000
+
+    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+    logits, _ = inception.inception_v2(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV2/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [None, num_classes])
+    images = tf.random_uniform((batch_size, height, width, 3))
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEquals(output.shape, (batch_size, num_classes))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+
+    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, _ = inception.inception_v2(eval_inputs, num_classes,
+                                       is_training=False)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (batch_size,))
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 5
+    eval_batch_size = 2
+    height, width = 150, 150
+    num_classes = 1000
+
+    train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+    inception.inception_v2(train_inputs, num_classes)
+    eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+    logits, _ = inception.inception_v2(eval_inputs, num_classes, reuse=True)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (eval_batch_size,))
+
+  def testLogitsNotSqueezed(self):
+    num_classes = 25
+    images = tf.random_uniform([1, 224, 224, 3])
+    logits, _ = inception.inception_v2(images,
+                                       num_classes=num_classes,
+                                       spatial_squeeze=False)
+
+    with self.test_session() as sess:
+      tf.initialize_all_variables().run()
+      logits_out = sess.run(logits)
+      self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
+
+
+if __name__ == '__main__':
+  tf.test.main()
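
The half-size and unknown-shape tests above hinge on
`_reduced_kernel_size_for_small_input`. A plain-Python restatement of that
rule (the standalone helper name here is hypothetical, for illustration only):

```python
def reduced_kernel_size(static_shape, kernel_size):
  """Restates _reduced_kernel_size_for_small_input for a static NHWC shape."""
  height, width = static_shape[1], static_shape[2]
  if height is None or width is None:
    return kernel_size  # unknown spatial shape: assume the input is large
  return [min(height, kernel_size[0]), min(width, kernel_size[1])]

# 224x224 inputs leave Mixed_5c at 7x7, so the requested [7, 7] pool is kept:
assert reduced_kernel_size([5, 7, 7, 1024], [7, 7]) == [7, 7]
# 112x112 inputs shrink Mixed_5c to 4x4 (see testHalfSizeImages), and the
# average pool shrinks with it instead of failing:
assert reduced_kernel_size([5, 4, 4, 1024], [7, 7]) == [4, 4]
```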

+ 587 - 0
slim/nets/inception_v3.py

@@ -0,0 +1,587 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition for inception v3 classification network."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def inception_v3_base(inputs,
+                      final_endpoint='Mixed_7c',
+                      min_depth=16,
+                      depth_multiplier=1.0,
+                      scope=None):
+  """Inception model from http://arxiv.org/abs/1512.00567.
+
+  Constructs an Inception v3 network from inputs to the given final endpoint.
+  This method can construct the network up to the final inception block
+  Mixed_7c.
+
+  Note that the names of the layers in the paper do not correspond to the names
+  of the endpoints registered by this function although they build the same
+  network.
+
+  Here is a mapping from the old_names to the new names:
+  Old name          | New name
+  =======================================
+  conv0             | Conv2d_1a_3x3
+  conv1             | Conv2d_2a_3x3
+  conv2             | Conv2d_2b_3x3
+  pool1             | MaxPool_3a_3x3
+  conv3             | Conv2d_3b_1x1
+  conv4             | Conv2d_4a_3x3
+  pool2             | MaxPool_5a_3x3
+  mixed_35x35x256a  | Mixed_5b
+  mixed_35x35x288a  | Mixed_5c
+  mixed_35x35x288b  | Mixed_5d
+  mixed_17x17x768a  | Mixed_6a
+  mixed_17x17x768b  | Mixed_6b
+  mixed_17x17x768c  | Mixed_6c
+  mixed_17x17x768d  | Mixed_6d
+  mixed_17x17x768e  | Mixed_6e
+  mixed_8x8x1280a   | Mixed_7a
+  mixed_8x8x2048a   | Mixed_7b
+  mixed_8x8x2048b   | Mixed_7c
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    final_endpoint: specifies the endpoint to construct the network up to. It
+      can be one of ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
+      'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3', 'MaxPool_5a_3x3',
+      'Mixed_5b', 'Mixed_5c', 'Mixed_5d', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c',
+      'Mixed_6d', 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c'].
+    min_depth: Minimum depth value (number of channels) for all convolution ops.
+      Enforced when depth_multiplier < 1, and not an active constraint when
+      depth_multiplier >= 1.
+    depth_multiplier: Float multiplier for the depth (number of channels)
+      for all convolution ops. The value must be greater than zero. Typical
+      usage will be to set this value in (0, 1) to reduce the number of
+      parameters or computation cost of the model.
+    scope: Optional variable_scope.
+
+  Returns:
+    tensor_out: output tensor corresponding to the final_endpoint.
+    end_points: a set of activations for external use, for example summaries or
+                losses.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values,
+                or depth_multiplier <= 0
+  """
+  # end_points will collect relevant activations for external use, for example
+  # summaries or losses.
+  end_points = {}
+
+  if depth_multiplier <= 0:
+    raise ValueError('depth_multiplier is not greater than zero.')
+  depth = lambda d: max(int(d * depth_multiplier), min_depth)
+
+  with tf.variable_scope(scope, 'InceptionV3', [inputs]):
+    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                        stride=1, padding='VALID'):
+      # 299 x 299 x 3
+      end_point = 'Conv2d_1a_3x3'
+      net = slim.conv2d(inputs, depth(32), [3, 3], stride=2, scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 149 x 149 x 32
+      end_point = 'Conv2d_2a_3x3'
+      net = slim.conv2d(net, depth(32), [3, 3], scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 147 x 147 x 32
+      end_point = 'Conv2d_2b_3x3'
+      net = slim.conv2d(net, depth(64), [3, 3], padding='SAME', scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 147 x 147 x 64
+      end_point = 'MaxPool_3a_3x3'
+      net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 73 x 73 x 64
+      end_point = 'Conv2d_3b_1x1'
+      net = slim.conv2d(net, depth(80), [1, 1], scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 73 x 73 x 80.
+      end_point = 'Conv2d_4a_3x3'
+      net = slim.conv2d(net, depth(192), [3, 3], scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 71 x 71 x 192.
+      end_point = 'MaxPool_5a_3x3'
+      net = slim.max_pool2d(net, [3, 3], stride=2, scope=end_point)
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # 35 x 35 x 192.
+
+    # Inception blocks
+    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                        stride=1, padding='SAME'):
+      # mixed: 35 x 35 x 256.
+      end_point = 'Mixed_5b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(64), [5, 5],
+                                 scope='Conv2d_0b_5x5')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(32), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_1: 35 x 35 x 288.
+      end_point = 'Mixed_5c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0b_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(64), [5, 5],
+                                 scope='Conv_1_0c_5x5')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(64), [1, 1],
+                                 scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(64), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_2: 35 x 35 x 288.
+      end_point = 'Mixed_5d'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(48), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(64), [5, 5],
+                                 scope='Conv2d_0b_5x5')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_2 = slim.conv2d(branch_2, depth(96), [3, 3],
+                                 scope='Conv2d_0c_3x3')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(64), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_3: 17 x 17 x 768.
+      end_point = 'Mixed_6a'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(384), [3, 3], stride=2,
+                                 padding='VALID', scope='Conv2d_1a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(64), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3],
+                                 scope='Conv2d_0b_3x3')
+          branch_1 = slim.conv2d(branch_1, depth(96), [3, 3], stride=2,
+                                 padding='VALID', scope='Conv2d_1a_1x1')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
+                                     scope='MaxPool_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1, branch_2])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed4: 17 x 17 x 768.
+      end_point = 'Mixed_6b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(128), [1, 7],
+                                 scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, depth(192), [7, 1],
+                                 scope='Conv2d_0c_7x1')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(128), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(128), [7, 1],
+                                 scope='Conv2d_0b_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(128), [1, 7],
+                                 scope='Conv2d_0c_1x7')
+          branch_2 = slim.conv2d(branch_2, depth(128), [7, 1],
+                                 scope='Conv2d_0d_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [1, 7],
+                                 scope='Conv2d_0e_1x7')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(192), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_5: 17 x 17 x 768.
+      end_point = 'Mixed_6c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(160), [1, 7],
+                                 scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, depth(192), [7, 1],
+                                 scope='Conv2d_0c_7x1')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(160), [7, 1],
+                                 scope='Conv2d_0b_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(160), [1, 7],
+                                 scope='Conv2d_0c_1x7')
+          branch_2 = slim.conv2d(branch_2, depth(160), [7, 1],
+                                 scope='Conv2d_0d_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [1, 7],
+                                 scope='Conv2d_0e_1x7')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(192), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # mixed_6: 17 x 17 x 768.
+      end_point = 'Mixed_6d'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(160), [1, 7],
+                                 scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, depth(192), [7, 1],
+                                 scope='Conv2d_0c_7x1')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(160), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(160), [7, 1],
+                                 scope='Conv2d_0b_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(160), [1, 7],
+                                 scope='Conv2d_0c_1x7')
+          branch_2 = slim.conv2d(branch_2, depth(160), [7, 1],
+                                 scope='Conv2d_0d_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [1, 7],
+                                 scope='Conv2d_0e_1x7')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(192), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_7: 17 x 17 x 768.
+      end_point = 'Mixed_6e'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(192), [1, 7],
+                                 scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, depth(192), [7, 1],
+                                 scope='Conv2d_0c_7x1')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [7, 1],
+                                 scope='Conv2d_0b_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [1, 7],
+                                 scope='Conv2d_0c_1x7')
+          branch_2 = slim.conv2d(branch_2, depth(192), [7, 1],
+                                 scope='Conv2d_0d_7x1')
+          branch_2 = slim.conv2d(branch_2, depth(192), [1, 7],
+                                 scope='Conv2d_0e_1x7')
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(branch_3, depth(192), [1, 1],
+                                 scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_8: 8 x 8 x 1280.
+      end_point = 'Mixed_7a'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+          branch_0 = slim.conv2d(branch_0, depth(320), [3, 3], stride=2,
+                                 padding='VALID', scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(192), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, depth(192), [1, 7],
+                                 scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, depth(192), [7, 1],
+                                 scope='Conv2d_0c_7x1')
+          branch_1 = slim.conv2d(branch_1, depth(192), [3, 3], stride=2,
+                                 padding='VALID', scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
+                                     scope='MaxPool_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1, branch_2])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+      # mixed_9: 8 x 8 x 2048.
+      end_point = 'Mixed_7b'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = tf.concat(3, [
+              slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'),
+              slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0b_3x1')])
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(
+              branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3')
+          branch_2 = tf.concat(3, [
+              slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'),
+              slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')])
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+
+      # mixed_10: 8 x 8 x 2048.
+      end_point = 'Mixed_7c'
+      with tf.variable_scope(end_point):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, depth(320), [1, 1], scope='Conv2d_0a_1x1')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, depth(384), [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = tf.concat(3, [
+              slim.conv2d(branch_1, depth(384), [1, 3], scope='Conv2d_0b_1x3'),
+              slim.conv2d(branch_1, depth(384), [3, 1], scope='Conv2d_0c_3x1')])
+        with tf.variable_scope('Branch_2'):
+          branch_2 = slim.conv2d(net, depth(448), [1, 1], scope='Conv2d_0a_1x1')
+          branch_2 = slim.conv2d(
+              branch_2, depth(384), [3, 3], scope='Conv2d_0b_3x3')
+          branch_2 = tf.concat(3, [
+              slim.conv2d(branch_2, depth(384), [1, 3], scope='Conv2d_0c_1x3'),
+              slim.conv2d(branch_2, depth(384), [3, 1], scope='Conv2d_0d_3x1')])
+        with tf.variable_scope('Branch_3'):
+          branch_3 = slim.avg_pool2d(net, [3, 3], scope='AvgPool_0a_3x3')
+          branch_3 = slim.conv2d(
+              branch_3, depth(192), [1, 1], scope='Conv2d_0b_1x1')
+        net = tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+      end_points[end_point] = net
+      if end_point == final_endpoint: return net, end_points
+    raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+
+def inception_v3(inputs,
+                 num_classes=1000,
+                 is_training=True,
+                 dropout_keep_prob=0.8,
+                 min_depth=16,
+                 depth_multiplier=1.0,
+                 prediction_fn=slim.softmax,
+                 spatial_squeeze=True,
+                 reuse=None,
+                 scope='InceptionV3'):
+  """Inception model from http://arxiv.org/abs/1512.00567.
+
+  "Rethinking the Inception Architecture for Computer Vision"
+
+  Christian Szegedy, Vincent Vanhoucke, Sergey Ioffe, Jonathon Shlens,
+  Zbigniew Wojna.
+
+  With the default arguments this method constructs the exact model defined in
+  the paper. However, one can experiment with variations of the inception_v3
+  network by changing arguments dropout_keep_prob, min_depth and
+  depth_multiplier.
+
+  The default image size used to train this network is 299x299.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether the model is being trained or not.
+    dropout_keep_prob: the percentage of activation values that are retained.
+    min_depth: Minimum depth value (number of channels) for all convolution ops.
+      Enforced when depth_multiplier < 1, and not an active constraint when
+      depth_multiplier >= 1.
+    depth_multiplier: Float multiplier for the depth (number of channels)
+      for all convolution ops. The value must be greater than zero. Typical
+      usage will be to set this value in (0, 1) to reduce the number of
+      parameters or computation cost of the model.
+    prediction_fn: a function to get predictions out of logits.
+    spatial_squeeze: if True, logits is of shape [B, C], if false logits is of
+        shape [B, 1, 1, C], where B is batch_size and C is number of classes.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the pre-softmax activations, a tensor of size
+      [batch_size, num_classes]
+    end_points: a dictionary from components of the network to the corresponding
+      activation.
+
+  Raises:
+    ValueError: if 'depth_multiplier' is less than or equal to zero.
+  """
+  if depth_multiplier <= 0:
+    raise ValueError('depth_multiplier is not greater than zero.')
+  depth = lambda d: max(int(d * depth_multiplier), min_depth)
+
+  with tf.variable_scope(scope, 'InceptionV3', [inputs, num_classes],
+                         reuse=reuse) as scope:
+    with slim.arg_scope([slim.batch_norm, slim.dropout],
+                        is_training=is_training):
+      net, end_points = inception_v3_base(
+          inputs, scope=scope, min_depth=min_depth,
+          depth_multiplier=depth_multiplier)
+
+      # Auxiliary Head logits
+      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                          stride=1, padding='SAME'):
+        aux_logits = end_points['Mixed_6e']
+        with tf.variable_scope('AuxLogits'):
+          aux_logits = slim.avg_pool2d(
+              aux_logits, [5, 5], stride=3, padding='VALID',
+              scope='AvgPool_1a_5x5')
+          aux_logits = slim.conv2d(aux_logits, depth(128), [1, 1],
+                                   scope='Conv2d_1b_1x1')
+
+          # Shape of feature map before the final layer.
+          kernel_size = _reduced_kernel_size_for_small_input(
+              aux_logits, [5, 5])
+          aux_logits = slim.conv2d(
+              aux_logits, depth(768), kernel_size,
+              weights_initializer=trunc_normal(0.01),
+              padding='VALID', scope='Conv2d_2a_{}x{}'.format(*kernel_size))
+          aux_logits = slim.conv2d(
+              aux_logits, num_classes, [1, 1], activation_fn=None,
+              normalizer_fn=None, weights_initializer=trunc_normal(0.001),
+              scope='Conv2d_2b_1x1')
+          if spatial_squeeze:
+            aux_logits = tf.squeeze(aux_logits, [1, 2], name='SpatialSqueeze')
+          end_points['AuxLogits'] = aux_logits
+
+      # Final pooling and prediction
+      with tf.variable_scope('Logits'):
+        kernel_size = _reduced_kernel_size_for_small_input(net, [8, 8])
+        net = slim.avg_pool2d(net, kernel_size, padding='VALID',
+                              scope='AvgPool_1a_{}x{}'.format(*kernel_size))
+        # 1 x 1 x 2048
+        net = slim.dropout(net, keep_prob=dropout_keep_prob, scope='Dropout_1b')
+        end_points['PreLogits'] = net
+        # 2048
+        logits = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+                             normalizer_fn=None, scope='Conv2d_1c_1x1')
+        if spatial_squeeze:
+          logits = tf.squeeze(logits, [1, 2], name='SpatialSqueeze')
+        # 1000
+      end_points['Logits'] = logits
+      end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
+  return logits, end_points
+inception_v3.default_image_size = 299
+
+
+def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
+  """Define kernel size which is automatically reduced for small input.
+
+  If the shape of the input images is unknown at graph construction time, this
+  function assumes that the input images are large enough.
+
+  Args:
+    input_tensor: input tensor of size [batch_size, height, width, channels].
+    kernel_size: desired kernel size of length 2: [kernel_height, kernel_width]
+
+  Returns:
+    a tensor with the kernel size.
+
+  TODO(jrru): Make this function work with unknown shapes. Theoretically, this
+  can be done with the code below. Problems are two-fold: (1) If the shape was
+  known, it will be lost. (2) inception.slim.ops._two_element_tuple cannot
+  handle tensors that define the kernel size.
+      shape = tf.shape(input_tensor)
+      return tf.pack([tf.minimum(shape[1], kernel_size[0]),
+                      tf.minimum(shape[2], kernel_size[1])])
+
+  """
+  shape = input_tensor.get_shape().as_list()
+  if shape[1] is None or shape[2] is None:
+    kernel_size_out = kernel_size
+  else:
+    kernel_size_out = [min(shape[1], kernel_size[0]),
+                       min(shape[2], kernel_size[1])]
+  return kernel_size_out
+
+
+def inception_v3_arg_scope(weight_decay=0.00004,
+                           stddev=0.1):
+  """Defines the default InceptionV3 arg scope.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    stddev: The standard deviation of the truncated normal weight initializer.
+
+  Returns:
+    An `arg_scope` to use for the inception v3 model.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': 0.9997,
+      # epsilon to prevent 0s in variance.
+      'epsilon': 0.001,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+  }
+
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=slim.batch_norm,
+        normalizer_params=batch_norm_params) as sc:
+      return sc
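
For reference, a short sketch of using `inception_v3_base` as a feature
extractor via `final_endpoint`, assuming the endpoint shapes documented above:

```python
import tensorflow as tf

from nets import inception

slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 299, 299, 3])

with slim.arg_scope(inception.inception_v3_arg_scope()):
  # Stop construction at the last 17x17 block, e.g. to attach a custom head.
  net, end_points = inception.inception_v3_base(
      images, final_endpoint='Mixed_6e')

# net is the Mixed_6e activation, [batch, 17, 17, 768]; end_points holds
# every endpoint built so far, from Conv2d_1a_3x3 through Mixed_6e.
```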

+ 292 - 0
slim/nets/inception_v3_test.py

@@ -0,0 +1,292 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for nets.inception_v1."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from nets import inception
+
+slim = tf.contrib.slim
+
+
+class InceptionV3Test(tf.test.TestCase):
+
+  def testBuildClassificationNetwork(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v3(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue('Predictions' in end_points)
+    self.assertListEqual(end_points['Predictions'].get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testBuildBaseNetwork(self):
+    batch_size = 5
+    height, width = 299, 299
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    final_endpoint, end_points = inception.inception_v3_base(inputs)
+    self.assertTrue(final_endpoint.op.name.startswith(
+        'InceptionV3/Mixed_7c'))
+    self.assertListEqual(final_endpoint.get_shape().as_list(),
+                         [batch_size, 8, 8, 2048])
+    expected_endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
+                          'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
+                          'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
+                          'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
+                          'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
+    self.assertItemsEqual(end_points.keys(), expected_endpoints)
+
+  def testBuildOnlyUptoFinalEndpoint(self):
+    batch_size = 5
+    height, width = 299, 299
+    endpoints = ['Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
+                 'MaxPool_3a_3x3', 'Conv2d_3b_1x1', 'Conv2d_4a_3x3',
+                 'MaxPool_5a_3x3', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
+                 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
+                 'Mixed_6e', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c']
+
+    for index, endpoint in enumerate(endpoints):
+      with tf.Graph().as_default():
+        inputs = tf.random_uniform((batch_size, height, width, 3))
+        out_tensor, end_points = inception.inception_v3_base(
+            inputs, final_endpoint=endpoint)
+        self.assertTrue(out_tensor.op.name.startswith(
+            'InceptionV3/' + endpoint))
+        self.assertItemsEqual(endpoints[:index+1], end_points)
+
+  def testBuildAndCheckAllEndPointsUptoMixed7c(self):
+    batch_size = 5
+    height, width = 299, 299
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v3_base(
+        inputs, final_endpoint='Mixed_7c')
+    endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
+                        'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
+                        'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
+                        'MaxPool_3a_3x3': [batch_size, 73, 73, 64],
+                        'Conv2d_3b_1x1': [batch_size, 73, 73, 80],
+                        'Conv2d_4a_3x3': [batch_size, 71, 71, 192],
+                        'MaxPool_5a_3x3': [batch_size, 35, 35, 192],
+                        'Mixed_5b': [batch_size, 35, 35, 256],
+                        'Mixed_5c': [batch_size, 35, 35, 288],
+                        'Mixed_5d': [batch_size, 35, 35, 288],
+                        'Mixed_6a': [batch_size, 17, 17, 768],
+                        'Mixed_6b': [batch_size, 17, 17, 768],
+                        'Mixed_6c': [batch_size, 17, 17, 768],
+                        'Mixed_6d': [batch_size, 17, 17, 768],
+                        'Mixed_6e': [batch_size, 17, 17, 768],
+                        'Mixed_7a': [batch_size, 8, 8, 1280],
+                        'Mixed_7b': [batch_size, 8, 8, 2048],
+                        'Mixed_7c': [batch_size, 8, 8, 2048]}
+    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
+    for endpoint_name in endpoints_shapes:
+      expected_shape = endpoints_shapes[endpoint_name]
+      self.assertTrue(endpoint_name in end_points)
+      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
+                           expected_shape)
+
+  def testModelHasExpectedNumberOfParameters(self):
+    batch_size = 5
+    height, width = 299, 299
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    with slim.arg_scope(inception.inception_v3_arg_scope()):
+      inception.inception_v3_base(inputs)
+    total_params, _ = slim.model_analyzer.analyze_vars(
+        slim.get_model_variables())
+    self.assertAlmostEqual(21802784, total_params)
+
+  def testBuildEndPoints(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v3(inputs, num_classes)
+    self.assertTrue('Logits' in end_points)
+    logits = end_points['Logits']
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue('AuxLogits' in end_points)
+    aux_logits = end_points['AuxLogits']
+    self.assertListEqual(aux_logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue('Mixed_7c' in end_points)
+    pre_pool = end_points['Mixed_7c']
+    self.assertListEqual(pre_pool.get_shape().as_list(),
+                         [batch_size, 8, 8, 2048])
+    self.assertTrue('PreLogits' in end_points)
+    pre_logits = end_points['PreLogits']
+    self.assertListEqual(pre_logits.get_shape().as_list(),
+                         [batch_size, 1, 1, 2048])
+
+  def testBuildEndPointsWithDepthMultiplierLessThanOne(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v3(inputs, num_classes)
+
+    endpoint_keys = [key for key in end_points.keys()
+                     if key.startswith('Mixed') or key.startswith('Conv')]
+
+    _, end_points_with_multiplier = inception.inception_v3(
+        inputs, num_classes, scope='depth_multiplied_net',
+        depth_multiplier=0.5)
+
+    for key in endpoint_keys:
+      original_depth = end_points[key].get_shape().as_list()[3]
+      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
+      self.assertEqual(0.5 * original_depth, new_depth)
+
+  def testBuildEndPointsWithDepthMultiplierGreaterThanOne(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v3(inputs, num_classes)
+
+    endpoint_keys = [key for key in end_points.keys()
+                     if key.startswith('Mixed') or key.startswith('Conv')]
+
+    _, end_points_with_multiplier = inception.inception_v3(
+        inputs, num_classes, scope='depth_multiplied_net',
+        depth_multiplier=2.0)
+
+    for key in endpoint_keys:
+      original_depth = end_points[key].get_shape().as_list()[3]
+      new_depth = end_points_with_multiplier[key].get_shape().as_list()[3]
+      self.assertEqual(2.0 * original_depth, new_depth)
+
+  def testRaiseValueErrorWithInvalidDepthMultiplier(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    with self.assertRaises(ValueError):
+      _ = inception.inception_v3(inputs, num_classes, depth_multiplier=-0.1)
+    with self.assertRaises(ValueError):
+      _ = inception.inception_v3(inputs, num_classes, depth_multiplier=0.0)
+
+  def testHalfSizeImages(self):
+    batch_size = 5
+    height, width = 150, 150
+    num_classes = 1000
+
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v3(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    pre_pool = end_points['Mixed_7c']
+    self.assertListEqual(pre_pool.get_shape().as_list(),
+                         [batch_size, 3, 3, 2048])
+
+  def testUnknownImageShape(self):
+    tf.reset_default_graph()
+    batch_size = 2
+    height, width = 299, 299
+    num_classes = 1000
+    input_np = np.random.uniform(0, 1, (batch_size, height, width, 3))
+    with self.test_session() as sess:
+      inputs = tf.placeholder(tf.float32, shape=(batch_size, None, None, 3))
+      logits, end_points = inception.inception_v3(inputs, num_classes)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      pre_pool = end_points['Mixed_7c']
+      feed_dict = {inputs: input_np}
+      tf.initialize_all_variables().run()
+      pre_pool_out = sess.run(pre_pool, feed_dict=feed_dict)
+      self.assertListEqual(list(pre_pool_out.shape), [batch_size, 8, 8, 2048])
+
+  def testUnknownBatchSize(self):
+    batch_size = 1
+    height, width = 299, 299
+    num_classes = 1000
+
+    inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+    logits, _ = inception.inception_v3(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV3/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [None, num_classes])
+    images = tf.random_uniform((batch_size, height, width, 3))
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEquals(output.shape, (batch_size, num_classes))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 299, 299
+    num_classes = 1000
+
+    eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, _ = inception.inception_v3(eval_inputs, num_classes,
+                                       is_training=False)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (batch_size,))
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 5
+    eval_batch_size = 2
+    height, width = 150, 150
+    num_classes = 1000
+
+    train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+    inception.inception_v3(train_inputs, num_classes)
+    eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+    logits, _ = inception.inception_v3(eval_inputs, num_classes,
+                                       is_training=False, reuse=True)
+    predictions = tf.argmax(logits, 1)
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (eval_batch_size,))
+
+  def testLogitsNotSqueezed(self):
+    num_classes = 25
+    images = tf.random_uniform([1, 299, 299, 3])
+    logits, _ = inception.inception_v3(images,
+                                       num_classes=num_classes,
+                                       spatial_squeeze=False)
+
+    with self.test_session() as sess:
+      tf.initialize_all_variables().run()
+      logits_out = sess.run(logits)
+      self.assertListEqual(list(logits_out.shape), [1, 1, 1, num_classes])
+
+
+if __name__ == '__main__':
+  tf.test.main()
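
The two depth-multiplier tests above pin down the channel-scaling rule used
throughout inception_v2/v3, `depth = lambda d: max(int(d * depth_multiplier),
min_depth)`. A standalone restatement (names hypothetical):

```python
def depth(d, depth_multiplier=1.0, min_depth=16):
  """Channel-scaling rule shared by inception_v2/v3."""
  return max(int(d * depth_multiplier), min_depth)

# depth_multiplier=0.5 halves every Conv/Mixed depth, which is what
# testBuildEndPointsWithDepthMultiplierLessThanOne asserts:
assert depth(192, depth_multiplier=0.5) == 96
# min_depth acts as a floor once the scaled depth would drop below it:
assert depth(24, depth_multiplier=0.5) == 16
```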

+ 2 - 1
slim/nets/lenet.py

@@ -33,7 +33,8 @@ def lenet(images, num_classes=10, is_training=False,
   interval of (-infinity, infinity). Consequently, to convert the outputs to a
   probability distribution over the characters, one will need to convert them
   using the softmax function:
-        logits = mnist.Mnist(images, is_training=False)
+
+        logits = lenet.lenet(images, is_training=False)
         probabilities = tf.nn.softmax(logits)
         predictions = tf.argmax(logits, 1)
 

+ 107 - 0
slim/nets/nets_factory.py

@@ -0,0 +1,107 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains a factory for building various models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import functools
+
+import tensorflow as tf
+
+from nets import alexnet
+from nets import cifarnet
+from nets import inception
+from nets import lenet
+from nets import overfeat
+from nets import resnet_v1
+from nets import resnet_v2
+from nets import vgg
+
+slim = tf.contrib.slim
+
+networks_map = {'alexnet_v2': alexnet.alexnet_v2,
+                'cifarnet': cifarnet.cifarnet,
+                'overfeat': overfeat.overfeat,
+                'vgg_a': vgg.vgg_a,
+                'vgg_16': vgg.vgg_16,
+                'vgg_19': vgg.vgg_19,
+                'inception_v1': inception.inception_v1,
+                'inception_v2': inception.inception_v2,
+                'inception_v3': inception.inception_v3,
+                'inception_resnet_v2': inception.inception_resnet_v2,
+                'lenet': lenet.lenet,
+                'resnet_v1_50': resnet_v1.resnet_v1_50,
+                'resnet_v1_101': resnet_v1.resnet_v1_101,
+                'resnet_v1_152': resnet_v1.resnet_v1_152,
+                'resnet_v1_200': resnet_v1.resnet_v1_200,
+                'resnet_v2_50': resnet_v2.resnet_v2_50,
+                'resnet_v2_101': resnet_v2.resnet_v2_101,
+                'resnet_v2_152': resnet_v2.resnet_v2_152,
+                'resnet_v2_200': resnet_v2.resnet_v2_200,
+               }
+
+arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
+                  'cifarnet': cifarnet.cifarnet_arg_scope,
+                  'overfeat': overfeat.overfeat_arg_scope,
+                  'vgg_a': vgg.vgg_arg_scope,
+                  'vgg_16': vgg.vgg_arg_scope,
+                  'vgg_19': vgg.vgg_arg_scope,
+                  'inception_v1': inception.inception_v3_arg_scope,
+                  'inception_v2': inception.inception_v3_arg_scope,
+                  'inception_v3': inception.inception_v3_arg_scope,
+                  'inception_resnet_v2':
+                  inception.inception_resnet_v2_arg_scope,
+                  'lenet': lenet.lenet_arg_scope,
+                  'resnet_v1_50': resnet_v1.resnet_arg_scope,
+                  'resnet_v1_101': resnet_v1.resnet_arg_scope,
+                  'resnet_v1_152': resnet_v1.resnet_arg_scope,
+                  'resnet_v1_200': resnet_v1.resnet_arg_scope,
+                  'resnet_v2_50': resnet_v2.resnet_arg_scope,
+                  'resnet_v2_101': resnet_v2.resnet_arg_scope,
+                  'resnet_v2_152': resnet_v2.resnet_arg_scope,
+                  'resnet_v2_200': resnet_v2.resnet_arg_scope,
+                 }
+
+
+def get_network_fn(name, num_classes, weight_decay=0.0, is_training=False):
+  """Returns a network_fn such as `logits, end_points = network_fn(images)`.
+
+  Args:
+    name: The name of the network.
+    num_classes: The number of classes to use for classification.
+    weight_decay: The l2 coefficient for the model weights.
+    is_training: `True` if the model is being used for training and `False`
+      otherwise.
+
+  Returns:
+    network_fn: A function that applies the model to a batch of images. It has
+      the following signature:
+        logits, end_points = network_fn(images)
+  Raises:
+    ValueError: If network `name` is not recognized.
+  """
+  if name not in networks_map:
+    raise ValueError('Name of network unknown: %s' % name)
+  arg_scope = arg_scopes_map[name](weight_decay=weight_decay)
+  func = networks_map[name]
+  @functools.wraps(func)
+  def network_fn(images):
+    with slim.arg_scope(arg_scope):
+      return func(images, num_classes, is_training=is_training)
+  if hasattr(func, 'default_image_size'):
+    network_fn.default_image_size = func.default_image_size
+
+  return network_fn

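For orientation, a minimal sketch of how the factory above is intended to be called; the 'vgg_16' name, batch size, and weight decay value are illustrative only:

    import tensorflow as tf
    from nets import nets_factory

    # Build a training-mode VGG-16 classifier over 1000 classes.
    network_fn = nets_factory.get_network_fn(
        'vgg_16', num_classes=1000, weight_decay=0.0005, is_training=True)

    # The factory forwards the network's preferred input resolution, when
    # the underlying model defines one.
    size = network_fn.default_image_size
    images = tf.random_uniform((8, size, size, 3))
    logits, end_points = network_fn(images)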
+ 46 - 0
slim/nets/nets_factory_test.py

@@ -0,0 +1,46 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for slim.inception."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+import tensorflow as tf
+
+from nets import nets_factory
+
+
+class NetworksTest(tf.test.TestCase):
+
+  def testGetNetworkFn(self):
+    batch_size = 5
+    num_classes = 1000
+    for net in nets_factory.networks_map:
+      with self.test_session():
+        net_fn = nets_factory.get_network_fn(net, num_classes)
+        # Most networks use 224 as their default_image_size
+        image_size = getattr(net_fn, 'default_image_size', 224)
+        inputs = tf.random_uniform((batch_size, image_size, image_size, 3))
+        logits, end_points = net_fn(inputs)
+        self.assertTrue(isinstance(logits, tf.Tensor))
+        self.assertTrue(isinstance(end_points, dict))
+        self.assertEqual(logits.get_shape().as_list()[0], batch_size)
+        self.assertEqual(logits.get_shape().as_list()[-1], num_classes)
+
+if __name__ == '__main__':
+  tf.test.main()

+ 118 - 0
slim/nets/overfeat.py

@@ -0,0 +1,118 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the model definition for the OverFeat network.
+
+The definition for the network was obtained from:
+  OverFeat: Integrated Recognition, Localization and Detection using
+  Convolutional Networks
+  Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
+  Yann LeCun, 2014
+  http://arxiv.org/abs/1312.6229
+
+Usage:
+  with slim.arg_scope(overfeat.overfeat_arg_scope()):
+    outputs, end_points = overfeat.overfeat(inputs)
+
+@@overfeat
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
+
+
+def overfeat_arg_scope(weight_decay=0.0005):
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      activation_fn=tf.nn.relu,
+                      weights_regularizer=slim.l2_regularizer(weight_decay),
+                      biases_initializer=tf.zeros_initializer):
+    with slim.arg_scope([slim.conv2d], padding='SAME'):
+      with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
+        return arg_sc
+
+
+def overfeat(inputs,
+             num_classes=1000,
+             is_training=True,
+             dropout_keep_prob=0.5,
+             spatial_squeeze=True,
+             scope='overfeat'):
+  """Contains the model definition for the OverFeat network.
+
+  The definition for the network was obtained from:
+    OverFeat: Integrated Recognition, Localization and Detection using
+    Convolutional Networks
+    Pierre Sermanet, David Eigen, Xiang Zhang, Michael Mathieu, Rob Fergus and
+    Yann LeCun, 2014
+    http://arxiv.org/abs/1312.6229
+
+  Note: All the fully_connected layers have been transformed to conv2d layers.
+        To use in classification mode, resize input to 231x231. To use in fully
+        convolutional mode, set spatial_squeeze to false.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether or not the model is being trained.
+    dropout_keep_prob: the probability that activations are kept in the dropout
+      layers during training.
+    spatial_squeeze: whether or not the spatial dimensions of the outputs
+      should be squeezed. Useful for removing unnecessary dimensions in
+      classification mode.
+    scope: Optional scope for the variables.
+
+  Returns:
+    the last op containing the log predictions and end_points dict.
+
+  """
+  with tf.variable_scope(scope, 'overfeat', [inputs]) as sc:
+    end_points_collection = sc.name + '_end_points'
+    # Collect outputs for conv2d, fully_connected and max_pool2d
+    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+                        outputs_collections=end_points_collection):
+      net = slim.conv2d(inputs, 64, [11, 11], 4, padding='VALID',
+                        scope='conv1')
+      net = slim.max_pool2d(net, [2, 2], scope='pool1')
+      net = slim.conv2d(net, 256, [5, 5], padding='VALID', scope='conv2')
+      net = slim.max_pool2d(net, [2, 2], scope='pool2')
+      net = slim.conv2d(net, 512, [3, 3], scope='conv3')
+      net = slim.conv2d(net, 1024, [3, 3], scope='conv4')
+      net = slim.conv2d(net, 1024, [3, 3], scope='conv5')
+      net = slim.max_pool2d(net, [2, 2], scope='pool5')
+      with slim.arg_scope([slim.conv2d],
+                          weights_initializer=trunc_normal(0.005),
+                          biases_initializer=tf.constant_initializer(0.1)):
+        # Use conv2d instead of fully_connected layers.
+        net = slim.conv2d(net, 3072, [6, 6], padding='VALID', scope='fc6')
+        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                           scope='dropout6')
+        net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+        net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                           scope='dropout7')
+        net = slim.conv2d(net, num_classes, [1, 1],
+                          activation_fn=None,
+                          normalizer_fn=None,
+                          biases_initializer=tf.zeros_initializer,
+                          scope='fc8')
+      # Convert end_points_collection into a end_point dict.
+      end_points = dict(tf.get_collection(end_points_collection))
+      if spatial_squeeze:
+        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+        end_points[sc.name + '/fc8'] = net
+      return net, end_points
+overfeat.default_image_size = 231

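A short sketch of the two modes described in the docstring above; the batch sizes and the 281x281 input are arbitrary, and the variable reuse mirrors the pattern used in the tests below:

    import tensorflow as tf
    from nets import overfeat

    slim = tf.contrib.slim

    # Classification mode: 231x231 inputs yield [batch, num_classes] logits.
    images = tf.random_uniform((4, 231, 231, 3))
    with slim.arg_scope(overfeat.overfeat_arg_scope()):
      logits, _ = overfeat.overfeat(images, num_classes=1000)

    # Fully convolutional mode over the same weights: 281x281 inputs yield a
    # [batch, 2, 2, num_classes] spatial map of logits.
    tf.get_variable_scope().reuse_variables()
    big_images = tf.random_uniform((4, 281, 281, 3))
    with slim.arg_scope(overfeat.overfeat_arg_scope()):
      spatial_logits, _ = overfeat.overfeat(
          big_images, num_classes=1000, is_training=False,
          spatial_squeeze=False)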
+ 145 - 0
slim/nets/overfeat_test.py

@@ -0,0 +1,145 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.overfeat."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import overfeat
+
+slim = tf.contrib.slim
+
+
+class OverFeatTest(tf.test.TestCase):
+
+  def testBuild(self):
+    batch_size = 5
+    height, width = 231, 231
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = overfeat.overfeat(inputs, num_classes)
+      self.assertEquals(logits.op.name, 'overfeat/fc8/squeezed')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testFullyConvolutional(self):
+    batch_size = 1
+    height, width = 281, 281
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = overfeat.overfeat(inputs, num_classes, spatial_squeeze=False)
+      self.assertEquals(logits.op.name, 'overfeat/fc8/BiasAdd')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, 2, 2, num_classes])
+
+  def testEndPoints(self):
+    batch_size = 5
+    height, width = 231, 231
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = overfeat.overfeat(inputs, num_classes)
+      expected_names = ['overfeat/conv1',
+                        'overfeat/pool1',
+                        'overfeat/conv2',
+                        'overfeat/pool2',
+                        'overfeat/conv3',
+                        'overfeat/conv4',
+                        'overfeat/conv5',
+                        'overfeat/pool5',
+                        'overfeat/fc6',
+                        'overfeat/fc7',
+                        'overfeat/fc8'
+                       ]
+      self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+  def testModelVariables(self):
+    batch_size = 5
+    height, width = 231, 231
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      overfeat.overfeat(inputs, num_classes)
+      expected_names = ['overfeat/conv1/weights',
+                        'overfeat/conv1/biases',
+                        'overfeat/conv2/weights',
+                        'overfeat/conv2/biases',
+                        'overfeat/conv3/weights',
+                        'overfeat/conv3/biases',
+                        'overfeat/conv4/weights',
+                        'overfeat/conv4/biases',
+                        'overfeat/conv5/weights',
+                        'overfeat/conv5/biases',
+                        'overfeat/fc6/weights',
+                        'overfeat/fc6/biases',
+                        'overfeat/fc7/weights',
+                        'overfeat/fc7/biases',
+                        'overfeat/fc8/weights',
+                        'overfeat/fc8/biases',
+                       ]
+      model_variables = [v.op.name for v in slim.get_model_variables()]
+      self.assertSetEqual(set(model_variables), set(expected_names))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 231, 231
+    num_classes = 1000
+    with self.test_session():
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = overfeat.overfeat(eval_inputs, is_training=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 2
+    eval_batch_size = 1
+    train_height, train_width = 231, 231
+    eval_height, eval_width = 281, 281
+    num_classes = 1000
+    with self.test_session():
+      train_inputs = tf.random_uniform(
+          (train_batch_size, train_height, train_width, 3))
+      logits, _ = overfeat.overfeat(train_inputs)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [train_batch_size, num_classes])
+      tf.get_variable_scope().reuse_variables()
+      eval_inputs = tf.random_uniform(
+          (eval_batch_size, eval_height, eval_width, 3))
+      logits, _ = overfeat.overfeat(eval_inputs, is_training=False,
+                                    spatial_squeeze=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [eval_batch_size, 2, 2, num_classes])
+      logits = tf.reduce_mean(logits, [1, 2])
+      predictions = tf.argmax(logits, 1)
+      self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+  def testForward(self):
+    batch_size = 1
+    height, width = 231, 231
+    with self.test_session() as sess:
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = overfeat.overfeat(inputs)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits)
+      self.assertTrue(output.any())
+
+if __name__ == '__main__':
+  tf.test.main()

+ 254 - 0
slim/nets/resnet_utils.py

@@ -0,0 +1,254 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains building blocks for various versions of Residual Networks.
+
+Residual networks (ResNets) were proposed in:
+  Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+  Deep Residual Learning for Image Recognition. arXiv:1512.03385, 2015
+
+More variants were introduced in:
+  Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+  Identity Mappings in Deep Residual Networks. arXiv: 1603.05027, 2016
+
+We can obtain different ResNet variants by changing the network depth, width,
+and form of residual unit. This module implements the infrastructure for
+building them. Concrete ResNet units and full ResNet networks are implemented in
+the accompanying resnet_v1.py and resnet_v2.py modules.
+
+Compared to https://github.com/KaimingHe/deep-residual-networks, in the current
+implementation we subsample the output activations in the last residual unit of
+each block, instead of subsampling the input activations in the first residual
+unit of each block. The two implementations give identical results but our
+implementation is more memory efficient.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+class Block(collections.namedtuple('Block', ['scope', 'unit_fn', 'args'])):
+  """A named tuple describing a ResNet block.
+
+  Its parts are:
+    scope: The scope of the `Block`.
+    unit_fn: The ResNet unit function which takes as input a `Tensor` and
+      returns another `Tensor` with the output of the ResNet unit.
+    args: A list of length equal to the number of units in the `Block`. The list
+      contains one (depth, depth_bottleneck, stride) tuple for each unit in the
+      block to serve as argument to unit_fn.
+  """
+
+
+def subsample(inputs, factor, scope=None):
+  """Subsamples the input along the spatial dimensions.
+
+  Args:
+    inputs: A `Tensor` of size [batch, height_in, width_in, channels].
+    factor: The subsampling factor.
+    scope: Optional variable_scope.
+
+  Returns:
+    output: A `Tensor` of size [batch, height_out, width_out, channels] with the
+      input, either intact (if factor == 1) or subsampled (if factor > 1).
+  """
+  if factor == 1:
+    return inputs
+  else:
+    return slim.max_pool2d(inputs, [1, 1], stride=factor, scope=scope)
+
+
+def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
+  """Strided 2-D convolution with 'SAME' padding.
+
+  When stride > 1, then we do explicit zero-padding, followed by conv2d with
+  'VALID' padding.
+
+  Note that
+
+     net = conv2d_same(inputs, num_outputs, 3, stride=stride)
+
+  is equivalent to
+
+     net = slim.conv2d(inputs, num_outputs, 3, stride=1, padding='SAME')
+     net = subsample(net, factor=stride)
+
+  whereas
+
+     net = slim.conv2d(inputs, num_outputs, 3, stride=stride, padding='SAME')
+
+  is different when the input's height or width is even, which is why we add the
+  current function. For more details, see ResnetUtilsTest.testConv2DSameEven().
+
+  Args:
+    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
+    num_outputs: An integer, the number of output filters.
+    kernel_size: An int with the kernel_size of the filters.
+    stride: An integer, the output stride.
+    rate: An integer, rate for atrous convolution.
+    scope: Scope.
+
+  Returns:
+    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
+      the convolution output.
+  """
+  if stride == 1:
+    return slim.conv2d(inputs, num_outputs, kernel_size, stride=1, rate=rate,
+                       padding='SAME', scope=scope)
+  else:
+    kernel_size_effective = kernel_size + (kernel_size - 1) * (rate - 1)
+    pad_total = kernel_size_effective - 1
+    pad_beg = pad_total // 2
+    pad_end = pad_total - pad_beg
+    inputs = tf.pad(inputs,
+                    [[0, 0], [pad_beg, pad_end], [pad_beg, pad_end], [0, 0]])
+    return slim.conv2d(inputs, num_outputs, kernel_size, stride=stride,
+                       rate=rate, padding='VALID', scope=scope)
+
+
+@slim.add_arg_scope
+def stack_blocks_dense(net, blocks, output_stride=None,
+                       outputs_collections=None):
+  """Stacks ResNet `Blocks` and controls output feature density.
+
+  First, this function creates scopes for the ResNet in the form of
+  'block_name/unit_1', 'block_name/unit_2', etc.
+
+  Second, this function allows the user to explicitly control the ResNet
+  output_stride, which is the ratio of the input to output spatial resolution.
+  This is useful for dense prediction tasks such as semantic segmentation or
+  object detection.
+
+  Most ResNets consist of 4 ResNet blocks and subsample the activations by a
+  factor of 2 when transitioning between consecutive ResNet blocks. This
+  results in a nominal ResNet output_stride equal to 8. If we set the
+  output_stride to half the nominal network stride (e.g., output_stride=4),
+  then the responses are computed at twice the nominal density.
+
+  Control of the output feature density is implemented by atrous convolution.
+
+  Args:
+    net: A `Tensor` of size [batch, height, width, channels].
+    blocks: A list of length equal to the number of ResNet `Blocks`. Each
+      element is a ResNet `Block` object describing the units in the `Block`.
+    output_stride: If `None`, then the output will be computed at the nominal
+      network stride. If output_stride is not `None`, it specifies the requested
+      ratio of input to output spatial resolution, which needs to be equal to
+      the product of unit strides from the start up to some level of the ResNet.
+      For example, if the ResNet employs units with strides 1, 2, 1, 3, 4, 1,
+      then valid values for the output_stride are 1, 2, 6, 24 or None (which
+      is equivalent to output_stride=24).
+    outputs_collections: Collection to add the ResNet block outputs.
+
+  Returns:
+    net: Output tensor with stride equal to the specified output_stride.
+
+  Raises:
+    ValueError: If the target output_stride is not valid.
+  """
+  # The current_stride variable keeps track of the effective stride of the
+  # activations. This allows us to invoke atrous convolution whenever applying
+  # the next residual unit would result in the activations having stride larger
+  # than the target output_stride.
+  current_stride = 1
+
+  # The atrous convolution rate parameter.
+  rate = 1
+
+  for block in blocks:
+    with tf.variable_scope(block.scope, 'block', [net]) as sc:
+      for i, unit in enumerate(block.args):
+        if output_stride is not None and current_stride > output_stride:
+          raise ValueError('The target output_stride cannot be reached.')
+
+        with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
+          unit_depth, unit_depth_bottleneck, unit_stride = unit
+
+          # If we have reached the target output_stride, then we need to employ
+          # atrous convolution with stride=1 and multiply the atrous rate by the
+          # current unit's stride for use in subsequent layers.
+          if output_stride is not None and current_stride == output_stride:
+            net = block.unit_fn(net,
+                                depth=unit_depth,
+                                depth_bottleneck=unit_depth_bottleneck,
+                                stride=1,
+                                rate=rate)
+            rate *= unit_stride
+
+          else:
+            net = block.unit_fn(net,
+                                depth=unit_depth,
+                                depth_bottleneck=unit_depth_bottleneck,
+                                stride=unit_stride,
+                                rate=1)
+            current_stride *= unit_stride
+      net = slim.utils.collect_named_outputs(outputs_collections, sc.name, net)
+
+  if output_stride is not None and current_stride != output_stride:
+    raise ValueError('The target output_stride cannot be reached.')
+
+  return net
+
+
+def resnet_arg_scope(weight_decay=0.0001,
+                     batch_norm_decay=0.997,
+                     batch_norm_epsilon=1e-5,
+                     batch_norm_scale=True):
+  """Defines the default ResNet arg scope.
+
+  TODO(gpapan): The batch-normalization related default values above are
+    appropriate for use in conjunction with the reference ResNet models
+    released at https://github.com/KaimingHe/deep-residual-networks. When
+    training ResNets from scratch, they might need to be tuned.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    batch_norm_decay: The moving average decay when estimating layer activation
+      statistics in batch normalization.
+    batch_norm_epsilon: Small constant to prevent division by zero when
+      normalizing activations by their variance in batch normalization.
+    batch_norm_scale: If True, uses an explicit `gamma` multiplier to scale the
+      activations in the batch normalization layer.
+
+  Returns:
+    An `arg_scope` to use for the resnet models.
+  """
+  batch_norm_params = {
+      'decay': batch_norm_decay,
+      'epsilon': batch_norm_epsilon,
+      'scale': batch_norm_scale,
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+  }
+
+  with slim.arg_scope(
+      [slim.conv2d],
+      weights_regularizer=slim.l2_regularizer(weight_decay),
+      weights_initializer=slim.variance_scaling_initializer(),
+      activation_fn=tf.nn.relu,
+      normalizer_fn=slim.batch_norm,
+      normalizer_params=batch_norm_params):
+    with slim.arg_scope([slim.batch_norm], **batch_norm_params):
+      # The following implies padding='SAME' for pool1, which makes feature
+      # alignment easier for dense prediction tasks. This is also used in
+      # https://github.com/facebook/fb.resnet.torch. However the accompanying
+      # code of 'Deep Residual Learning for Image Recognition' uses
+      # padding='VALID' for pool1. You can switch to that choice by setting
+      # slim.arg_scope([slim.max_pool2d], padding='VALID').
+      with slim.arg_scope([slim.max_pool2d], padding='SAME') as arg_sc:
+        return arg_sc

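A toy sketch of the Block / stack_blocks_dense machinery; the depths are tiny and arbitrary, and the bottleneck unit is borrowed from the resnet_v1 module added below:

    import tensorflow as tf
    from nets import resnet_utils
    from nets import resnet_v1

    slim = tf.contrib.slim

    # Each (depth, depth_bottleneck, stride) tuple describes one unit.
    blocks = [
        resnet_utils.Block('block1', resnet_v1.bottleneck,
                           [(8, 2, 1), (8, 2, 2)]),
        resnet_utils.Block('block2', resnet_v1.bottleneck,
                           [(16, 4, 1), (16, 4, 2)]),
    ]

    inputs = tf.random_uniform((1, 64, 64, 3))
    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
      # The nominal stride of these two blocks is 4. Requesting
      # output_stride=2 makes every unit after the first stride-2 unit run
      # with stride 1 and a growing atrous rate instead of subsampling.
      net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride=2)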
+ 295 - 0
slim/nets/resnet_v1.py

@@ -0,0 +1,295 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains definitions for the original form of Residual Networks.
+
+The 'v1' residual networks (ResNets) implemented in this module were proposed
+by:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Deep Residual Learning for Image Recognition. arXiv:1512.03385
+
+Other variants were introduced in:
+[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Identity Mappings in Deep Residual Networks. arXiv: 1603.05027
+
+The networks defined in this module utilize the bottleneck building block of
+[1] with projection shortcuts only for increasing depths. They employ batch
+normalization *after* every weight layer. This is the architecture used by
+MSRA in the Imagenet and MSCOCO 2016 competition models ResNet-101 and
+ResNet-152. See [2; Fig. 1a] for a comparison between the current 'v1'
+architecture and the alternative 'v2' architecture of [2] which uses batch
+normalization *before* every weight layer in the so-called full pre-activation
+units.
+
+Typical use:
+
+   from tensorflow.contrib.slim.nets import resnet_v1
+
+ResNet-101 for image classification into 1000 classes:
+
+   # inputs has shape [batch, 224, 224, 3]
+   with slim.arg_scope(resnet_v1.resnet_arg_scope()):
+      net, end_points = resnet_v1.resnet_v1_101(inputs, 1000, is_training=False)
+
+ResNet-101 for semantic segmentation into 21 classes:
+
+   # inputs has shape [batch, 513, 513, 3]
+   with slim.arg_scope(resnet_v1.resnet_arg_scope()):
+      net, end_points = resnet_v1.resnet_v1_101(inputs,
+                                                21,
+                                                is_training=False,
+                                                global_pool=False,
+                                                output_stride=16)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import resnet_utils
+
+
+resnet_arg_scope = resnet_utils.resnet_arg_scope
+slim = tf.contrib.slim
+
+
+@slim.add_arg_scope
+def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
+               outputs_collections=None, scope=None):
+  """Bottleneck residual unit variant with BN after convolutions.
+
+  This is the original residual unit proposed in [1]. See Fig. 1(a) of [2] for
+  its definition. Note that here we use the bottleneck variant, which has an
+  extra bottleneck layer.
+
+  When putting together two consecutive ResNet blocks that use this unit, one
+  should use stride = 2 in the last unit of the first block.
+
+  Args:
+    inputs: A tensor of size [batch, height, width, channels].
+    depth: The depth of the ResNet unit output.
+    depth_bottleneck: The depth of the bottleneck layers.
+    stride: The ResNet unit's stride. Determines the amount of downsampling of
+      the units output compared to its input.
+    rate: An integer, rate for atrous convolution.
+    outputs_collections: Collection to add the ResNet unit output.
+    scope: Optional variable_scope.
+
+  Returns:
+    The ResNet unit's output.
+  """
+  with tf.variable_scope(scope, 'bottleneck_v1', [inputs]) as sc:
+    depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
+    if depth == depth_in:
+      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
+    else:
+      shortcut = slim.conv2d(inputs, depth, [1, 1], stride=stride,
+                             activation_fn=None, scope='shortcut')
+
+    residual = slim.conv2d(inputs, depth_bottleneck, [1, 1], stride=1,
+                           scope='conv1')
+    residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
+                                        rate=rate, scope='conv2')
+    residual = slim.conv2d(residual, depth, [1, 1], stride=1,
+                           activation_fn=None, scope='conv3')
+
+    output = tf.nn.relu(shortcut + residual)
+
+    return slim.utils.collect_named_outputs(outputs_collections,
+                                            sc.original_name_scope,
+                                            output)
+
+
+def resnet_v1(inputs,
+              blocks,
+              num_classes=None,
+              is_training=True,
+              global_pool=True,
+              output_stride=None,
+              include_root_block=True,
+              reuse=None,
+              scope=None):
+  """Generator for v1 ResNet models.
+
+  This function generates a family of ResNet v1 models. See the resnet_v1_*()
+  methods for specific model instantiations, obtained by selecting different
+  block instantiations that produce ResNets of various depths.
+
+  Training for image classification on Imagenet is usually done with [224, 224]
+  inputs, resulting in [7, 7] feature maps at the output of the last ResNet
+  block for the ResNets defined in [1] that have nominal stride equal to 32.
+  However, for dense prediction tasks we advise that one uses inputs with
+  spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
+  this case the feature maps at the ResNet output will have spatial shape
+  [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
+  and corners exactly aligned with the input image corners, which greatly
+  facilitates alignment of the features to the image. Using as input [225, 225]
+  images results in [8, 8] feature maps at the output of the last ResNet block.
+
+  For dense prediction tasks, the ResNet needs to run in fully-convolutional
+  (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
+  have nominal stride equal to 32 and a good choice in FCN mode is to use
+  output_stride=16 in order to increase the density of the computed features at
+  small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
+
+  Args:
+    inputs: A tensor of size [batch, height_in, width_in, channels].
+    blocks: A list of length equal to the number of ResNet blocks. Each element
+      is a resnet_utils.Block object describing the units in the block.
+    num_classes: Number of predicted classes for classification tasks. If None
+      we return the features before the logit layer.
+    is_training: whether the model is being trained.
+    global_pool: If True, we perform global average pooling before computing the
+      logits. Set to True for image classification, False for dense prediction.
+    output_stride: If None, then the output will be computed at the nominal
+      network stride. If output_stride is not None, it specifies the requested
+      ratio of input to output spatial resolution.
+    include_root_block: If True, include the initial convolution followed by
+      max-pooling; if False, exclude it.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+      If global_pool is False, then height_out and width_out are reduced by a
+      factor of output_stride compared to the respective height_in and width_in,
+      else both height_out and width_out equal one. If num_classes is None, then
+      net is the output of the last ResNet block, potentially after global
+      average pooling. If num_classes is not None, net contains the pre-softmax
+      activations.
+    end_points: A dictionary from components of the network to the corresponding
+      activation.
+
+  Raises:
+    ValueError: If the target output_stride is not valid.
+  """
+  with tf.variable_scope(scope, 'resnet_v1', [inputs], reuse=reuse) as sc:
+    end_points_collection = sc.name + '_end_points'
+    with slim.arg_scope([slim.conv2d, bottleneck,
+                         resnet_utils.stack_blocks_dense],
+                        outputs_collections=end_points_collection):
+      with slim.arg_scope([slim.batch_norm], is_training=is_training):
+        net = inputs
+        if include_root_block:
+          if output_stride is not None:
+            if output_stride % 4 != 0:
+              raise ValueError('The output_stride needs to be a multiple of 4.')
+            output_stride /= 4
+          net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
+          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
+        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
+        if global_pool:
+          # Global average pooling.
+          net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
+        if num_classes is not None:
+          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+                            normalizer_fn=None, scope='logits')
+        # Convert end_points_collection into a dictionary of end_points.
+        end_points = dict(tf.get_collection(end_points_collection))
+        if num_classes is not None:
+          end_points['predictions'] = slim.softmax(net, scope='predictions')
+        return net, end_points
+resnet_v1.default_image_size = 224
+
+
+def resnet_v1_50(inputs,
+                 num_classes=None,
+                 is_training=True,
+                 global_pool=True,
+                 output_stride=None,
+                 reuse=None,
+                 scope='resnet_v1_50'):
+  """ResNet-50 model of [1]. See resnet_v1() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)
+  ]
+  return resnet_v1(inputs, blocks, num_classes, is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v1_101(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v1_101'):
+  """ResNet-101 model of [1]. See resnet_v1() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)
+  ]
+  return resnet_v1(inputs, blocks, num_classes, is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v1_152(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v1_152'):
+  """ResNet-152 model of [1]. See resnet_v1() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v1(inputs, blocks, num_classes, is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v1_200(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v1_200'):
+  """ResNet-200 model of [2]. See resnet_v1() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v1(inputs, blocks, num_classes, is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)

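Complementing the classification and segmentation examples in the docstring above, a sketch of feature extraction without a logits layer; the batch size is illustrative:

    import tensorflow as tf
    from nets import resnet_v1

    slim = tf.contrib.slim

    # With num_classes=None the network returns the output of the last block
    # after global average pooling, here of shape [2, 1, 1, 2048].
    images = tf.random_uniform((2, 224, 224, 3))
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
      features, end_points = resnet_v1.resnet_v1_50(
          images, num_classes=None, is_training=False, global_pool=True)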
+ 450 - 0
slim/nets/resnet_v1_test.py

@@ -0,0 +1,450 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.resnet_v1."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from nets import resnet_utils
+from nets import resnet_v1
+
+slim = tf.contrib.slim
+
+
+def create_test_input(batch_size, height, width, channels):
+  """Create test input tensor.
+
+  Args:
+    batch_size: The number of images per batch or `None` if unknown.
+    height: The height of each image or `None` if unknown.
+    width: The width of each image or `None` if unknown.
+    channels: The number of channels per image or `None` if unknown.
+
+  Returns:
+    Either a placeholder `Tensor` of dimension
+      [batch_size, height, width, channels] if any of the inputs are `None` or a
+    constant `Tensor` with the mesh grid values along the spatial dimensions.
+  """
+  if None in [batch_size, height, width, channels]:
+    return tf.placeholder(tf.float32, (batch_size, height, width, channels))
+  else:
+    return tf.to_float(
+        np.tile(
+            np.reshape(
+                np.reshape(np.arange(height), [height, 1]) +
+                np.reshape(np.arange(width), [1, width]),
+                [1, height, width, 1]),
+            [batch_size, 1, 1, channels]))
+
+
+class ResnetUtilsTest(tf.test.TestCase):
+
+  def testSubsampleThreeByThree(self):
+    x = tf.reshape(tf.to_float(tf.range(9)), [1, 3, 3, 1])
+    x = resnet_utils.subsample(x, 2)
+    expected = tf.reshape(tf.constant([0, 2, 6, 8]), [1, 2, 2, 1])
+    with self.test_session():
+      self.assertAllClose(x.eval(), expected.eval())
+
+  def testSubsampleFourByFour(self):
+    x = tf.reshape(tf.to_float(tf.range(16)), [1, 4, 4, 1])
+    x = resnet_utils.subsample(x, 2)
+    expected = tf.reshape(tf.constant([0, 2, 8, 10]), [1, 2, 2, 1])
+    with self.test_session():
+      self.assertAllClose(x.eval(), expected.eval())
+
+  def testConv2DSameEven(self):
+    n, n2 = 4, 2
+
+    # Input image.
+    x = create_test_input(1, n, n, 1)
+
+    # Convolution kernel.
+    w = create_test_input(1, 3, 3, 1)
+    w = tf.reshape(w, [3, 3, 1, 1])
+
+    tf.get_variable('Conv/weights', initializer=w)
+    tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+    tf.get_variable_scope().reuse_variables()
+
+    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
+    y1_expected = tf.to_float([[14, 28, 43, 26],
+                               [28, 48, 66, 37],
+                               [43, 66, 84, 46],
+                               [26, 37, 46, 22]])
+    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+    y2 = resnet_utils.subsample(y1, 2)
+    y2_expected = tf.to_float([[14, 43],
+                               [43, 84]])
+    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
+    y3_expected = y2_expected
+
+    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
+    y4_expected = tf.to_float([[48, 37],
+                               [37, 22]])
+    y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      self.assertAllClose(y1.eval(), y1_expected.eval())
+      self.assertAllClose(y2.eval(), y2_expected.eval())
+      self.assertAllClose(y3.eval(), y3_expected.eval())
+      self.assertAllClose(y4.eval(), y4_expected.eval())
+
+  def testConv2DSameOdd(self):
+    n, n2 = 5, 3
+
+    # Input image.
+    x = create_test_input(1, n, n, 1)
+
+    # Convolution kernel.
+    w = create_test_input(1, 3, 3, 1)
+    w = tf.reshape(w, [3, 3, 1, 1])
+
+    tf.get_variable('Conv/weights', initializer=w)
+    tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+    tf.get_variable_scope().reuse_variables()
+
+    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
+    y1_expected = tf.to_float([[14, 28, 43, 58, 34],
+                               [28, 48, 66, 84, 46],
+                               [43, 66, 84, 102, 55],
+                               [58, 84, 102, 120, 64],
+                               [34, 46, 55, 64, 30]])
+    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+    y2 = resnet_utils.subsample(y1, 2)
+    y2_expected = tf.to_float([[14, 43, 34],
+                               [43, 84, 55],
+                               [34, 55, 30]])
+    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
+    y3_expected = y2_expected
+
+    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
+    y4_expected = y2_expected
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      self.assertAllClose(y1.eval(), y1_expected.eval())
+      self.assertAllClose(y2.eval(), y2_expected.eval())
+      self.assertAllClose(y3.eval(), y3_expected.eval())
+      self.assertAllClose(y4.eval(), y4_expected.eval())
+
+  def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
+    """A plain ResNet without extra layers before or after the ResNet blocks."""
+    with tf.variable_scope(scope, values=[inputs]):
+      with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
+        net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
+        end_points = dict(tf.get_collection('end_points'))
+        return net, end_points
+
+  def testEndPointsV1(self):
+    """Test the end points of a tiny v1 bottleneck network."""
+    bottleneck = resnet_v1.bottleneck
+    blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
+              resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])]
+    inputs = create_test_input(2, 32, 16, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
+    expected = [
+        'tiny/block1/unit_1/bottleneck_v1/shortcut',
+        'tiny/block1/unit_1/bottleneck_v1/conv1',
+        'tiny/block1/unit_1/bottleneck_v1/conv2',
+        'tiny/block1/unit_1/bottleneck_v1/conv3',
+        'tiny/block1/unit_2/bottleneck_v1/conv1',
+        'tiny/block1/unit_2/bottleneck_v1/conv2',
+        'tiny/block1/unit_2/bottleneck_v1/conv3',
+        'tiny/block2/unit_1/bottleneck_v1/shortcut',
+        'tiny/block2/unit_1/bottleneck_v1/conv1',
+        'tiny/block2/unit_1/bottleneck_v1/conv2',
+        'tiny/block2/unit_1/bottleneck_v1/conv3',
+        'tiny/block2/unit_2/bottleneck_v1/conv1',
+        'tiny/block2/unit_2/bottleneck_v1/conv2',
+        'tiny/block2/unit_2/bottleneck_v1/conv3']
+    self.assertItemsEqual(expected, end_points)
+
+  def _stack_blocks_nondense(self, net, blocks):
+    """A simplified ResNet Block stacker without output stride control."""
+    for block in blocks:
+      with tf.variable_scope(block.scope, 'block', [net]):
+        for i, unit in enumerate(block.args):
+          depth, depth_bottleneck, stride = unit
+          with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
+            net = block.unit_fn(net,
+                                depth=depth,
+                                depth_bottleneck=depth_bottleneck,
+                                stride=stride,
+                                rate=1)
+    return net
+
+  def _atrousValues(self, bottleneck):
+    """Verify the values of dense feature extraction by atrous convolution.
+
+    Make sure that dense feature extraction by stack_blocks_dense() followed by
+    subsampling gives identical results to feature extraction at the nominal
+    network output stride using the simple self._stack_blocks_nondense() above.
+
+    Args:
+      bottleneck: The bottleneck function.
+    """
+    blocks = [
+        resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
+        resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
+        resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
+        resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)])
+    ]
+    nominal_stride = 8
+
+    # Test both odd and even input dimensions.
+    height = 30
+    width = 31
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      with slim.arg_scope([slim.batch_norm], is_training=False):
+        for output_stride in [1, 2, 4, 8, None]:
+          with tf.Graph().as_default():
+            with self.test_session() as sess:
+              tf.set_random_seed(0)
+              inputs = create_test_input(1, height, width, 3)
+              # Dense feature extraction followed by subsampling.
+              output = resnet_utils.stack_blocks_dense(inputs,
+                                                       blocks,
+                                                       output_stride)
+              if output_stride is None:
+                factor = 1
+              else:
+                factor = nominal_stride // output_stride
+
+              output = resnet_utils.subsample(output, factor)
+              # Make the two networks use the same weights.
+              tf.get_variable_scope().reuse_variables()
+              # Feature extraction at the nominal network rate.
+              expected = self._stack_blocks_nondense(inputs, blocks)
+              sess.run(tf.initialize_all_variables())
+              output, expected = sess.run([output, expected])
+              self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
+
+  def testAtrousValuesBottleneck(self):
+    self._atrousValues(resnet_v1.bottleneck)
+
+
+class ResnetCompleteNetworkTest(tf.test.TestCase):
+  """Tests with complete small ResNet v1 networks."""
+
+  def _resnet_small(self,
+                    inputs,
+                    num_classes=None,
+                    is_training=True,
+                    global_pool=True,
+                    output_stride=None,
+                    include_root_block=True,
+                    reuse=None,
+                    scope='resnet_v1_small'):
+    """A shallow and thin ResNet v1 for faster tests."""
+    bottleneck = resnet_v1.bottleneck
+    blocks = [
+        resnet_utils.Block(
+            'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]),
+        resnet_utils.Block(
+            'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]),
+        resnet_utils.Block(
+            'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
+        resnet_utils.Block(
+            'block4', bottleneck, [(32, 8, 1)] * 2)]
+    return resnet_v1.resnet_v1(inputs, blocks, num_classes,
+                               is_training=is_training,
+                               global_pool=global_pool,
+                               output_stride=output_stride,
+                               include_root_block=include_root_block,
+                               reuse=reuse,
+                               scope=scope)
+
+  def testClassificationEndPoints(self):
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(2, 224, 224, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      logits, end_points = self._resnet_small(inputs, num_classes,
+                                              global_pool=global_pool,
+                                              scope='resnet')
+    self.assertTrue(logits.op.name.startswith('resnet/logits'))
+    self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+    self.assertTrue('predictions' in end_points)
+    self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+                         [2, 1, 1, num_classes])
+
+  def testClassificationShapes(self):
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(2, 224, 224, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 28, 28, 4],
+          'resnet/block2': [2, 14, 14, 8],
+          'resnet/block3': [2, 7, 7, 16],
+          'resnet/block4': [2, 7, 7, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    inputs = create_test_input(2, 321, 321, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 41, 41, 4],
+          'resnet/block2': [2, 21, 21, 8],
+          'resnet/block3': [2, 11, 11, 16],
+          'resnet/block4': [2, 11, 11, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testRootlessFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    inputs = create_test_input(2, 128, 128, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         include_root_block=False,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 64, 64, 4],
+          'resnet/block2': [2, 32, 32, 8],
+          'resnet/block3': [2, 16, 16, 16],
+          'resnet/block4': [2, 16, 16, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testAtrousFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    output_stride = 8
+    inputs = create_test_input(2, 321, 321, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs,
+                                         num_classes,
+                                         global_pool=global_pool,
+                                         output_stride=output_stride,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 41, 41, 4],
+          'resnet/block2': [2, 41, 41, 8],
+          'resnet/block3': [2, 41, 41, 16],
+          'resnet/block4': [2, 41, 41, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testAtrousFullyConvolutionalValues(self):
+    """Verify dense feature extraction with atrous convolution."""
+    nominal_stride = 32
+    for output_stride in [4, 8, 16, 32, None]:
+      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+        with tf.Graph().as_default():
+          with self.test_session() as sess:
+            tf.set_random_seed(0)
+            inputs = create_test_input(2, 81, 81, 3)
+            # Dense feature extraction followed by subsampling.
+            output, _ = self._resnet_small(inputs, None, is_training=False,
+                                           global_pool=False,
+                                           output_stride=output_stride)
+            if output_stride is None:
+              factor = 1
+            else:
+              factor = nominal_stride // output_stride
+            output = resnet_utils.subsample(output, factor)
+            # Make the two networks use the same weights.
+            tf.get_variable_scope().reuse_variables()
+            # Feature extraction at the nominal network rate.
+            expected, _ = self._resnet_small(inputs, None, is_training=False,
+                                             global_pool=False)
+            sess.run(tf.initialize_all_variables())
+            self.assertAllClose(output.eval(), expected.eval(),
+                                atol=1e-4, rtol=1e-4)
+
+  def testUnknownBatchSize(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(None, height, width, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      logits, _ = self._resnet_small(inputs, num_classes,
+                                     global_pool=global_pool,
+                                     scope='resnet')
+    self.assertTrue(logits.op.name.startswith('resnet/logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [None, 1, 1, num_classes])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 1, 1, num_classes))
+
+  def testFullyConvolutionalUnknownHeightWidth(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = False
+    inputs = create_test_input(batch, None, None, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      output, _ = self._resnet_small(inputs, None, global_pool=global_pool)
+    self.assertListEqual(output.get_shape().as_list(),
+                         [batch, None, None, 32])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(output, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 3, 3, 32))
+
+  def testAtrousFullyConvolutionalUnknownHeightWidth(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = False
+    output_stride = 8
+    inputs = create_test_input(batch, None, None, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      output, _ = self._resnet_small(inputs,
+                                     None,
+                                     global_pool=global_pool,
+                                     output_stride=output_stride)
+    self.assertListEqual(output.get_shape().as_list(),
+                         [batch, None, None, 32])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(output, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 9, 9, 32))
+
+
+if __name__ == '__main__':
+  tf.test.main()
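
As a sanity check of the arithmetic the atrous tests above rely on — a minimal pure-Python sketch, assuming aligned inputs where height - 1 is divisible by every stride tested — dense extraction at output_stride s followed by subsampling with factor = nominal_stride // s keeps exactly the positions that extraction at the nominal rate would visit:

    # Hedged sketch of the stride bookkeeping in testAtrousFullyConvolutionalValues.
    nominal_stride = 32
    height = 81                                       # the tests use 81x81 inputs
    nominal = (height - 1) // nominal_stride + 1      # 3 positions at nominal rate
    for output_stride in [4, 8, 16, 32]:
        dense = (height - 1) // output_stride + 1     # dense feature-map size
        factor = nominal_stride // output_stride      # subsampling factor
        kept = len(range(0, dense, factor))           # rows kept by subsample()
        assert kept == nominal, (output_stride, kept, nominal)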

+ 302 - 0
slim/nets/resnet_v2.py

@@ -0,0 +1,302 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains definitions for the preactivation form of Residual Networks.
+
+Residual networks (ResNets) were originally proposed in:
+[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Deep Residual Learning for Image Recognition. arXiv:1512.03385
+
+The full preactivation 'v2' ResNet variant implemented in this module was
+introduced by:
+[2] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+    Identity Mappings in Deep Residual Networks. arXiv:1603.05027
+
+The key difference of the full preactivation 'v2' variant compared to the
+'v1' variant in [1] is the use of batch normalization before every weight layer.
+Another difference is that 'v2' ResNets do not include an activation function in
+the main pathway. Also see [2; Fig. 4e].
+
+Typical use:
+
+   from tensorflow.contrib.slim.nets import resnet_v2
+
+ResNet-101 for image classification into 1000 classes:
+
+   # inputs has shape [batch, 224, 224, 3]
+   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
+      net, end_points = resnet_v2.resnet_v2_101(inputs, 1000, is_training=False)
+
+ResNet-101 for semantic segmentation into 21 classes:
+
+   # inputs has shape [batch, 513, 513, 3]
+   with slim.arg_scope(resnet_v2.resnet_arg_scope()):
+      net, end_points = resnet_v2.resnet_v2_101(inputs,
+                                                21,
+                                                is_training=False,
+                                                global_pool=False,
+                                                output_stride=16)
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import resnet_utils
+
+slim = tf.contrib.slim
+resnet_arg_scope = resnet_utils.resnet_arg_scope
+
+
+@slim.add_arg_scope
+def bottleneck(inputs, depth, depth_bottleneck, stride, rate=1,
+               outputs_collections=None, scope=None):
+  """Bottleneck residual unit variant with BN before convolutions.
+
+  This is the full preactivation residual unit variant proposed in [2]. See
+  Fig. 1(b) of [2] for its definition. Note that we use here the bottleneck
+  variant which has an extra bottleneck layer.
+
+  When putting together two consecutive ResNet blocks that use this unit, one
+  should use stride = 2 in the last unit of the first block.
+
+  Args:
+    inputs: A tensor of size [batch, height, width, channels].
+    depth: The depth of the ResNet unit output.
+    depth_bottleneck: The depth of the bottleneck layers.
+    stride: The ResNet unit's stride. Determines the amount of downsampling of
+      the unit's output compared to its input.
+    rate: An integer, rate for atrous convolution.
+    outputs_collections: Collection to add the ResNet unit output.
+    scope: Optional variable_scope.
+
+  Returns:
+    The ResNet unit's output.
+  """
+  with tf.variable_scope(scope, 'bottleneck_v2', [inputs]) as sc:
+    depth_in = slim.utils.last_dimension(inputs.get_shape(), min_rank=4)
+    preact = slim.batch_norm(inputs, activation_fn=tf.nn.relu, scope='preact')
+    if depth == depth_in:
+      shortcut = resnet_utils.subsample(inputs, stride, 'shortcut')
+    else:
+      shortcut = slim.conv2d(preact, depth, [1, 1], stride=stride,
+                             normalizer_fn=None, activation_fn=None,
+                             scope='shortcut')
+
+    residual = slim.conv2d(preact, depth_bottleneck, [1, 1], stride=1,
+                           scope='conv1')
+    residual = resnet_utils.conv2d_same(residual, depth_bottleneck, 3, stride,
+                                        rate=rate, scope='conv2')
+    residual = slim.conv2d(residual, depth, [1, 1], stride=1,
+                           normalizer_fn=None, activation_fn=None,
+                           scope='conv3')
+
+    output = shortcut + residual
+
+    return slim.utils.collect_named_outputs(outputs_collections,
+                                            sc.original_name_scope,
+                                            output)
+
+
+def resnet_v2(inputs,
+              blocks,
+              num_classes=None,
+              is_training=True,
+              global_pool=True,
+              output_stride=None,
+              include_root_block=True,
+              reuse=None,
+              scope=None):
+  """Generator for v2 (preactivation) ResNet models.
+
+  This function generates a family of ResNet v2 models. See the resnet_v2_*()
+  methods for specific model instantiations, obtained by selecting different
+  block instantiations that produce ResNets of various depths.
+
+  Training for image classification on Imagenet is usually done with [224, 224]
+  inputs, resulting in [7, 7] feature maps at the output of the last ResNet
+  block for the ResNets defined in [1] that have nominal stride equal to 32.
+  However, for dense prediction tasks we advise that one uses inputs with
+  spatial dimensions that are multiples of 32 plus 1, e.g., [321, 321]. In
+  this case the feature maps at the ResNet output will have spatial shape
+  [(height - 1) / output_stride + 1, (width - 1) / output_stride + 1]
+  and corners exactly aligned with the input image corners, which greatly
+  facilitates alignment of the features to the image. Using as input [225, 225]
+  images results in [8, 8] feature maps at the output of the last ResNet block.
+
+  For dense prediction tasks, the ResNet needs to run in fully-convolutional
+  (FCN) mode and global_pool needs to be set to False. The ResNets in [1, 2] all
+  have nominal stride equal to 32 and a good choice in FCN mode is to use
+  output_stride=16 in order to increase the density of the computed features at
+  small computational and memory overhead, cf. http://arxiv.org/abs/1606.00915.
+
+  Args:
+    inputs: A tensor of size [batch, height_in, width_in, channels].
+    blocks: A list of length equal to the number of ResNet blocks. Each element
+      is a resnet_utils.Block object describing the units in the block.
+    num_classes: Number of predicted classes for classification tasks. If None
+      we return the features before the logit layer.
+    is_training: whether or not the model is being trained.
+    global_pool: If True, we perform global average pooling before computing the
+      logits. Set to True for image classification, False for dense prediction.
+    output_stride: If None, then the output will be computed at the nominal
+      network stride. If output_stride is not None, it specifies the requested
+      ratio of input to output spatial resolution.
+    include_root_block: If True, include the initial convolution followed by
+      max-pooling; if False, exclude it. If excluded, `inputs` should be the
+      result of an activation-less convolution.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse, 'scope' must be given.
+    scope: Optional variable_scope.
+
+  Returns:
+    net: A rank-4 tensor of size [batch, height_out, width_out, channels_out].
+      If global_pool is False, then height_out and width_out are reduced by a
+      factor of output_stride compared to the respective height_in and width_in,
+      else both height_out and width_out equal one. If num_classes is None, then
+      net is the output of the last ResNet block, potentially after global
+      average pooling. If num_classes is not None, net contains the pre-softmax
+      activations.
+    end_points: A dictionary from components of the network to the corresponding
+      activation.
+
+  Raises:
+    ValueError: If the target output_stride is not valid.
+  """
+  with tf.variable_scope(scope, 'resnet_v2', [inputs], reuse=reuse) as sc:
+    end_points_collection = sc.name + '_end_points'
+    with slim.arg_scope([slim.conv2d, bottleneck,
+                         resnet_utils.stack_blocks_dense],
+                        outputs_collections=end_points_collection):
+      with slim.arg_scope([slim.batch_norm], is_training=is_training):
+        net = inputs
+        if include_root_block:
+          if output_stride is not None:
+            if output_stride % 4 != 0:
+              raise ValueError('The output_stride needs to be a multiple of 4.')
+            output_stride /= 4
+          # We do not include batch normalization or activation functions in
+          # conv1 because the first ResNet unit will perform these. Cf.
+          # Appendix of [2].
+          with slim.arg_scope([slim.conv2d],
+                              activation_fn=None, normalizer_fn=None):
+            net = resnet_utils.conv2d_same(net, 64, 7, stride=2, scope='conv1')
+          net = slim.max_pool2d(net, [3, 3], stride=2, scope='pool1')
+        net = resnet_utils.stack_blocks_dense(net, blocks, output_stride)
+        # This is needed because the pre-activation variant does not have batch
+        # normalization or activation functions in the residual unit output. See
+        # Appendix of [2].
+        net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='postnorm')
+        if global_pool:
+          # Global average pooling.
+          net = tf.reduce_mean(net, [1, 2], name='pool5', keep_dims=True)
+        if num_classes is not None:
+          net = slim.conv2d(net, num_classes, [1, 1], activation_fn=None,
+                            normalizer_fn=None, scope='logits')
+        # Convert end_points_collection into a dictionary of end_points.
+        end_points = dict(tf.get_collection(end_points_collection))
+        if num_classes is not None:
+          end_points['predictions'] = slim.softmax(net, scope='predictions')
+        return net, end_points
+resnet_v2.default_image_size = 224
+
+
+def resnet_v2_50(inputs,
+                 num_classes=None,
+                 is_training=True,
+                 global_pool=True,
+                 output_stride=None,
+                 reuse=None,
+                 scope='resnet_v2_50'):
+  """ResNet-50 model of [1]. See resnet_v2() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 5 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v2_101(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v2_101'):
+  """ResNet-101 model of [1]. See resnet_v2() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 3 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 22 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v2_152(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v2_152'):
+  """ResNet-152 model of [1]. See resnet_v2() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 7 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
+
+
+def resnet_v2_200(inputs,
+                  num_classes=None,
+                  is_training=True,
+                  global_pool=True,
+                  output_stride=None,
+                  reuse=None,
+                  scope='resnet_v2_200'):
+  """ResNet-200 model of [2]. See resnet_v2() for arg and return description."""
+  blocks = [
+      resnet_utils.Block(
+          'block1', bottleneck, [(256, 64, 1)] * 2 + [(256, 64, 2)]),
+      resnet_utils.Block(
+          'block2', bottleneck, [(512, 128, 1)] * 23 + [(512, 128, 2)]),
+      resnet_utils.Block(
+          'block3', bottleneck, [(1024, 256, 1)] * 35 + [(1024, 256, 2)]),
+      resnet_utils.Block(
+          'block4', bottleneck, [(2048, 512, 1)] * 3)]
+  return resnet_v2(inputs, blocks, num_classes, is_training=is_training,
+                   global_pool=global_pool, output_stride=output_stride,
+                   include_root_block=True, reuse=reuse, scope=scope)
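
A minimal usage sketch combining the two docstring examples above, assuming this commit's layout so `from nets import resnet_v2` resolves; the scope names `resnet_cls` and `resnet_fcn` are illustrative, chosen so both graphs can coexist. In FCN mode, 321x321 inputs at output_stride=16 yield (321 - 1) / 16 + 1 = 21 spatial positions:

    import tensorflow as tf
    from nets import resnet_v2

    slim = tf.contrib.slim

    # Classification: [batch, 224, 224, 3] -> [batch, 1, 1, 1000] logits.
    images = tf.placeholder(tf.float32, [None, 224, 224, 3])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      logits, end_points = resnet_v2.resnet_v2_101(
          images, 1000, is_training=False, scope='resnet_cls')

    # Dense prediction: [batch, 321, 321, 3] -> [batch, 21, 21, 21] activations.
    dense_images = tf.placeholder(tf.float32, [None, 321, 321, 3])
    with slim.arg_scope(resnet_v2.resnet_arg_scope()):
      net, _ = resnet_v2.resnet_v2_101(
          dense_images, 21, is_training=False, global_pool=False,
          output_stride=16, scope='resnet_fcn')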

+ 453 - 0
slim/nets/resnet_v2_test.py

@@ -0,0 +1,453 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.resnet_v2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from nets import resnet_utils
+from nets import resnet_v2
+
+slim = tf.contrib.slim
+
+
+def create_test_input(batch_size, height, width, channels):
+  """Create test input tensor.
+
+  Args:
+    batch_size: The number of images per batch or `None` if unknown.
+    height: The height of each image or `None` if unknown.
+    width: The width of each image or `None` if unknown.
+    channels: The number of channels per image or `None` if unknown.
+
+  Returns:
+    Either a placeholder `Tensor` of dimension
+    [batch_size, height, width, channels] if any of the inputs are `None`, or
+    a constant `Tensor` with the mesh grid values along the spatial dimensions.
+  """
+  if None in [batch_size, height, width, channels]:
+    return tf.placeholder(tf.float32, (batch_size, height, width, channels))
+  else:
+    return tf.to_float(
+        np.tile(
+            np.reshape(
+                np.reshape(np.arange(height), [height, 1]) +
+                np.reshape(np.arange(width), [1, width]),
+                [1, height, width, 1]),
+            [batch_size, 1, 1, channels]))
+
+
+class ResnetUtilsTest(tf.test.TestCase):
+
+  def testSubsampleThreeByThree(self):
+    x = tf.reshape(tf.to_float(tf.range(9)), [1, 3, 3, 1])
+    x = resnet_utils.subsample(x, 2)
+    expected = tf.reshape(tf.constant([0, 2, 6, 8]), [1, 2, 2, 1])
+    with self.test_session():
+      self.assertAllClose(x.eval(), expected.eval())
+
+  def testSubsampleFourByFour(self):
+    x = tf.reshape(tf.to_float(tf.range(16)), [1, 4, 4, 1])
+    x = resnet_utils.subsample(x, 2)
+    expected = tf.reshape(tf.constant([0, 2, 8, 10]), [1, 2, 2, 1])
+    with self.test_session():
+      self.assertAllClose(x.eval(), expected.eval())
+
+  def testConv2DSameEven(self):
+    n, n2 = 4, 2
+
+    # Input image.
+    x = create_test_input(1, n, n, 1)
+
+    # Convolution kernel.
+    w = create_test_input(1, 3, 3, 1)
+    w = tf.reshape(w, [3, 3, 1, 1])
+
+    tf.get_variable('Conv/weights', initializer=w)
+    tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+    tf.get_variable_scope().reuse_variables()
+
+    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
+    y1_expected = tf.to_float([[14, 28, 43, 26],
+                               [28, 48, 66, 37],
+                               [43, 66, 84, 46],
+                               [26, 37, 46, 22]])
+    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+    y2 = resnet_utils.subsample(y1, 2)
+    y2_expected = tf.to_float([[14, 43],
+                               [43, 84]])
+    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
+    y3_expected = y2_expected
+
+    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
+    y4_expected = tf.to_float([[48, 37],
+                               [37, 22]])
+    y4_expected = tf.reshape(y4_expected, [1, n2, n2, 1])
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      self.assertAllClose(y1.eval(), y1_expected.eval())
+      self.assertAllClose(y2.eval(), y2_expected.eval())
+      self.assertAllClose(y3.eval(), y3_expected.eval())
+      self.assertAllClose(y4.eval(), y4_expected.eval())
+
+  def testConv2DSameOdd(self):
+    n, n2 = 5, 3
+
+    # Input image.
+    x = create_test_input(1, n, n, 1)
+
+    # Convolution kernel.
+    w = create_test_input(1, 3, 3, 1)
+    w = tf.reshape(w, [3, 3, 1, 1])
+
+    tf.get_variable('Conv/weights', initializer=w)
+    tf.get_variable('Conv/biases', initializer=tf.zeros([1]))
+    tf.get_variable_scope().reuse_variables()
+
+    y1 = slim.conv2d(x, 1, [3, 3], stride=1, scope='Conv')
+    y1_expected = tf.to_float([[14, 28, 43, 58, 34],
+                               [28, 48, 66, 84, 46],
+                               [43, 66, 84, 102, 55],
+                               [58, 84, 102, 120, 64],
+                               [34, 46, 55, 64, 30]])
+    y1_expected = tf.reshape(y1_expected, [1, n, n, 1])
+
+    y2 = resnet_utils.subsample(y1, 2)
+    y2_expected = tf.to_float([[14, 43, 34],
+                               [43, 84, 55],
+                               [34, 55, 30]])
+    y2_expected = tf.reshape(y2_expected, [1, n2, n2, 1])
+
+    y3 = resnet_utils.conv2d_same(x, 1, 3, stride=2, scope='Conv')
+    y3_expected = y2_expected
+
+    y4 = slim.conv2d(x, 1, [3, 3], stride=2, scope='Conv')
+    y4_expected = y2_expected
+
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      self.assertAllClose(y1.eval(), y1_expected.eval())
+      self.assertAllClose(y2.eval(), y2_expected.eval())
+      self.assertAllClose(y3.eval(), y3_expected.eval())
+      self.assertAllClose(y4.eval(), y4_expected.eval())
+
+  def _resnet_plain(self, inputs, blocks, output_stride=None, scope=None):
+    """A plain ResNet without extra layers before or after the ResNet blocks."""
+    with tf.variable_scope(scope, values=[inputs]):
+      with slim.arg_scope([slim.conv2d], outputs_collections='end_points'):
+        net = resnet_utils.stack_blocks_dense(inputs, blocks, output_stride)
+        end_points = dict(tf.get_collection('end_points'))
+        return net, end_points
+
+  def testEndPointsV2(self):
+    """Test the end points of a tiny v2 bottleneck network."""
+    bottleneck = resnet_v2.bottleneck
+    blocks = [resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
+              resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 1)])]
+    inputs = create_test_input(2, 32, 16, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_plain(inputs, blocks, scope='tiny')
+    expected = [
+        'tiny/block1/unit_1/bottleneck_v2/shortcut',
+        'tiny/block1/unit_1/bottleneck_v2/conv1',
+        'tiny/block1/unit_1/bottleneck_v2/conv2',
+        'tiny/block1/unit_1/bottleneck_v2/conv3',
+        'tiny/block1/unit_2/bottleneck_v2/conv1',
+        'tiny/block1/unit_2/bottleneck_v2/conv2',
+        'tiny/block1/unit_2/bottleneck_v2/conv3',
+        'tiny/block2/unit_1/bottleneck_v2/shortcut',
+        'tiny/block2/unit_1/bottleneck_v2/conv1',
+        'tiny/block2/unit_1/bottleneck_v2/conv2',
+        'tiny/block2/unit_1/bottleneck_v2/conv3',
+        'tiny/block2/unit_2/bottleneck_v2/conv1',
+        'tiny/block2/unit_2/bottleneck_v2/conv2',
+        'tiny/block2/unit_2/bottleneck_v2/conv3']
+    self.assertItemsEqual(expected, end_points)
+
+  def _stack_blocks_nondense(self, net, blocks):
+    """A simplified ResNet Block stacker without output stride control."""
+    for block in blocks:
+      with tf.variable_scope(block.scope, 'block', [net]):
+        for i, unit in enumerate(block.args):
+          depth, depth_bottleneck, stride = unit
+          with tf.variable_scope('unit_%d' % (i + 1), values=[net]):
+            net = block.unit_fn(net,
+                                depth=depth,
+                                depth_bottleneck=depth_bottleneck,
+                                stride=stride,
+                                rate=1)
+    return net
+
+  def _atrousValues(self, bottleneck):
+    """Verify the values of dense feature extraction by atrous convolution.
+
+    Make sure that dense feature extraction by stack_blocks_dense() followed by
+    subsampling gives identical results to feature extraction at the nominal
+    network output stride using the simple self._stack_blocks_nondense() above.
+
+    Args:
+      bottleneck: The bottleneck function.
+    """
+    blocks = [
+        resnet_utils.Block('block1', bottleneck, [(4, 1, 1), (4, 1, 2)]),
+        resnet_utils.Block('block2', bottleneck, [(8, 2, 1), (8, 2, 2)]),
+        resnet_utils.Block('block3', bottleneck, [(16, 4, 1), (16, 4, 2)]),
+        resnet_utils.Block('block4', bottleneck, [(32, 8, 1), (32, 8, 1)])
+    ]
+    nominal_stride = 8
+
+    # Test both odd and even input dimensions.
+    height = 30
+    width = 31
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      with slim.arg_scope([slim.batch_norm], is_training=False):
+        for output_stride in [1, 2, 4, 8, None]:
+          with tf.Graph().as_default():
+            with self.test_session() as sess:
+              tf.set_random_seed(0)
+              inputs = create_test_input(1, height, width, 3)
+              # Dense feature extraction followed by subsampling.
+              output = resnet_utils.stack_blocks_dense(inputs,
+                                                       blocks,
+                                                       output_stride)
+              if output_stride is None:
+                factor = 1
+              else:
+                factor = nominal_stride // output_stride
+
+              output = resnet_utils.subsample(output, factor)
+              # Make the two networks use the same weights.
+              tf.get_variable_scope().reuse_variables()
+              # Feature extraction at the nominal network rate.
+              expected = self._stack_blocks_nondense(inputs, blocks)
+              sess.run(tf.initialize_all_variables())
+              output, expected = sess.run([output, expected])
+              self.assertAllClose(output, expected, atol=1e-4, rtol=1e-4)
+
+  def testAtrousValuesBottleneck(self):
+    self._atrousValues(resnet_v2.bottleneck)
+
+
+class ResnetCompleteNetworkTest(tf.test.TestCase):
+  """Tests with complete small ResNet v2 networks."""
+
+  def _resnet_small(self,
+                    inputs,
+                    num_classes=None,
+                    is_training=True,
+                    global_pool=True,
+                    output_stride=None,
+                    include_root_block=True,
+                    reuse=None,
+                    scope='resnet_v2_small'):
+    """A shallow and thin ResNet v2 for faster tests."""
+    bottleneck = resnet_v2.bottleneck
+    blocks = [
+        resnet_utils.Block(
+            'block1', bottleneck, [(4, 1, 1)] * 2 + [(4, 1, 2)]),
+        resnet_utils.Block(
+            'block2', bottleneck, [(8, 2, 1)] * 2 + [(8, 2, 2)]),
+        resnet_utils.Block(
+            'block3', bottleneck, [(16, 4, 1)] * 2 + [(16, 4, 2)]),
+        resnet_utils.Block(
+            'block4', bottleneck, [(32, 8, 1)] * 2)]
+    return resnet_v2.resnet_v2(inputs, blocks, num_classes,
+                               is_training=is_training,
+                               global_pool=global_pool,
+                               output_stride=output_stride,
+                               include_root_block=include_root_block,
+                               reuse=reuse,
+                               scope=scope)
+
+  def testClassificationEndPoints(self):
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(2, 224, 224, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      logits, end_points = self._resnet_small(inputs, num_classes,
+                                              global_pool=global_pool,
+                                              scope='resnet')
+    self.assertTrue(logits.op.name.startswith('resnet/logits'))
+    self.assertListEqual(logits.get_shape().as_list(), [2, 1, 1, num_classes])
+    self.assertTrue('predictions' in end_points)
+    self.assertListEqual(end_points['predictions'].get_shape().as_list(),
+                         [2, 1, 1, num_classes])
+
+  def testClassificationShapes(self):
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(2, 224, 224, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 28, 28, 4],
+          'resnet/block2': [2, 14, 14, 8],
+          'resnet/block3': [2, 7, 7, 16],
+          'resnet/block4': [2, 7, 7, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    inputs = create_test_input(2, 321, 321, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 41, 41, 4],
+          'resnet/block2': [2, 21, 21, 8],
+          'resnet/block3': [2, 11, 11, 16],
+          'resnet/block4': [2, 11, 11, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testRootlessFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    inputs = create_test_input(2, 128, 128, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs, num_classes,
+                                         global_pool=global_pool,
+                                         include_root_block=False,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 64, 64, 4],
+          'resnet/block2': [2, 32, 32, 8],
+          'resnet/block3': [2, 16, 16, 16],
+          'resnet/block4': [2, 16, 16, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testAtrousFullyConvolutionalEndpointShapes(self):
+    global_pool = False
+    num_classes = 10
+    output_stride = 8
+    inputs = create_test_input(2, 321, 321, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      _, end_points = self._resnet_small(inputs,
+                                         num_classes,
+                                         global_pool=global_pool,
+                                         output_stride=output_stride,
+                                         scope='resnet')
+      endpoint_to_shape = {
+          'resnet/block1': [2, 41, 41, 4],
+          'resnet/block2': [2, 41, 41, 8],
+          'resnet/block3': [2, 41, 41, 16],
+          'resnet/block4': [2, 41, 41, 32]}
+      for endpoint in endpoint_to_shape:
+        shape = endpoint_to_shape[endpoint]
+        self.assertListEqual(end_points[endpoint].get_shape().as_list(), shape)
+
+  def testAtrousFullyConvolutionalValues(self):
+    """Verify dense feature extraction with atrous convolution."""
+    nominal_stride = 32
+    for output_stride in [4, 8, 16, 32, None]:
+      with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+        with tf.Graph().as_default():
+          with self.test_session() as sess:
+            tf.set_random_seed(0)
+            inputs = create_test_input(2, 81, 81, 3)
+            # Dense feature extraction followed by subsampling.
+            output, _ = self._resnet_small(inputs, None,
+                                           is_training=False,
+                                           global_pool=False,
+                                           output_stride=output_stride)
+            if output_stride is None:
+              factor = 1
+            else:
+              factor = nominal_stride // output_stride
+            output = resnet_utils.subsample(output, factor)
+            # Make the two networks use the same weights.
+            tf.get_variable_scope().reuse_variables()
+            # Feature extraction at the nominal network rate.
+            expected, _ = self._resnet_small(inputs, None,
+                                             is_training=False,
+                                             global_pool=False)
+            sess.run(tf.initialize_all_variables())
+            self.assertAllClose(output.eval(), expected.eval(),
+                                atol=1e-4, rtol=1e-4)
+
+  def testUnknownBatchSize(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = True
+    num_classes = 10
+    inputs = create_test_input(None, height, width, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      logits, _ = self._resnet_small(inputs, num_classes,
+                                     global_pool=global_pool,
+                                     scope='resnet')
+    self.assertTrue(logits.op.name.startswith('resnet/logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [None, 1, 1, num_classes])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 1, 1, num_classes))
+
+  def testFullyConvolutionalUnknownHeightWidth(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = False
+    inputs = create_test_input(batch, None, None, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      output, _ = self._resnet_small(inputs, None,
+                                     global_pool=global_pool)
+    self.assertListEqual(output.get_shape().as_list(),
+                         [batch, None, None, 32])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(output, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 3, 3, 32))
+
+  def testAtrousFullyConvolutionalUnknownHeightWidth(self):
+    batch = 2
+    height, width = 65, 65
+    global_pool = False
+    output_stride = 8
+    inputs = create_test_input(batch, None, None, 3)
+    with slim.arg_scope(resnet_utils.resnet_arg_scope()):
+      output, _ = self._resnet_small(inputs,
+                                     None,
+                                     global_pool=global_pool,
+                                     output_stride=output_stride)
+    self.assertListEqual(output.get_shape().as_list(),
+                         [batch, None, None, 32])
+    images = create_test_input(batch, height, width, 3)
+    with self.test_session() as sess:
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(output, {inputs: images.eval()})
+      self.assertEqual(output.shape, (batch, 9, 9, 32))
+
+
+if __name__ == '__main__':
+  tf.test.main()
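
A minimal sketch of the mesh-grid values create_test_input() fills in when all dimensions are known — pixel (i, j) holds i + j, tiled over batch and channels — which is what makes the expected convolution outputs in the tests above deterministic:

    import numpy as np

    height, width = 3, 3
    grid = (np.reshape(np.arange(height), [height, 1]) +
            np.reshape(np.arange(width), [1, width]))
    print(grid)
    # [[0 1 2]
    #  [1 2 3]
    #  [2 3 4]]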

+ 244 - 0
slim/nets/vgg.py

@@ -0,0 +1,244 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains model definitions for versions of the Oxford VGG network.
+
+These model definitions were introduced in the following technical report:
+
+  Very Deep Convolutional Networks For Large-Scale Image Recognition
+  Karen Simonyan and Andrew Zisserman
+  arXiv technical report, 2015
+  PDF: http://arxiv.org/pdf/1409.1556.pdf
+  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
+  CC-BY-4.0
+
+More information can be obtained from the VGG website:
+www.robots.ox.ac.uk/~vgg/research/very_deep/
+
+Usage:
+  with slim.arg_scope(vgg.vgg_arg_scope()):
+    outputs, end_points = vgg.vgg_a(inputs)
+
+  with slim.arg_scope(vgg.vgg_arg_scope()):
+    outputs, end_points = vgg.vgg_16(inputs)
+
+@@vgg_a
+@@vgg_16
+@@vgg_19
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def vgg_arg_scope(weight_decay=0.0005):
+  """Defines the VGG arg scope.
+
+  Args:
+    weight_decay: The l2 regularization coefficient.
+
+  Returns:
+    An arg_scope.
+  """
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      activation_fn=tf.nn.relu,
+                      weights_regularizer=slim.l2_regularizer(weight_decay),
+                      biases_initializer=tf.zeros_initializer):
+    with slim.arg_scope([slim.conv2d], padding='SAME') as arg_sc:
+      return arg_sc
+
+
+def vgg_a(inputs,
+          num_classes=1000,
+          is_training=True,
+          dropout_keep_prob=0.5,
+          spatial_squeeze=True,
+          scope='vgg_a'):
+  """Oxford Net VGG 11-Layers version A Example.
+
+  Note: All the fully_connected layers have been transformed to conv2d layers.
+        To use in classification mode, resize input to 224x224.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether or not the model is being trained.
+    dropout_keep_prob: the probability that activations are kept in the dropout
+      layers during training.
+    spatial_squeeze: whether or not the spatial dimensions of the outputs
+      should be squeezed. Useful to remove unnecessary dimensions for
+      classification.
+    scope: Optional scope for the variables.
+
+  Returns:
+    the last op containing the log predictions, and the end_points dict.
+  """
+  with tf.variable_scope(scope, 'vgg_a', [inputs]) as sc:
+    end_points_collection = sc.name + '_end_points'
+    # Collect outputs for conv2d, fully_connected and max_pool2d.
+    with slim.arg_scope([slim.conv2d, slim.max_pool2d],
+                        outputs_collections=end_points_collection):
+      net = slim.repeat(inputs, 1, slim.conv2d, 64, [3, 3], scope='conv1')
+      net = slim.max_pool2d(net, [2, 2], scope='pool1')
+      net = slim.repeat(net, 1, slim.conv2d, 128, [3, 3], scope='conv2')
+      net = slim.max_pool2d(net, [2, 2], scope='pool2')
+      net = slim.repeat(net, 2, slim.conv2d, 256, [3, 3], scope='conv3')
+      net = slim.max_pool2d(net, [2, 2], scope='pool3')
+      net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv4')
+      net = slim.max_pool2d(net, [2, 2], scope='pool4')
+      net = slim.repeat(net, 2, slim.conv2d, 512, [3, 3], scope='conv5')
+      net = slim.max_pool2d(net, [2, 2], scope='pool5')
+      # Use conv2d instead of fully_connected layers.
+      net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout6')
+      net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout7')
+      net = slim.conv2d(net, num_classes, [1, 1],
+                        activation_fn=None,
+                        normalizer_fn=None,
+                        scope='fc8')
+      # Convert end_points_collection into an end_point dict.
+      end_points = dict(tf.get_collection(end_points_collection))
+      if spatial_squeeze:
+        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+        end_points[sc.name + '/fc8'] = net
+      return net, end_points
+vgg_a.default_image_size = 224
+
+
+def vgg_16(inputs,
+           num_classes=1000,
+           is_training=True,
+           dropout_keep_prob=0.5,
+           spatial_squeeze=True,
+           scope='vgg_16'):
+  """Oxford Net VGG 16-Layers version D Example.
+
+  Note: All the fully_connected layers have been transformed to conv2d layers.
+        To use in classification mode, resize input to 224x224.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether or not the model is being trained.
+    dropout_keep_prob: the probability that activations are kept in the dropout
+      layers during training.
+    spatial_squeeze: whether or not the spatial dimensions of the outputs
+      should be squeezed. Useful to remove unnecessary dimensions for
+      classification.
+    scope: Optional scope for the variables.
+
+  Returns:
+    the last op containing the log predictions, and the end_points dict.
+  """
+  with tf.variable_scope(scope, 'vgg_16', [inputs]) as sc:
+    end_points_collection = sc.name + '_end_points'
+    # Collect outputs for conv2d, fully_connected and max_pool2d.
+    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+                        outputs_collections=end_points_collection):
+      net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
+      net = slim.max_pool2d(net, [2, 2], scope='pool1')
+      net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
+      net = slim.max_pool2d(net, [2, 2], scope='pool2')
+      net = slim.repeat(net, 3, slim.conv2d, 256, [3, 3], scope='conv3')
+      net = slim.max_pool2d(net, [2, 2], scope='pool3')
+      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv4')
+      net = slim.max_pool2d(net, [2, 2], scope='pool4')
+      net = slim.repeat(net, 3, slim.conv2d, 512, [3, 3], scope='conv5')
+      net = slim.max_pool2d(net, [2, 2], scope='pool5')
+      # Use conv2d instead of fully_connected layers.
+      net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout6')
+      net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout7')
+      net = slim.conv2d(net, num_classes, [1, 1],
+                        activation_fn=None,
+                        normalizer_fn=None,
+                        scope='fc8')
+      # Convert end_points_collection into an end_point dict.
+      end_points = dict(tf.get_collection(end_points_collection))
+      if spatial_squeeze:
+        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+        end_points[sc.name + '/fc8'] = net
+      return net, end_points
+vgg_16.default_image_size = 224
+
+
+def vgg_19(inputs,
+           num_classes=1000,
+           is_training=True,
+           dropout_keep_prob=0.5,
+           spatial_squeeze=True,
+           scope='vgg_19'):
+  """Oxford Net VGG 19-Layers version E Example.
+
+  Note: All the fully_connected layers have been transformed to conv2d layers.
+        To use in classification mode, resize input to 224x224.
+
+  Args:
+    inputs: a tensor of size [batch_size, height, width, channels].
+    num_classes: number of predicted classes.
+    is_training: whether or not the model is being trained.
+    dropout_keep_prob: the probability that activations are kept in the dropout
+      layers during training.
+    spatial_squeeze: whether or not the spatial dimensions of the outputs
+      should be squeezed. Useful to remove unnecessary dimensions for
+      classification.
+    scope: Optional scope for the variables.
+
+  Returns:
+    the last op containing the log predictions, and the end_points dict.
+  """
+  with tf.variable_scope(scope, 'vgg_19', [inputs]) as sc:
+    end_points_collection = sc.name + '_end_points'
+    # Collect outputs for conv2d, fully_connected and max_pool2d.
+    with slim.arg_scope([slim.conv2d, slim.fully_connected, slim.max_pool2d],
+                        outputs_collections=end_points_collection):
+      net = slim.repeat(inputs, 2, slim.conv2d, 64, [3, 3], scope='conv1')
+      net = slim.max_pool2d(net, [2, 2], scope='pool1')
+      net = slim.repeat(net, 2, slim.conv2d, 128, [3, 3], scope='conv2')
+      net = slim.max_pool2d(net, [2, 2], scope='pool2')
+      net = slim.repeat(net, 4, slim.conv2d, 256, [3, 3], scope='conv3')
+      net = slim.max_pool2d(net, [2, 2], scope='pool3')
+      net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv4')
+      net = slim.max_pool2d(net, [2, 2], scope='pool4')
+      net = slim.repeat(net, 4, slim.conv2d, 512, [3, 3], scope='conv5')
+      net = slim.max_pool2d(net, [2, 2], scope='pool5')
+      # Use conv2d instead of fully_connected layers.
+      net = slim.conv2d(net, 4096, [7, 7], padding='VALID', scope='fc6')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout6')
+      net = slim.conv2d(net, 4096, [1, 1], scope='fc7')
+      net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
+                         scope='dropout7')
+      net = slim.conv2d(net, num_classes, [1, 1],
+                        activation_fn=None,
+                        normalizer_fn=None,
+                        scope='fc8')
+      # Convert end_points_collection into a end_point dict.
+      end_points = dict(tf.get_collection(end_points_collection))
+      if spatial_squeeze:
+        net = tf.squeeze(net, [1, 2], name='fc8/squeezed')
+        end_points[sc.name + '/fc8'] = net
+      return net, end_points
+vgg_19.default_image_size = 224
+
+# Alias
+vgg_d = vgg_16
+vgg_e = vgg_19
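
A minimal sketch of the fully-convolutional use the docstrings describe, assuming this commit's layout so `from nets import vgg` resolves: a 256x256 input leaves 8x8 activations at pool5, the 7x7 'fc6' convolution (VALID padding) reduces that to 2x2, and the 2x2 grid of logits can be averaged per image, as in testTrainEvalWithReuse below:

    import tensorflow as tf
    from nets import vgg

    slim = tf.contrib.slim

    images = tf.placeholder(tf.float32, [None, 256, 256, 3])
    with slim.arg_scope(vgg.vgg_arg_scope()):
      logits, _ = vgg.vgg_16(images, num_classes=1000, is_training=False,
                             spatial_squeeze=False)   # [batch, 2, 2, 1000]
    logits = tf.reduce_mean(logits, [1, 2])           # [batch, 1000]
    predictions = tf.argmax(logits, 1)                # [batch]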

+ 455 - 0
slim/nets/vgg_test.py

@@ -0,0 +1,455 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.nets.vgg."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import vgg
+
+slim = tf.contrib.slim
+
+
+class VGGATest(tf.test.TestCase):
+
+  def testBuild(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_a(inputs, num_classes)
+      self.assertEquals(logits.op.name, 'vgg_a/fc8/squeezed')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testFullyConvolutional(self):
+    batch_size = 1
+    height, width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_a(inputs, num_classes, spatial_squeeze=False)
+      self.assertEquals(logits.op.name, 'vgg_a/fc8/BiasAdd')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, 2, 2, num_classes])
+
+  def testEndPoints(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = vgg.vgg_a(inputs, num_classes)
+      expected_names = ['vgg_a/conv1/conv1_1',
+                        'vgg_a/pool1',
+                        'vgg_a/conv2/conv2_1',
+                        'vgg_a/pool2',
+                        'vgg_a/conv3/conv3_1',
+                        'vgg_a/conv3/conv3_2',
+                        'vgg_a/pool3',
+                        'vgg_a/conv4/conv4_1',
+                        'vgg_a/conv4/conv4_2',
+                        'vgg_a/pool4',
+                        'vgg_a/conv5/conv5_1',
+                        'vgg_a/conv5/conv5_2',
+                        'vgg_a/pool5',
+                        'vgg_a/fc6',
+                        'vgg_a/fc7',
+                        'vgg_a/fc8'
+                       ]
+      self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+  def testModelVariables(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      vgg.vgg_a(inputs, num_classes)
+      expected_names = ['vgg_a/conv1/conv1_1/weights',
+                        'vgg_a/conv1/conv1_1/biases',
+                        'vgg_a/conv2/conv2_1/weights',
+                        'vgg_a/conv2/conv2_1/biases',
+                        'vgg_a/conv3/conv3_1/weights',
+                        'vgg_a/conv3/conv3_1/biases',
+                        'vgg_a/conv3/conv3_2/weights',
+                        'vgg_a/conv3/conv3_2/biases',
+                        'vgg_a/conv4/conv4_1/weights',
+                        'vgg_a/conv4/conv4_1/biases',
+                        'vgg_a/conv4/conv4_2/weights',
+                        'vgg_a/conv4/conv4_2/biases',
+                        'vgg_a/conv5/conv5_1/weights',
+                        'vgg_a/conv5/conv5_1/biases',
+                        'vgg_a/conv5/conv5_2/weights',
+                        'vgg_a/conv5/conv5_2/biases',
+                        'vgg_a/fc6/weights',
+                        'vgg_a/fc6/biases',
+                        'vgg_a/fc7/weights',
+                        'vgg_a/fc7/biases',
+                        'vgg_a/fc8/weights',
+                        'vgg_a/fc8/biases',
+                       ]
+      model_variables = [v.op.name for v in slim.get_model_variables()]
+      self.assertSetEqual(set(model_variables), set(expected_names))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_a(eval_inputs, is_training=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 2
+    eval_batch_size = 1
+    train_height, train_width = 224, 224
+    eval_height, eval_width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      train_inputs = tf.random_uniform(
+          (train_batch_size, train_height, train_width, 3))
+      logits, _ = vgg.vgg_a(train_inputs)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [train_batch_size, num_classes])
+      tf.get_variable_scope().reuse_variables()
+      eval_inputs = tf.random_uniform(
+          (eval_batch_size, eval_height, eval_width, 3))
+      logits, _ = vgg.vgg_a(eval_inputs, is_training=False,
+                            spatial_squeeze=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [eval_batch_size, 2, 2, num_classes])
+      logits = tf.reduce_mean(logits, [1, 2])
+      predictions = tf.argmax(logits, 1)
+      self.assertEquals(predictions.get_shape().as_list(), [eval_batch_size])
+
+  def testForward(self):
+    batch_size = 1
+    height, width = 224, 224
+    with self.test_session() as sess:
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_a(inputs)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits)
+      self.assertTrue(output.any())
+
+
+class VGG16Test(tf.test.TestCase):
+
+  def testBuild(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_16(inputs, num_classes)
+      self.assertEquals(logits.op.name, 'vgg_16/fc8/squeezed')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testFullyConvolutional(self):
+    batch_size = 1
+    height, width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_16(inputs, num_classes, spatial_squeeze=False)
+      self.assertEquals(logits.op.name, 'vgg_16/fc8/BiasAdd')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, 2, 2, num_classes])
+
+  def testEndPoints(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = vgg.vgg_16(inputs, num_classes)
+      expected_names = ['vgg_16/conv1/conv1_1',
+                        'vgg_16/conv1/conv1_2',
+                        'vgg_16/pool1',
+                        'vgg_16/conv2/conv2_1',
+                        'vgg_16/conv2/conv2_2',
+                        'vgg_16/pool2',
+                        'vgg_16/conv3/conv3_1',
+                        'vgg_16/conv3/conv3_2',
+                        'vgg_16/conv3/conv3_3',
+                        'vgg_16/pool3',
+                        'vgg_16/conv4/conv4_1',
+                        'vgg_16/conv4/conv4_2',
+                        'vgg_16/conv4/conv4_3',
+                        'vgg_16/pool4',
+                        'vgg_16/conv5/conv5_1',
+                        'vgg_16/conv5/conv5_2',
+                        'vgg_16/conv5/conv5_3',
+                        'vgg_16/pool5',
+                        'vgg_16/fc6',
+                        'vgg_16/fc7',
+                        'vgg_16/fc8'
+                       ]
+      self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+  def testModelVariables(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      vgg.vgg_16(inputs, num_classes)
+      expected_names = ['vgg_16/conv1/conv1_1/weights',
+                        'vgg_16/conv1/conv1_1/biases',
+                        'vgg_16/conv1/conv1_2/weights',
+                        'vgg_16/conv1/conv1_2/biases',
+                        'vgg_16/conv2/conv2_1/weights',
+                        'vgg_16/conv2/conv2_1/biases',
+                        'vgg_16/conv2/conv2_2/weights',
+                        'vgg_16/conv2/conv2_2/biases',
+                        'vgg_16/conv3/conv3_1/weights',
+                        'vgg_16/conv3/conv3_1/biases',
+                        'vgg_16/conv3/conv3_2/weights',
+                        'vgg_16/conv3/conv3_2/biases',
+                        'vgg_16/conv3/conv3_3/weights',
+                        'vgg_16/conv3/conv3_3/biases',
+                        'vgg_16/conv4/conv4_1/weights',
+                        'vgg_16/conv4/conv4_1/biases',
+                        'vgg_16/conv4/conv4_2/weights',
+                        'vgg_16/conv4/conv4_2/biases',
+                        'vgg_16/conv4/conv4_3/weights',
+                        'vgg_16/conv4/conv4_3/biases',
+                        'vgg_16/conv5/conv5_1/weights',
+                        'vgg_16/conv5/conv5_1/biases',
+                        'vgg_16/conv5/conv5_2/weights',
+                        'vgg_16/conv5/conv5_2/biases',
+                        'vgg_16/conv5/conv5_3/weights',
+                        'vgg_16/conv5/conv5_3/biases',
+                        'vgg_16/fc6/weights',
+                        'vgg_16/fc6/biases',
+                        'vgg_16/fc7/weights',
+                        'vgg_16/fc7/biases',
+                        'vgg_16/fc8/weights',
+                        'vgg_16/fc8/biases',
+                       ]
+      model_variables = [v.op.name for v in slim.get_model_variables()]
+      self.assertSetEqual(set(model_variables), set(expected_names))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_16(eval_inputs, is_training=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 2
+    eval_batch_size = 1
+    train_height, train_width = 224, 224
+    eval_height, eval_width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      train_inputs = tf.random_uniform(
+          (train_batch_size, train_height, train_width, 3))
+      logits, _ = vgg.vgg_16(train_inputs)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [train_batch_size, num_classes])
+      tf.get_variable_scope().reuse_variables()
+      eval_inputs = tf.random_uniform(
+          (eval_batch_size, eval_height, eval_width, 3))
+      logits, _ = vgg.vgg_16(eval_inputs, is_training=False,
+                             spatial_squeeze=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [eval_batch_size, 2, 2, num_classes])
+      logits = tf.reduce_mean(logits, [1, 2])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(),
+                           [eval_batch_size])
+
+  def testForward(self):
+    batch_size = 1
+    height, width = 224, 224
+    with self.test_session() as sess:
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_16(inputs)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits)
+      self.assertTrue(output.any())
+
+
+class VGG19Test(tf.test.TestCase):
+
+  def testBuild(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_19(inputs, num_classes)
+      self.assertEquals(logits.op.name, 'vgg_19/fc8/squeezed')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+
+  def testFullyConvolutional(self):
+    batch_size = 1
+    height, width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_19(inputs, num_classes, spatial_squeeze=False)
+      self.assertEquals(logits.op.name, 'vgg_19/fc8/BiasAdd')
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, 2, 2, num_classes])
+
+  def testEndPoints(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      _, end_points = vgg.vgg_19(inputs, num_classes)
+      expected_names = [
+          'vgg_19/conv1/conv1_1',
+          'vgg_19/conv1/conv1_2',
+          'vgg_19/pool1',
+          'vgg_19/conv2/conv2_1',
+          'vgg_19/conv2/conv2_2',
+          'vgg_19/pool2',
+          'vgg_19/conv3/conv3_1',
+          'vgg_19/conv3/conv3_2',
+          'vgg_19/conv3/conv3_3',
+          'vgg_19/conv3/conv3_4',
+          'vgg_19/pool3',
+          'vgg_19/conv4/conv4_1',
+          'vgg_19/conv4/conv4_2',
+          'vgg_19/conv4/conv4_3',
+          'vgg_19/conv4/conv4_4',
+          'vgg_19/pool4',
+          'vgg_19/conv5/conv5_1',
+          'vgg_19/conv5/conv5_2',
+          'vgg_19/conv5/conv5_3',
+          'vgg_19/conv5/conv5_4',
+          'vgg_19/pool5',
+          'vgg_19/fc6',
+          'vgg_19/fc7',
+          'vgg_19/fc8'
+      ]
+      self.assertSetEqual(set(end_points.keys()), set(expected_names))
+
+  def testModelVariables(self):
+    batch_size = 5
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      vgg.vgg_19(inputs, num_classes)
+      expected_names = [
+          'vgg_19/conv1/conv1_1/weights',
+          'vgg_19/conv1/conv1_1/biases',
+          'vgg_19/conv1/conv1_2/weights',
+          'vgg_19/conv1/conv1_2/biases',
+          'vgg_19/conv2/conv2_1/weights',
+          'vgg_19/conv2/conv2_1/biases',
+          'vgg_19/conv2/conv2_2/weights',
+          'vgg_19/conv2/conv2_2/biases',
+          'vgg_19/conv3/conv3_1/weights',
+          'vgg_19/conv3/conv3_1/biases',
+          'vgg_19/conv3/conv3_2/weights',
+          'vgg_19/conv3/conv3_2/biases',
+          'vgg_19/conv3/conv3_3/weights',
+          'vgg_19/conv3/conv3_3/biases',
+          'vgg_19/conv3/conv3_4/weights',
+          'vgg_19/conv3/conv3_4/biases',
+          'vgg_19/conv4/conv4_1/weights',
+          'vgg_19/conv4/conv4_1/biases',
+          'vgg_19/conv4/conv4_2/weights',
+          'vgg_19/conv4/conv4_2/biases',
+          'vgg_19/conv4/conv4_3/weights',
+          'vgg_19/conv4/conv4_3/biases',
+          'vgg_19/conv4/conv4_4/weights',
+          'vgg_19/conv4/conv4_4/biases',
+          'vgg_19/conv5/conv5_1/weights',
+          'vgg_19/conv5/conv5_1/biases',
+          'vgg_19/conv5/conv5_2/weights',
+          'vgg_19/conv5/conv5_2/biases',
+          'vgg_19/conv5/conv5_3/weights',
+          'vgg_19/conv5/conv5_3/biases',
+          'vgg_19/conv5/conv5_4/weights',
+          'vgg_19/conv5/conv5_4/biases',
+          'vgg_19/fc6/weights',
+          'vgg_19/fc6/biases',
+          'vgg_19/fc7/weights',
+          'vgg_19/fc7/biases',
+          'vgg_19/fc8/weights',
+          'vgg_19/fc8/biases',
+      ]
+      model_variables = [v.op.name for v in slim.get_model_variables()]
+      self.assertSetEqual(set(model_variables), set(expected_names))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 224, 224
+    num_classes = 1000
+    with self.test_session():
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_19(eval_inputs, is_training=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [batch_size, num_classes])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(), [batch_size])
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 2
+    eval_batch_size = 1
+    train_height, train_width = 224, 224
+    eval_height, eval_width = 256, 256
+    num_classes = 1000
+    with self.test_session():
+      train_inputs = tf.random_uniform(
+          (train_batch_size, train_height, train_width, 3))
+      logits, _ = vgg.vgg_19(train_inputs)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [train_batch_size, num_classes])
+      tf.get_variable_scope().reuse_variables()
+      eval_inputs = tf.random_uniform(
+          (eval_batch_size, eval_height, eval_width, 3))
+      logits, _ = vgg.vgg_19(eval_inputs, is_training=False,
+                             spatial_squeeze=False)
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [eval_batch_size, 2, 2, num_classes])
+      logits = tf.reduce_mean(logits, [1, 2])
+      predictions = tf.argmax(logits, 1)
+      self.assertListEqual(predictions.get_shape().as_list(),
+                           [eval_batch_size])
+
+  def testForward(self):
+    batch_size = 1
+    height, width = 224, 224
+    with self.test_session() as sess:
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = vgg.vgg_19(inputs)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits)
+      self.assertTrue(output.any())
+
+if __name__ == '__main__':
+  tf.test.main()
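The two testTrainEvalWithReuse cases above encode a pattern worth calling out: build the training graph at 224x224, reuse the same variables for a larger fully-convolutional evaluation graph, and average the resulting logit map. A minimal sketch of that pattern, assuming the TF 0.x-era API and the nets.vgg module added in this commit:

import tensorflow as tf
from nets import vgg

with tf.Graph().as_default():
  # Training graph at the canonical 224x224 resolution; this call creates
  # the vgg_16 variables.
  train_inputs = tf.random_uniform((2, 224, 224, 3))
  train_logits, _ = vgg.vgg_16(train_inputs)

  # Share the variables with an eval graph at a larger resolution. With
  # spatial_squeeze=False the logits keep their spatial extent
  # ([1, 2, 2, 1000] for 256x256 inputs), which we average per image.
  tf.get_variable_scope().reuse_variables()
  eval_inputs = tf.random_uniform((1, 256, 256, 3))
  eval_logits, _ = vgg.vgg_16(eval_inputs, is_training=False,
                              spatial_squeeze=False)
  predictions = tf.argmax(tf.reduce_mean(eval_logits, [1, 2]), 1)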

+ 1 - 0
slim/preprocessing/__init__.py

@@ -0,0 +1 @@
+

+ 17 - 17
slim/models/cifar10_preprocessing.py

@@ -12,20 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Provides utilities to preprocess images.
+"""Provides utilities to preprocess images in CIFAR-10.
 
-The preprocessing steps for VGG were introduced in the following technical
-report:
-
-  Very Deep Convolutional Networks For Large-Scale Image Recognition
-  Karen Simonyan and Andrew Zisserman
-  arXiv technical report, 2015
-  PDF: http://arxiv.org/pdf/1409.1556.pdf
-  ILSVRC 2014 Slides: http://www.robots.ox.ac.uk/~karen/pdf/ILSVRC_2014.pdf
-  CC-BY-4.0
-
-More information can be obtained from the VGG website:
-www.robots.ox.ac.uk/~vgg/research/very_deep/
 """
 
 from __future__ import absolute_import
@@ -34,7 +22,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 
-_PADDING = 2
+_PADDING = 4
 
 slim = tf.contrib.slim
 
@@ -57,21 +45,27 @@ def preprocess_for_train(image,
   Returns:
     A preprocessed image.
   """
-  padded_image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
+  tf.image_summary('image', tf.expand_dims(image, 0))
+
+  # Transform the image to floats.
+  image = tf.to_float(image)
+  if padding > 0:
+    image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
   # Randomly crop a [height, width] section of the image.
-  distorted_image = tf.random_crop(padded_image,
+  distorted_image = tf.random_crop(image,
                                    [output_height, output_width, 3])
 
   # Randomly flip the image horizontally.
   distorted_image = tf.image.random_flip_left_right(distorted_image)
 
+  tf.image_summary('distorted_image', tf.expand_dims(distorted_image, 0))
+
  # Because these operations are not commutative, consider randomizing
  # the order of their operations.
   distorted_image = tf.image.random_brightness(distorted_image,
                                                max_delta=63)
   distorted_image = tf.image.random_contrast(distorted_image,
                                              lower=0.2, upper=1.8)
-
   # Subtract off the mean and divide by the variance of the pixels.
   return tf.image.per_image_whitening(distorted_image)
 
@@ -87,9 +81,15 @@ def preprocess_for_eval(image, output_height, output_width):
   Returns:
     A preprocessed image.
   """
+  tf.image_summary('image', tf.expand_dims(image, 0))
+  # Transform the image to floats.
+  image = tf.to_float(image)
+
+  # Resize and crop if needed.
   resized_image = tf.image.resize_image_with_crop_or_pad(image,
                                                         output_height,
                                                         output_width)
+  tf.image_summary('resized_image', tf.expand_dims(resized_image, 0))
 
   # Subtract off the mean and divide by the variance of the pixels.
   return tf.image.per_image_whitening(resized_image)
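Put together, the train-time path above now converts to float, pads by 4 pixels per side, takes a random 32x32 crop, flips, jitters brightness and contrast, and whitens. A self-contained sketch of the same pipeline, assuming the TF 0.x image ops used in this file (the function name is ours, for illustration):

import tensorflow as tf

def cifar10_train_preprocess(image, padding=4, out_height=32, out_width=32):
  """Applies the augmentation sequence above to one [32, 32, 3] uint8 image."""
  image = tf.to_float(image)
  if padding > 0:
    image = tf.pad(image, [[padding, padding], [padding, padding], [0, 0]])
  image = tf.random_crop(image, [out_height, out_width, 3])
  image = tf.image.random_flip_left_right(image)
  image = tf.image.random_brightness(image, max_delta=63)
  image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
  # Zero mean, unit variance per image.
  return tf.image.per_image_whitening(image)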

slim/models/inception_preprocessing.py → slim/preprocessing/inception_preprocessing.py


slim/models/lenet_preprocessing.py → slim/preprocessing/lenet_preprocessing.py


+ 6 - 5
slim/models/preprocessing_factory.py

@@ -20,10 +20,10 @@ from __future__ import print_function
 
 import tensorflow as tf
 
-from slim.models import cifar10_preprocessing
-from slim.models import inception_preprocessing
-from slim.models import lenet_preprocessing
-from slim.models import vgg_preprocessing
+from preprocessing import cifarnet_preprocessing
+from preprocessing import inception_preprocessing
+from preprocessing import lenet_preprocessing
+from preprocessing import vgg_preprocessing
 
 slim = tf.contrib.slim
 
@@ -45,11 +45,12 @@ def get_preprocessing(name, is_training=False):
     ValueError: If Preprocessing `name` is not recognized.
   """
   preprocessing_fn_map = {
-      'cifar10': cifar10_preprocessing,
+      'cifarnet': cifarnet_preprocessing,
       'inception': inception_preprocessing,
       'inception_v1': inception_preprocessing,
       'inception_v2': inception_preprocessing,
       'inception_v3': inception_preprocessing,
+      'inception_resnet_v2': inception_preprocessing,
       'lenet': lenet_preprocessing,
       'resnet_v1_50': vgg_preprocessing,
       'resnet_v1_101': vgg_preprocessing,

slim/models/vgg_preprocessing.py → slim/preprocessing/vgg_preprocessing.py


+ 89 - 0
slim/scripts/finetune_inception_v1_on_flowers.sh

@@ -0,0 +1,89 @@
+#!/bin/bash
+#
+# This script performs the following operations:
+# 1. Downloads the Flowers dataset.
+# 2. Fine-tunes an InceptionV1 model on the Flowers training set.
+# 3. Evaluates the model on the Flowers validation set.
+#
+# Usage:
+# cd slim
+# ./scripts/finetune_inception_v1_on_flowers.sh
+
+# Where the pre-trained InceptionV1 checkpoint is saved to.
+PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
+
+# Where the training (fine-tuned) checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/flowers-models/inception_v1
+
+# Where the dataset is saved to.
+DATASET_DIR=/tmp/flowers
+
+# Download the pre-trained checkpoint.
+if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
+  mkdir ${PRETRAINED_CHECKPOINT_DIR}
+fi
+if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt ]; then
+  wget http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz
+  tar -xvf inception_v1_2016_08_28.tar.gz
+  mv inception_v1.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt
+  rm inception_v1_2016_08_28.tar.gz
+fi
+
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=flowers \
+  --dataset_dir=${DATASET_DIR}
+
+# Fine-tune only the new layers for 3000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v1 \
+  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v1.ckpt \
+  --checkpoint_exclude_scopes=InceptionV1/Logits \
+  --trainable_scopes=InceptionV1/Logits \
+  --max_number_of_steps=3000 \
+  --batch_size=32 \
+  --learning_rate=0.01 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v1
+
+# Fine-tune all layers for 1000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --checkpoint_path=${TRAIN_DIR} \
+  --model_name=inception_v1 \
+  --max_number_of_steps=1000 \
+  --batch_size=32 \
+  --learning_rate=0.001 \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR}/all \
+  --eval_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v1
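A note on the two scope flags used above: --checkpoint_exclude_scopes controls which variables are not restored from the pre-trained checkpoint (the new Flowers logits layer must be excluded because its shape no longer matches the ImageNet head), while --trainable_scopes controls which variables the optimizer updates. A sketch of the restore-side filtering, mirroring the logic of _get_init_fn in the train.py diff below:

import tensorflow as tf

slim = tf.contrib.slim

# Parsed from --checkpoint_exclude_scopes.
exclusions = ['InceptionV1/Logits']
variables_to_restore = [
    var for var in slim.get_model_variables()
    if not any(var.op.name.startswith(ex) for ex in exclusions)]
init_fn = slim.assign_from_checkpoint_fn(
    '/tmp/checkpoints/inception_v1.ckpt', variables_to_restore)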

+ 91 - 0
slim/scripts/finetune_inception_v3_on_flowers.sh

@@ -0,0 +1,91 @@
+#!/bin/bash
+#
+# This script performs the following operations:
+# 1. Downloads the Flowers dataset.
+# 2. Fine-tunes an InceptionV3 model on the Flowers training set.
+# 3. Evaluates the model on the Flowers validation set.
+#
+# Usage:
+# cd slim
+# ./scripts/finetune_inception_v3_on_flowers.sh
+
+# Where the pre-trained InceptionV3 checkpoint is saved to.
+PRETRAINED_CHECKPOINT_DIR=/tmp/checkpoints
+
+# Where the training (fine-tuned) checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/flowers-models/inception_v3
+
+# Where the dataset is saved to.
+DATASET_DIR=/tmp/flowers
+
+# Download the pre-trained checkpoint.
+if [ ! -d "$PRETRAINED_CHECKPOINT_DIR" ]; then
+  mkdir ${PRETRAINED_CHECKPOINT_DIR}
+fi
+if [ ! -f ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt ]; then
+  wget http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz
+  tar -xvf inception_v3_2016_08_28.tar.gz
+  mv inception_v3.ckpt ${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt
+  rm inception_v3_2016_08_28.tar.gz
+fi
+
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=flowers \
+  --dataset_dir=${DATASET_DIR}
+
+# Fine-tune only the new layers for 1000 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v3 \
+  --checkpoint_path=${PRETRAINED_CHECKPOINT_DIR}/inception_v3.ckpt \
+  --checkpoint_exclude_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
+  --trainable_scopes=InceptionV3/Logits,InceptionV3/AuxLogits \
+  --max_number_of_steps=1000 \
+  --batch_size=32 \
+  --learning_rate=0.01 \
+  --learning_rate_decay_type=fixed \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v3
+
+# Fine-tune all layers for 500 steps.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v3 \
+  --checkpoint_path=${TRAIN_DIR} \
+  --max_number_of_steps=500 \
+  --batch_size=32 \
+  --learning_rate=0.0001 \
+  --learning_rate_decay_type=fixed \
+  --save_interval_secs=60 \
+  --save_summaries_secs=60 \
+  --log_every_n_steps=10 \
+  --optimizer=rmsprop \
+  --weight_decay=0.00004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR}/all \
+  --eval_dir=${TRAIN_DIR}/all \
+  --dataset_name=flowers \
+  --dataset_split_name=validation \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=inception_v3

+ 49 - 0
slim/scripts/train_cifarnet_on_cifar10.sh

@@ -0,0 +1,49 @@
+#!/bin/bash
+#
+# This script performs the following operations:
+# 1. Downloads the Cifar10 dataset.
+# 2. Trains a CifarNet model on the Cifar10 training set.
+# 3. Evaluates the model on the Cifar10 testing set.
+#
+# Usage:
+# cd slim
+# ./scripts/train_cifarnet_on_cifar10.sh
+
+# Where the checkpoint and logs will be saved to.
+TRAIN_DIR=/tmp/cifarnet-model
+
+# Where the dataset is saved to.
+DATASET_DIR=/tmp/cifar10
+
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=cifar10 \
+  --dataset_dir=${DATASET_DIR}
+
+# Run training.
+python train_image_classifier.py \
+  --train_dir=${TRAIN_DIR} \
+  --dataset_name=cifar10 \
+  --dataset_split_name=train \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=cifarnet \
+  --preprocessing_name=cifarnet \
+  --max_number_of_steps=100000 \
+  --batch_size=128 \
+  --save_interval_secs=120 \
+  --save_summaries_secs=120 \
+  --log_every_n_steps=100 \
+  --optimizer=sgd \
+  --learning_rate=0.1 \
+  --learning_rate_decay_factor=0.1 \
+  --num_epochs_per_decay=200 \
+  --weight_decay=0.004
+
+# Run evaluation.
+python eval_image_classifier.py \
+  --checkpoint_path=${TRAIN_DIR} \
+  --eval_dir=${TRAIN_DIR} \
+  --dataset_name=cifar10 \
+  --dataset_split_name=test \
+  --dataset_dir=${DATASET_DIR} \
+  --model_name=cifarnet

+ 16 - 11
slim/scripts/train_lenet_on_mnist.sh

@@ -1,24 +1,27 @@
 #!/bin/bash
 #
-# Before running this script, make sure you've followed the instructions for
-# downloading and converting the MNIST dataset.
-# See slim/datasets/download_and_convert_mnist.py.
+# This script performs the following operations:
+# 1. Downloads the MNIST dataset.
+# 2. Trains a LeNet model on the MNIST training set.
+# 3. Evaluates the model on the MNIST testing set.
 #
 # Usage:
+# cd slim
 # ./scripts/train_lenet_on_mnist.sh
 
-# Compile the training and evaluation binaries
-bazel build slim:train
-bazel build slim:eval
-
 # Where the checkpoint and logs will be saved to.
 TRAIN_DIR=/tmp/lenet-model
 
-# Where the dataset was saved to.
+# Where the dataset is saved to.
 DATASET_DIR=/tmp/mnist
 
+# Download the dataset
+python download_and_convert_data.py \
+  --dataset_name=mnist \
+  --dataset_dir=${DATASET_DIR}
+
 # Run training.
-./bazel-bin/slim/train \
+python train_image_classifier.py \
   --train_dir=${TRAIN_DIR} \
   --dataset_name=mnist \
   --dataset_split_name=train \
@@ -26,15 +29,17 @@ DATASET_DIR=/tmp/mnist
   --model_name=lenet \
   --preprocessing_name=lenet \
   --max_number_of_steps=20000 \
+  --batch_size=50 \
   --learning_rate=0.01 \
   --save_interval_secs=60 \
   --save_summaries_secs=60 \
+  --log_every_n_steps=100 \
   --optimizer=sgd \
-  --learning_rate_decay_factor=1.0
+  --learning_rate_decay_type=fixed \
   --weight_decay=0
 
 # Run evaluation.
-./blaze-bin/slim/eval \
+python eval_image_classifier.py \
   --checkpoint_path=${TRAIN_DIR} \
   --eval_dir=${TRAIN_DIR} \
   --dataset_name=mnist \

The file diff was suppressed because it is too large
+ 1058 - 0
slim/slim_walkthough.ipynb


+ 69 - 24
slim/train.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Generic training script that trains a given model a specified dataset."""
+"""Generic training script that trains a model using a given dataset."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -21,10 +21,10 @@ from __future__ import print_function
 import tensorflow as tf
 
 from tensorflow.python.ops import control_flow_ops
-from slim.datasets import dataset_factory
-from slim.models import model_deploy
-from slim.models import model_factory
-from slim.models import preprocessing_factory
+from datasets import dataset_factory
+from deployment import model_deploy
+from nets import nets_factory
+from preprocessing import preprocessing_factory
 
 slim = tf.contrib.slim
 
@@ -57,7 +57,7 @@ tf.app.flags.DEFINE_integer(
     'The number of threads used to create the batches.')
 
 tf.app.flags.DEFINE_integer(
-    'log_every_n_steps', 5,
+    'log_every_n_steps', 10,
     'The frequency with which logs are print.')
 
 tf.app.flags.DEFINE_integer(
@@ -161,8 +161,6 @@ tf.app.flags.DEFINE_float(
     'The decay to use for the moving average.'
     'If left as None, then moving averages are not used.')
 
-
-
 #######################
 # Dataset Flags #
 #######################
@@ -208,9 +206,18 @@ tf.app.flags.DEFINE_string(
 
 tf.app.flags.DEFINE_string(
     'checkpoint_exclude_scopes', None,
-    'Comma-separated list of scopes to include when fine-tuning '
+    'Comma-separated list of scopes of variables to exclude when restoring '
     'from a checkpoint.')
 
+tf.app.flags.DEFINE_string(
+    'trainable_scopes', None,
+    'Comma-separated list of scopes to filter the set of variables to train. '
+    'By default, None trains all variables.')
+
+tf.app.flags.DEFINE_boolean(
+    'ignore_missing_vars', False,
+    'Whether to ignore missing variables when restoring a checkpoint.')
+
 FLAGS = tf.app.flags.FLAGS
 
 
@@ -350,15 +357,42 @@ def _get_init_fn():
     if not excluded:
       variables_to_restore.append(var)
 
+  if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
+    checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
+  else:
+    checkpoint_path = FLAGS.checkpoint_path
+
+  tf.logging.info('Fine-tuning from %s' % checkpoint_path)
+
   return slim.assign_from_checkpoint_fn(
-      FLAGS.checkpoint_path,
-      variables_to_restore)
+      checkpoint_path,
+      variables_to_restore,
+      ignore_missing_vars=FLAGS.ignore_missing_vars)
+
+
+def _get_variables_to_train():
+  """Returns a list of variables to train.
+
+  Returns:
+    A list of variables to train by the optimizer.
+  """
+  if FLAGS.trainable_scopes is None:
+    return tf.trainable_variables()
+  else:
+    scopes = [scope.strip() for scope in FLAGS.trainable_scopes.split(',')]
+
+  variables_to_train = []
+  for scope in scopes:
+    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
+    variables_to_train.extend(variables)
+  return variables_to_train
 
 
 def main(_):
   if not FLAGS.dataset_dir:
     raise ValueError('You must supply the dataset directory with --dataset_dir')
 
+  tf.logging.set_verbosity(tf.logging.INFO)
   with tf.Graph().as_default():
     ######################
     # Config model_deploy#
@@ -381,9 +415,9 @@ def main(_):
         FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
 
-    ####################
-    # Select the model #
-    ####################
+    ######################
+    # Select the network #
+    ######################
-    model_fn = model_factory.get_model(
+    network_fn = nets_factory.get_network_fn(
         FLAGS.model_name,
         num_classes=(dataset.num_classes - FLAGS.labels_offset),
         weight_decay=FLAGS.weight_decay,
@@ -409,10 +443,7 @@ def main(_):
       [image, label] = provider.get(['image', 'label'])
       label -= FLAGS.labels_offset
 
-      if FLAGS.train_image_size is None:
-        train_image_size = model_fn.default_image_size
-      else:
-        train_image_size = FLAGS.train_image_size
+      train_image_size = FLAGS.train_image_size or network_fn.default_image_size
 
       image = image_preprocessing_fn(image, train_image_size, train_image_size)
 
@@ -430,9 +461,9 @@ def main(_):
     # Define the model #
     ####################
     def clone_fn(batch_queue):
-      """Allows data parallelism by creating multiple clones of the model_fn."""
+      """Allows data parallelism by creating multiple clones of network_fn."""
       images, labels = batch_queue.dequeue()
-      logits, end_points = model_fn(images)
+      logits, end_points = network_fn(images)
 
       #############################
       # Specify the loss function #
@@ -443,6 +474,7 @@ def main(_):
             label_smoothing=FLAGS.label_smoothing, weight=0.4, scope='aux_loss')
       slim.losses.softmax_cross_entropy(
           logits, labels, label_smoothing=FLAGS.label_smoothing, weight=1.0)
+      return end_points
 
     # Gather initial summaries.
     summaries = set(tf.get_collection(tf.GraphKeys.SUMMARIES))
@@ -450,12 +482,20 @@ def main(_):
     clones = model_deploy.create_clones(deploy_config, clone_fn, [batch_queue])
     first_clone_scope = deploy_config.clone_scope(0)
     # Gather update_ops from the first clone. These contain, for example,
-    # the updates for the batch_norm variables created by model_fn.
+    # the updates for the batch_norm variables created by network_fn.
     update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS, first_clone_scope)
 
+    # Add summaries for end_points.
+    end_points = clones[0].outputs
+    for end_point in end_points:
+      x = end_points[end_point]
+      summaries.add(tf.histogram_summary('activations/' + end_point, x))
+      summaries.add(tf.scalar_summary('sparsity/' + end_point,
+                                      tf.nn.zero_fraction(x)))
+
     # Add summaries for losses.
     for loss in tf.get_collection(tf.GraphKeys.LOSSES, first_clone_scope):
-      tf.scalar_summary('losses/%s' % loss.op.name, loss)
+      summaries.add(tf.scalar_summary('losses/%s' % loss.op.name, loss))
 
     # Add summaries for variables.
     for variable in slim.get_model_variables():
@@ -494,10 +534,14 @@ def main(_):
       # Update ops executed locally by trainer.
       update_ops.append(variable_averages.apply(moving_average_variables))
 
-    # TODO(sguada) Refactor into function that takes the clones and optimizer
-    #  and returns a train_tensor and summary_op
+    # Variables to train.
+    variables_to_train = _get_variables_to_train()
+
-    total_loss, clones_gradients = model_deploy.optimize_clones(clones,
-                                                                optimizer)
+    total_loss, clones_gradients = model_deploy.optimize_clones(
+        clones,
+        optimizer,
+        var_list=variables_to_train)
     # Add total_loss to summary.
     summaries.add(tf.scalar_summary('total_loss', total_loss,
                                     name='total_loss'))
@@ -519,6 +563,7 @@ def main(_):
     # Merge all summaries together.
     summary_op = tf.merge_summary(list(summaries), name='summary_op')
 
+
     ###########################
     # Kicks off the training. #
     ###########################
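To make the new --trainable_scopes flag concrete: _get_variables_to_train above narrows the optimizer's var_list to the trainable variables under the listed scopes, so every other variable stays at its restored value. A standalone sketch of that filtering (the scope name is the one used by the fine-tuning scripts):

import tensorflow as tf

# Parsed from --trainable_scopes.
scopes = ['InceptionV1/Logits']
variables_to_train = []
for scope in scopes:
  variables_to_train.extend(
      tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope))
# model_deploy.optimize_clones(clones, optimizer, var_list=variables_to_train)
# then computes gradients only for these variables.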