
Merge pull request #40 from jmchen-g/master

update inception slim.
Jianmin Chen 9 years ago
parent
commit
c3d18895ef

+ 9 - 0
inception/inception/slim/BUILD

@@ -101,3 +101,12 @@ py_library(
         ":variables",
     ],
 )
+
+py_test(
+    name = "collections_test",
+    size = "small",
+    srcs = ["collections_test.py"],
+    deps = [
+        ":slim",
+    ],
+)
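(A hedged usage note: assuming the Bazel workspace root is the top-level `inception/` directory, as for the other targets in this repository, the new test target would be run with `bazel test inception/slim:collections_test`.)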

+ 126 - 142
inception/inception/slim/README.md

@@ -1,21 +1,20 @@
 # TensorFlow-Slim
 
-TF-Slim is a lightweight library for defining, training and evaluating models
-in TensorFlow. It enables defining complex networks quickly and concisely while
+TF-Slim is a lightweight library for defining, training and evaluating models in
+TensorFlow. It enables defining complex networks quickly and concisely while
 keeping a model's architecture transparent and its hyperparameters explicit.
 
-
 [TOC]
 
 ## Teaser
 
-As a demonstration of the simplicity of using TF-Slim, compare the simplicity
-of the code necessary for defining the entire
-[VGG](http://www.robots.ox.ac.uk/~vgg/research/very_deep/) network using TF-Slim
-to the lengthy and verbose nature of defining just the first three layers (out
-of 16) using native tensorflow:
+As a demonstration of the simplicity of using TF-Slim, compare the simplicity of
+the code necessary for defining the entire [VGG]
+(http://www.robots.ox.ac.uk/~vgg/research/very_deep/) network using TF-Slim to
+the lengthy and verbose nature of defining just the first three layers (out of
+16) using native tensorflow:
 
-```python{.good}
+```python {.good}
 # VGG16 in TF-Slim.
 def vgg16(inputs):
   with slim.arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005):
@@ -38,7 +37,7 @@ def vgg16(inputs):
   return net
 ```
 
-```python{.bad}
+```python {.bad}
 # Layers 1-3 (out of 16) of VGG16 in native tensorflow.
 def vgg16(inputs):
   with tf.name_scope('conv1_1') as scope:
@@ -61,47 +60,42 @@ def vgg16(inputs):
 
 TF-Slim offers several advantages over just the built-in tensorflow libraries:
 
-* Allows one to define models much more compactly by eliminating
-boilerplate code. This is accomplished through the use of
-[argument scoping](scopes.py)
-and numerous high level
-[operations](ops.py).
-These tools increase readability and maintainability, reduce the likelihood
-of an error from copy-and-pasting hyperparameter values and simplifies
-hyperparameter tuning.
-* Makes developing models simple by providing commonly used
-[loss functions](losses.py)
-* Provides a concise
-[definition](inception.py)
-of [Inception v3](http://arxiv.org/abs/1512.00567) network architecture
-ready to be used out-of-the-box or subsumed into new models.
+*   Allows one to define models much more compactly by eliminating boilerplate
+    code. This is accomplished through the use of [argument scoping](scopes.py)
+    and numerous high level [operations](ops.py). These tools increase
+    readability and maintainability, reduce the likelihood of an error from
+    copy-and-pasting hyperparameter values and simplifies hyperparameter tuning.
+*   Makes developing models simple by providing commonly used [loss functions]
+    (losses.py)
+*   Provides a concise [definition](inception.py) of [Inception v3]
+    (http://arxiv.org/abs/1512.00567) network architecture ready to be used
+    out-of-the-box or subsumed into new models.
 
 Additionally TF-Slim was designed with several principles in mind:
 
-* The various modules of TF-Slim (scopes, variables, ops, losses) are
-independent. This flexibility allows users to pick and choose
-components of TF-Slim completely à la carte.
-* TF-Slim is written using a Functional Programming style. That means it's
-super-lightweight and can be used right alongside any of TensorFlow's native
-operations.
-* Makes re-using network architectures easy. This allows users to build new
-networks on top of existing ones as well as fine-tuning pre-trained models on
-new tasks.
+*   The various modules of TF-Slim (scopes, variables, ops, losses) are
+    independent. This flexibility allows users to pick and choose components of
+    TF-Slim completely à la carte.
+*   TF-Slim is written using a Functional Programming style. That means it's
+    super-lightweight and can be used right alongside any of TensorFlow's native
+    operations.
+*   Makes re-using network architectures easy. This allows users to build new
+    networks on top of existing ones as well as fine-tuning pre-trained models
+    on new tasks.
 
 ## What are the various components of TF-Slim?
 
 TF-Slim is composed of several parts which were designed to exist independently.
 These include:
 
-* [scopes.py](./scopes.py):
-provides a new scope named `arg_scope` that allows a user to define default
-arguments for specific operations within that scope.
-* [variables.py](./variables.py):
-provides convenience wrappers for variable creation and manipulation.
-* [ops.py](./ops.py):
-provides high level operations for building models using tensorflow.
-* [losses.py](./losses.py):
-contains commonly used loss functions.
+*   [scopes.py](./scopes.py): provides a new scope named `arg_scope` that allows
+    a user to define default arguments for specific operations within that
+    scope.
+*   [variables.py](./variables.py): provides convenience wrappers for variable
+    creation and manipulation.
+*   [ops.py](./ops.py): provides high level operations for building models using
+    tensorflow.
+*   [losses.py](./losses.py): contains commonly used loss functions.
 
 ## Defining Models
 
@@ -110,16 +104,14 @@ operations and scopes. Each of these elements are defined below.
 
 ### Variables
 
-Creating
-[`Variables`](https://www.tensorflow.org/how_tos/variables/index.html)
+Creating [`Variables`](https://www.tensorflow.org/how_tos/variables/index.html)
 in native tensorflow requires either a predefined value or an initialization
-mechanism
-(random, normally distributed). Furthermore, if a variable needs to be created
-on a specific device, such as a GPU, the specification must be
-[made explicit](https://www.tensorflow.org/how_tos/using_gpu/index.html).
-To alleviate the code required for variable creation, TF-Slim provides a set
-of thin wrapper functions in [variables.py](./variables.py)
-which allow callers to easily define variables.
+mechanism (random, normally distributed). Furthermore, if a variable needs to be
+created on a specific device, such as a GPU, the specification must be [made
+explicit](https://www.tensorflow.org/how_tos/using_gpu/index.html). To alleviate
+the code required for variable creation, TF-Slim provides a set of thin wrapper
+functions in [variables.py](./variables.py) which allow callers to easily define
+variables.
 
 For example, to create a `weight` variable, initialize it using a truncated
 normal distribution, regularize it with an `l2_loss` and place it on the `CPU`,
@@ -159,21 +151,20 @@ weights = variables.variable('weights',
 
 ### Operations (Layers)
 
-While the set of TensorFlow operations is quite extensive, builders of
-neural networks typically think of models in terms of "layers". A layer,
-such as a Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer
-are more abstract than a single TensorFlow operation and typically involve
-many such operations. For example, a Convolutional Layer in a neural network
-is built using several steps:
+While the set of TensorFlow operations is quite extensive, builders of neural
+networks typically think of models in terms of "layers". A layer, such as a
+Convolutional Layer, a Fully Connected Layer or a BatchNorm Layer are more
+abstract than a single TensorFlow operation and typically involve many such
+operations. For example, a Convolutional Layer in a neural network is built
+using several steps:
 
-1. Creating the weight variables
-2. Creating the bias variables
-3. Convolving the weights with the input from the previous layer
-4. Adding the biases to the result of the convolution.
+1.  Creating the weight variables
+2.  Creating the bias variables
+3.  Convolving the weights with the input from the previous layer
+4.  Adding the biases to the result of the convolution.
 
 In python code this can be rather laborious:
 
-
 ```python
 input = ...
 with tf.name_scope('conv1_1') as scope:
@@ -187,9 +178,9 @@ with tf.name_scope('conv1_1') as scope:
 ```
 
 To alleviate the need to duplicate this code repeatedly, TF-Slim provides a
-number of convenient operations defined at the (more abstract) level of
-neural network layers. For example, compare the code above to an invocation
-of the TF-Slim code:
+number of convenient operations defined at the (more abstract) level of neural
+network layers. For example, compare the code above to an invocation of the
+TF-Slim code:
 
 ```python
 input = ...
@@ -199,22 +190,21 @@ net = slim.ops.conv2d(input, [3, 3], 128, scope='conv1_1')
 TF-Slim provides numerous operations used in building neural networks which
 roughly correspond to such layers. These include:
 
-Layer | TF-Slim Op
-------- | --------
-Convolutional Layer | [ops.conv2d](ops.py)
+Layer                 | TF-Slim Op
+--------------------- | ------------------------
+Convolutional Layer   | [ops.conv2d](ops.py)
 Fully Connected Layer | [ops.fc](ops.py)
-BatchNorm layer | [ops.batch_norm](ops.py)
-Max Pooling Layer | [ops.max_pool](ops.py)
-Avg Pooling Layer | [ops.avg_pool](ops.py)
-Dropout Layer | [ops.dropout](ops.py)
+BatchNorm layer       | [ops.batch_norm](ops.py)
+Max Pooling Layer     | [ops.max_pool](ops.py)
+Avg Pooling Layer     | [ops.avg_pool](ops.py)
+Dropout Layer         | [ops.dropout](ops.py)
 
-[ops.py](./ops.py)
-also includes operations that are not really "layers" per se, but are
-often used to manipulate hidden unit representations during inference:
+[ops.py](./ops.py) also includes operations that are not really "layers" per se,
+but are often used to manipulate hidden unit representations during inference:
 
 Operation | TF-Slim Op
-------- | --------
-Flatten | [ops.flatten](ops.py)
+--------- | ---------------------
+Flatten   | [ops.flatten](ops.py)
 
 TF-Slim also provides a meta-operation called `repeat_op` that allows one to
 repeatedly perform the same operation. Consider the following snippet from the
@@ -238,36 +228,35 @@ for i in range(3):
 net = slim.ops.max_pool(net, [2, 2], scope='pool3')
 ```
 
-While this does reduce the amount of duplication, it can be made even cleaner
-by using the `RepeatOp`:
+While this does reduce the amount of duplication, it can be made even cleaner by
+using the `RepeatOp`:
 
 ```python
 net = slim.ops.repeat_op(3, net, slim.ops.conv2d, 256, [3, 3], scope='conv3')
 net = slim.ops.max_pool(net, [2, 2], scope='pool2')
 ```
 
-Notice that the RepeatOp not only applies the same argument in-line, it also
-is smart enough to unroll the scopes such that the scopes assigned to each
+Notice that the RepeatOp not only applies the same argument in-line, it also is
+smart enough to unroll the scopes such that the scopes assigned to each
 subsequent call of `ops.conv2d` is appended with an underscore and iteration
 number. More concretely, the scopes in the example above would be 'conv3_1',
 'conv3_2' and 'conv3_3'.
 
-
 ### Scopes
 
-In addition to the types of scope mechanisms in TensorFlow
-([name_scope](https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
+In addition to the types of scope mechanisms in TensorFlow ([name_scope]
+(https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
 [op_scope](https://www.tensorflow.org/api_docs/python/framework.html#op_scope),
-[variable_scope](https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
-[variable_op_scope](https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
-TF-Slim adds a new scoping mechanism called "argument scope" or
-[arg_scope](scopes.py).
-This new scope allows a user to specify one or more operations and a set of
-arguments which will be passed to each of the operations defined in the
+[variable_scope]
+(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
+[variable_op_scope]
+(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
+TF-Slim adds a new scoping mechanism called "argument scope" or [arg_scope]
+(scopes.py). This new scope allows a user to specify one or more operations and
+a set of arguments which will be passed to each of the operations defined in the
 `arg_scope`. This functionality is best illustrated by example. Consider the
 following code snippet:
 
-
 ```python
 net = slim.ops.conv2d(inputs, 64, [11, 11], 4, padding='SAME', stddev=0.01, weight_decay=0.0005, scope='conv1')
 net = slim.ops.conv2d(net, 128, [11, 11], padding='VALID', stddev=0.01, weight_decay=0.0005, scope='conv2')
@@ -278,8 +267,8 @@ It should be clear that these three Convolution layers share many of the same
 hyperparameters. Two have the same padding, all three have the same weight_decay
 and standard deviation of its weights. Not only do the duplicated values make
 the code more difficult to read, it also adds the addition burder to the writer
-of needing to doublecheck that all of the values are identical in each step.
-One solution would be to specify default values using variables:
+of needing to doublecheck that all of the values are identical in each step. One
+solution would be to specify default values using variables:
 
 ```python
 padding='SAME'
@@ -302,11 +291,11 @@ ensure that each layer uses the same values and simplify the code:
     net = slim.ops.conv2d(net, 256, [11, 11], scope='conv3')
 ```
 
-As the example illustrates, the use of arg_scope makes the code cleaner,
-simpler and easier to maintain. Notice that while argument values are specifed
-in the arg_scope, they can be overwritten locally. In particular, while
-the padding argument has been set to 'SAME', the second convolution overrides
-it with the value of 'VALID'.
+As the example illustrates, the use of arg_scope makes the code cleaner, simpler
+and easier to maintain. Notice that while argument values are specifed in the
+arg_scope, they can be overwritten locally. In particular, while the padding
+argument has been set to 'SAME', the second convolution overrides it with the
+value of 'VALID'.
 
 One can also nest `arg_scope`s and use multiple operations in the same scope.
 For example:
@@ -320,9 +309,9 @@ with arg_scope([slim.ops.conv2d, slim.ops.fc], stddev=0.01, weight_decay=0.0005)
     net = slim.ops.fc(net, 1000, activation=None, scope='fc')
 ```
 
-In this example, the first `arg_scope` applies the same `stddev` and `weight_decay`
-arguments to the `conv2d` and `fc` ops in its scope. In the second `arg_scope`,
-additional default arguments to `conv2d` only are specified.
+In this example, the first `arg_scope` applies the same `stddev` and
+`weight_decay` arguments to the `conv2d` and `fc` ops in its scope. In the
+second `arg_scope`, additional default arguments to `conv2d` only are specified.
 
 In addition to `arg_scope`, TF-Slim provides several decorators that wrap the
 use of tensorflow arg scopes. These include `@AddArgScope`, `@AddNameScope`,
@@ -349,8 +338,8 @@ with tf.variable_scope('layer1'):
   outputs = MyNewOp(inputs)
 ```
 
-As an alternative, one can use TF-Slim's decorators to decorate the function
-and simplify the call:
+As an alternative, one can use TF-Slim's decorators to decorate the function and
+simplify the call:
 
 ```python
 @AddVariableScope
@@ -364,8 +353,8 @@ outputs = MyNewOp('layer1')
 ```
 
 The `@AddVariableScope` decorater simply applies the `tf.variable_scope` scoping
-to the called function taking "layer1" as its argument. This allows the code
-to be written more concisely.
+to the called function taking "layer1" as its argument. This allows the code to
+be written more concisely.
 
 ### Losses
 
@@ -375,19 +364,16 @@ classification problems, this is typically the cross entropy between the true
 classes. For regression problems, this is often the sum-of-squares differences
 between the predicted and true values.
 
-Certain models, such as multi-task
-learning models, require the use of multiple loss functions simultaneously. In
-other words, the loss function ultimatey being minimized is the sum of various
-other loss functions. For example, consider a model that predicts both
-the type of scene in an image as well as the depth from the
-camera of each pixel. This model's loss function would be the sum of the
+Certain models, such as multi-task learning models, require the use of multiple
+loss functions simultaneously. In other words, the loss function ultimatey being
+minimized is the sum of various other loss functions. For example, consider a
+model that predicts both the type of scene in an image as well as the depth from
+the camera of each pixel. This model's loss function would be the sum of the
 classification loss and depth prediction loss.
 
-TF-Slim provides an easy-to-use mechanism for defining and keeping track of
-loss functions via the
-[losses.py](./losses.py)
-module. Consider the simple case where we want to train the VGG network:
-
+TF-Slim provides an easy-to-use mechanism for defining and keeping track of loss
+functions via the [losses.py](./losses.py) module. Consider the simple case
+where we want to train the VGG network:
 
 ```python
 # Load the images and labels.
@@ -401,9 +387,8 @@ loss = losses.ClassificationLoss(predictions, labels)
 ```
 
 In this example, we start by creating the model (using TF-Slim's VGG
-implementation), and add the standard classification loss. Now, lets turn
-to the case where we have a multi-task model that produces multiple outputs:
-
+implementation), and add the standard classification loss. Now, lets turn to the
+case where we have a multi-task model that produces multiple outputs:
 
 ```python
 # Load the images and labels.
@@ -424,16 +409,14 @@ total_loss2 = tf.get_collection(slim.losses.LOSSES_COLLECTION)
 In this example, we have two losses which we add by calling
 `losses.ClassificationLoss` and `losses.SumOfSquaresLoss`. We can obtain the
 total loss by adding them together (`total_loss1`) or by calling
-`losses.GetTotalLoss()`. How did this work?
-When you create a loss function via TF-Slim, TF-Slim adds the loss to a
-special TensorFlow collection of loss functions. This enables you to either
-manage the total loss manually, or allow TF-Slim to manage them for you.
+`losses.GetTotalLoss()`. How did this work? When you create a loss function via
+TF-Slim, TF-Slim adds the loss to a special TensorFlow collection of loss
+functions. This enables you to either manage the total loss manually, or allow
+TF-Slim to manage them for you.
 
 What if you want to let TF-Slim manage the losses for you but have a custom loss
-function?
-[losses.py](./losses.py)
-also has a function that adds this loss to TF-Slims collection. For example:
-
+function? [losses.py](./losses.py) also has a function that adds this loss to
+TF-Slims collection. For example:
 
 ```python
 # Load the images and labels.
@@ -452,15 +435,15 @@ tf.add_to_collection(slim.losses.LOSSES_COLLECTION, pose_loss) # Letting TF-Slim
 total_loss1 = classification_loss + sum_of_squares_loss + pose_loss
 total_loss2 = losses.GetTotalLoss()
 ```
-In this example, we can again either produce the total loss function manually
-or let TF-Slim know about the additional loss and let TF-Slim handle the losses.
 
+In this example, we can again either produce the total loss function manually or
+let TF-Slim know about the additional loss and let TF-Slim handle the losses.
 
 ## Putting the Pieces Together
 
 By combining TF-Slim Variables, Operations and scopes, we can write a normally
-very complex network with very few lines of code. For example, the entire
-[VGG](https://www.robots.ox.ac.uk/~vgg/research/very_deep/) architecture can be
+very complex network with very few lines of code. For example, the entire [VGG]
+(https://www.robots.ox.ac.uk/~vgg/research/very_deep/) architecture can be
 defined with just the following snippet:
 
 ```python
@@ -490,8 +473,8 @@ return net
 
 After a model has been trained, it can be restored using `tf.train.Saver()`
 which restores `Variables` from a given checkpoint. For many cases,
-`tf.train.Saver()` provides a simple mechanism to restore all or just a
-few variables.
+`tf.train.Saver()` provides a simple mechanism to restore all or just a few
+variables.
 
 ```python
 # Create some variables.
@@ -514,19 +497,21 @@ with tf.Session() as sess:
   ...
 ```
 
-See [Restoring Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#restoring-variables)
-and
-[Choosing which Variables to Save and Restore](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#choosing-which-variables-to-save-and-restore)
-sections of the [Variables](https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html)
-page for more details.
+See [Restoring Variables]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#restoring-variables)
+and [Choosing which Variables to Save and Restore]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html#choosing-which-variables-to-save-and-restore)
+sections of the [Variables]
+(https://www.tensorflow.org/versions/r0.7/how_tos/variables/index.html) page for
+more details.
 
 ### Using slim.variables to Track which Variables need to be Restored
 
 It is often desirable to fine-tune a pre-trained model on an entirely new
 dataset or even a new task. In these situations, one must specify which layers
-of the model should be reused (and consequently loaded from a checkpoint)
-and which layers are new. Indicating which variables or layers should be
-restored is a process that quickly becomes cumbersome when done manually.
+of the model should be reused (and consequently loaded from a checkpoint) and
+which layers are new. Indicating which variables or layers should be restored is
+a process that quickly becomes cumbersome when done manually.
 
 To help keep track of which variables to restore, `slim.variables` provides a
 `restore` argument when creating each Variable. By default, all variables are
@@ -554,7 +539,6 @@ Additionally, every layer in `slim.ops` that creates slim.variables (such as
 argument which controls whether the variables created by that layer should be
 restored or not.
 
-
 ```python
 # Create a small network.
 net = slim.ops.conv2d(images, 32, [7, 7], stride=2, scope='conv1')

+ 31 - 4
inception/inception/slim/inception_model.py

@@ -43,7 +43,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import ops
@@ -98,10 +97,10 @@ def inception_v3(inputs,
         # 73 x 73 x 64
         end_points['conv3'] = ops.conv2d(end_points['pool1'], 80, [1, 1],
                                          scope='conv3')
-        # 71 x 71 x 80.
+        # 73 x 73 x 80.
         end_points['conv4'] = ops.conv2d(end_points['conv3'], 192, [3, 3],
                                          scope='conv4')
-        # 69 x 69 x 192.
+        # 71 x 71 x 192.
         end_points['pool2'] = ops.max_pool(end_points['conv4'], [3, 3],
                                            stride=2, scope='pool2')
         # 35 x 35 x 192.
@@ -260,7 +259,10 @@ def inception_v3(inputs,
           aux_logits = ops.fc(aux_logits, num_classes, activation=None,
                               stddev=0.001, restore=restore_logits)
           end_points['aux_logits'] = aux_logits
-        # mixed_8: 17 x 17 x 1280.
+        # mixed_8: 8 x 8 x 1280.
+        # Note that the scope below is not changed to not void previous
+        # checkpoints.
+        # (TODO) Fix the scope when appropriate.
         with tf.variable_scope('mixed_17x17x1280a'):
           with tf.variable_scope('branch3x3'):
             branch3x3 = ops.conv2d(net, 192, [1, 1])
@@ -327,3 +329,28 @@ def inception_v3(inputs,
           end_points['predictions'] = tf.nn.softmax(logits, name='predictions')
       return logits, end_points
 
+
+def inception_v3_parameters(weight_decay=0.00004, stddev=0.1,
+                            batch_norm_decay=0.9997, batch_norm_epsilon=0.001):
+  """Yields the scope with the default parameters for inception_v3.
+
+  Args:
+    weight_decay: the weight decay for weights variables.
+    stddev: standard deviation of the truncated guassian weight distribution.
+    batch_norm_decay: decay for the moving average of batch_norm momentums.
+    batch_norm_epsilon: small float added to variance to avoid dividing by zero.
+
+  Yields:
+    a arg_scope with the parameters needed for inception_v3.
+  """
+  # Set weight_decay for weights in Conv and FC layers.
+  with scopes.arg_scope([ops.conv2d, ops.fc],
+                        weight_decay=weight_decay):
+    # Set stddev, activation and parameters for batch_norm.
+    with scopes.arg_scope([ops.conv2d],
+                          stddev=stddev,
+                          activation=tf.nn.relu,
+                          batch_norm_params={
+                              'decay': batch_norm_decay,
+                              'epsilon': batch_norm_epsilon}) as arg_scope:
+      yield arg_scope
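A minimal sketch of how this generator-style scope helper could be consumed — the `contextlib` wrapping and the input tensor below are illustrative assumptions, not part of this change, and the surrounding training code may wire the scope up differently:

```python
import contextlib

import tensorflow as tf

from inception.slim import inception_model

# Hypothetical plumbing: wrap the generator so it can drive a `with` block and
# apply the default conv2d/fc parameters to everything built inside it.
inception_defaults = contextlib.contextmanager(
    inception_model.inception_v3_parameters)

images = tf.random_uniform((8, 299, 299, 3))  # illustrative input batch
with inception_defaults(weight_decay=0.00004, stddev=0.1):
  logits, end_points = inception_model.inception_v3(images, num_classes=1000)
```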

+ 16 - 1
inception/inception/slim/inception_test.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 from inception.slim import inception_model as inception
@@ -55,6 +54,22 @@ class InceptionTest(tf.test.TestCase):
       self.assertListEqual(pre_pool.get_shape().as_list(),
                            [batch_size, 8, 8, 2048])
 
+  def testVariablesSetDevice(self):
+    batch_size = 5
+    height, width = 299, 299
+    num_classes = 1000
+    with self.test_session():
+      inputs = tf.random_uniform((batch_size, height, width, 3))
+      # Force all Variables to reside on the device.
+      with tf.variable_scope('on_cpu'), tf.device('/cpu:0'):
+        inception.inception_v3(inputs, num_classes)
+      with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
+        inception.inception_v3(inputs, num_classes)
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+        self.assertDeviceEqual(v.device, '/cpu:0')
+      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+        self.assertDeviceEqual(v.device, '/gpu:0')
+
   def testHalfSizeImages(self):
     batch_size = 5
     height, width = 150, 150

+ 65 - 1
inception/inception/slim/losses.py

@@ -26,7 +26,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 # In order to gather all losses in a network, the user should use this
@@ -35,6 +34,71 @@ import tensorflow as tf
 LOSSES_COLLECTION = '_losses'
 
 
+def l1_regularizer(weight=1.0, scope=None):
+  """Define a L1 regularizer.
+
+  Args:
+    weight: scale the loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L1Regularizer'):
+      l1_weight = tf.convert_to_tensor(weight,
+                                       dtype=tensor.dtype.base_dtype,
+                                       name='weight')
+      return tf.mul(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+  return regularizer
+
+
+def l2_regularizer(weight=1.0, scope=None):
+  """Define a L2 regularizer.
+
+  Args:
+    weight: scale the loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L2Regularizer'):
+      l2_weight = tf.convert_to_tensor(weight,
+                                       dtype=tensor.dtype.base_dtype,
+                                       name='weight')
+      return tf.mul(l2_weight, tf.nn.l2_loss(tensor), name='value')
+  return regularizer
+
+
+def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
+  """Define a L1L2 regularizer.
+
+  Args:
+    weight_l1: scale the L1 loss by this factor.
+    weight_l2: scale the L2 loss by this factor.
+    scope: Optional scope for op_scope.
+
+  Returns:
+    a regularizer function.
+  """
+  def regularizer(tensor):
+    with tf.op_scope([tensor], scope, 'L1L2Regularizer'):
+      weight_l1_t = tf.convert_to_tensor(weight_l1,
+                                         dtype=tensor.dtype.base_dtype,
+                                         name='weight_l1')
+      weight_l2_t = tf.convert_to_tensor(weight_l2,
+                                         dtype=tensor.dtype.base_dtype,
+                                         name='weight_l2')
+      reg_l1 = tf.mul(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
+                      name='value_l1')
+      reg_l2 = tf.mul(weight_l2_t, tf.nn.l2_loss(tensor),
+                      name='value_l2')
+      return tf.add(reg_l1, reg_l2, name='value')
+  return regularizer
+
+
 def l1_loss(tensor, weight=1.0, scope=None):
   """Define a L1Loss, useful for regularize, i.e. lasso.
 
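As a quick, hedged illustration of how the new regularizer factories are meant to be used (mirroring the `ops.conv2d` change later in this diff; the decay value and variable shape are made up for the example):

```python
import tensorflow as tf

from inception.slim import losses
from inception.slim import variables

# l2_regularizer(weight) returns a function tensor -> scalar loss that can be
# handed to variables.variable as its `regularizer` argument.
weight_decay = 0.0004  # illustrative value
regularizer = losses.l2_regularizer(weight_decay)

weights = variables.variable('weights',
                             shape=[3, 3, 64, 128],  # illustrative shape
                             initializer=tf.truncated_normal_initializer(
                                 stddev=0.01),
                             regularizer=regularizer)
```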

+ 89 - 1
inception/inception/slim/losses_test.py

@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 
 from inception.slim import losses
@@ -47,6 +46,95 @@ class LossesTest(tf.test.TestCase):
       self.assertAlmostEqual(loss.eval(), num_elem * wd / 2, 5)
 
 
+class RegularizersTest(tf.test.TestCase):
+
+  def testL1Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L1Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+  def testL1RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_regularizer(scope='L1')(tensor)
+      self.assertEquals(loss.op.name, 'L1/value')
+      self.assertAlmostEqual(loss.eval(), num_elem, 5)
+
+  def testL1RegularizerWithWeight(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight = 0.01
+      loss = losses.l1_regularizer(weight)(tensor)
+      self.assertEquals(loss.op.name, 'L1Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem * weight, 5)
+
+  def testL2Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l2_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+  def testL2RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l2_regularizer(scope='L2')(tensor)
+      self.assertEquals(loss.op.name, 'L2/value')
+      self.assertAlmostEqual(loss.eval(), num_elem / 2, 5)
+
+  def testL2RegularizerWithWeight(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight = 0.01
+      loss = losses.l2_regularizer(weight)(tensor)
+      self.assertEquals(loss.op.name, 'L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem * weight / 2, 5)
+
+  def testL1L2Regularizer(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_l2_regularizer()(tensor)
+      self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+  def testL1L2RegularizerWithScope(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      loss = losses.l1_l2_regularizer(scope='L1L2')(tensor)
+      self.assertEquals(loss.op.name, 'L1L2/value')
+      self.assertAlmostEqual(loss.eval(), num_elem + num_elem / 2, 5)
+
+  def testL1L2RegularizerWithWeights(self):
+    with self.test_session():
+      shape = [5, 5, 5]
+      num_elem = 5 * 5 * 5
+      tensor = tf.constant(1.0, shape=shape)
+      weight_l1 = 0.01
+      weight_l2 = 0.05
+      loss = losses.l1_l2_regularizer(weight_l1, weight_l2)(tensor)
+      self.assertEquals(loss.op.name, 'L1L2Regularizer/value')
+      self.assertAlmostEqual(loss.eval(),
+                             num_elem * weight_l1 + num_elem * weight_l2 / 2, 5)
+
+
 class CrossEntropyLossTest(tf.test.TestCase):
 
   def testCrossEntropyLossAllCorrect(self):

+ 81 - 32
inception/inception/slim/ops.py

@@ -27,7 +27,6 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 
 from tensorflow.python.training import moving_averages
@@ -50,7 +49,8 @@ def batch_norm(inputs,
                is_training=True,
                trainable=True,
                restore=True,
-               scope=None):
+               scope=None,
+               reuse=None):
   """Adds a Batch Normalization layer.
 
   Args:
@@ -67,13 +67,15 @@ def batch_norm(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
 
   Returns:
     a tensor representing the output of the operation.
 
   """
   inputs_shape = inputs.get_shape()
-  with tf.variable_op_scope([inputs], scope, 'BatchNorm'):
+  with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
     axis = range(len(inputs_shape) - 1)
     params_shape = inputs_shape[-1:]
     with scopes.arg_scope([variables.variable], restore=restore):
@@ -124,6 +126,37 @@ def batch_norm(inputs,
     return outputs
 
 
+def _two_element_tuple(int_or_tuple):
+  """Converts `int_or_tuple` to height, width.
+
+  Several of the functions that follow accept arguments as either
+  a tuple of 2 integers or a single integer.  A single integer
+  indicates that the 2 values of the tuple are the same.
+
+  This functions normalizes the input value by always returning a tuple.
+
+  Args:
+    int_or_tuple: A list of 2 ints, a single int or a tf.TensorShape.
+
+  Returns:
+    A tuple with 2 values.
+
+  Raises:
+    ValueError: If `int_or_tuple` it not well formed.
+  """
+  if isinstance(int_or_tuple, (list, tuple)):
+    if len(int_or_tuple) != 2:
+      raise ValueError('Must be a list with 2 elements: %s' % int_or_tuple)
+    return int(int_or_tuple[0]), int(int_or_tuple[1])
+  if isinstance(int_or_tuple, int):
+    return int(int_or_tuple), int(int_or_tuple)
+  if isinstance(int_or_tuple, tf.TensorShape):
+    if len(int_or_tuple) == 2:
+      return int_or_tuple[0], int_or_tuple[1]
+  raise ValueError('Must be an int, a list with 2 elements or a TensorShape of '
+                   'length 2')
+
+
 @scopes.add_arg_scope
 def conv2d(inputs,
            num_filters_out,
@@ -138,7 +171,8 @@ def conv2d(inputs,
            is_training=True,
            trainable=True,
            restore=True,
-           scope=None):
+           scope=None,
+           reuse=None):
   """Adds a 2D convolution followed by an optional batch_norm layer.
 
   conv2d creates a variable called 'weights', representing the convolutional
@@ -149,8 +183,11 @@ def conv2d(inputs,
   Args:
     inputs: a tensor of size [batch_size, height, width, channels].
     num_filters_out: the number of output filters.
-    kernel_size: a 2-D list comprising of the height and width of the filters.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of
+      of the filters. Can be an int if both values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: one of 'VALID' or 'SAME'.
     activation: activation function.
     stddev: standard deviation of the truncated guassian weight distribution.
@@ -161,28 +198,29 @@ def conv2d(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
-
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
   Returns:
     a tensor representing the output of the operation.
 
-  Raises:
-    ValueError: if 'kernel_size' is not a 2-D list.
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
-  with tf.variable_op_scope([inputs], scope, 'Conv'):
+  with tf.variable_op_scope([inputs], scope, 'Conv', reuse=reuse):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     num_filters_in = inputs.get_shape()[-1]
-    weights_shape = [kernel_size[0], kernel_size[1],
+    weights_shape = [kernel_h, kernel_w,
                      num_filters_in, num_filters_out]
     weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
-    l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
+    l2_regularizer = None
+    if weight_decay and weight_decay > 0:
+      l2_regularizer = losses.l2_regularizer(weight_decay)
     weights = variables.variable('weights',
                                  shape=weights_shape,
                                  initializer=weights_initializer,
                                  regularizer=l2_regularizer,
                                  trainable=trainable,
                                  restore=restore)
-    conv = tf.nn.conv2d(inputs, weights, [1, stride, stride, 1],
+    conv = tf.nn.conv2d(inputs, weights, [1, stride_h, stride_w, 1],
                         padding=padding)
     if batch_norm_params is not None:
       with scopes.arg_scope([batch_norm], is_training=is_training,
@@ -213,7 +251,8 @@ def fc(inputs,
        is_training=True,
        trainable=True,
        restore=True,
-       scope=None):
+       scope=None,
+       reuse=None):
   """Adds a fully connected layer followed by an optional batch_norm layer.
 
   FC creates a variable called 'weights', representing the fully connected
@@ -234,15 +273,19 @@ def fc(inputs,
     trainable: whether or not the variables should be trainable or not.
     restore: whether or not the variables should be marked for restore.
     scope: Optional scope for variable_op_scope.
+    reuse: whether or not the layer and its variables should be reused. To be
+      able to reuse the layer scope must be given.
 
   Returns:
      the tensor variable representing the result of the series of operations.
   """
-  with tf.variable_op_scope([inputs], scope, 'FC'):
+  with tf.variable_op_scope([inputs], scope, 'FC', reuse=reuse):
     num_units_in = inputs.get_shape()[1]
     weights_shape = [num_units_in, num_units_out]
     weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
-    l2_regularizer = lambda t: losses.l2_loss(t, weight_decay)
+    l2_regularizer = None
+    if weight_decay and weight_decay > 0:
+      l2_regularizer = losses.l2_regularizer(weight_decay)
     weights = variables.variable('weights',
                                  shape=weights_shape,
                                  initializer=weights_initializer,
@@ -298,8 +341,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
 
   Args:
     inputs: a tensor of size [batch_size, height, width, depth].
-    kernel_size: the size of the pooling kernel over which the op is computed.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+      pooling kernel over which the op is computed. Can be an int if both
+      values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: the padding method, either 'VALID' or 'SAME'.
     scope: Optional scope for op_scope.
 
@@ -308,12 +355,12 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
   Raises:
     ValueError: if 'kernel_size' is not a 2-D list
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
   with tf.op_scope([inputs], scope, 'MaxPool'):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     return tf.nn.max_pool(inputs,
-                          ksize=[1, kernel_size[0], kernel_size[1], 1],
-                          strides=[1, stride, stride, 1],
+                          ksize=[1, kernel_h, kernel_w, 1],
+                          strides=[1, stride_h, stride_w, 1],
                           padding=padding)
 
 
@@ -326,22 +373,24 @@ def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
 
   Args:
     inputs: a tensor of size [batch_size, height, width, depth].
-    kernel_size: the size of the pooling kernel over which the op is computed.
-    stride: the stride in height and width of the convolution.
+    kernel_size: a list of length 2: [kernel_height, kernel_width] of the
+      pooling kernel over which the op is computed. Can be an int if both
+      values are the same.
+    stride: a list of length 2: [stride_height, stride_width].
+      Can be an int if both strides are the same.  Note that presently
+      both strides must have the same value.
     padding: the padding method, either 'VALID' or 'SAME'.
     scope: Optional scope for op_scope.
 
   Returns:
     a tensor representing the results of the pooling operation.
-  Raises:
-    ValueError: if 'kernel_size' is not a 2-D list
   """
-  if len(kernel_size) != 2:
-    raise ValueError('kernel_size must be a 2-D list.')
   with tf.op_scope([inputs], scope, 'AvgPool'):
+    kernel_h, kernel_w = _two_element_tuple(kernel_size)
+    stride_h, stride_w = _two_element_tuple(stride)
     return tf.nn.avg_pool(inputs,
-                          ksize=[1, kernel_size[0], kernel_size[1], 1],
-                          strides=[1, stride, stride, 1],
+                          ksize=[1, kernel_h, kernel_w, 1],
+                          strides=[1, stride_h, stride_w, 1],
                           padding=padding)
 
 
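The net effect of the ops.py changes, in a hedged sketch (the tensors and scope names are illustrative; the tests in the next file are the authoritative reference):

```python
import tensorflow as tf

from inception.slim import ops

images = tf.random_uniform((5, 32, 32, 3))  # illustrative input

# kernel_size and stride now accept a plain int or a [height, width] pair.
net = ops.conv2d(images, 32, 3, scope='conv1')       # 3 -> (3, 3) kernel
net = ops.conv2d(net, 32, [3, 1], scope='conv2')     # asymmetric 3x1 kernel
net = ops.max_pool(net, 2, stride=2, scope='pool1')  # int kernel for pooling

# The new `reuse` argument shares the variables of an existing, explicitly
# named scope instead of creating a second set.
net2 = ops.conv2d(images, 32, 3, scope='conv1', reuse=True)

# weight_decay=0 now skips the L2 regularizer entirely rather than adding a
# zero-weight loss term.
net3 = ops.conv2d(images, 16, [3, 3], weight_decay=0, scope='conv3')
```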

+ 173 - 33
inception/inception/slim/ops_test.py

@@ -18,13 +18,11 @@ from __future__ import division
 from __future__ import print_function
 
 
-
 import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.ops import control_flow_ops
 
-from inception.slim import losses
 from inception.slim import ops
 from inception.slim import scopes
 from inception.slim import variables
@@ -40,6 +38,57 @@ class ConvTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'Conv/Relu')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
 
+  def testCreateSquareConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, 3)
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+  def testCreateConvWithTensorShape(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, images.get_shape()[1:3])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, height, width, 32])
+
+  def testCreateFullyConv(self):
+    height, width = 6, 6
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      output = ops.conv2d(images, 64, images.get_shape()[1:3], padding='VALID')
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 64])
+
+  def testCreateVerticalConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [3, 1])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height, width, 32])
+
+  def testCreateHorizontalConv(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [1, 3])
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height, width, 32])
+
+  def testCreateConvWithStride(self):
+    height, width = 6, 6
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.conv2d(images, 32, [3, 3], stride=2)
+      self.assertEquals(output.op.name, 'Conv/Relu')
+      self.assertListEqual(output.get_shape().as_list(),
+                           [5, height/2, width/2, 32])
+
   def testCreateConvCreatesWeightsAndBiasesVars(self):
   def testCreateConvCreatesWeightsAndBiasesVars(self):
     height, width = 3, 3
     height, width = 3, 3
     images = tf.random_uniform((5, height, width, 3), seed=1)
     images = tf.random_uniform((5, height, width, 3), seed=1)
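
The new tests above exercise the more flexible `kernel_size` handling in `ops.conv2d`. A condensed sketch of those call patterns (hypothetical usage, not part of the diff), assuming this branch's `inception.slim.ops`:

```python
import tensorflow as tf
from inception.slim import ops

images = tf.random_uniform((5, 6, 6, 3), seed=1)

# Scalar kernel_size -> square 3x3 kernel; 'SAME' padding keeps height/width.
net = ops.conv2d(images, 32, 3)

# Rectangular kernels: 3x1 (vertical) and 1x3 (horizontal).
net = ops.conv2d(net, 32, [3, 1])
net = ops.conv2d(net, 32, [1, 3])

# stride=2 halves the spatial dimensions (6x6 -> 3x3 here).
net = ops.conv2d(net, 32, [3, 3], stride=2)
```
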
@@ -76,31 +125,73 @@ class ConvTest(tf.test.TestCase):
     with self.test_session() as sess:
     with self.test_session() as sess:
       images = tf.random_uniform((5, height, width, 3), seed=1)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01)
-      wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
-      self.assertEquals(wd.op.name, 'Conv/weights/Regularizer/L2Loss/value')
+      wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+      self.assertEquals(wd.op.name,
+                        'Conv/weights/Regularizer/L2Regularizer/value')
       sess.run(tf.initialize_all_variables())
       sess.run(tf.initialize_all_variables())
       self.assertTrue(sess.run(wd) <= 0.01)
       self.assertTrue(sess.run(wd) <= 0.01)
 
 
+  def testCreateConvWithoutWD(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3], weight_decay=0)
+      self.assertEquals(
+          tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
+  def testReuseVars(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3], scope='conv1')
+      self.assertEquals(len(variables.get_variables()), 2)
+      ops.conv2d(images, 32, [3, 3], scope='conv1', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+
+  def testNonReuseVars(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      ops.conv2d(images, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 2)
+      ops.conv2d(images, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 4)
+
   def testReuseConvWithWD(self):
   def testReuseConvWithWD(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
       ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
-      tf.get_variable_scope().reuse_variables()
-      ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+      ops.conv2d(images, 32, [3, 3], weight_decay=0.01, scope='conv1',
+                 reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
 
 
   def testConvWithBatchNorm(self):
   def testConvWithBatchNorm(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
-      images = tf.random_uniform((5, height, width, 3), seed=1)
-      with scopes.arg_scope([ops.conv2d], batch_norm_params={}):
-        net = ops.conv2d(images, 32, [3, 3], scope='conv1')
-        net = ops.conv2d(net, 32, [3, 3], scope='conv2')
-      self.assertEquals(len(tf.get_collection('moving_vars')), 4)
-      self.assertEquals(len(variables.get_variables('conv1/BatchNorm')), 3)
-      self.assertEquals(len(variables.get_variables('conv2/BatchNorm')), 3)
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+        net = ops.conv2d(images, 32, [3, 3])
+        net = ops.conv2d(net, 32, [3, 3])
+      self.assertEquals(len(variables.get_variables()), 8)
+      self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 3)
+
+  def testReuseConvWithBatchNorm(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 32), seed=1)
+      with scopes.arg_scope([ops.conv2d], batch_norm_params={'decay': 0.9}):
+        net = ops.conv2d(images, 32, [3, 3], scope='Conv')
+        net = ops.conv2d(net, 32, [3, 3], scope='Conv', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 4)
+      self.assertEquals(len(variables.get_variables('Conv/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('Conv_1/BatchNorm')), 0)
 
 
 
 
 class FCTest(tf.test.TestCase):
 class FCTest(tf.test.TestCase):
@@ -136,8 +227,7 @@ class FCTest(tf.test.TestCase):
     with self.test_session():
     with self.test_session():
       ops.fc(inputs, 32, scope='fc1')
       ops.fc(inputs, 32, scope='fc1')
       self.assertEquals(len(variables.get_variables('fc1')), 2)
       self.assertEquals(len(variables.get_variables('fc1')), 2)
-      tf.get_variable_scope().reuse_variables()
-      ops.fc(inputs, 32, scope='fc1')
+      ops.fc(inputs, 32, scope='fc1', reuse=True)
       self.assertEquals(len(variables.get_variables('fc1')), 2)
       self.assertEquals(len(variables.get_variables('fc1')), 2)
 
 
   def testNonReuseVars(self):
   def testNonReuseVars(self):
@@ -161,31 +251,53 @@ class FCTest(tf.test.TestCase):
     with self.test_session() as sess:
     with self.test_session() as sess:
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       ops.fc(inputs, 32, weight_decay=0.01)
       ops.fc(inputs, 32, weight_decay=0.01)
-      wd = tf.get_collection(losses.LOSSES_COLLECTION)[0]
-      self.assertEquals(wd.op.name, 'FC/weights/Regularizer/L2Loss/value')
+      wd = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)[0]
+      self.assertEquals(wd.op.name,
+                        'FC/weights/Regularizer/L2Regularizer/value')
       sess.run(tf.initialize_all_variables())
       sess.run(tf.initialize_all_variables())
       self.assertTrue(sess.run(wd) <= 0.01)
       self.assertTrue(sess.run(wd) <= 0.01)
 
 
+  def testCreateFCWithoutWD(self):
+    height, width = 3, 3
+    with self.test_session():
+      inputs = tf.random_uniform((5, height * width * 3), seed=1)
+      ops.fc(inputs, 32, weight_decay=0)
+      self.assertEquals(
+          tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES), [])
+
   def testReuseFCWithWD(self):
   def testReuseFCWithWD(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       inputs = tf.random_uniform((5, height * width * 3), seed=1)
       ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
       ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
-      tf.get_variable_scope().reuse_variables()
-      ops.fc(inputs, 32, weight_decay=0.01, scope='fc')
-      self.assertEquals(len(tf.get_collection(losses.LOSSES_COLLECTION)), 1)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
+      ops.fc(inputs, 32, weight_decay=0.01, scope='fc', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 2)
+      self.assertEquals(
+          len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)), 1)
 
 
   def testFCWithBatchNorm(self):
   def testFCWithBatchNorm(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
       images = tf.random_uniform((5, height * width * 3), seed=1)
       images = tf.random_uniform((5, height * width * 3), seed=1)
       with scopes.arg_scope([ops.fc], batch_norm_params={}):
       with scopes.arg_scope([ops.fc], batch_norm_params={}):
-        net = ops.fc(images, 32, scope='fc1')
-        net = ops.fc(net, 32, scope='fc2')
-      self.assertEquals(len(tf.get_collection('moving_vars')), 4)
+        net = ops.fc(images, 27)
+        net = ops.fc(net, 27)
+      self.assertEquals(len(variables.get_variables()), 8)
+      self.assertEquals(len(variables.get_variables('FC/BatchNorm')), 3)
+      self.assertEquals(len(variables.get_variables('FC_1/BatchNorm')), 3)
+
+  def testReuseFCWithBatchNorm(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height * width * 3), seed=1)
+      with scopes.arg_scope([ops.fc], batch_norm_params={'decay': 0.9}):
+        net = ops.fc(images, 27, scope='fc1')
+        net = ops.fc(net, 27, scope='fc1', reuse=True)
+      self.assertEquals(len(variables.get_variables()), 4)
       self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
       self.assertEquals(len(variables.get_variables('fc1/BatchNorm')), 3)
-      self.assertEquals(len(variables.get_variables('fc2/BatchNorm')), 3)
 
 
 
 
 class MaxPoolTest(tf.test.TestCase):
 class MaxPoolTest(tf.test.TestCase):
@@ -198,6 +310,14 @@ class MaxPoolTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'MaxPool/MaxPool')
       self.assertEquals(output.op.name, 'MaxPool/MaxPool')
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
 
 
+  def testCreateSquareMaxPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.max_pool(images, 3)
+      self.assertEquals(output.op.name, 'MaxPool/MaxPool')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
   def testCreateMaxPoolWithScope(self):
   def testCreateMaxPoolWithScope(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
@@ -219,6 +339,13 @@ class MaxPoolTest(tf.test.TestCase):
       output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
       output = ops.max_pool(images, [3, 3], stride=1, padding='SAME')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
 
+  def testGlobalMaxPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.max_pool(images, images.get_shape()[1:3], stride=1)
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
 
 
 class AvgPoolTest(tf.test.TestCase):
 class AvgPoolTest(tf.test.TestCase):
 
 
@@ -230,6 +357,14 @@ class AvgPoolTest(tf.test.TestCase):
       self.assertEquals(output.op.name, 'AvgPool/AvgPool')
       self.assertEquals(output.op.name, 'AvgPool/AvgPool')
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
       self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
 
 
+  def testCreateSquareAvgPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.avg_pool(images, 3)
+      self.assertEquals(output.op.name, 'AvgPool/AvgPool')
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
   def testCreateAvgPoolWithScope(self):
   def testCreateAvgPoolWithScope(self):
     height, width = 3, 3
     height, width = 3, 3
     with self.test_session():
     with self.test_session():
@@ -251,6 +386,13 @@ class AvgPoolTest(tf.test.TestCase):
       output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
       output = ops.avg_pool(images, [3, 3], stride=1, padding='SAME')
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
       self.assertListEqual(output.get_shape().as_list(), [5, height, width, 3])
 
 
+  def testGlobalAvgPool(self):
+    height, width = 3, 3
+    with self.test_session():
+      images = tf.random_uniform((5, height, width, 3), seed=1)
+      output = ops.avg_pool(images, images.get_shape()[1:3], stride=1)
+      self.assertListEqual(output.get_shape().as_list(), [5, 1, 1, 3])
+
 
 
 class OneHotEncodingTest(tf.test.TestCase):
 class OneHotEncodingTest(tf.test.TestCase):
 
 
@@ -342,8 +484,8 @@ class BatchNormTest(tf.test.TestCase):
       gamma = variables.get_variables_by_name('gamma')[0]
       gamma = variables.get_variables_by_name('gamma')[0]
       self.assertEquals(beta.op.name, 'BatchNorm/beta')
       self.assertEquals(beta.op.name, 'BatchNorm/beta')
       self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
       self.assertEquals(gamma.op.name, 'BatchNorm/gamma')
-      moving_mean = tf.get_collection('moving_vars')[0]
-      moving_variance = tf.get_collection('moving_vars')[1]
+      moving_mean = tf.moving_average_variables()[0]
+      moving_variance = tf.moving_average_variables()[1]
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_mean.op.name, 'BatchNorm/moving_mean')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
       self.assertEquals(moving_variance.op.name, 'BatchNorm/moving_variance')
 
 
@@ -375,8 +517,7 @@ class BatchNormTest(tf.test.TestCase):
     with self.test_session():
     with self.test_session():
       images = tf.random_uniform((5, height, width, 3), seed=1)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.batch_norm(images, scale=True, scope='bn')
       ops.batch_norm(images, scale=True, scope='bn')
-      tf.get_variable_scope().reuse_variables()
-      ops.batch_norm(images, scale=True, scope='bn')
+      ops.batch_norm(images, scale=True, scope='bn', reuse=True)
       beta = variables.get_variables_by_name('beta')
       beta = variables.get_variables_by_name('beta')
       gamma = variables.get_variables_by_name('gamma')
       gamma = variables.get_variables_by_name('gamma')
       self.assertEquals(len(beta), 1)
       self.assertEquals(len(beta), 1)
@@ -390,8 +531,7 @@ class BatchNormTest(tf.test.TestCase):
       images = tf.random_uniform((5, height, width, 3), seed=1)
       images = tf.random_uniform((5, height, width, 3), seed=1)
       ops.batch_norm(images, scope='bn')
       ops.batch_norm(images, scope='bn')
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 2)
-      tf.get_variable_scope().reuse_variables()
-      ops.batch_norm(images, scope='bn')
+      ops.batch_norm(images, scope='bn', reuse=True)
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
       self.assertEquals(len(tf.get_collection(ops.UPDATE_OPS_COLLECTION)), 4)
 
 
   def testCreateMovingVars(self):
   def testCreateMovingVars(self):

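Outside the test harness, the reuse pattern for `ops.batch_norm` that these tests pin down looks roughly like this (a sketch, assuming this branch's `inception.slim.ops`):

```python
import tensorflow as tf
from inception.slim import ops

images = tf.random_uniform((5, 3, 3, 3), seed=1)

# The first call creates beta/gamma (scale=True) plus the moving statistics.
net = ops.batch_norm(images, scale=True, scope='bn')

# reuse=True shares those variables instead of creating a second set.
net = ops.batch_norm(images, scale=True, scope='bn', reuse=True)

# The moving mean/variance are tracked as moving-average variables.
for var in tf.moving_average_variables():
  print(var.op.name)
```
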
+ 49 - 23
inception/inception/slim/scopes.py

@@ -19,7 +19,7 @@
 
 
   Example of how to use scopes.arg_scope:
   Example of how to use scopes.arg_scope:
 
 
-  with slim.arg_scope(ops.conv2d, padding='SAME',
+  with scopes.arg_scope([ops.conv2d], padding='SAME',
                       stddev=0.01, weight_decay=0.0005):
                       stddev=0.01, weight_decay=0.0005):
     net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
     net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
     net = ops.conv2d(net, 256, [5, 5], scope='conv2')
     net = ops.conv2d(net, 256, [5, 5], scope='conv2')
@@ -32,6 +32,15 @@
     ops.conv2d(inputs, 256, [5, 5], padding='SAME',
     ops.conv2d(inputs, 256, [5, 5], padding='SAME',
                stddev=0.01, weight_decay=0.0005, scope='conv2')
                stddev=0.01, weight_decay=0.0005, scope='conv2')
 
 
+  Example of how to reuse an arg_scope:
+  with scopes.arg_scope([ops.conv2d], padding='SAME',
+                        stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
+    net = ops.conv2d(net, 256, [5, 5], scope='conv1')
+    ....
+
+  with scopes.arg_scope(conv2d_arg_scope):
+    net = ops.conv2d(net, 256, [5, 5], scope='conv2')
+
   Example of how to use scopes.add_arg_scope:
   Example of how to use scopes.add_arg_scope:
 
 
   @scopes.add_arg_scope
   @scopes.add_arg_scope
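
Putting the documented reuse pattern together, a minimal end-to-end sketch (not part of this change, assuming `inception.slim.ops` and `inception.slim.scopes` from this branch are importable):

```python
from inception.slim import ops
from inception.slim import scopes

def my_net(inputs):
  # Every conv2d inside the scope picks up the shared defaults.
  with scopes.arg_scope([ops.conv2d], padding='SAME',
                        stddev=0.01, weight_decay=0.0005) as conv2d_arg_scope:
    net = ops.conv2d(inputs, 64, [11, 11], 4, padding='VALID', scope='conv1')
    net = ops.conv2d(net, 256, [5, 5], scope='conv2')

  # The captured scope (a plain dict) can be re-entered later, with no
  # additional keyword arguments.
  with scopes.arg_scope(conv2d_arg_scope):
    net = ops.conv2d(net, 256, [3, 3], scope='conv3')
  return net
```
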
@@ -44,7 +53,6 @@ from __future__ import print_function
 import contextlib
 import contextlib
 import functools
 import functools
 
 
-
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import ops
 
 
 _ARGSTACK_KEY = ("__arg_stack",)
 _ARGSTACK_KEY = ("__arg_stack",)
@@ -74,12 +82,16 @@ def _add_op(op):
 
 
 
 
 @contextlib.contextmanager
 @contextlib.contextmanager
-def arg_scope(list_ops, **kwargs):
+def arg_scope(list_ops_or_scope, **kwargs):
   """Stores the default arguments for the given set of list_ops.
   """Stores the default arguments for the given set of list_ops.
 
 
+  For usage, please see examples at top of the file.
+
   Args:
   Args:
-    list_ops: List or tuple of operations to set argument scope for. Every op in
-              list_ops need to be decorated with @add_arg_scope to work.
+    list_ops_or_scope: List or tuple of operations to set argument scope for or
+      a dictionary containing the current scope. When list_ops_or_scope is a
+      dict, kwargs must be empty. When list_ops_or_scope is a list or tuple,
+      every op in it needs to be decorated with @add_arg_scope to work.
     **kwargs: keyword=value that will define the defaults for each op in
     **kwargs: keyword=value that will define the defaults for each op in
               list_ops. All the ops need to accept the given set of arguments.
               list_ops. All the ops need to accept the given set of arguments.
 
 
@@ -89,24 +101,38 @@ def arg_scope(list_ops, **kwargs):
     TypeError: if list_ops is not a list or a tuple.
     TypeError: if list_ops is not a list or a tuple.
     ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
     ValueError: if any op in list_ops has not be decorated with @add_arg_scope.
   """
   """
-  if not isinstance(list_ops, (list, tuple)):
-    raise TypeError("list_ops is not a list or a tuple")
-  try:
-    current_scope = _current_arg_scope().copy()
-    for op in list_ops:
-      key_op = (op.__module__, op.__name__)
-      if not has_arg_scope(op):
-        raise ValueError("%s is not decorated with @add_arg_scope", key_op)
-      if key_op in current_scope:
-        current_kwargs = current_scope[key_op].copy()
-        current_kwargs.update(kwargs)
-        current_scope[key_op] = current_kwargs
-      else:
-        current_scope[key_op] = kwargs.copy()
-    _get_arg_stack().append(current_scope)
-    yield current_scope
-  finally:
-    _get_arg_stack().pop()
+  if isinstance(list_ops_or_scope, dict):
+    # Assumes that list_ops_or_scope is a scope that is being reused.
+    if kwargs:
+      raise ValueError("When attempting to re-use a scope by supplying a "
+                       "dictionary, kwargs must be empty.")
+    current_scope = list_ops_or_scope.copy()
+    try:
+      _get_arg_stack().append(current_scope)
+      yield current_scope
+    finally:
+      _get_arg_stack().pop()
+  else:
+    # Assumes that list_ops_or_scope is a list/tuple of ops with kwargs.
+    if not isinstance(list_ops_or_scope, (list, tuple)):
+      raise TypeError("list_ops_or_scope must either be a list/tuple or reused"
+                      " scope (i.e. dict)")
+    try:
+      current_scope = _current_arg_scope().copy()
+      for op in list_ops_or_scope:
+        key_op = (op.__module__, op.__name__)
+        if not has_arg_scope(op):
+          raise ValueError("%s is not decorated with @add_arg_scope" %
+                           (key_op,))
+        if key_op in current_scope:
+          current_kwargs = current_scope[key_op].copy()
+          current_kwargs.update(kwargs)
+          current_scope[key_op] = current_kwargs
+        else:
+          current_scope[key_op] = kwargs.copy()
+      _get_arg_stack().append(current_scope)
+      yield current_scope
+    finally:
+      _get_arg_stack().pop()
 
 
 
 
 def add_arg_scope(func):
 def add_arg_scope(func):
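
The rewritten `arg_scope` also spells out its error contract: a reused scope must be a dict with no extra kwargs, and anything else must be a list or tuple of decorated ops. A short sketch of that behaviour (hypothetical usage, assuming this branch's slim modules):

```python
from inception.slim import ops
from inception.slim import scopes

with scopes.arg_scope([ops.conv2d], stddev=0.01) as saved_scope:
  pass

# Re-entering a captured scope works only when no extra kwargs are supplied.
with scopes.arg_scope(saved_scope):
  pass

try:
  with scopes.arg_scope(saved_scope, padding='VALID'):
    pass
except ValueError:
  pass  # kwargs must be empty when a scope dict is reused.

try:
  with scopes.arg_scope(ops.conv2d, stddev=0.01):
    pass
except TypeError:
  pass  # a bare op must be wrapped in a list or tuple.
```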

+ 45 - 1
inception/inception/slim/scopes_test.py

@@ -18,7 +18,6 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import print_function
 
 
 
 
-
 import tensorflow as tf
 import tensorflow as tf
 from inception.slim import scopes
 from inception.slim import scopes
 
 
@@ -39,6 +38,51 @@ class ArgScopeTest(tf.test.TestCase):
     with self.test_session():
     with self.test_session():
       self.assertEqual(scopes._current_arg_scope(), {})
       self.assertEqual(scopes._current_arg_scope(), {})
 
 
+  def testCurrentArgScope(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    key_op = (func1.__module__, func1.__name__)
+    current_scope = {key_op: func1_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope:
+        self.assertDictEqual(scope, current_scope)
+
+  def testCurrentArgScopeNested(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    func2_kwargs = {'b': 2, 'd': [2]}
+    key = lambda f: (f.__module__, f.__name__)
+    current_scope = {key(func1): func1_kwargs.copy(),
+                     key(func2): func2_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]):
+        with scopes.arg_scope([func2], b=2, d=[2]) as scope:
+          self.assertDictEqual(scope, current_scope)
+
+  def testReuseArgScope(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    key_op = (func1.__module__, func1.__name__)
+    current_scope = {key_op: func1_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+        pass
+      with scopes.arg_scope(scope1) as scope:
+        self.assertDictEqual(scope, current_scope)
+
+  def testReuseArgScopeNested(self):
+    func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
+    func2_kwargs = {'b': 2, 'd': [2]}
+    key = lambda f: (f.__module__, f.__name__)
+    current_scope1 = {key(func1): func1_kwargs.copy()}
+    current_scope2 = {key(func1): func1_kwargs.copy(),
+                      key(func2): func2_kwargs.copy()}
+    with self.test_session():
+      with scopes.arg_scope([func1], a=1, b=None, c=[1]) as scope1:
+        with scopes.arg_scope([func2], b=2, d=[2]) as scope2:
+          pass
+      with scopes.arg_scope(scope1):
+        self.assertDictEqual(scopes._current_arg_scope(), current_scope1)
+      with scopes.arg_scope(scope2):
+        self.assertDictEqual(scopes._current_arg_scope(), current_scope2)
+
   def testSimpleArgScope(self):
   def testSimpleArgScope(self):
     func1_args = (0,)
     func1_args = (0,)
     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}
     func1_kwargs = {'a': 1, 'b': None, 'c': [1]}

+ 54 - 60
inception/inception/slim/variables.py

@@ -12,7 +12,15 @@
 # See the License for the specific language governing permissions and
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # limitations under the License.
 # ==============================================================================
 # ==============================================================================
-"""Contains convenience wrappers for creating Variables in TensorFlow.
+"""Contains convenience wrappers for creating variables in TF-Slim.
+
+The variables module is typically used for defining model variables from the
+ops routines (see slim.ops). Such variables are used for training, evaluation
+and inference of models.
+
+All the variables created through this module are added to the
+MODEL_VARIABLES collection. If you create a model variable outside slim, it can
+be added with slim.variables.add_variable(external_variable, restore).
 
 
 Usage:
 Usage:
   weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
   weights_initializer = tf.truncated_normal_initializer(stddev=0.01)
@@ -24,15 +32,15 @@ Usage:
                                device='/cpu:0')
                                device='/cpu:0')
 
 
   biases = variables.variable('biases',
   biases = variables.variable('biases',
-                               shape=[100],
-                               initializer=tf.zeros_initializer,
-                               device='/cpu:0')
+                              shape=[100],
+                              initializer=tf.zeros_initializer,
+                              device='/cpu:0')
 
 
   # More complex example.
   # More complex example.
 
 
   net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
   net = slim.ops.conv2d(input, 32, [3, 3], scope='conv1')
   net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
   net = slim.ops.conv2d(net, 64, [3, 3], scope='conv2')
-  with slim.arg_scope(variables.Variables, restore=False):
+  with slim.arg_scope([variables.variable], restore=False):
     net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
     net = slim.ops.conv2d(net, 64, [3, 3], scope='conv3')
 
 
   # Get all model variables from all the layers.
   # Get all model variables from all the layers.
@@ -47,9 +55,9 @@ Usage:
   # Get all bias from all the layers.
   # Get all bias from all the layers.
   biases = slim.variables.get_variables_by_name('biases')
   biases = slim.variables.get_variables_by_name('biases')
 
 
-  # Get all variables in the VARIABLES_TO_RESTORE collection
+  # Get all variables to restore.
   # (i.e. only those created by 'conv1' and 'conv2')
   # (i.e. only those created by 'conv1' and 'conv2')
-  variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+  variables_to_restore = slim.variables.get_variables_to_restore()
 
 
 ************************************************
 ************************************************
 * Initializing model variables from a checkpoint
 * Initializing model variables from a checkpoint
@@ -60,7 +68,7 @@ v1 = slim.variables.variable(name="v1", ..., restore=False)
 v2 = slim.variables.variable(name="v2", ...) # By default restore=True
 v2 = slim.variables.variable(name="v2", ...) # By default restore=True
 ...
 ...
 # The list of variables to restore should only contain 'v2'.
 # The list of variables to restore should only contain 'v2'.
-variables_to_restore = tf.get_collection(slim.variables.VARIABLES_TO_RESTORE)
+variables_to_restore = slim.variables.get_variables_to_restore()
 restorer = tf.train.Saver(variables_to_restore)
 restorer = tf.train.Saver(variables_to_restore)
 with tf.Session() as sess:
 with tf.Session() as sess:
   # Restore variables from disk.
   # Restore variables from disk.
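
For restoring only part of a model, the arg_scope form documented above can also switch off `restore` for a whole block of new layers. A sketch (not part of the diff), assuming this branch's slim modules:

```python
import tensorflow as tf
from inception.slim import ops
from inception.slim import scopes
from inception.slim import variables

inputs = tf.random_uniform((5, 32, 32, 3), seed=1)

net = ops.conv2d(inputs, 32, [3, 3], scope='conv1')
net = ops.conv2d(net, 64, [3, 3], scope='conv2')

# Variables created in this block are excluded from VARIABLES_TO_RESTORE.
with scopes.arg_scope([variables.variable], restore=False):
  net = ops.conv2d(net, 64, [3, 3], scope='conv3')

# Only the conv1/conv2 variables end up in the restore list.
restorer = tf.train.Saver(variables.get_variables_to_restore())
```
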
@@ -74,92 +82,71 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import division
 from __future__ import print_function
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 import tensorflow as tf
 
 
 from inception.slim import scopes
 from inception.slim import scopes
 
 
 # Collection containing all the variables created using slim.variables
 # Collection containing all the variables created using slim.variables
-VARIABLES_COLLECTION = '_variables_'
+MODEL_VARIABLES = '_model_variables_'
 
 
-# Collection containing all the slim.variables that are marked to_restore
+# Collection containing the slim.variables that are created with restore=True.
 VARIABLES_TO_RESTORE = '_variables_to_restore_'
 VARIABLES_TO_RESTORE = '_variables_to_restore_'
 
 
 
 
-def get_variable_given_name(var):
-  """Gets the variable given name without the scope.
-
-  Args:
-    var: a variable.
-
-  Returns:
-    the given name of the variable without the scope.
-  """
-  name = var.op.name
-  if '/' in name:
-    name = name.split('/')[-1]
-  return name
-
-
-def default_collections(given_name, restore):
-  """Define the set of default collections that variables should be added.
-
-  Args:
-    given_name: the given name of the variable.
-    restore: whether the variable should be added to the VARIABLES_TO_RESTORE
-      collection.
-
-  Returns:
-    a list of default collections.
-  """
-  defaults = [tf.GraphKeys.VARIABLES, VARIABLES_COLLECTION]
-  defaults += [VARIABLES_COLLECTION + given_name]
-  if restore:
-    defaults += [VARIABLES_TO_RESTORE]
-  return defaults
-
-
 def add_variable(var, restore=True):
 def add_variable(var, restore=True):
-  """Adds a variable to the default set of collections.
+  """Adds a variable to the MODEL_VARIABLES collection.
 
 
+  Optionally it will add the variable to the VARIABLES_TO_RESTORE collection.
   Args:
   Args:
     var: a variable.
     var: a variable.
     restore: whether the variable should be added to the
     restore: whether the variable should be added to the
       VARIABLES_TO_RESTORE collection.
       VARIABLES_TO_RESTORE collection.
+
   """
   """
-  given_name = get_variable_given_name(var)
-  for collection in default_collections(given_name, restore):
+  collections = [MODEL_VARIABLES]
+  if restore:
+    collections.append(VARIABLES_TO_RESTORE)
+  for collection in collections:
     if var not in tf.get_collection(collection):
     if var not in tf.get_collection(collection):
       tf.add_to_collection(collection, var)
       tf.add_to_collection(collection, var)
 
 
 
 
-def get_variables(prefix=None, suffix=None):
-  """Gets the list of variables, filtered by prefix and/or suffix.
+def get_variables(scope=None, suffix=None):
+  """Gets the list of variables, filtered by scope and/or suffix.
 
 
   Args:
   Args:
-    prefix: an optional prefix for filtering the variables to return.
+    scope: an optional scope for filtering the variables to return.
     suffix: an optional suffix for filtering the variables to return.
     suffix: an optional suffix for filtering the variables to return.
 
 
   Returns:
   Returns:
-    a list of variables with prefix and suffix.
+    a copied list of variables with scope and suffix.
   """
   """
-  candidates = tf.get_collection(VARIABLES_COLLECTION, prefix)
+  candidates = tf.get_collection(MODEL_VARIABLES, scope)[:]
   if suffix is not None:
   if suffix is not None:
     candidates = [var for var in candidates if var.op.name.endswith(suffix)]
     candidates = [var for var in candidates if var.op.name.endswith(suffix)]
   return candidates
   return candidates
 
 
 
 
-def get_variables_by_name(given_name, prefix=None):
-  """Gets the list of variables were given that name.
+def get_variables_to_restore():
+  """Gets the list of variables to restore.
+
+  Returns:
+    a copied list of variables.
+  """
+  return tf.get_collection(VARIABLES_TO_RESTORE)[:]
+
+
+def get_variables_by_name(given_name, scope=None):
+  """Gets the list of variables that were given that name.
 
 
   Args:
   Args:
     given_name: name given to the variable without scope.
     given_name: name given to the variable without scope.
-    prefix: an optional prefix for filtering the variables to return.
+    scope: an optional scope for filtering the variables to return.
 
 
   Returns:
   Returns:
-    a list of variables with prefix and suffix.
+    a copied list of variables with the given name and scope.
   """
   """
-  return tf.get_collection(VARIABLES_COLLECTION + given_name, prefix)
+  return get_variables(scope=scope, suffix=given_name)
 
 
 
 
 def get_unique_variable(name):
 def get_unique_variable(name):
@@ -204,7 +191,7 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
       `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
     collections: A list of collection names to which the Variable will be added.
     collections: A list of collection names to which the Variable will be added.
       Note that the variable is always also added to the tf.GraphKeys.VARIABLES
       Note that the variable is always also added to the tf.GraphKeys.VARIABLES
-      collection.
+      and MODEL_VARIABLES collections.
     device: Optional device to place the variable. It can be an string or a
     device: Optional device to place the variable. It can be an string or a
       function that is called to get the device for the variable.
       function that is called to get the device for the variable.
     restore: whether the variable should be added to the
     restore: whether the variable should be added to the
@@ -216,8 +203,15 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   # Instantiate the device for this variable if it is passed as a function.
   # Instantiate the device for this variable if it is passed as a function.
   if device and callable(device):
   if device and callable(device):
     device = device()
     device = device()
-  collections = set(list(collections or []) + default_collections(name,
-                                                                  restore))
+  collections = list(collections or [])
+
+  # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
+  collections += [tf.GraphKeys.VARIABLES, MODEL_VARIABLES]
+  # Add to VARIABLES_TO_RESTORE if necessary
+  if restore:
+    collections.append(VARIABLES_TO_RESTORE)
+  # Remove duplicates
+  collections = set(collections)
   with tf.device(device):
   with tf.device(device):
     return tf.get_variable(name, shape=shape, dtype=dtype,
     return tf.get_variable(name, shape=shape, dtype=dtype,
                            initializer=initializer, regularizer=regularizer,
                            initializer=initializer, regularizer=regularizer,
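
A compact sketch of how `variables.variable` interacts with the collections logic above, mirroring the updated tests (not part of the diff, assuming this branch's `inception.slim`):

```python
import tensorflow as tf
from inception.slim import scopes
from inception.slim import variables

# 'b' goes into the user collections 'A' and 'B' in addition to the
# always-added tf.GraphKeys.VARIABLES and MODEL_VARIABLES collections.
with scopes.arg_scope([variables.variable], collections=['A', 'B']):
  b = variables.variable('b', [])
c = variables.variable('c', [])

assert tf.get_collection('A') == [b]
assert tf.get_collection('B') == [b]
assert variables.get_variables() == [b, c]
assert variables.get_variables_to_restore() == [b, c]
```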

+ 58 - 28
inception/inception/slim/variables_test.py

@@ -17,7 +17,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import division
 from __future__ import print_function
 from __future__ import print_function
 
 
-
 import tensorflow as tf
 import tensorflow as tf
 
 
 from inception.slim import scopes
 from inception.slim import scopes
@@ -33,29 +32,13 @@ class VariablesTest(tf.test.TestCase):
         self.assertEquals(a.op.name, 'A/a')
         self.assertEquals(a.op.name, 'A/a')
         self.assertListEqual(a.get_shape().as_list(), [5])
         self.assertListEqual(a.get_shape().as_list(), [5])
 
 
-  def testGetVariableGivenName(self):
-    with self.test_session():
-      with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
-      with tf.variable_scope('B'):
-        b = variables.variable('a', [5])
-      self.assertEquals('a', variables.get_variable_given_name(a))
-      self.assertEquals('a', variables.get_variable_given_name(b))
-
-  def testGetVariableGivenNameScoped(self):
-    with self.test_session():
-      with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
-        b = variables.variable('b', [5])
-        self.assertEquals([a], variables.get_variables_by_name('a'))
-        self.assertEquals([b], variables.get_variables_by_name('b'))
-
   def testGetVariables(self):
   def testGetVariables(self):
     with self.test_session():
     with self.test_session():
       with tf.variable_scope('A'):
       with tf.variable_scope('A'):
         a = variables.variable('a', [5])
         a = variables.variable('a', [5])
       with tf.variable_scope('B'):
       with tf.variable_scope('B'):
         b = variables.variable('a', [5])
         b = variables.variable('a', [5])
+      self.assertEquals([a, b], variables.get_variables())
       self.assertEquals([a], variables.get_variables('A'))
       self.assertEquals([a], variables.get_variables('A'))
       self.assertEquals([b], variables.get_variables('B'))
       self.assertEquals([b], variables.get_variables('B'))
 
 
@@ -103,19 +86,28 @@ class VariablesTest(tf.test.TestCase):
       with tf.variable_scope('A'):
       with tf.variable_scope('A'):
         a = variables.variable('a', [5])
         a = variables.variable('a', [5])
       with tf.variable_scope('B'):
       with tf.variable_scope('B'):
-        b = variables.variable('b', [5])
-      self.assertListEqual([a, b],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+        b = variables.variable('a', [5])
+      self.assertEquals([a, b], variables.get_variables_to_restore())
 
 
-  def testGetVariablesToRestorePartial(self):
+  def testNoneGetVariablesToRestore(self):
     with self.test_session():
     with self.test_session():
       with tf.variable_scope('A'):
       with tf.variable_scope('A'):
-        a = variables.variable('a', [5])
+        a = variables.variable('a', [5], restore=False)
       with tf.variable_scope('B'):
       with tf.variable_scope('B'):
+        b = variables.variable('a', [5], restore=False)
+      self.assertEquals([], variables.get_variables_to_restore())
+      self.assertEquals([a, b], variables.get_variables())
+
+  def testGetMixedVariablesToRestore(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
         b = variables.variable('b', [5], restore=False)
         b = variables.variable('b', [5], restore=False)
-      self.assertListEqual([a, b], variables.get_variables())
-      self.assertListEqual([a],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+      with tf.variable_scope('B'):
+        c = variables.variable('c', [5])
+        d = variables.variable('d', [5], restore=False)
+      self.assertEquals([a, b, c, d], variables.get_variables())
+      self.assertEquals([a, c], variables.get_variables_to_restore())
 
 
   def testReuseVariable(self):
   def testReuseVariable(self):
     with self.test_session():
     with self.test_session():
@@ -190,11 +182,49 @@ class VariablesTest(tf.test.TestCase):
                               collections=['A', 'B']):
                               collections=['A', 'B']):
           b = variables.variable('b', [])
           b = variables.variable('b', [])
         c = variables.variable('c', [])
         c = variables.variable('c', [])
-      self.assertListEqual([a, b, c],
-                           tf.get_collection(variables.VARIABLES_TO_RESTORE))
+      self.assertListEqual([a, b, c], variables.get_variables_to_restore())
       self.assertListEqual([a, c], tf.trainable_variables())
       self.assertListEqual([a, c], tf.trainable_variables())
       self.assertListEqual([b], tf.get_collection('A'))
       self.assertListEqual([b], tf.get_collection('A'))
       self.assertListEqual([b], tf.get_collection('B'))
       self.assertListEqual([b], tf.get_collection('B'))
 
 
+
+class GetVariablesByNameTest(tf.test.TestCase):
+
+  def testGetVariableGivenNameScoped(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
+        b = variables.variable('b', [5])
+        self.assertEquals([a], variables.get_variables_by_name('a'))
+        self.assertEquals([b], variables.get_variables_by_name('b'))
+
+  def testGetVariablesByNameReturnsByValueWithScope(self):
+    with self.test_session():
+      with tf.variable_scope('A'):
+        a = variables.variable('a', [5])
+        matched_variables = variables.get_variables_by_name('a')
+
+        # If variables.get_variables_by_name returns the list by reference, the
+        # following append should persist, and be returned, in subsequent calls
+        # to variables.get_variables_by_name('a').
+        matched_variables.append(4)
+
+        matched_variables = variables.get_variables_by_name('a')
+        self.assertEquals([a], matched_variables)
+
+  def testGetVariablesByNameReturnsByValueWithoutScope(self):
+    with self.test_session():
+      a = variables.variable('a', [5])
+      matched_variables = variables.get_variables_by_name('a')
+
+      # If variables.get_variables_by_name returns the list by reference, the
+      # following append should persist, and be returned, in subsequent calls
+      # to variables.get_variables_by_name('a').
+      matched_variables.append(4)
+
+      matched_variables = variables.get_variables_by_name('a')
+      self.assertEquals([a], matched_variables)
+
+
 if __name__ == '__main__':
 if __name__ == '__main__':
   tf.test.main()
   tf.test.main()