Explorar o código

Implementation of Inception V4

Alex Kurakin %!s(int64=8) %!d(string=hai) anos
pai
achega
804d90f75f

+ 34 - 0
slim/BUILD

@@ -164,25 +164,50 @@ py_library(
         ":inception_v1",
         ":inception_v2",
         ":inception_v3",
+        ":inception_v4",
     ],
 )
 
 py_library(
+    name = "inception_utils",
+    srcs = ["nets/inception_utils.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
     name = "inception_v1",
     srcs = ["nets/inception_v1.py"],
     srcs_version = "PY2AND3",
+    deps = [
+        ":inception_utils",
+    ],
 )
 
 py_library(
     name = "inception_v2",
     srcs = ["nets/inception_v2.py"],
     srcs_version = "PY2AND3",
+    deps = [
+        ":inception_utils",
+    ],
 )
 
 py_library(
     name = "inception_v3",
     srcs = ["nets/inception_v3.py"],
     srcs_version = "PY2AND3",
+    deps = [
+        ":inception_utils",
+    ],
+)
+
+py_library(
+    name = "inception_v4",
+    srcs = ["nets/inception_v4.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":inception_utils",
+    ],
 )
 
 py_library(
@@ -219,6 +244,15 @@ py_test(
 )
 
 py_test(
+    name = "inception_v4_test",
+    size = "large",
+    srcs = ["nets/inception_v4_test.py"],
+    shard_count = 3,
+    srcs_version = "PY2AND3",
+    deps = [":inception"],
+)
+
+py_test(
     name = "inception_resnet_v2_test",
     size = "large",
     srcs = ["nets/inception_resnet_v2_test.py"],

+ 4 - 3
slim/README.md

@@ -197,9 +197,10 @@ crops at multiple scales.
 
 Model | TF-Slim File | Checkpoint | Top-1 Accuracy| Top-5 Accuracy |
 :----:|:------------:|:----------:|:-------:|:--------:|
-[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v1.py)|[inception_v1.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz)|69.8|89.6|
-[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v2.py)|[inception_v2.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz)|73.9|91.8|
-[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/inception_v3.py)|[inception_v3.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)|78.0|93.9|
+[Inception V1](http://arxiv.org/abs/1409.4842v1)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v1.py)|[inception_v1_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v1_2016_08_28.tar.gz)|69.8|89.6|
+[Inception V2](http://arxiv.org/abs/1502.03167)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v2.py)|[inception_v2_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v2_2016_08_28.tar.gz)|73.9|91.8|
+[Inception V3](http://arxiv.org/abs/1512.00567)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v3.py)|[inception_v3_2016_08_28.tar.gz](http://download.tensorflow.org/models/inception_v3_2016_08_28.tar.gz)|78.0|93.9|
+[Inception V4](http://arxiv.org/abs/1602.07261)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_v4.py)|[inception_v4_2016_09_09.tar.gz](http://download.tensorflow.org/models/inception_v4_2016_09_09.tar.gz)|80.2|95.2|
 [Inception-ResNet-v2](http://arxiv.org/abs/1602.07261)|[Code](https://github.com/tensorflow/models/blob/master/slim/nets/inception_resnet_v2.py)|[inception_resnet_v2.tar.gz](http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz)|80.4|95.3|
 [ResNet 50](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_50.tar.gz](http://download.tensorflow.org/models/resnet_v1_50_2016_08_28.tar.gz)|75.2|92.2|
 [ResNet 101](https://arxiv.org/abs/1512.03385)|[Code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py)|[resnet_v1_101.tar.gz](http://download.tensorflow.org/models/resnet_v1_101_2016_08_28.tar.gz)|76.4|92.9|

+ 4 - 1
slim/nets/inception.py

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Brings inception_v1, inception_v2 and inception_v3 under one namespace."""
+"""Brings all inception models under one namespace."""
 
 from __future__ import absolute_import
 from __future__ import division
@@ -30,4 +30,7 @@ from nets.inception_v2 import inception_v2_base
 from nets.inception_v3 import inception_v3
 from nets.inception_v3 import inception_v3_arg_scope
 from nets.inception_v3 import inception_v3_base
+from nets.inception_v4 import inception_v4
+from nets.inception_v4 import inception_v4_arg_scope
+from nets.inception_v4 import inception_v4_base
 # pylint: enable=unused-import

+ 71 - 0
slim/nets/inception_utils.py

@@ -0,0 +1,71 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains common code shared by all inception models.
+
+Usage of arg scope:
+  with slim.arg_scope(inception_arg_scope()):
+    logits, end_points = inception.inception_v3(images, num_classes,
+                                                is_training=is_training)
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+slim = tf.contrib.slim
+
+
+def inception_arg_scope(weight_decay=0.00004,
+                        use_batch_norm=True,
+                        batch_norm_decay=0.9997,
+                        batch_norm_epsilon=0.001):
+  """Defines the default arg scope for inception models.
+
+  Args:
+    weight_decay: The weight decay to use for regularizing the model.
+    use_batch_norm: "If `True`, batch_norm is applied after each convolution.
+    batch_norm_decay: Decay for batch norm moving average.
+    batch_norm_epsilon: Small float added to variance to avoid dividing by zero
+      in batch norm.
+
+  Returns:
+    An `arg_scope` to use for the inception models.
+  """
+  batch_norm_params = {
+      # Decay for the moving averages.
+      'decay': batch_norm_decay,
+      # epsilon to prevent 0s in variance.
+      'epsilon': batch_norm_epsilon,
+      # collection containing update_ops.
+      'updates_collections': tf.GraphKeys.UPDATE_OPS,
+  }
+  if use_batch_norm:
+    normalizer_fn = slim.batch_norm
+    normalizer_params = batch_norm_params
+  else:
+    normalizer_fn = None
+    normalizer_params = {}
+  # Set weight_decay for weights in Conv and FC layers.
+  with slim.arg_scope([slim.conv2d, slim.fully_connected],
+                      weights_regularizer=slim.l2_regularizer(weight_decay)):
+    with slim.arg_scope(
+        [slim.conv2d],
+        weights_initializer=slim.variance_scaling_initializer(),
+        activation_fn=tf.nn.relu,
+        normalizer_fn=normalizer_fn,
+        normalizer_params=normalizer_params) as sc:
+      return sc

+ 3 - 38
slim/nets/inception_v1.py

@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from nets import inception_utils
+
 slim = tf.contrib.slim
 trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
 
@@ -300,41 +302,4 @@ def inception_v1(inputs,
   return logits, end_points
 inception_v1.default_image_size = 224
 
-
-def inception_v1_arg_scope(weight_decay=0.00004,
-                           use_batch_norm=True):
-  """Defines the default InceptionV1 arg scope.
-
-  Note: Althougth the original paper didn't use batch_norm we found it useful.
-
-  Args:
-    weight_decay: The weight decay to use for regularizing the model.
-    use_batch_norm: "If `True`, batch_norm is applied after each convolution.
-
-  Returns:
-    An `arg_scope` to use for the inception v3 model.
-  """
-  batch_norm_params = {
-      # Decay for the moving averages.
-      'decay': 0.9997,
-      # epsilon to prevent 0s in variance.
-      'epsilon': 0.001,
-      # collection containing update_ops.
-      'updates_collections': tf.GraphKeys.UPDATE_OPS,
-  }
-  if use_batch_norm:
-    normalizer_fn = slim.batch_norm
-    normalizer_params = batch_norm_params
-  else:
-    normalizer_fn = None
-    normalizer_params = {}
-  # Set weight_decay for weights in Conv and FC layers.
-  with slim.arg_scope([slim.conv2d, slim.fully_connected],
-                      weights_regularizer=slim.l2_regularizer(weight_decay)):
-    with slim.arg_scope(
-        [slim.conv2d],
-        weights_initializer=slim.variance_scaling_initializer(),
-        activation_fn=tf.nn.relu,
-        normalizer_fn=normalizer_fn,
-        normalizer_params=normalizer_params) as sc:
-      return sc
+inception_v1_arg_scope = inception_utils.inception_arg_scope

+ 3 - 28
slim/nets/inception_v2.py

@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from nets import inception_utils
+
 slim = tf.contrib.slim
 trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
 
@@ -515,31 +517,4 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
   return kernel_size_out
 
 
-def inception_v2_arg_scope(weight_decay=0.00004):
-  """Defines the default InceptionV2 arg scope.
-
-  Args:
-    weight_decay: The weight decay to use for regularizing the model.
-
-  Returns:
-    An `arg_scope` to use for the inception v3 model.
-  """
-  batch_norm_params = {
-      # Decay for the moving averages.
-      'decay': 0.9997,
-      # epsilon to prevent 0s in variance.
-      'epsilon': 0.001,
-      # collection containing update_ops.
-      'updates_collections': tf.GraphKeys.UPDATE_OPS,
-  }
-
-  # Set weight_decay for weights in Conv and FC layers.
-  with slim.arg_scope([slim.conv2d, slim.fully_connected],
-                      weights_regularizer=slim.l2_regularizer(weight_decay)):
-    with slim.arg_scope(
-        [slim.conv2d],
-        weights_initializer=slim.variance_scaling_initializer(),
-        activation_fn=tf.nn.relu,
-        normalizer_fn=slim.batch_norm,
-        normalizer_params=batch_norm_params) as sc:
-      return sc
+inception_v2_arg_scope = inception_utils.inception_arg_scope

+ 3 - 30
slim/nets/inception_v3.py

@@ -20,6 +20,8 @@ from __future__ import print_function
 
 import tensorflow as tf
 
+from nets import inception_utils
+
 slim = tf.contrib.slim
 trunc_normal = lambda stddev: tf.truncated_normal_initializer(0.0, stddev)
 
@@ -555,33 +557,4 @@ def _reduced_kernel_size_for_small_input(input_tensor, kernel_size):
   return kernel_size_out
 
 
-def inception_v3_arg_scope(weight_decay=0.00004,
-                           stddev=0.1):
-  """Defines the default InceptionV3 arg scope.
-
-  Args:
-    weight_decay: The weight decay to use for regularizing the model.
-    stddev: The standard deviation of the trunctated normal weight initializer.
-
-  Returns:
-    An `arg_scope` to use for the inception v3 model.
-  """
-  batch_norm_params = {
-      # Decay for the moving averages.
-      'decay': 0.9997,
-      # epsilon to prevent 0s in variance.
-      'epsilon': 0.001,
-      # collection containing update_ops.
-      'updates_collections': tf.GraphKeys.UPDATE_OPS,
-  }
-
-  # Set weight_decay for weights in Conv and FC layers.
-  with slim.arg_scope([slim.conv2d, slim.fully_connected],
-                      weights_regularizer=slim.l2_regularizer(weight_decay)):
-    with slim.arg_scope(
-        [slim.conv2d],
-        weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
-        activation_fn=tf.nn.relu,
-        normalizer_fn=slim.batch_norm,
-        normalizer_params=batch_norm_params) as sc:
-      return sc
+inception_v3_arg_scope = inception_utils.inception_arg_scope

+ 323 - 0
slim/nets/inception_v4.py

@@ -0,0 +1,323 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains the definition of the Inception V4 architecture.
+
+As described in http://arxiv.org/abs/1602.07261.
+
+  Inception-v4, Inception-ResNet and the Impact of Residual Connections
+    on Learning
+  Christian Szegedy, Sergey Ioffe, Vincent Vanhoucke, Alex Alemi
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import inception_utils
+
+slim = tf.contrib.slim
+
+
+def block_inception_a(inputs, scope=None, reuse=None):
+  """Builds Inception-A block for Inception v4 network."""
+  # By default use stride=1 and SAME padding
+  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
+                      stride=1, padding='SAME'):
+    with tf.variable_scope(scope, 'BlockInceptionA', [inputs], reuse=reuse):
+      with tf.variable_scope('Branch_0'):
+        branch_0 = slim.conv2d(inputs, 96, [1, 1], scope='Conv2d_0a_1x1')
+      with tf.variable_scope('Branch_1'):
+        branch_1 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1')
+        branch_1 = slim.conv2d(branch_1, 96, [3, 3], scope='Conv2d_0b_3x3')
+      with tf.variable_scope('Branch_2'):
+        branch_2 = slim.conv2d(inputs, 64, [1, 1], scope='Conv2d_0a_1x1')
+        branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0b_3x3')
+        branch_2 = slim.conv2d(branch_2, 96, [3, 3], scope='Conv2d_0c_3x3')
+      with tf.variable_scope('Branch_3'):
+        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
+        branch_3 = slim.conv2d(branch_3, 96, [1, 1], scope='Conv2d_0b_1x1')
+      return tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+
+
+def block_reduction_a(inputs, scope=None, reuse=None):
+  """Builds Reduction-A block for Inception v4 network."""
+  # By default use stride=1 and SAME padding
+  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
+                      stride=1, padding='SAME'):
+    with tf.variable_scope(scope, 'BlockReductionA', [inputs], reuse=reuse):
+      with tf.variable_scope('Branch_0'):
+        branch_0 = slim.conv2d(inputs, 384, [3, 3], stride=2, padding='VALID',
+                               scope='Conv2d_1a_3x3')
+      with tf.variable_scope('Branch_1'):
+        branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
+        branch_1 = slim.conv2d(branch_1, 224, [3, 3], scope='Conv2d_0b_3x3')
+        branch_1 = slim.conv2d(branch_1, 256, [3, 3], stride=2,
+                               padding='VALID', scope='Conv2d_1a_3x3')
+      with tf.variable_scope('Branch_2'):
+        branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID',
+                                   scope='MaxPool_1a_3x3')
+      return tf.concat(3, [branch_0, branch_1, branch_2])
+
+
+def block_inception_b(inputs, scope=None, reuse=None):
+  """Builds Inception-B block for Inception v4 network."""
+  # By default use stride=1 and SAME padding
+  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
+                      stride=1, padding='SAME'):
+    with tf.variable_scope(scope, 'BlockInceptionB', [inputs], reuse=reuse):
+      with tf.variable_scope('Branch_0'):
+        branch_0 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
+      with tf.variable_scope('Branch_1'):
+        branch_1 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
+        branch_1 = slim.conv2d(branch_1, 224, [1, 7], scope='Conv2d_0b_1x7')
+        branch_1 = slim.conv2d(branch_1, 256, [7, 1], scope='Conv2d_0c_7x1')
+      with tf.variable_scope('Branch_2'):
+        branch_2 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
+        branch_2 = slim.conv2d(branch_2, 192, [7, 1], scope='Conv2d_0b_7x1')
+        branch_2 = slim.conv2d(branch_2, 224, [1, 7], scope='Conv2d_0c_1x7')
+        branch_2 = slim.conv2d(branch_2, 224, [7, 1], scope='Conv2d_0d_7x1')
+        branch_2 = slim.conv2d(branch_2, 256, [1, 7], scope='Conv2d_0e_1x7')
+      with tf.variable_scope('Branch_3'):
+        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
+        branch_3 = slim.conv2d(branch_3, 128, [1, 1], scope='Conv2d_0b_1x1')
+      return tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+
+
+def block_reduction_b(inputs, scope=None, reuse=None):
+  """Builds Reduction-B block for Inception v4 network."""
+  # By default use stride=1 and SAME padding
+  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
+                      stride=1, padding='SAME'):
+    with tf.variable_scope(scope, 'BlockReductionB', [inputs], reuse=reuse):
+      with tf.variable_scope('Branch_0'):
+        branch_0 = slim.conv2d(inputs, 192, [1, 1], scope='Conv2d_0a_1x1')
+        branch_0 = slim.conv2d(branch_0, 192, [3, 3], stride=2,
+                               padding='VALID', scope='Conv2d_1a_3x3')
+      with tf.variable_scope('Branch_1'):
+        branch_1 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1')
+        branch_1 = slim.conv2d(branch_1, 256, [1, 7], scope='Conv2d_0b_1x7')
+        branch_1 = slim.conv2d(branch_1, 320, [7, 1], scope='Conv2d_0c_7x1')
+        branch_1 = slim.conv2d(branch_1, 320, [3, 3], stride=2,
+                               padding='VALID', scope='Conv2d_1a_3x3')
+      with tf.variable_scope('Branch_2'):
+        branch_2 = slim.max_pool2d(inputs, [3, 3], stride=2, padding='VALID',
+                                   scope='MaxPool_1a_3x3')
+      return tf.concat(3, [branch_0, branch_1, branch_2])
+
+
+def block_inception_c(inputs, scope=None, reuse=None):
+  """Builds Inception-C block for Inception v4 network."""
+  # By default use stride=1 and SAME padding
+  with slim.arg_scope([slim.conv2d, slim.avg_pool2d, slim.max_pool2d],
+                      stride=1, padding='SAME'):
+    with tf.variable_scope(scope, 'BlockInceptionC', [inputs], reuse=reuse):
+      with tf.variable_scope('Branch_0'):
+        branch_0 = slim.conv2d(inputs, 256, [1, 1], scope='Conv2d_0a_1x1')
+      with tf.variable_scope('Branch_1'):
+        branch_1 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
+        branch_1 = tf.concat(3, [
+            slim.conv2d(branch_1, 256, [1, 3], scope='Conv2d_0b_1x3'),
+            slim.conv2d(branch_1, 256, [3, 1], scope='Conv2d_0c_3x1')])
+      with tf.variable_scope('Branch_2'):
+        branch_2 = slim.conv2d(inputs, 384, [1, 1], scope='Conv2d_0a_1x1')
+        branch_2 = slim.conv2d(branch_2, 448, [3, 1], scope='Conv2d_0b_3x1')
+        branch_2 = slim.conv2d(branch_2, 512, [1, 3], scope='Conv2d_0c_1x3')
+        branch_2 = tf.concat(3, [
+            slim.conv2d(branch_2, 256, [1, 3], scope='Conv2d_0d_1x3'),
+            slim.conv2d(branch_2, 256, [3, 1], scope='Conv2d_0e_3x1')])
+      with tf.variable_scope('Branch_3'):
+        branch_3 = slim.avg_pool2d(inputs, [3, 3], scope='AvgPool_0a_3x3')
+        branch_3 = slim.conv2d(branch_3, 256, [1, 1], scope='Conv2d_0b_1x1')
+      return tf.concat(3, [branch_0, branch_1, branch_2, branch_3])
+
+
+def inception_v4_base(inputs, final_endpoint='Mixed_7d', scope=None):
+  """Creates the Inception V4 network up to the given final endpoint.
+
+  Args:
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+    final_endpoint: specifies the endpoint to construct the network up to.
+      It can be one of [ 'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3',
+      'Mixed_3a', 'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
+      'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d', 'Mixed_6e',
+      'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a', 'Mixed_7b', 'Mixed_7c',
+      'Mixed_7d']
+    scope: Optional variable_scope.
+
+  Returns:
+    logits: the logits outputs of the model.
+    end_points: the set of end_points from the inception model.
+
+  Raises:
+    ValueError: if final_endpoint is not set to one of the predefined values,
+  """
+  end_points = {}
+
+  def add_and_check_final(name, net):
+    end_points[name] = net
+    return name == final_endpoint
+
+  with tf.variable_scope(scope, 'InceptionV4', [inputs]):
+    with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                        stride=1, padding='SAME'):
+      # 299 x 299 x 3
+      net = slim.conv2d(inputs, 32, [3, 3], stride=2,
+                        padding='VALID', scope='Conv2d_1a_3x3')
+      if add_and_check_final('Conv2d_1a_3x3', net): return net, end_points
+      # 149 x 149 x 32
+      net = slim.conv2d(net, 32, [3, 3], padding='VALID',
+                        scope='Conv2d_2a_3x3')
+      if add_and_check_final('Conv2d_2a_3x3', net): return net, end_points
+      # 147 x 147 x 32
+      net = slim.conv2d(net, 64, [3, 3], scope='Conv2d_2b_3x3')
+      if add_and_check_final('Conv2d_2b_3x3', net): return net, end_points
+      # 147 x 147 x 64
+      with tf.variable_scope('Mixed_3a'):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
+                                     scope='MaxPool_0a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, 96, [3, 3], stride=2, padding='VALID',
+                                 scope='Conv2d_0a_3x3')
+        net = tf.concat(3, [branch_0, branch_1])
+        if add_and_check_final('Mixed_3a', net): return net, end_points
+
+      # 73 x 73 x 160
+      with tf.variable_scope('Mixed_4a'):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
+          branch_0 = slim.conv2d(branch_0, 96, [3, 3], padding='VALID',
+                                 scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.conv2d(net, 64, [1, 1], scope='Conv2d_0a_1x1')
+          branch_1 = slim.conv2d(branch_1, 64, [1, 7], scope='Conv2d_0b_1x7')
+          branch_1 = slim.conv2d(branch_1, 64, [7, 1], scope='Conv2d_0c_7x1')
+          branch_1 = slim.conv2d(branch_1, 96, [3, 3], padding='VALID',
+                                 scope='Conv2d_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1])
+        if add_and_check_final('Mixed_4a', net): return net, end_points
+
+      # 71 x 71 x 192
+      with tf.variable_scope('Mixed_5a'):
+        with tf.variable_scope('Branch_0'):
+          branch_0 = slim.conv2d(net, 192, [3, 3], stride=2, padding='VALID',
+                                 scope='Conv2d_1a_3x3')
+        with tf.variable_scope('Branch_1'):
+          branch_1 = slim.max_pool2d(net, [3, 3], stride=2, padding='VALID',
+                                     scope='MaxPool_1a_3x3')
+        net = tf.concat(3, [branch_0, branch_1])
+        if add_and_check_final('Mixed_5a', net): return net, end_points
+
+      # 35 x 35 x 384
+      # 4 x Inception-A blocks
+      for idx in xrange(4):
+        block_scope = 'Mixed_5' + chr(ord('b') + idx)
+        net = block_inception_a(net, block_scope)
+        if add_and_check_final(block_scope, net): return net, end_points
+
+      # 35 x 35 x 384
+      # Reduction-A block
+      net = block_reduction_a(net, 'Mixed_6a')
+      if add_and_check_final('Mixed_6a', net): return net, end_points
+
+      # 17 x 17 x 1024
+      # 7 x Inception-B blocks
+      for idx in xrange(7):
+        block_scope = 'Mixed_6' + chr(ord('b') + idx)
+        net = block_inception_b(net, block_scope)
+        if add_and_check_final(block_scope, net): return net, end_points
+
+      # 17 x 17 x 1024
+      # Reduction-B block
+      net = block_reduction_b(net, 'Mixed_7a')
+      if add_and_check_final('Mixed_7a', net): return net, end_points
+
+      # 8 x 8 x 1536
+      # 3 x Inception-C blocks
+      for idx in xrange(3):
+        block_scope = 'Mixed_7' + chr(ord('b') + idx)
+        net = block_inception_c(net, block_scope)
+        if add_and_check_final(block_scope, net): return net, end_points
+  raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+
+def inception_v4(inputs, num_classes=1001, is_training=True,
+                 dropout_keep_prob=0.8,
+                 reuse=None,
+                 scope='InceptionV4',
+                 create_aux_logits=True):
+  """Creates the Inception V4 model.
+
+  Args:
+    inputs: a 4-D tensor of size [batch_size, height, width, 3].
+    num_classes: number of predicted classes.
+    is_training: whether is training or not.
+    dropout_keep_prob: float, the fraction to keep before final layer.
+    reuse: whether or not the network and its variables should be reused. To be
+      able to reuse 'scope' must be given.
+    scope: Optional variable_scope.
+    create_aux_logits: Whether to include the auxilliary logits.
+
+  Returns:
+    logits: the logits outputs of the model.
+    end_points: the set of end_points from the inception model.
+  """
+  end_points = {}
+  with tf.variable_scope(scope, 'InceptionV4', [inputs], reuse=reuse) as scope:
+    with slim.arg_scope([slim.batch_norm, slim.dropout],
+                        is_training=is_training):
+      net, end_points = inception_v4_base(inputs, scope=scope)
+
+      with slim.arg_scope([slim.conv2d, slim.max_pool2d, slim.avg_pool2d],
+                          stride=1, padding='SAME'):
+        # Auxiliary Head logits
+        if create_aux_logits:
+          with tf.variable_scope('AuxLogits'):
+            # 17 x 17 x 1024
+            aux_logits = end_points['Mixed_6h']
+            aux_logits = slim.avg_pool2d(aux_logits, [5, 5], stride=3,
+                                         padding='VALID',
+                                         scope='AvgPool_1a_5x5')
+            aux_logits = slim.conv2d(aux_logits, 128, [1, 1],
+                                     scope='Conv2d_1b_1x1')
+            aux_logits = slim.conv2d(aux_logits, 768,
+                                     aux_logits.get_shape()[1:3],
+                                     padding='VALID', scope='Conv2d_2a')
+            aux_logits = slim.flatten(aux_logits)
+            aux_logits = slim.fully_connected(aux_logits, num_classes,
+                                              activation_fn=None,
+                                              scope='Aux_logits')
+            end_points['AuxLogits'] = aux_logits
+
+        # Final pooling and prediction
+        with tf.variable_scope('Logits'):
+          # 8 x 8 x 1536
+          net = slim.avg_pool2d(net, net.get_shape()[1:3], padding='VALID',
+                                scope='AvgPool_1a')
+          # 1 x 1 x 1536
+          net = slim.dropout(net, dropout_keep_prob, scope='Dropout_1b')
+          net = slim.flatten(net, scope='PreLogitsFlatten')
+          end_points['PreLogitsFlatten'] = net
+          # 1536
+          logits = slim.fully_connected(net, num_classes, activation_fn=None,
+                                        scope='Logits')
+          end_points['Logits'] = logits
+          end_points['Predictions'] = tf.nn.softmax(logits, name='Predictions')
+    return logits, end_points
+inception_v4.default_image_size = 299
+
+
+inception_v4_arg_scope = inception_utils.inception_arg_scope

+ 216 - 0
slim/nets/inception_v4_test.py

@@ -0,0 +1,216 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for slim.inception_v4."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from nets import inception
+
+
+class InceptionTest(tf.test.TestCase):
+
+  def testBuildLogits(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v4(inputs, num_classes)
+    auxlogits = end_points['AuxLogits']
+    predictions = end_points['Predictions']
+    self.assertTrue(auxlogits.op.name.startswith('InceptionV4/AuxLogits'))
+    self.assertListEqual(auxlogits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    self.assertTrue(predictions.op.name.startswith(
+        'InceptionV4/Logits/Predictions'))
+    self.assertListEqual(predictions.get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testBuildWithoutAuxLogits(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, endpoints = inception.inception_v4(inputs, num_classes,
+                                               create_aux_logits=False)
+    self.assertFalse('AuxLogits' in endpoints)
+    self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+
+  def testAllEndPointsShapes(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    _, end_points = inception.inception_v4(inputs, num_classes)
+    endpoints_shapes = {'Conv2d_1a_3x3': [batch_size, 149, 149, 32],
+                        'Conv2d_2a_3x3': [batch_size, 147, 147, 32],
+                        'Conv2d_2b_3x3': [batch_size, 147, 147, 64],
+                        'Mixed_3a': [batch_size, 73, 73, 160],
+                        'Mixed_4a': [batch_size, 71, 71, 192],
+                        'Mixed_5a': [batch_size, 35, 35, 384],
+                        # 4 x Inception-A blocks
+                        'Mixed_5b': [batch_size, 35, 35, 384],
+                        'Mixed_5c': [batch_size, 35, 35, 384],
+                        'Mixed_5d': [batch_size, 35, 35, 384],
+                        'Mixed_5e': [batch_size, 35, 35, 384],
+                        # Reduction-A block
+                        'Mixed_6a': [batch_size, 17, 17, 1024],
+                        # 7 x Inception-B blocks
+                        'Mixed_6b': [batch_size, 17, 17, 1024],
+                        'Mixed_6c': [batch_size, 17, 17, 1024],
+                        'Mixed_6d': [batch_size, 17, 17, 1024],
+                        'Mixed_6e': [batch_size, 17, 17, 1024],
+                        'Mixed_6f': [batch_size, 17, 17, 1024],
+                        'Mixed_6g': [batch_size, 17, 17, 1024],
+                        'Mixed_6h': [batch_size, 17, 17, 1024],
+                        # Reduction-A block
+                        'Mixed_7a': [batch_size, 8, 8, 1536],
+                        # 3 x Inception-C blocks
+                        'Mixed_7b': [batch_size, 8, 8, 1536],
+                        'Mixed_7c': [batch_size, 8, 8, 1536],
+                        'Mixed_7d': [batch_size, 8, 8, 1536],
+                        # Logits and predictions
+                        'AuxLogits': [batch_size, num_classes],
+                        'PreLogitsFlatten': [batch_size, 1536],
+                        'Logits': [batch_size, num_classes],
+                        'Predictions': [batch_size, num_classes]}
+    self.assertItemsEqual(endpoints_shapes.keys(), end_points.keys())
+    for endpoint_name in endpoints_shapes:
+      expected_shape = endpoints_shapes[endpoint_name]
+      self.assertTrue(endpoint_name in end_points)
+      self.assertListEqual(end_points[endpoint_name].get_shape().as_list(),
+                           expected_shape)
+
+  def testBuildBaseNetwork(self):
+    batch_size = 5
+    height, width = 299, 299
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    net, end_points = inception.inception_v4_base(inputs)
+    self.assertTrue(net.op.name.startswith(
+        'InceptionV4/Mixed_7d'))
+    self.assertListEqual(net.get_shape().as_list(), [batch_size, 8, 8, 1536])
+    expected_endpoints = [
+        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
+        'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
+        'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
+        'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a',
+        'Mixed_7b', 'Mixed_7c', 'Mixed_7d']
+    self.assertItemsEqual(end_points.keys(), expected_endpoints)
+    for name, op in end_points.iteritems():
+      self.assertTrue(op.name.startswith('InceptionV4/' + name))
+
+  def testBuildOnlyUpToFinalEndpoint(self):
+    batch_size = 5
+    height, width = 299, 299
+    all_endpoints = [
+        'Conv2d_1a_3x3', 'Conv2d_2a_3x3', 'Conv2d_2b_3x3', 'Mixed_3a',
+        'Mixed_4a', 'Mixed_5a', 'Mixed_5b', 'Mixed_5c', 'Mixed_5d',
+        'Mixed_5e', 'Mixed_6a', 'Mixed_6b', 'Mixed_6c', 'Mixed_6d',
+        'Mixed_6e', 'Mixed_6f', 'Mixed_6g', 'Mixed_6h', 'Mixed_7a',
+        'Mixed_7b', 'Mixed_7c', 'Mixed_7d']
+    for index, endpoint in enumerate(all_endpoints):
+      with tf.Graph().as_default():
+        inputs = tf.random_uniform((batch_size, height, width, 3))
+        out_tensor, end_points = inception.inception_v4_base(
+            inputs, final_endpoint=endpoint)
+        self.assertTrue(out_tensor.op.name.startswith(
+            'InceptionV4/' + endpoint))
+        self.assertItemsEqual(all_endpoints[:index+1], end_points)
+
+  def testVariablesSetDevice(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    # Force all Variables to reside on the device.
+    with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
+      inception.inception_v4(inputs, num_classes)
+    with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
+      inception.inception_v4(inputs, num_classes)
+    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+      self.assertDeviceEqual(v.device, '/cpu:0')
+    for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+      self.assertDeviceEqual(v.device, '/gpu:0')
+
+  def testHalfSizeImages(self):
+    batch_size = 5
+    height, width = 150, 150
+    num_classes = 1000
+    inputs = tf.random_uniform((batch_size, height, width, 3))
+    logits, end_points = inception.inception_v4(inputs, num_classes)
+    self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
+    self.assertListEqual(logits.get_shape().as_list(),
+                         [batch_size, num_classes])
+    pre_pool = end_points['Mixed_7d']
+    self.assertListEqual(pre_pool.get_shape().as_list(),
+                         [batch_size, 3, 3, 1536])
+
+  def testUnknownBatchSize(self):
+    batch_size = 1
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session() as sess:
+      inputs = tf.placeholder(tf.float32, (None, height, width, 3))
+      logits, _ = inception.inception_v4(inputs, num_classes)
+      self.assertTrue(logits.op.name.startswith('InceptionV4/Logits'))
+      self.assertListEqual(logits.get_shape().as_list(),
+                           [None, num_classes])
+      images = tf.random_uniform((batch_size, height, width, 3))
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(logits, {inputs: images.eval()})
+      self.assertEquals(output.shape, (batch_size, num_classes))
+
+  def testEvaluation(self):
+    batch_size = 2
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session() as sess:
+      eval_inputs = tf.random_uniform((batch_size, height, width, 3))
+      logits, _ = inception.inception_v4(eval_inputs,
+                                         num_classes,
+                                         is_training=False)
+      predictions = tf.argmax(logits, 1)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (batch_size,))
+
+  def testTrainEvalWithReuse(self):
+    train_batch_size = 5
+    eval_batch_size = 2
+    height, width = 150, 150
+    num_classes = 1000
+    with self.test_session() as sess:
+      train_inputs = tf.random_uniform((train_batch_size, height, width, 3))
+      inception.inception_v4(train_inputs, num_classes)
+      eval_inputs = tf.random_uniform((eval_batch_size, height, width, 3))
+      logits, _ = inception.inception_v4(eval_inputs,
+                                         num_classes,
+                                         is_training=False,
+                                         reuse=True)
+      predictions = tf.argmax(logits, 1)
+      sess.run(tf.initialize_all_variables())
+      output = sess.run(predictions)
+      self.assertEquals(output.shape, (eval_batch_size,))
+
+
+if __name__ == '__main__':
+  tf.test.main()

+ 2 - 0
slim/nets/nets_factory.py

@@ -41,6 +41,7 @@ networks_map = {'alexnet_v2': alexnet.alexnet_v2,
                 'inception_v1': inception.inception_v1,
                 'inception_v2': inception.inception_v2,
                 'inception_v3': inception.inception_v3,
+                'inception_v4': inception.inception_v4,
                 'inception_resnet_v2': inception.inception_resnet_v2,
                 'lenet': lenet.lenet,
                 'resnet_v1_50': resnet_v1.resnet_v1_50,
@@ -62,6 +63,7 @@ arg_scopes_map = {'alexnet_v2': alexnet.alexnet_v2_arg_scope,
                   'inception_v1': inception.inception_v3_arg_scope,
                   'inception_v2': inception.inception_v3_arg_scope,
                   'inception_v3': inception.inception_v3_arg_scope,
+                  'inception_v4': inception.inception_v4_arg_scope,
                   'inception_resnet_v2':
                   inception.inception_resnet_v2_arg_scope,
                   'lenet': lenet.lenet_arg_scope,

+ 1 - 0
slim/preprocessing/preprocessing_factory.py

@@ -50,6 +50,7 @@ def get_preprocessing(name, is_training=False):
       'inception_v1': inception_preprocessing,
       'inception_v2': inception_preprocessing,
       'inception_v3': inception_preprocessing,
+      'inception_v4': inception_preprocessing,
       'inception_resnet_v2': inception_preprocessing,
       'lenet': lenet_preprocessing,
       'resnet_v1_50': vgg_preprocessing,