Forráskód Böngészése

update inception slim.

Jianmin Chen 9 éve
szülő
commit
1a8c712118

+ 9 - 0
inception/inception/slim/BUILD

@@ -101,3 +101,12 @@ py_library(
         ":variables",
     ],
 )
+
+py_test(
+    name = "collections_test",
+    size = "small",
+    srcs = ["collections_test.py"],
+    deps = [
+        ":slim",
+    ],
+)

+ 126 - 142
inception/inception/slim/README.md

@@ -1,21 +1,20 @@
 # TensorFlow-Slim
 
-TF-Slim is a lightweight library for defining, training and evaluating models
-in TensorFlow. It enables defining complex networks quickly and concisely while
+TF-Slim is a lightweight library for defining, training and evaluating models in
+TensorFlow. It enables defining complex networks quickly and concisely while
 keeping a model's architecture transparent and its hyperparameters explicit.
 
-
 [TOC]
 
 ## Teaser
 
-As a demonstration of the simplicity of using TF-Slim, compare the simplicity
-of the code necessary for defining the entire
-[VGG](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) network using TF-Slim
-to the lengthy and verbose nature of defining just the first three layers (out
-of 16) using native tensorflow:
+As a demonstration of the simplicity of using TF-Slim, compare the simplicity of
+the code necessary for defining the entire [VGG]
+(http://www.robots.ox.ac.uk/~vgg/research/very_deep/) network using TF-Slim to
+the lengthy and verbose nature of defining just the first three layers (out of
+16) using native tensorflow:
 
-```python{.good}
+```python {.good}
 # VGG16 in TF-Slim.
 def vgg16(inputs):
   with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
@@ -38,7 +37,7 @@ def vgg16(inputs):
   return net
 ```
 
-```python{.bad}
+```python {.bad}
 # Layers 1-3 (out of 16) of VGG16 in native tensorflow.
 def vgg16(inputs):
   with tf.name_scope('conv1_1') as scope:
@@ -61,47 +60,42 @@ def vgg16(inputs):
 
 TF-Slim offers several advantages over just the built-in tensorflow libraries:
 
-* Allows one to define models much more compactly by eliminating
-boilerplate code. This is accomplished through the use of
-[argument scoping](scopes.py)
-and numerous high level
-[operations](ops.py).
-These tools increase readability and maintainability, reduce the likelihood
-of an error from copy-and-pasting hyperparameter values and simplifies
-hyperparameter tuning.
-* Makes developing models simple by providing commonly used
-[loss functions](losses.py)
-* Provides a concise
-[definition](inception.py)
-of [Inception v3](http://arxiv.org/abs/1512.00567) network architecture
-ready to be used out-of-the-box or subsumed into new models.
+*   Allows one to define models much more compactly by eliminating boilerplate
+    code. This is accomplished through the use of [argument scoping](scopes.py)
+    and numerous high level [operations](ops.py). These tools increase
+    readability and maintainability, reduce the likelihood of an error from
+    copy-and-pasting hyperparameter values and simplifies hyperparameter tuning.
+*   Makes developing models simple by providing commonly used [loss functions]
+    (losses.py)
+*   Provides a concise [definition](inception.py) of [Inception v3]
+    (http://arxiv.org/abs/1512.00567) network architecture ready to be used
+    out-of-the-box or subsumed into new models.
 
 Additionally TF-Slim was designed with several principles in mind:
 
-* The various modules of TF-Slim (scopes, variables, ops, losses) are
-independent. This flexibility allows users to pick and choose
-components of TF-Slim completely à la carte.
-* TF-Slim is written using a Functional Programming style. That means it's
-super-lightweight and can be used right alongside any of TensorFlow's native
-operations.
-* Makes re-using network architectures easy. This allows users to build new
-networks on top of existing ones as well as fine-tuning pre-trained models on
-new tasks.
+*   The various modules of TF-Slim (scopes, variables, ops, losses) are
+    independent. This flexibility allows users to pick and choose components of
+    TF-Slim completely à la carte.
+*   TF-Slim is written using a Functional Programming style. That means it's
+    super-lightweight and can be used right alongside any of TensorFlow's native
+    operations.
+*   Makes re-using network architectures easy. This allows users to build new
+    networks on top of existing ones as well as fine-tuning pre-trained models
+    on new tasks.
 
 ## What are the various components of TF-Slim?
 
 TF-Slim is composed of several parts which were designed to exist independently.
 These include:
 
-* [scopes.py](./scopes.py):
-provides a new scope named `arg_scope` that allows a user to define default
-arguments for specific operations within that scope.
-* [variables.py](./variables.py):
-provides convenience wrappers for variable creation and manipulation.
-* [ops.py](./ops.py):
-provides high level operations for building models using tensorflow.
-* [losses.py](./losses.py):
-contains commonly used loss functions.
+*   [scopes.py](./scopes.py): provides a new scope named `arg_scope` that allows
+    a user to define default arguments for specific operations within that
+    scope.
+*   [variables.py](./variables.py): provides convenience wrappers for variable
+    creation and manipulation.
+*   [ops.py](./ops.py): provides high level operations for building models using
+    tensorflow.
+*   [losses.py](./losses.py): contains commonly used loss functions.
 
 ## Defining Models
 
@@ -110,16 +104,14 @@ operations and scopes. Each of these elements are defined below.
 
 ### Variables
 
-Creating
-[`Variables`](https://www.tensorflow.org/how_tos/variables/index.html)
+Creating [`Variables`](https://www.tensorflow.org/how_tos/variables/index.html)
 in native tensorflow requires either a predefined value or an initialization
-mechanism
-(random, normally distributed). Furthermore, if a variable needs to be created
-on a specific device, such as a GPU, the specification must be
-[made explicit](https://www.tensorflow.org/how_tos/using_gpu/index.html).
-To alleviate the code required for variable creation, TF-Slim provides a set
-of thin wrapper functions in [variables.py](./variables.py)
-which allow callers to easily define variables.
+mechanism (random, normally distributed). Furthermore, if a variable needs to be
+created on a specific device, such as a GPU, the specification must be [made
+explicit](https://www.tensorflow.org/how_tos/using_gpu/index.html). To alleviate
+the code required for variable creation, TF-Slim provides a set of thin wrapper
+functions in [variables.py](./variables.py) which allow callers to easily define
+variables.
 
 For example, to create a `weight` variable, initialize it using a truncated
 normal distribution, regularize it with an `l2_loss` and place it on the `CPU`,
@@ -159,21 +151,20 @@ weights = variables.variable('weights',
 
 ### Operations (Layers)
 
-While the set of TensorFlow operations is quite extensive, builders of
-neural networks typically think of models in terms of "layers". A layer,
-such as a Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer
-are more abstract than a single TensorFlow operation and typically involve
-many such operations. For example, a Convolutional Layer in a neural network
-is built using several steps:
+While the set of TensorFlow operations is quite extensive, builders of neural
+networks typically think of models in terms of "layers". A layer, such as a
+Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer are more
+abstract than a single TensorFlow operation and typically involve many such
+operations. For example, a Convolutional Layer in a neural network is built
+using several steps:
 
-1. Creating the weight variables
-2. Creating the bias variables
-3. Convolving the weights with the input from the previous layer
-4. Adding the biases to the result of the convolution.
+1.  Creating the weight variables
+2.  Creating the bias variables
+3.  Convolving the weights with the input from the previous layer
+4.  Adding the biases to the result of the convolution.
 
 In python code this can be rather laborious:
 
-
 ```python
 input = ...
 with tf.name_scope('conv1_1') as scope:
@@ -187,9 +178,9 @@ with tf.name_scope('conv1_1') as scope:
 ```
 
 To alleviate the need to duplicate this code repeatedly, TF-Slim provides a
-number of convenient operations defined at the (more abstract) level of
-neural network layers. For example, compare the code above to an invocation
-of the TF-Slim code:
+number of convenient operations defined at the (more abstract) level of neural
+network layers. For example, compare the code above to an invocation of the
+TF-Slim code:
 
 ```python
 input = ...
@@ -199,22 +190,21 @@ net = slim.ops.conv2d(input, [3, 3], 128, scope='conv1_1')
 TF-Slim provides numerous operations used in building neural networks which
 roughly correspond to such layers. These include:
 
-Layer | TF-Slim Op
-------- | --------
-Convolutional Layer | [ops.conv2d](ops.py)
+Layer                 | TF-Slim Op
+--------------------- | ------------------------
+Convolutional Layer   | [ops.conv2d](ops.py)
 Fully Connected Layer | [ops.fc](ops.py)
-BatchNorm layer | [ops.batch_norm](ops.py)
-Max Pooling Layer | [ops.max_pool](ops.py)
-Avg Pooling Layer | [ops.avg_pool](ops.py)
-Dropout Layer | [ops.dropout](ops.py)
+BatchNorm layer       | [ops.batch_norm](ops.py)
+Max Pooling Layer     | [ops.max_pool](ops.py)
+Avg Pooling Layer     | [ops.avg_pool](ops.py)
+Dropout Layer         | [ops.dropout](ops.py)
 
-[ops.py](./ops.py)
-also includes operations that are not really "layers" per se, but are
-often used to manipulate hidden unit representations during inference:
+[ops.py](./ops.py) also includes operations that are not really "layers" per se,
+but are often used to manipulate hidden unit representations during inference:
 
 Operation | TF-Slim Op
-------- | --------
-Flatten | [ops.flatten](ops.py)
+--------- | ---------------------
+Flatten   | [ops.flatten](ops.py)
 
 TF-Slim also provides a meta-operation called `repeat_op` that allows one to
 repeatedly perform the same operation. Consider the following snippet from the
@@ -238,36 +228,35 @@ for i in range(3):
 net = slim.ops.max_pool(net, [2, 2], scope='pool3')
 ```
 
-While this does reduce the amount of duplication, it can be made even cleaner
-by using the `RepeatOp`:
+While this does reduce the amount of duplication, it can be made even cleaner by
+using the `RepeatOp`:
 
 ```python
 net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
 net = slim.ops.max_pool(net, [2, 2], scope='pool2')
 ```
 
-Notice that the RepeatOp not only applies the same argument in-line, it also
-is smart enough to unroll the scopes such that the scopes assigned to each
+Notice that the RepeatOp not only applies the same argument in-line, it also is
+smart enough to unroll the scopes such that the scopes assigned to each
 subsequent call of `ops.conv2d` is appended with an underscore and iteration
 number. More concretely, the scopes in the example above would be 'conv3_1',
 'conv3_2' and 'conv3_3'.
 
-
 ### Scopes
 
-In addition to the types of scope mechanisms in TensorFlow
-([name_scope](https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
+In addition to the types of scope mechanisms in TensorFlow ([name_scope]
+(https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
 [op_scope](https://www.tensorflow.org/api_docs/python/framework.html#op_scope),
-[variable_scope](https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
-[variable_op_scope](https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
-TF-Slim adds a new scoping mechanism called "argument scope" or
-[arg_scope](scopes.py).
-This new scope allows a user to specify one or more operations and a set of
-arguments which will be passed to each of the operations defined in the
+[variable_scope]
+(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
+[variable_op_scope]
+(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
+TF-Slim adds a new scoping mechanism called "argument scope" or [arg_scope]
+(scopes.py). This new scope allows a user to specify one or more operations and
+a set of arguments which will be passed to each of the operations defined in the
 `arg_scope`. This functionality is best illustrated by example. Consider the
 following code snippet:
 
-
 ```python
 net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding='SAME', stddev=0.01, weight_decay=0.0005, scope='conv1')
 net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', stddev=0.01, weight_decay=0.0005, scope='conv2')
@@ -278,8 +267,8 @@ It should be clear that these three Convolution layers share many of the same
 hyperparameters. Two have the same padding, all three have the same weight_decay
 and standard deviation of its weights. Not only do the duplicated values make
 the code more difficult to read, it also adds the addition burder to the writer
-of needing to doublecheck that all of the values are identical in each step.
-One solution would be to specify default values using variables:
+of needing to doublecheck that all of the values are identical in each step. One
+solution would be to specify default values using variables:
 
 ```python
 padding='SAME'
@@ -302,11 +291,11 @@ ensure that each layer uses the same values and simplify the code:
     net = slim.ops.conv2d(net, 256, [11, 11], scope='conv3')
 ```
 
-As the example illustrates, the use of arg_scope makes the code cleaner,
-simpler and easier to maintain. Notice that while argument values are specifed
-in the arg_scope, they can be overwritten locally. In particular, while
-the padding argument has been set to 'SAME', the second convolution overrides
-it with the value of 'VALID'.
+As the example illustrates, the use of arg_scope makes the code cleaner, simpler
+and easier to maintain. Notice that while argument values are specifed in the
+arg_scope, they can be overwritten locally. In particular, while the padding
+argument has been set to 'SAME', the second convolution overrides it with the
+value of 'VALID'.
 
 One can also nest `arg_scope`s and use multiple operations in the same scope.
 For example:
@@ -320,9 +309,9 @@ with arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005)
     net = slim.ops.fc(net, 1000, activation=None, scope='fc')
 ```
 
-In this example, the first `arg_scope` applies the same `stddev` and `weight_decay`
-arguments to the `conv2d` and `fc` ops in its scope. In the second `arg_scope`,
-additional default arguments to `conv2d` only are specified.
+In this example, the first `arg_scope` applies the same `stddev` and
+`weight_decay` arguments to the `conv2d` and `fc` ops in its scope. In the
+second `arg_scope`, additional default arguments to `conv2d` only are specified.
 
 In addition to `arg_scope`, TF-Slim provides several decorators that wrap the
 use of tensorflow arg scopes. These include `@AddArgScope`, `@AddNameScope`,
@@ -349,8 +338,8 @@ with tf.variable_scope('layer1'):
   outputs = MyNewOp(inputs)
 ```
 
-As an alternative, one can use TF-Slim's decorators to decorate the function
-and simplify the call:
+As an alternative, one can use TF-Slim's decorators to decorate the function and
+simplify the call:
 
 ```python
 @AddVariableScope
@@ -364,8 +353,8 @@ outputs = MyNewOp('layer1')
 ```
 
 The `@AddVariableScope` decorater simply applies the `tf.variable_scope` scoping
-to the called function taking "layer1" as its argument. This allows the code
-to be written more concisely.
+to the called function taking "layer1" as its argument. This allows the code to
+be written more concisely.
 
 ### Losses
 
@@ -375,19 +364,16 @@ classification problems, this is typically the cross entropy between the true
 classes. For regression problems, this is often the sum-of-squares differences
 between the predicted and true values.
 
-Certain models, such as multi-task
-learning models, require the use of multiple loss functions simultaneously. In
-other words, the loss function ultimatey being minimized is the sum of various
-other loss functions. For example, consider a model that predicts both
-the type of scene in an image as well as the depth from the
-camera of each pixel. This model's loss function would be the sum of the
+Certain models, such as multi-task learning models, require the use of multiple
+loss functions simultaneously. In other words, the loss function ultimatey being
+minimized is the sum of various other loss functions. For example, consider a
+model that predicts both the type of scene in an image as well as the depth from
+the camera of each pixel. This model's loss function would be the sum of the
 classification loss and depth prediction loss.
 
-TF-Slim provides an easy-to-use mechanism for defining and keeping track of
-loss functions via the
-[losses.py](./losses.py)
-module. Consider the simple case where we want to train the VGG network:
-
+TF-Slim provides an easy-to-use mechanism for defining and keeping track of loss
+functions via the [losses.py](./losses.py) module. Consider the simple case
+where we want to train the VGG network:
 
 ```python
 # Load the images and labels.
@@ -401,9 +387,8 @@ loss = losses.ClassificationLoss(predictions, labels)
 ```
 
 In this example, we start by creating the model (using TF-Slim's VGG
-implementation), and add the standard classification loss. Now, lets turn
-to the case where we have a multi-task model that produces multiple outputs:
-
+implementation), and add the standard classification loss. Now, lets turn to the
+case where we have a multi-task model that produces multiple outputs:
 
 ```python
 # Load the images and labels.
@@ -424,16 +409,14 @@ total_loss2 = tf.get_collection(slim.losses.LOSSES_COLLECTION)
 In this example, we have two losses which we add by calling
 `losses.ClassificationLoss` and `losses.SumOfSquaresLoss`. We can obtain the
 total loss by adding them together (`total_loss1`) or by calling
-`losses.GetTotalLoss()`. How did this work?
-When you create a loss function via TF-Slim, TF-Slim adds the loss to a
-special TensorFlow collection of loss functions. This enables you to either
-manage the total loss manually, or allow TF-Slim to manage them for you.
+`losses.GetTotalLoss()`. How did this work? When you create a loss function via
+TF-Slim, TF-Slim adds the loss to a special TensorFlow collection of loss
+functions. This enables you to either manage the total loss manually, or allow
+TF-Slim to manage them for you.
 
 What if you want to let TF-Slim manage the losses for you but have a custom loss
-function?
-[losses.py](./losses.py)
-also has a function that adds this loss to TF-Slims collection. For example:
-
+function? [losses.py](./losses.py) also has a function that adds this loss to
+TF-Slims collection. For example:
 
 ```python
 # Load the images and labels.
@@ -452,15 +435,15 @@ tf.add_to_collection(slim.losses.LOSSES_COLLECTION, pose_loss) # Letting TF-Slim
 total_loss1 = classification_loss + sum_of_squares_loss + pose_loss
 total_loss2 = losses.GetTotalLoss()
 ```
-In this example, we can again either produce the total loss function manually
-or let TF-Slim know about the additional loss and let TF-Slim handle the losses.
 
+In this example, we can again either produce the total loss function manually or
+let TF-Slim know about the additional loss and let TF-Slim handle the losses.
 
 ## Putting the Pieces Together
 
 By combining TF-Slim Variables, Operations and scopes, we can write a normally
-very complex network with very few lines of code. For example, the entire
-[VGG](https://www.robots.ox.ac.uk/~vgg/research/very_deep/) architecture can be
+very complex network with very few lines of code. For example, the entire [VGG]
+(https://www.robots.ox.ac.uk/~vgg/research/very_deep/) architecture can be
 defined with just the following snippet:
 
 ```python
@@ -490,8 +473,8 @@ return net
 
 After a model has been trained, it can be restored using `tf.train.Saver()`
 which restores `Variables` from a given checkpoint. For many cases,
-`tf.train.Saver()` provides a simple mechanism to restore all or just a
-few variables.
+`tf.train.Saver()` provides a simple mechanism to restore all or just a few
+variables.
 
 ```python
 # Create some variables.
@@ -514,19 +497,21 @@ with tf.Session() as sess:
   ...
 ```
 
-See [Restoring Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#restoring-variables)
-and
-[Choosing which Variables to Save and Restore](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#choosing-which-variables-to-save-and-restore)
-sections of the [Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html)
-page for more details.
+See [Restoring Variables]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#restoring-variables)
+and [Choosing which Variables to Save and Restore]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#choosing-which-variables-to-save-and-restore)
+sections of the [Variables]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html) page for
+more details.
 
 ### Using slim.variables to Track which Variables need to be Restored
 
 It is often desirable to fine-tune a pre-trained model on an entirely new
 dataset or even a new task. In these situations, one must specify which layers
-of the model should be reused (and consequently loaded from a checkpoint)
-and which layers are new. Indicating which variables or layers should be
-restored is a process that quickly becomes cumbersome when done manually.
+of the model should be reused (and consequently loaded from a checkpoint) and
+which layers are new. Indicating which variables or layers should be restored is
+a process that quickly becomes cumbersome when done manually.
 
 To help keep track of which variables to restore, `slim.variables` provides a
 `restore` argument when creating each Variable. By default, all variables are
@@ -554,7 +539,6 @@ Additionally, every layer in `slim.ops` that creates slim.variables (such as
 argument which controls whether the variables created by that layer should be
 restored or not.
 
-
 ```python
 # Create a small network.
 net = slim.ops.conv2d(images, 32, [7, 7], stride=2, scope='conv1')

+ 31 - 4
inception/inception/slim/inception_model.py

@@ -43,7 +43,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import ops
@@ -98,10 +97,10 @@ def inception_v3(inputs,
         # 73 x 73 x 64
         end_points['conv3'] = ops.conv2d(end_points['pool1'], 80, [1, 1],
                                          scope='conv3')
-        # 71 x 71 x 80.
+        # 73 x 73 x 80.
         end_points['conv4'] = ops.conv2d(end_points['conv3'], 192, [3, 3],
                                          scope='conv4')
-        # 69 x 69 x 192.
+        # 71 x 71 x 192.
         end_points['pool2'] = ops.max_pool(end_points['conv4'], [3, 3],
                                            stride=2, scope='pool2')
         # 35 x 35 x 192.
@@ -260,7 +259,10 @@ def inception_v3(inputs,
           aux_logits = ops.fc(aux_logits, num_classes, activation=None,
                               stddev=0.001, restore=restore_logits)
           end_points['aux_logits'] = aux_logits
-        # mixed_8: 17 x 17 x 1280.
+        # mixed_8: 8 x 8 x 1280.
+        # Note that the scope below is not changed to not void previous
+        # checkpoints.
+        # (TODO) Fix the scope when appropriate.
         with tf.variable_scope('mixed_17x17x1280a'):
           with tf.variable_scope('branch3x3'):
             branch3x3 = ops.conv2d(net, 192, [1, 1])
@@ -327,3 +329,28 @@ def inception_v3(inputs,
           end_points['predictions'] = tf.nn.softmax(logits, name='predictions')
       return logits, end_points
 
+
+def inception_v3_parameters(weight_decay=0.00004, stddev=0.1,
+                            batch_norm_decay=0.9997, batch_norm_epsilon=0.001):
+  """Yields the scope with the default parameters for inception_v3.
+
+  Args:
+    weight_decay: the weight decay for weights variables.
+    stddev: standard deviation of the truncated guassian weight distribution.
+    batch_norm_decay: decay for the moving average of batch_norm momentums.
+    batch_norm_epsilon: small float added to variance to avoid dividing by zero.
+
+  Yields:
+    a arg_scope with the parameters needed for inception_v3.
+  """
+  # Set weight_decay for weights in Conv and FC layers.
+  with scopes.arg_scope([ops.conv2d, ops.fc],
+                        weight_decay=weight_decay):
+    # Set stddev, activation and parameters for batch_norm.
+    with scopes.arg_scope([ops.conv2d],
+                          stddev=stddev,
+                          activation=tf.nn.relu,
+                          batch_norm_params={
+                              'decay': batch_norm_decay,
+                              'epsilon': batch_norm_epsilon}) as arg_scope:
+      yield arg_scope

+ 16 - 1
inception/inception/slim/inception_test.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import inception_model as inception
@@ -55,6 +54,22 @@ class InceptionTest(tf.test.TestCase):
       self.assertListEqual(pre_pool.get_shape().as_list(),
                            [batch_size, 8, 8, 2048])
 
+  def testVariablesSetDevice(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      # Force all Variables to reside on the device.
+      with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
+        inception.inception_v3(inputs, num_classes)
+      with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
+        inception.inception_v3(inputs, num_classes)
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+        self.assertDeviceEqual(v.device, '/cpu:0')
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+        self.assertDeviceEqual(v.device, '/gpu:0')
+
   def testHalfSizeImages(self):
     batch_size = 5
     height, width = 150, 150

+ 65 - 1
inception/inception/slim/losses.py

@@ -26,7 +26,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 # In order to gather all losses in a network, the user should use this
@@ -35,6 +34,71 @@ import tensorflow as tf
 LOSSES_COLLECTION = '_losses'
 
 
+def l1_regularizer(weight=1.0, scope=None):
+  """Define a L1 regularizer.
+
+  Args:
+    weight: scale the loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L1Regularizer'):
+      l1_weight = tf.convert_to_tensor(weight,
+                                       dtype=tensor.dtype.base_dtype,
+                                       name='weight')
+      return tf.mul(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+  return regularizer
+
+
+def l2_regularizer(weight=1.0, scope=None):
+  """Define a L2 regularizer.
+
+  Args:
+    weight: scale the loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L2Regularizer'):
+      l2_weight = tf.convert_to_tensor(weight,
+                                       dtype=tensor.dtype.base_dtype,
+                                       name='weight')
+      return tf.mul(l2_weight, tf.nn.l2_loss(tensor), name='value')
+  return regularizer
+
+
+def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
+  """Define a L1L2 regularizer.
+
+  Args:
+    weight_l1: scale the L1 loss by this factor.
+    weight_l2: scale the L2 loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L1L2Regularizer'):
+      weight_l1_t = tf.convert_to_tensor(weight_l1,
+                                         dtype=tensor.dtype.base_dtype,
+                                         name='weight_l1')
+      weight_l2_t = tf.convert_to_tensor(weight_l2,
+                                         dtype=tensor.dtype.base_dtype,
+                                         name='weight_l2')
+      reg_l1 = tf.mul(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
+                      name='value_l1')
+      reg_l2 = tf.mul(weight_l2_t, tf.nn.l2_loss(tensor),
+                      name='value_l2')
+      return tf.add(reg_l1, reg_l2, name='value')
+  return regularizer
+
+
 def l1_loss(tensor, weight=1.0, scope=None):
   """Define a L1Loss, useful for regularize, i.e. lasso.
 

+ 89 - 1
inception/inception/slim/losses_test.py

@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 
 from inception.slim import losses
@@ -47,6 +46,95 @@ class LossesTest(tf.test.TestCase):
       self.assertAlmostEqual(loss.eval(), num_elem * wd / 2, 5)
 
 
+class RegularizersTest(tf.test.TestCase):
+
+  def testL1Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L1Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+  def testL1RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_regularizer(scope='L1')(tensor)
+      self.assertEquals(loss.op.name, 'L1/value')
+      self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+  def testL1RegularizerWithWeight(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight = 0.01
+      loss = losses.l1_regularizer(weight)(tensor)
+      self.assertEquals(loss.op.name, 'L1Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem * weight, 5)
+
+  def testL2Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l2_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+  def testL2RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l2_regularizer(scope='L2')(tensor)
+      self.assertEquals(loss.op.name, 'L2/value')
+      self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+  def testL2RegularizerWithWeight(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight = 0.01
+      loss = losses.l2_regularizer(weight)(tensor)
+      self.assertEquals(loss.op.name, 'L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem * weight / 2, 5)
+
+  def testL1L2Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_l2_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+  def testL1L2RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_l2_regularizer(scope='L1L2')(tensor)
+      self.assertEquals(loss.op.name, 'L1L2/value')
+      self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+  def testL1L2RegularizerWithWeights(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight_l1 = 0.01
+      weight_l2 = 0.05
+      loss = losses.l1_l2_regularizer(weight_l1, weight_l2)(tensor)
+      self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(),
+                             num_elem * weight_l1 + num_elem * weight_l2 / 2, 5)
+
+
 class CrossEntropyLossTest(tf.test.TestCase):
 
   def testCrossEntropyLossAllCorrect(self):

+ 81 - 32
inception/inception/slim/ops.py

@@ -27,7 +27,6 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 
 from tensorflow.python.training import moving_averages
@@ -50,7 +49,8 @@ def batch_norm(inputs,
                is_training=True,
                trainable=True,
                restore=True,
-               scope=None):
+               scope=None,
+               reuse=None):
   """Adds a Batch Normalization layer.
 
   Args:
@@ -67,13 +67,15 @@ def batch_norm(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
 
   Returns:
     a tensor representing the output of the operation.
 
   """
   inputs_shape = inputs.get_shape()
-  with tf.variable_op_scope([inputs], scope, 'BatchNorm'):
+  with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
     axis = range(len(inputs_shape) - 1)
     params_shape = inputs_shape[-1:]
     with scopes.arg_scope([variables.variable], restore=restore):
@@ -124,6 +126,37 @@ def batch_norm(inputs,
     return outputs
 
 
+def _two_element_tuple(int_or_tuple):
+  """Converts `int_or_tuple` to height, width.
+
+  Several of the functions that follow accept arguments as either
+  a tuple of 2 integers or a single integer.  A single integer
+  indicates that the 2 values of the tuple are the same.
+
+  This functions normalizes the input value by always returning a tuple.
+
+  Args:
+    int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape.
+
+  Returns:
+    A tuple with 2 values.
+
+  Raises:
+    ValueError: If `int_or_tuple` it not well formed.
+  """
+  if isinstance(int_or_tuple, (list, tuple)):
+    if len(int_or_tuple) != 2:
+      raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
+    return int(int_or_tuple[0]), int(int_or_tuple[1])
+  if isinstance(int_or_tuple, int):
+    return int(int_or_tuple), int(int_or_tuple)
+  if isinstance(int_or_tuple, tf.TensorShape):
+    if len(int_or_tuple) == 2:
+      return int_or_tuple[0], int_or_tuple[1]
+  raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
+                   'length 2')
+
+
 @scopes.add_arg_scope
 def conv2d(inputs,
            num_filters_out,
@@ -138,7 +171,8 @@ def conv2d(inputs,
            is_training=True,
            trainable=True,
            restore=True,
-           scope=None):
+           scope=None,
+           reuse=None):
   """Adds a 2D convolution followed by an optional batch_norm layer.
 
   conv2d creates a variable called 'weights', representing the convolutional
@@ -149,8 +183,11 @@ def conv2d(inputs,
   Args:
     inputs: a tensor of size [batch_size, height, width, channels].
     num_filters_out: the number of output filters.
-    kernel_size: a 2-D list comprising of the height and width of the filters.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of
+      of the filters. Can be an int if both values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: one of 'VALID' or 'SAME'.
     activation: activation function.
     stddev: standard deviation of the truncated guassian weight distribution.
@@ -161,28 +198,29 @@ def conv2d(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
-
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
   Returns:
     a tensor representing the output of the operation.
 
-  Raises:
-    ValueError: if 'kernel_size' is not a 2-D list.
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
-  with tf.variable_op_scope([inputs], scope, 'Conv'):
+  with tf.variable_op_scope([inputs], scope, 'Conv', reuse=reuse):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     num_filters_in = inputs.get_shape()[-1]
-    weights_shape = [kernel_size[0], kernel_size[1],
+    weights_shape = [kernel_h, kernel_w,
                      num_filters_in, num_filters_out]
     weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
-    l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
+    l2_regularizer = None
+    if weight_decay and weight_decay > 0:
+      l2_regularizer = losses.l2_regularizer(weight_decay)
     weights = variables.variable('weights',
                                  shape=weights_shape,
                                  initializer=weights_initializer,
                                  regularizer=l2_regularizer,
                                  trainable=trainable,
                                  restore=restore)
-    conv = tf.nn.conv2d(inputs, weights, [1, stride, stride, 1],
+    conv = tf.nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
                         padding=padding)
     if batch_norm_params is not None:
       with scopes.arg_scope([batch_norm], is_training=is_training,
@@ -213,7 +251,8 @@ def fc(inputs,
        is_training=True,
        trainable=True,
        restore=True,
-       scope=None):
+       scope=None,
+       reuse=None):
   """Adds a fully connected layer followed by an optional batch_norm layer.
 
   FC creates a variable called 'weights', representing the fully connected
@@ -234,15 +273,19 @@ def fc(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
 
   Returns:
      the tensor variable representing the result of the series of operations.
   """
-  with tf.variable_op_scope([inputs], scope, 'FC'):
+  with tf.variable_op_scope([inputs], scope, 'FC', reuse=reuse):
     num_units_in = inputs.get_shape()[1]
     weights_shape = [num_units_in, num_units_out]
     weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
-    l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
+    l2_regularizer = None
+    if weight_decay and weight_decay > 0:
+      l2_regularizer = losses.l2_regularizer(weight_decay)
     weights = variables.variable('weights',
                                  shape=weights_shape,
                                  initializer=weights_initializer,
@@ -298,8 +341,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
 
   Args:
     inputs: a tensor of size [batch_size, height, width, depth].
-    kernel_size: the size of the pooling kernel over which the op is computed.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+      pooling kernel over which the op is computed. Can be an int if both
+      values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: the padding method, either 'VALID' or 'SAME'.
     scope: Optional scope for op_scope.
 
@@ -308,12 +355,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
   Raises:
     ValueError: if 'kernel_size' is not a 2-D list
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
   with tf.op_scope([inputs], scope, 'MaxPool'):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     return tf.nn.max_pool(inputs,
-                          ksize=[1, kernel_size[0], kernel_size[1], 1],
-                          strides=[1, stride, stride, 1],
+                          ksize=[1, kernel_h, kernel_w, 1],
+                          strides=[1, stride_h, stride_w, 1],
                           padding=padding)
 
 
@@ -326,22 +373,24 @@ def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
 
   Args:
     inputs: a tensor of size [batch_size, height, width, depth].
-    kernel_size: the size of the pooling kernel over which the op is computed.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+      pooling kernel over which the op is computed. Can be an int if both
+      values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: the padding method, either 'VALID' or 'SAME'.
     scope: Optional scope for op_scope.
 
   Returns:
     a tensor representing the results of the pooling operation.
-  Raises:
-    ValueError: if 'kernel_size' is not a 2-D list
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
   with tf.op_scope([inputs], scope, 'AvgPool'):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     return tf.nn.avg_pool(inputs,
-                          ksize=[1, kernel_size[0], kernel_size[1], 1],
-                          strides=[1, stride, stride, 1],
+                          ksize=[1, kernel_h, kernel_w, 1],
+                          strides=[1, stride_h, stride_w, 1],
                           padding=padding)
 
 

+ 173 - 33
inception/inception/slim/ops_test.py

@@ -18,13 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.ops import control_flow_ops
 
-from inception.slim import losses
 from inception.slim import ops
 from inception.slim import scopes
 from inception.slim import variables
@@ -40,6 +38,57 @@ class ConvTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'Conv/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
 
+  def testCreateSquareConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, 3)
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+  def testCreateConvWithTensorShape(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, images.get_shape()[1:3])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+  def testCreateFullyConv(self):
+    height, width = 6, 6
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      output = ops.conv2d(images, 64, images.get_shape()[1:3], padding='VALID')
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
+
+  def testCreateVerticalConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [3, 1])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height, width, 32])
+
+  def testCreateHorizontalConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [1, 3])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height, width, 32])
+
+  def testCreateConvWithStride(self):
+    height, width = 6, 6
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [3, 3], stride=2)
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height/2, width/2, 32])
+
   def testCreateConvCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     images = tf.random_uniform((5, height, width, 3), seed=1)
@@ -76,31 +125,73 @@ class ConvTest(tf.test.TestCase):
     with self.test_session() as sess:
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
-      wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
-      self.assertEquals(wd.op.name, 'Conv/weights/Regularizer/L2Loss/value')
+      wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+      self.assertEquals(wd.op.name,
+                        'Conv/weights/Regularizer/L2Regularizer/value')
       sess.run(tf.initialize_all_variables())
       self.assertTrue(sess.run(wd) <= 0.01)
 
+  def testCreateConvWithoutWD(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3], weight_decay=0)
+      self.assertEquals(
+          tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
+  def testReuseVars(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3], scope='conv1')
+      self.assertEquals(len(variables.get_variables()), 2)
+      ops.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+
+  def testNonReuseVars(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 2)
+      ops.conv2d(images, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 4)
+
   def testReuseConvWithWD(self):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
-      tf.get_variable_scope().reuse_variables()
-      ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+      ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1',
+                 reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
 
   def testConvWithBatchNorm(self):
     height, width = 3, 3
     with self.test_session():
-      images = tf.random_uniform((5, height, width, 3), seed=1)
-      with scopes.arg_scope([ops.conv2d], batch_norm_params={}):
-        net = ops.conv2d(images, 32, [3, 3], scope='conv1')
-        net = ops.conv2d(net, 32, [3, 3], scope='conv2')
-      self.assertEquals(len(tf.get_collection('moving_vars')), 4)
-      self.assertEquals(len(variables.get_variables('conv1/BatchNorm')), 3)
-      self.assertEquals(len(variables.get_variables('conv2/BatchNorm')), 3)
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+        net = ops.conv2d(images, 32, [3, 3])
+        net = ops.conv2d(net, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 8)
+      self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 3)
+
+  def testReuseConvWithBatchNorm(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+        net = ops.conv2d(images, 32, [3, 3], scope='Conv')
+        net = ops.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 4)
+      self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 0)
 
 
 class FCTest(tf.test.TestCase):
@@ -136,8 +227,7 @@ class FCTest(tf.test.TestCase):
     with self.test_session():
       ops.fc(inputs, 32, scope='fc1')
       self.assertEquals(len(variables.get_variables('fc1')), 2)
-      tf.get_variable_scope().reuse_variables()
-      ops.fc(inputs, 32, scope='fc1')
+      ops.fc(inputs, 32, scope='fc1', reuse=True)
       self.assertEquals(len(variables.get_variables('fc1')), 2)
 
   def testNonReuseVars(self):
@@ -161,31 +251,53 @@ class FCTest(tf.test.TestCase):
     with self.test_session() as sess:
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       ops.fc(inputs, 32, weight_decay=0.01)
-      wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
-      self.assertEquals(wd.op.name, 'FC/weights/Regularizer/L2Loss/value')
+      wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+      self.assertEquals(wd.op.name,
+                        'FC/weights/Regularizer/L2Regularizer/value')
       sess.run(tf.initialize_all_variables())
       self.assertTrue(sess.run(wd) <= 0.01)
 
+  def testCreateFCWithoutWD(self):
+    height, width = 3, 3
+    with self.test_session():
+      inputs = tf.random_uniform((5, height * width * 3), seed=1)
+      ops.fc(inputs, 32, weight_decay=0)
+      self.assertEquals(
+          tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
   def testReuseFCWithWD(self):
     height, width = 3, 3
     with self.test_session():
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
-      tf.get_variable_scope().reuse_variables()
-      ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+      ops.fc(inputs, 32, weight_decay=0.01, scope='fc', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
 
   def testFCWithBatchNorm(self):
     height, width = 3, 3
     with self.test_session():
       images = tf.random_uniform((5, height * width * 3), seed=1)
       with scopes.arg_scope([ops.fc], batch_norm_params={}):
-        net = ops.fc(images, 32, scope='fc1')
-        net = ops.fc(net, 32, scope='fc2')
-      self.assertEquals(len(tf.get_collection('moving_vars')), 4)
+        net = ops.fc(images, 27)
+        net = ops.fc(net, 27)
+      self.assertEquals(len(variables.get_variables()), 8)
+      self.assertEquals(len(variables.get_variables('FC/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('FC_1/BatchNorm')), 3)
+
+  def testReuseFCWithBatchNorm(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height * width * 3), seed=1)
+      with scopes.arg_scope([ops.fc], batch_norm_params={'decay': 0.9}):
+        net = ops.fc(images, 27, scope='fc1')
+        net = ops.fc(net, 27, scope='fc1', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 4)
       self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
-      self.assertEquals(len(variables.get_variables('fc2/BatchNorm')), 3)
 
 
 class MaxPoolTest(tf.test.TestCase):
@@ -198,6 +310,14 @@ class MaxPoolTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'MaxPool/MaxPool')
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
 
+  def testCreateSquareMaxPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.max_pool(images, 3)
+      self.assertEquals(output.op.name, 'MaxPool/MaxPool')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
   def testCreateMaxPoolWithScope(self):
     height, width = 3, 3
     with self.test_session():
@@ -219,6 +339,13 @@ class MaxPoolTest(tf.test.TestCase):
       output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
+  def testGlobalMaxPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.max_pool(images, images.get_shape()[1:3], stride=1)
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
 
 class AvgPoolTest(tf.test.TestCase):
 
@@ -230,6 +357,14 @@ class AvgPoolTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'AvgPool/AvgPool')
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
 
+  def testCreateSquareAvgPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.avg_pool(images, 3)
+      self.assertEquals(output.op.name, 'AvgPool/AvgPool')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
   def testCreateAvgPoolWithScope(self):
     height, width = 3, 3
     with self.test_session():
@@ -251,6 +386,13 @@ class AvgPoolTest(tf.test.TestCase):
       output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
+  def testGlobalAvgPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.avg_pool(images, images.get_shape()[1:3], stride=1)
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
 
 class OneHotEncodingTest(tf.test.TestCase):
 
@@ -342,8 +484,8 @@ class BatchNormTest(tf.test.TestCase):
       gamma = variables.get_variables_by_name('gamma')[0]
       self.assertEquals(beta.op.name, 'BatchNorm/beta')
       self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
-      moving_mean = tf.get_collection('moving_vars')[0]
-      moving_variance = tf.get_collection('moving_vars')[1]
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
 
@@ -375,8 +517,7 @@ class BatchNormTest(tf.test.TestCase):
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.batch_norm(images, scale=True, scope='bn')
-      tf.get_variable_scope().reuse_variables()
-      ops.batch_norm(images, scale=True, scope='bn')
+      ops.batch_norm(images, scale=True, scope='bn', reuse=True)
       beta = variables.get_variables_by_name('beta')
       gamma = variables.get_variables_by_name('gamma')
       self.assertEquals(len(beta), 1)
@@ -390,8 +531,7 @@ class BatchNormTest(tf.test.TestCase):
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.batch_norm(images, scope='bn')
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
-      tf.get_variable_scope().reuse_variables()
-      ops.batch_norm(images, scope='bn')
+      ops.batch_norm(images, scope='bn', reuse=True)
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
 
   def testCreateMovingVars(self):

+ 49 - 23
inception/inception/slim/scopes.py

@@ -19,7 +19,7 @@
 
   Example of how to use scopes.arg_scope:
 
-  with slim.arg_scope(ops.conv2d, padding='SAME',
+  with scopes.arg_scope(ops.conv2d, padding='SAME',
                       stddev=0.01, weight_decay=0.0005):
     net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
     net = ops.conv2d(net, 256, [5, 5], scope='conv2')
@@ -32,6 +32,15 @@
     ops.conv2d(inputs, 256, [5, 5], padding='SAME',
                stddev=0.01, weight_decay=0.0005, scope='conv2')
 
+  Example of how to reuse an arg_scope:
+  with scopes.arg_scope(ops.conv2d, padding='SAME',
+                      stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
+    net = ops.conv2d(net, 256, [5, 5], scope='conv1')
+    ....
+
+  with scopes.arg_scope(conv2d_arg_scope):
+    net = ops.conv2d(net, 256, [5, 5], scope='conv2')
+
   Example of how to use scopes.add_arg_scope:
 
   @scopes.add_arg_scope
@@ -44,7 +53,6 @@ from __future__ import print_function
 import contextlib
 import functools
 
-
 from tensorflow.python.framework import ops
 
 _ARGSTACK_KEY = ("__arg_stack",)
@@ -74,12 +82,16 @@ def _add_op(op):
 
 
 @contextlib.contextmanager
-def arg_scope(list_ops, **kwargs):
+def arg_scope(list_ops_or_scope, **kwargs):
   """Stores the default arguments for the given set of list_ops.
 
+  For usage, please see examples at top of the file.
+
   Args:
-    list_ops: List or tuple of operations to set argument scope for. Every op in
-              list_ops need to be decorated with @add_arg_scope to work.
+    list_ops_or_scope: List or tuple of operations to set argument scope for or
+      a dictionary containg the current scope. When list_ops_or_scope is a dict,
+      kwargs must be empty. When list_ops_or_scope is a list or tuple, then
+      every op in it need to be decorated with @add_arg_scope to work.
     **kwargs: keyword=value that will define the defaults for each op in
               list_ops. All the ops need to accept the given set of arguments.
 
@@ -89,24 +101,38 @@ def arg_scope(list_ops, **kwargs):
     TypeError: if list_ops is not a list or a tuple.
     ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
   """
-  if not isinstance(list_ops, (list, tuple)):
-    raise TypeError("list_ops is not a list or a tuple")
-  try:
-    current_scope = _current_arg_scope().copy()
-    for op in list_ops:
-      key_op = (op.__module__, op.__name__)
-      if not has_arg_scope(op):
-        raise ValueError("%s is not decorated with @add_arg_scope", key_op)
-      if key_op in current_scope:
-        current_kwargs = current_scope[key_op].copy()
-        current_kwargs.update(kwargs)
-        current_scope[key_op] = current_kwargs
-      else:
-        current_scope[key_op] = kwargs.copy()
-    _get_arg_stack().append(current_scope)
-    yield current_scope
-  finally:
-    _get_arg_stack().pop()
+  if isinstance(list_ops_or_scope, dict):
+    # Assumes that list_ops_or_scope is a scope that is being reused.
+    if kwargs:
+      raise ValueError("When attempting to re-use a scope by suppling a"
+                       "dictionary, kwargs must be empty.")
+    current_scope = list_ops_or_scope.copy()
+    try:
+      _get_arg_stack().append(current_scope)
+      yield current_scope
+    finally:
+      _get_arg_stack().pop()
+  else:
+    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
+    if not isinstance(list_ops_or_scope, (list, tuple)):
+      raise TypeError("list_ops_or_scope must either be a list/tuple or reused"
+                      "scope (i.e. dict)")
+    try:
+      current_scope = _current_arg_scope().copy()
+      for op in list_ops_or_scope:
+        key_op = (op.__module__, op.__name__)
+        if not has_arg_scope(op):
+          raise ValueError("%s is not decorated with @add_arg_scope", key_op)
+        if key_op in current_scope:
+          current_kwargs = current_scope[key_op].copy()
+          current_kwargs.update(kwargs)
+          current_scope[key_op] = current_kwargs
+        else:
+          current_scope[key_op] = kwargs.copy()
+      _get_arg_stack().append(current_scope)
+      yield current_scope
+    finally:
+      _get_arg_stack().pop()
 
 
 def add_arg_scope(func):

+ 45 - 1
inception/inception/slim/scopes_test.py

@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 from inception.slim import scopes
 
@@ -39,6 +38,51 @@ class ArgScopeTest(tf.test.TestCase):
     with self.test_session():
       self.assertEqual(scopes._current_arg_scope(), {})
 
+  def testCurrentArgScope(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    key_op = (func1.__module__, func1.__name__)
+    current_scope = {key_op: func1_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope:
+        self.assertDictEqual(scope, current_scope)
+
+  def testCurrentArgScopeNested(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    func2_kwargs = {'b': 2, 'd': [2]}
+    key = lambda f: (f.__module__, f.__name__)
+    current_scope = {key(func1): func1_kwargs.copy(),
+                     key(func2): func2_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+        with scopes.arg_scope([func2], b=2, d=[2]) as scope:
+          self.assertDictEqual(scope, current_scope)
+
+  def testReuseArgScope(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    key_op = (func1.__module__, func1.__name__)
+    current_scope = {key_op: func1_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+        pass
+      with scopes.arg_scope(scope1) as scope:
+        self.assertDictEqual(scope, current_scope)
+
+  def testReuseArgScopeNested(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    func2_kwargs = {'b': 2, 'd': [2]}
+    key = lambda f: (f.__module__, f.__name__)
+    current_scope1 = {key(func1): func1_kwargs.copy()}
+    current_scope2 = {key(func1): func1_kwargs.copy(),
+                      key(func2): func2_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+        with scopes.arg_scope([func2], b=2, d=[2]) as scope2:
+          pass
+      with scopes.arg_scope(scope1):
+        self.assertDictEqual(scopes._current_arg_scope(), current_scope1)
+      with scopes.arg_scope(scope2):
+        self.assertDictEqual(scopes._current_arg_scope(), current_scope2)
+
   def testSimpleArgScope(self):
     func1_args = (0,)
     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}

+ 54 - 60
inception/inception/slim/variables.py

@@ -12,7 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Contains convenience wrappers for creating Variables in TensorFlow.
+"""Contains convenience wrappers for creating variables in TF-Slim.
+
+The variables module is typically used for defining model variables from the
+ops routines (see slim.ops). Such variables are used for training, evaluation
+and inference of models.
+
+All the variables created through this module would be added to the
+MODEL_VARIABLES collection, if you create a model variable outside slim, it can
+be added with slim.variables.add_variable(external_variable, reuse).
 
 Usage:
   weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
@@ -24,15 +32,15 @@ Usage:
                                device='/cpu:0')
 
   biases = variables.variable('biases',
-                               shape=[100],
-                               initializer=tf.zeros_initializer,
-                               device='/cpu:0')
+                              shape=[100],
+                              initializer=tf.zeros_initializer,
+                              device='/cpu:0')
 
   # More complex example.
 
   net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
   net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
-  with slim.arg_scope(variables.Variables, restore=False):
+  with slim.arg_scope([variables.variable], restore=False):
     net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
 
   # Get all model variables from all the layers.
@@ -47,9 +55,9 @@ Usage:
   # Get all bias from all the layers.
   biases = slim.variables.get_variables_by_name('biases')
 
-  # Get all variables in the VARIABLES_TO_RESTORE collection
+  # Get all variables to restore.
   # (i.e. only those created by 'conv1' and 'conv2')
-  variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+  variables_to_restore = slim.variables.get_variables_to_restore()
 
 ************************************************
 * Initializing model variables from a checkpoint
@@ -60,7 +68,7 @@ v1 = slim.variables.variable(name="v1", ..., restore=False)
 v2 = slim.variables.variable(name="v2", ...) # By default restore=True
 ...
 # The list of variables to restore should only contain 'v2'.
-variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+variables_to_restore = slim.variables.get_variables_to_restore()
 restorer = tf.train.Saver(variables_to_restore)
 with tf.Session() as sess:
   # Restore variables from disk.
@@ -74,92 +82,71 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import scopes
 
 # Collection containing all the variables created using slim.variables
-VARIABLES_COLLECTION = '_variables_'
+MODEL_VARIABLES = '_model_variables_'
 
-# Collection containing all the slim.variables that are marked to_restore
+# Collection containing the slim.variables that are created with restore=True.
 VARIABLES_TO_RESTORE = '_variables_to_restore_'
 
 
-def get_variable_given_name(var):
-  """Gets the variable given name without the scope.
-
-  Args:
-    var: a variable.
-
-  Returns:
-    the given name of the variable without the scope.
-  """
-  name = var.op.name
-  if '/' in name:
-    name = name.split('/')[-1]
-  return name
-
-
-def default_collections(given_name, restore):
-  """Define the set of default collections that variables should be added.
-
-  Args:
-    given_name: the given name of the variable.
-    restore: whether the variable should be added to the VARIABLES_TO_RESTORE
-      collection.
-
-  Returns:
-    a list of default collections.
-  """
-  defaults = [tf.GraphKeys.VARIABLES, VARIABLES_COLLECTION]
-  defaults += [VARIABLES_COLLECTION + given_name]
-  if restore:
-    defaults += [VARIABLES_TO_RESTORE]
-  return defaults
-
-
 def add_variable(var, restore=True):
-  """Adds a variable to the default set of collections.
+  """Adds a variable to the MODEL_VARIABLES collection.
 
+    Optionally it will add the variable to  the VARIABLES_TO_RESTORE collection.
   Args:
     var: a variable.
     restore: whether the variable should be added to the
       VARIABLES_TO_RESTORE collection.
+
   """
-  given_name = get_variable_given_name(var)
-  for collection in default_collections(given_name, restore):
+  collections = [MODEL_VARIABLES]
+  if restore:
+    collections.append(VARIABLES_TO_RESTORE)
+  for collection in collections:
     if var not in tf.get_collection(collection):
       tf.add_to_collection(collection, var)
 
 
-def get_variables(prefix=None, suffix=None):
-  """Gets the list of variables, filtered by prefix and/or suffix.
+def get_variables(scope=None, suffix=None):
+  """Gets the list of variables, filtered by scope and/or suffix.
 
   Args:
-    prefix: an optional prefix for filtering the variables to return.
+    scope: an optional scope for filtering the variables to return.
     suffix: an optional suffix for filtering the variables to return.
 
   Returns:
-    a list of variables with prefix and suffix.
+    a copied list of variables with scope and suffix.
   """
-  candidates = tf.get_collection(VARIABLES_COLLECTION, prefix)
+  candidates = tf.get_collection(MODEL_VARIABLES, scope)[:]
   if suffix is not None:
     candidates = [var for var in candidates if var.op.name.endswith(suffix)]
   return candidates
 
 
-def get_variables_by_name(given_name, prefix=None):
-  """Gets the list of variables were given that name.
+def get_variables_to_restore():
+  """Gets the list of variables to restore.
+
+  Returns:
+    a copied list of variables.
+  """
+  return tf.get_collection(VARIABLES_TO_RESTORE)[:]
+
+
+def get_variables_by_name(given_name, scope=None):
+  """Gets the list of variables that were given that name.
 
   Args:
     given_name: name given to the variable without scope.
-    prefix: an optional prefix for filtering the variables to return.
+    scope: an optional scope for filtering the variables to return.
 
   Returns:
-    a list of variables with prefix and suffix.
+    a copied list of variables with the given name and prefix.
   """
-  return tf.get_collection(VARIABLES_COLLECTION + given_name, prefix)
+  return get_variables(scope=scope, suffix=given_name)
 
 
 def get_unique_variable(name):
@@ -204,7 +191,7 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     collections: A list of collection names to which the Variable will be added.
       Note that the variable is always also added to the tf.GraphKeys.VARIABLES
-      collection.
+      and MODEL_VARIABLES collections.
     device: Optional device to place the variable. It can be an string or a
       function that is called to get the device for the variable.
     restore: whether the variable should be added to the
@@ -216,8 +203,15 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   # Instantiate the device for this variable if it is passed as a function.
   if device and callable(device):
     device = device()
-  collections = set(list(collections or []) + default_collections(name,
-                                                                  restore))
+  collections = list(collections or [])
+
+  # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
+  collections += [tf.GraphKeys.VARIABLES, MODEL_VARIABLES]
+  # Add to VARIABLES_TO_RESTORE if necessary
+  if restore:
+    collections.append(VARIABLES_TO_RESTORE)
+  # Remove duplicates
+  collections = set(collections)
   with tf.device(device):
     return tf.get_variable(name, shape=shape, dtype=dtype,
                            initializer=initializer, regularizer=regularizer,

+ 58 - 28
inception/inception/slim/variables_test.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import scopes
@@ -33,29 +32,13 @@ class VariablesTest(tf.test.TestCase):
         self.assertEquals(a.op.name, 'A/a')
         self.assertListEqual(a.get_shape().as_list(), [5])
 
-  def testGetVariableGivenName(self):
-    with self.test_session():
-      with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
-      with tf.variable_scope('B'):
-        b = variables.variable('a', [5])
-      self.assertEquals('a', variables.get_variable_given_name(a))
-      self.assertEquals('a', variables.get_variable_given_name(b))
-
-  def testGetVariableGivenNameScoped(self):
-    with self.test_session():
-      with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
-        b = variables.variable('b', [5])
-        self.assertEquals([a], variables.get_variables_by_name('a'))
-        self.assertEquals([b], variables.get_variables_by_name('b'))
-
   def testGetVariables(self):
     with self.test_session():
       with tf.variable_scope('A'):
         a = variables.variable('a', [5])
       with tf.variable_scope('B'):
         b = variables.variable('a', [5])
+      self.assertEquals([a, b], variables.get_variables())
       self.assertEquals([a], variables.get_variables('A'))
       self.assertEquals([b], variables.get_variables('B'))
 
@@ -103,19 +86,28 @@ class VariablesTest(tf.test.TestCase):
       with tf.variable_scope('A'):
         a = variables.variable('a', [5])
       with tf.variable_scope('B'):
-        b = variables.variable('b', [5])
-      self.assertListEqual([a, b],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+        b = variables.variable('a', [5])
+      self.assertEquals([a, b], variables.get_variables_to_restore())
 
-  def testGetVariablesToRestorePartial(self):
+  def testNoneGetVariablesToRestore(self):
     with self.test_session():
       with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
+        a = variables.variable('a', [5], restore=False)
       with tf.variable_scope('B'):
+        b = variables.variable('a', [5], restore=False)
+      self.assertEquals([], variables.get_variables_to_restore())
+      self.assertEquals([a, b], variables.get_variables())
+
+  def testGetMixedVariablesToRestore(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
         b = variables.variable('b', [5], restore=False)
-      self.assertListEqual([a, b], variables.get_variables())
-      self.assertListEqual([a],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+      with tf.variable_scope('B'):
+        c = variables.variable('c', [5])
+        d = variables.variable('d', [5], restore=False)
+      self.assertEquals([a, b, c, d], variables.get_variables())
+      self.assertEquals([a, c], variables.get_variables_to_restore())
 
   def testReuseVariable(self):
     with self.test_session():
@@ -190,11 +182,49 @@ class VariablesTest(tf.test.TestCase):
                               collections=['A', 'B']):
           b = variables.variable('b', [])
         c = variables.variable('c', [])
-      self.assertListEqual([a, b, c],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+      self.assertListEqual([a, b, c], variables.get_variables_to_restore())
       self.assertListEqual([a, c], tf.trainable_variables())
       self.assertListEqual([b], tf.get_collection('A'))
       self.assertListEqual([b], tf.get_collection('B'))
 
+
+class GetVariablesByNameTest(tf.test.TestCase):
+
+  def testGetVariableGivenNameScoped(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
+        b = variables.variable('b', [5])
+        self.assertEquals([a], variables.get_variables_by_name('a'))
+        self.assertEquals([b], variables.get_variables_by_name('b'))
+
+  def testGetVariablesByNameReturnsByValueWithScope(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
+        matched_variables = variables.get_variables_by_name('a')
+
+        # If variables.get_variables_by_name returns the list by reference, the
+        # following append should persist, and be returned, in subsequent calls
+        # to variables.get_variables_by_name('a').
+        matched_variables.append(4)
+
+        matched_variables = variables.get_variables_by_name('a')
+        self.assertEquals([a], matched_variables)
+
+  def testGetVariablesByNameReturnsByValueWithoutScope(self):
+    with self.test_session():
+      a = variables.variable('a', [5])
+      matched_variables = variables.get_variables_by_name('a')
+
+      # If variables.get_variables_by_name returns the list by reference, the
+      # following append should persist, and be returned, in subsequent calls
+      # to variables.get_variables_by_name('a').
+      matched_variables.append(4)
+
+      matched_variables = variables.get_variables_by_name('a')
+      self.assertEquals([a], matched_variables)
+
+
 if __name__ == '__main__':
   tf.test.main()