
Update inception model based on tf API changes: replace tf.op_scope with tf.name_scope and tf.variable_op_scope with tf.variable_scope; fix the order of arguments for tf.concat; replace tf.mul with tf.multiply.

Li Lao, 8 years ago
commit e5079c8390
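
For reference, the migration pattern applied across the files below looks roughly like this (a minimal sketch only; `inputs`, `scope`, `a`, `b`, and `w` are placeholder names, not code from the commit):

    # Before (pre-1.0 TensorFlow API)
    with tf.op_scope([inputs], scope, 'InceptionV3'):
      net = tf.concat(3, [a, b])                 # axis first, values second
      loss = tf.mul(w, tf.reduce_sum(net))

    # After (API used in this commit)
    with tf.name_scope(scope, 'InceptionV3', [inputs]):
      net = tf.concat([a, b], 3)                 # values first, axis last
      loss = tf.multiply(w, tf.reduce_sum(net))

The same argument reordering applies to tf.variable_op_scope([inputs], scope, 'Name'), which becomes tf.variable_scope(scope, 'Name', [inputs]).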

+ 2 - 2
inception/inception/inception_model.py

@@ -147,8 +147,8 @@ def _activation_summary(x):
  # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
  # session. This helps the clarity of presentation on tensorboard.
  tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name)
-  tf.histogram_summary(tensor_name + '/activations', x)
-  tf.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))
+  tf.contrib.deprecated.histogram_summary(tensor_name + '/activations', x)
+  tf.contrib.deprecated.scalar_summary(tensor_name + '/sparsity', tf.nn.zero_fraction(x))


 def _activation_summaries(endpoints):

+ 0 - 3
inception/inception/slim/README.md

@@ -246,11 +246,8 @@ number. More concretely, the scopes in the example above would be 'conv3_1',

 In addition to the types of scope mechanisms in TensorFlow ([name_scope]
 (https://www.tensorflow.org/api_docs/python/framework.html#name_scope),
-[op_scope](https://www.tensorflow.org/api_docs/python/framework.html#op_scope),
 [variable_scope]
 (https://www.tensorflow.org/api_docs/python/state_ops.html#variable_scope),
-[variable_op_scope]
-(https://www.tensorflow.org/api_docs/python/state_ops.html#variable_op_scope)),
 TF-Slim adds a new scoping mechanism called "argument scope" or [arg_scope]
 (scopes.py). This new scope allows a user to specify one or more operations and
 a set of arguments which will be passed to each of the operations defined in the
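
A minimal sketch of the arg_scope mechanism described above, based on the scopes.arg_scope calls that appear in slim/inception_model.py later in this commit (the `images` tensor and the layer sizes are placeholders):

    # Every op listed in the first argument receives the given keyword
    # arguments as defaults unless an individual call overrides them.
    with scopes.arg_scope([ops.conv2d, ops.fc], is_training=False):
      net = ops.conv2d(images, 64, [3, 3])   # gets is_training=False implicitly
      logits = ops.fc(net, 10)               # gets is_training=False implicitly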

+ 21 - 21
inception/inception/slim/inception_model.py

@@ -69,7 +69,7 @@ def inception_v3(inputs,
    is_training: whether is training or not.
    restore_logits: whether or not the logits layers should be restored.
      Useful for fine-tuning a model with different num_classes.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a list containing 'logits', 'aux_logits' Tensors.
@@ -77,7 +77,7 @@ def inception_v3(inputs,
  # end_points will collect relevant activations for external use, for example
  # summaries or losses.
  end_points = {}
-  with tf.op_scope([inputs], scope, 'inception_v3'):
+  with tf.name_scope(scope, 'inception_v3', [inputs]):
    with scopes.arg_scope([ops.conv2d, ops.fc, ops.batch_norm, ops.dropout],
                          is_training=is_training):
      with scopes.arg_scope([ops.conv2d, ops.max_pool, ops.avg_pool],
@@ -122,7 +122,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 32, [1, 1])
-          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+          net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
          end_points['mixed_35x35x256a'] = net
        # mixed_1: 35 x 35 x 288.
        with tf.variable_scope('mixed_35x35x288a'):
@@ -138,7 +138,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
-          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+          net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
          end_points['mixed_35x35x288a'] = net
        # mixed_2: 35 x 35 x 288.
        with tf.variable_scope('mixed_35x35x288b'):
@@ -154,7 +154,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 64, [1, 1])
-          net = tf.concat(3, [branch1x1, branch5x5, branch3x3dbl, branch_pool])
+          net = tf.concat([branch1x1, branch5x5, branch3x3dbl, branch_pool], 3)
          end_points['mixed_35x35x288b'] = net
        # mixed_3: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768a'):
@@ -167,7 +167,7 @@ def inception_v3(inputs,
                                      stride=2, padding='VALID')
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
-          net = tf.concat(3, [branch3x3, branch3x3dbl, branch_pool])
+          net = tf.concat([branch3x3, branch3x3dbl, branch_pool], 3)
          end_points['mixed_17x17x768a'] = net
        # mixed4: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768b'):
@@ -186,7 +186,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+          net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
          end_points['mixed_17x17x768b'] = net
        # mixed_5: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768c'):
@@ -205,7 +205,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+          net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
          end_points['mixed_17x17x768c'] = net
        # mixed_6: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768d'):
@@ -224,7 +224,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+          net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
          end_points['mixed_17x17x768d'] = net
        # mixed_7: 17 x 17 x 768.
        with tf.variable_scope('mixed_17x17x768e'):
@@ -243,7 +243,7 @@ def inception_v3(inputs,
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch7x7, branch7x7dbl, branch_pool])
+          net = tf.concat([branch1x1, branch7x7, branch7x7dbl, branch_pool], 3)
          end_points['mixed_17x17x768e'] = net
        # Auxiliary Head logits
        aux_logits = tf.identity(end_points['mixed_17x17x768e'])
@@ -276,7 +276,7 @@ def inception_v3(inputs,
                                     stride=2, padding='VALID')
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.max_pool(net, [3, 3], stride=2, padding='VALID')
-          net = tf.concat(3, [branch3x3, branch7x7x3, branch_pool])
+          net = tf.concat([branch3x3, branch7x7x3, branch_pool], 3)
          end_points['mixed_17x17x1280a'] = net
        # mixed_9: 8 x 8 x 2048.
        with tf.variable_scope('mixed_8x8x2048a'):
@@ -284,17 +284,17 @@ def inception_v3(inputs,
            branch1x1 = ops.conv2d(net, 320, [1, 1])
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 384, [1, 1])
-            branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
-                                      ops.conv2d(branch3x3, 384, [3, 1])])
+            branch3x3 = tf.concat([ops.conv2d(branch3x3, 384, [1, 3]),
+                                   ops.conv2d(branch3x3, 384, [3, 1])], 3)
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 448, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
-            branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
-                                         ops.conv2d(branch3x3dbl, 384, [3, 1])])
+            branch3x3dbl = tf.concat([ops.conv2d(branch3x3dbl, 384, [1, 3]),
+                                      ops.conv2d(branch3x3dbl, 384, [3, 1])], 3)
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
+          net = tf.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], 3)
          end_points['mixed_8x8x2048a'] = net
        # mixed_10: 8 x 8 x 2048.
        with tf.variable_scope('mixed_8x8x2048b'):
@@ -302,17 +302,17 @@ def inception_v3(inputs,
            branch1x1 = ops.conv2d(net, 320, [1, 1])
          with tf.variable_scope('branch3x3'):
            branch3x3 = ops.conv2d(net, 384, [1, 1])
-            branch3x3 = tf.concat(3, [ops.conv2d(branch3x3, 384, [1, 3]),
-                                      ops.conv2d(branch3x3, 384, [3, 1])])
+            branch3x3 = tf.concat([ops.conv2d(branch3x3, 384, [1, 3]),
+                                   ops.conv2d(branch3x3, 384, [3, 1])], 3)
          with tf.variable_scope('branch3x3dbl'):
            branch3x3dbl = ops.conv2d(net, 448, [1, 1])
            branch3x3dbl = ops.conv2d(branch3x3dbl, 384, [3, 3])
-            branch3x3dbl = tf.concat(3, [ops.conv2d(branch3x3dbl, 384, [1, 3]),
-                                         ops.conv2d(branch3x3dbl, 384, [3, 1])])
+            branch3x3dbl = tf.concat([ops.conv2d(branch3x3dbl, 384, [1, 3]),
+                                      ops.conv2d(branch3x3dbl, 384, [3, 1])], 3)
          with tf.variable_scope('branch_pool'):
            branch_pool = ops.avg_pool(net, [3, 3])
            branch_pool = ops.conv2d(branch_pool, 192, [1, 1])
-          net = tf.concat(3, [branch1x1, branch3x3, branch3x3dbl, branch_pool])
+          net = tf.concat([branch1x1, branch3x3, branch3x3dbl, branch_pool], 3)
          end_points['mixed_8x8x2048b'] = net
        # Final pooling and prediction
        with tf.variable_scope('logits'):

+ 2 - 2
inception/inception/slim/inception_test.py

@@ -65,9 +65,9 @@ class InceptionTest(tf.test.TestCase):
        inception.inception_v3(inputs, num_classes)
      with tf.variable_scope('on_gpu'), tf.device('/gpu:0'):
        inception.inception_v3(inputs, num_classes)
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_cpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_cpu'):
        self.assertDeviceEqual(v.device, '/cpu:0')
-      for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope='on_gpu'):
+      for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='on_gpu'):
        self.assertDeviceEqual(v.device, '/gpu:0')

  def testHalfSizeImages(self):

+ 22 - 22
inception/inception/slim/losses.py

@@ -39,17 +39,17 @@ def l1_regularizer(weight=1.0, scope=None):

  Args:
    weight: scale the loss by this factor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a regularizer function.
  """
  def regularizer(tensor):
-    with tf.op_scope([tensor], scope, 'L1Regularizer'):
+    with tf.name_scope(scope, 'L1Regularizer', [tensor]):
      l1_weight = tf.convert_to_tensor(weight,
                                       dtype=tensor.dtype.base_dtype,
                                       name='weight')
-      return tf.mul(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+      return tf.multiply(l1_weight, tf.reduce_sum(tf.abs(tensor)), name='value')
  return regularizer


@@ -58,17 +58,17 @@ def l2_regularizer(weight=1.0, scope=None):

  Args:
    weight: scale the loss by this factor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a regularizer function.
  """
  def regularizer(tensor):
-    with tf.op_scope([tensor], scope, 'L2Regularizer'):
+    with tf.name_scope(scope, 'L2Regularizer', [tensor]):
      l2_weight = tf.convert_to_tensor(weight,
                                       dtype=tensor.dtype.base_dtype,
                                       name='weight')
-      return tf.mul(l2_weight, tf.nn.l2_loss(tensor), name='value')
+      return tf.multiply(l2_weight, tf.nn.l2_loss(tensor), name='value')
  return regularizer


@@ -78,22 +78,22 @@ def l1_l2_regularizer(weight_l1=1.0, weight_l2=1.0, scope=None):
  Args:
    weight_l1: scale the L1 loss by this factor.
    weight_l2: scale the L2 loss by this factor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a regularizer function.
  """
  def regularizer(tensor):
-    with tf.op_scope([tensor], scope, 'L1L2Regularizer'):
+    with tf.name_scope(scope, 'L1L2Regularizer', [tensor]):
      weight_l1_t = tf.convert_to_tensor(weight_l1,
                                         dtype=tensor.dtype.base_dtype,
                                         name='weight_l1')
      weight_l2_t = tf.convert_to_tensor(weight_l2,
                                         dtype=tensor.dtype.base_dtype,
                                         name='weight_l2')
-      reg_l1 = tf.mul(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
+      reg_l1 = tf.multiply(weight_l1_t, tf.reduce_sum(tf.abs(tensor)),
                      name='value_l1')
-      reg_l2 = tf.mul(weight_l2_t, tf.nn.l2_loss(tensor),
+      reg_l2 = tf.multiply(weight_l2_t, tf.nn.l2_loss(tensor),
                      name='value_l2')
      return tf.add(reg_l1, reg_l2, name='value')
  return regularizer
@@ -105,16 +105,16 @@ def l1_loss(tensor, weight=1.0, scope=None):
  Args:
    tensor: tensor to regularize.
    weight: scale the loss by this factor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    the L1 loss op.
  """
-  with tf.op_scope([tensor], scope, 'L1Loss'):
+  with tf.name_scope(scope, 'L1Loss', [tensor]):
    weight = tf.convert_to_tensor(weight,
                                  dtype=tensor.dtype.base_dtype,
                                  name='loss_weight')
-    loss = tf.mul(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
+    loss = tf.multiply(weight, tf.reduce_sum(tf.abs(tensor)), name='value')
    tf.add_to_collection(LOSSES_COLLECTION, loss)
    return loss

@@ -125,16 +125,16 @@ def l2_loss(tensor, weight=1.0, scope=None):
  Args:
    tensor: tensor to regularize.
    weight: an optional weight to modulate the loss.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    the L2 loss op.
  """
-  with tf.op_scope([tensor], scope, 'L2Loss'):
+  with tf.name_scope(scope, 'L2Loss', [tensor]):
    weight = tf.convert_to_tensor(weight,
                                  dtype=tensor.dtype.base_dtype,
                                  name='loss_weight')
-    loss = tf.mul(weight, tf.nn.l2_loss(tensor), name='value')
+    loss = tf.multiply(weight, tf.nn.l2_loss(tensor), name='value')
    tf.add_to_collection(LOSSES_COLLECTION, loss)
    return loss

@@ -150,25 +150,25 @@ def cross_entropy_loss(logits, one_hot_labels, label_smoothing=0,
    one_hot_labels: [batch_size, num_classes] target one_hot_encoded labels.
    label_smoothing: if greater than 0 then smooth the labels.
    weight: scale the loss by this factor.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    A tensor with the softmax_cross_entropy loss.
  """
  logits.get_shape().assert_is_compatible_with(one_hot_labels.get_shape())
-  with tf.op_scope([logits, one_hot_labels], scope, 'CrossEntropyLoss'):
+  with tf.name_scope(scope, 'CrossEntropyLoss', [logits, one_hot_labels]):
    num_classes = one_hot_labels.get_shape()[-1].value
    one_hot_labels = tf.cast(one_hot_labels, logits.dtype)
    if label_smoothing > 0:
      smooth_positives = 1.0 - label_smoothing
      smooth_negatives = label_smoothing / num_classes
      one_hot_labels = one_hot_labels * smooth_positives + smooth_negatives
-    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
-                                                            labels=one_hot_labels,
-                                                            name='xentropy')
+    cross_entropy = tf.contrib.nn.deprecated_flipped_softmax_cross_entropy_with_logits(
+        logits, one_hot_labels, name='xentropy')
+
    weight = tf.convert_to_tensor(weight,
                                  dtype=logits.dtype.base_dtype,
                                  name='loss_weight')
-    loss = tf.mul(weight, tf.reduce_mean(cross_entropy), name='value')
+    loss = tf.multiply(weight, tf.reduce_mean(cross_entropy), name='value')
    tf.add_to_collection(LOSSES_COLLECTION, loss)
    return loss

+ 18 - 18
inception/inception/slim/ops.py

@@ -68,7 +68,7 @@ def batch_norm(inputs,
    is_training: whether or not the model is in training mode.
    trainable: whether or not the variables should be trainable or not.
    restore: whether or not the variables should be marked for restore.
-    scope: Optional scope for variable_op_scope.
+    scope: Optional scope for variable_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.

@@ -77,7 +77,7 @@ def batch_norm(inputs,

  """
  inputs_shape = inputs.get_shape()
-  with tf.variable_op_scope([inputs], scope, 'BatchNorm', reuse=reuse):
+  with tf.variable_scope(scope, 'BatchNorm', [inputs], reuse=reuse):
    axis = list(range(len(inputs_shape) - 1))
    params_shape = inputs_shape[-1:]
    # Allocate parameters for the beta and gamma of the normalization.
@@ -203,14 +203,14 @@ def conv2d(inputs,
    is_training: whether or not the model is in training mode.
    trainable: whether or not the variables should be trainable or not.
    restore: whether or not the variables should be marked for restore.
-    scope: Optional scope for variable_op_scope.
+    scope: Optional scope for variable_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.
  Returns:
    a tensor representing the output of the operation.

  """
-  with tf.variable_op_scope([inputs], scope, 'Conv', reuse=reuse):
+  with tf.variable_scope(scope, 'Conv', [inputs], reuse=reuse):
    kernel_h, kernel_w = _two_element_tuple(kernel_size)
    stride_h, stride_w = _two_element_tuple(stride)
    num_filters_in = inputs.get_shape()[-1]
@@ -278,14 +278,14 @@ def fc(inputs,
    is_training: whether or not the model is in training mode.
    trainable: whether or not the variables should be trainable or not.
    restore: whether or not the variables should be marked for restore.
-    scope: Optional scope for variable_op_scope.
+    scope: Optional scope for variable_scope.
    reuse: whether or not the layer and its variables should be reused. To be
      able to reuse the layer scope must be given.

  Returns:
     the tensor variable representing the result of the series of operations.
  """
-  with tf.variable_op_scope([inputs], scope, 'FC', reuse=reuse):
+  with tf.variable_scope(scope, 'FC', [inputs], reuse=reuse):
    num_units_in = inputs.get_shape()[1]
    weights_shape = [num_units_in, num_units_out]
    weights_initializer = tf.truncated_normal_initializer(stddev=stddev)
@@ -323,15 +323,15 @@ def one_hot_encoding(labels, num_classes, scope=None):
  Args:
    labels: [batch_size] target labels.
    num_classes: total number of classes.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.
  Returns:
    one hot encoding of the labels.
  """
-  with tf.op_scope([labels], scope, 'OneHotEncoding'):
+  with tf.name_scope(scope, 'OneHotEncoding', [labels]):
    batch_size = labels.get_shape()[0]
    indices = tf.expand_dims(tf.range(0, batch_size), 1)
    labels = tf.cast(tf.expand_dims(labels, 1), indices.dtype)
-    concated = tf.concat(1, [indices, labels])
+    concated = tf.concat([indices, labels], 1)
    onehot_labels = tf.sparse_to_dense(
        concated, tf.pack([batch_size, num_classes]), 1.0, 0.0)
    onehot_labels.set_shape([batch_size, num_classes])
@@ -354,14 +354,14 @@ def max_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
      Can be an int if both strides are the same.  Note that presently
      both strides must have the same value.
    padding: the padding method, either 'VALID' or 'SAME'.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a tensor representing the results of the pooling operation.
  Raises:
    ValueError: if 'kernel_size' is not a 2-D list
  """
-  with tf.op_scope([inputs], scope, 'MaxPool'):
+  with tf.name_scope(scope, 'MaxPool', [inputs]):
    kernel_h, kernel_w = _two_element_tuple(kernel_size)
    stride_h, stride_w = _two_element_tuple(stride)
    return tf.nn.max_pool(inputs,
@@ -386,12 +386,12 @@ def avg_pool(inputs, kernel_size, stride=2, padding='VALID', scope=None):
      Can be an int if both strides are the same.  Note that presently
      both strides must have the same value.
    padding: the padding method, either 'VALID' or 'SAME'.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a tensor representing the results of the pooling operation.
  """
-  with tf.op_scope([inputs], scope, 'AvgPool'):
+  with tf.name_scope(scope, 'AvgPool', [inputs]):
    kernel_h, kernel_w = _two_element_tuple(kernel_size)
    stride_h, stride_w = _two_element_tuple(stride)
    return tf.nn.avg_pool(inputs,
@@ -409,13 +409,13 @@ def dropout(inputs, keep_prob=0.5, is_training=True, scope=None):
    keep_prob: the probability of keeping each input unit.
    is_training: whether or not the model is in training mode. If so, dropout is
    applied and values scaled. Otherwise, inputs is returned.
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a tensor representing the output of the operation.
  """
  if is_training and keep_prob > 0:
-    with tf.op_scope([inputs], scope, 'Dropout'):
+    with tf.name_scope(scope, 'Dropout', [inputs]):
      return tf.nn.dropout(inputs, keep_prob)
  else:
    return inputs
@@ -428,7 +428,7 @@ def flatten(inputs, scope=None):

  Args:
    inputs: a tensor of size [batch_size, ...].
-    scope: Optional scope for op_scope.
+    scope: Optional scope for name_scope.

  Returns:
    a flattened tensor with shape [batch_size, k].
@@ -439,7 +439,7 @@ def flatten(inputs, scope=None):
    raise ValueError('Inputs must be have a least 2 dimensions')
  dims = inputs.get_shape()[1:]
  k = dims.num_elements()
-  with tf.op_scope([inputs], scope, 'Flatten'):
+  with tf.name_scope(scope, 'Flatten', [inputs]):
    return tf.reshape(inputs, [-1, k])


@@ -466,7 +466,7 @@ def repeat_op(repetitions, inputs, op, *args, **kwargs):
    ValueError: if the op is unknown or wrong.
  """
  scope = kwargs.pop('scope', None)
-  with tf.variable_op_scope([inputs], scope, 'RepeatOp'):
+  with tf.variable_scope(scope, 'RepeatOp', [inputs]):
    tower = inputs
    for _ in range(repetitions):
      tower = op(tower, *args, **kwargs)

+ 5 - 5
inception/inception/slim/variables.py

@@ -161,7 +161,7 @@ def get_unique_variable(name):
  Raises:
    ValueError: if no variable uniquely identified by the name exists.
  """
-  candidates = tf.get_collection(tf.GraphKeys.VARIABLES, name)
+  candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, name)
  if not candidates:
    raise ValueError('Couldnt find variable %s' % name)

@@ -234,7 +234,7 @@ def global_step(device=''):
  else:
    collections = [
        VARIABLES_TO_RESTORE,
-        tf.GraphKeys.VARIABLES,
+        tf.GraphKeys.GLOBAL_VARIABLES,
        tf.GraphKeys.GLOBAL_STEP,
    ]
    # Get the device for the variable.
@@ -263,7 +263,7 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable).
    collections: A list of collection names to which the Variable will be added.
-      Note that the variable is always also added to the tf.GraphKeys.VARIABLES
+      Note that the variable is always also added to the tf.GraphKeys.GLOBAL_VARIABLES
      and MODEL_VARIABLES collections.
    device: Optional device to place the variable. It can be an string or a
      function that is called to get the device for the variable.
@@ -275,8 +275,8 @@ def variable(name, shape=None, dtype=tf.float32, initializer=None,
   """
   """
   collections = list(collections or [])
   collections = list(collections or [])
 
 
-  # Make sure variables are added to tf.GraphKeys.VARIABLES and MODEL_VARIABLES
-  collections += [tf.GraphKeys.VARIABLES, MODEL_VARIABLES]
+  # Make sure variables are added to tf.GraphKeys.GLOBAL_VARIABLES and MODEL_VARIABLES
+  collections += [tf.GraphKeys.GLOBAL_VARIABLES, MODEL_VARIABLES]
   # Add to VARIABLES_TO_RESTORE if necessary
   # Add to VARIABLES_TO_RESTORE if necessary
   if restore:
   if restore:
     collections.append(VARIABLES_TO_RESTORE)
     collections.append(VARIABLES_TO_RESTORE)
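
A minimal sketch of the collection behaviour the docstring above describes, assuming the TF 1.x names this commit moves to (the variable name and shape are placeholders):

    # variable() always adds the result to tf.GraphKeys.GLOBAL_VARIABLES (and
    # MODEL_VARIABLES), so it appears in the standard global-variable lookup:
    my_var = variable('conv1/weights', shape=[3, 3, 3, 64])
    assert my_var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)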