
Variables defined in ExponentialMovingAverage need not be shared. (#778)

* Variables defined in ExponentialMovingAverage need not be shared.

* Address comments.
Yuefeng Zhou 8 years ago
parent
commit
91c7b91f83
1 changed file with 9 additions and 7 deletions
  1. inception/inception/inception_train.py (+9, -7)
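
For context, the commit message points at a side effect of the previous tower loop: `tf.get_variable_scope().reuse_variables()` switches the root variable scope into reuse mode for the rest of graph construction. Below is a minimal, hypothetical repro of that pre-commit pattern (stand-in model and shapes, not the actual inception_train.py code) showing why later variable creation is affected.

```python
import tensorflow as tf

# Hypothetical minimal repro of the pre-commit pattern (not the actual
# inception_train.py code): reuse_variables() flips the *root* variable
# scope into reuse mode for everything created afterwards.
def tiny_model(x):
  w = tf.get_variable('w', [4, 4])           # shared across towers
  return tf.matmul(x, w)

images = [tf.ones([2, 4]), tf.ones([2, 4])]  # stand-ins for per-GPU splits
for i, x in enumerate(images):
  with tf.name_scope('tower_%d' % i):
    logits = tiny_model(x)
    # Sets reuse=True on the root scope; the flag persists after the loop.
    tf.get_variable_scope().reuse_variables()

# Any later tf.get_variable() call is now also made under a reusing scope,
# e.g. the shadow variables allocated by ExponentialMovingAverage.apply(),
# which (per the commit message) need not be shared.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
maintain_op = ema.apply(tf.trainable_variables())  # may raise a reuse error
```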

+ 9 - 7
inception/inception/inception_train.py

@@ -79,7 +79,7 @@ RMSPROP_MOMENTUM = 0.9             # Momentum in RMSProp.
 RMSPROP_EPSILON = 1.0              # Epsilon term for RMSProp.
 
 
-def _tower_loss(images, labels, num_classes, scope):
+def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
   """Calculate the total loss on a single tower running the ImageNet model.
 
   We perform 'batch splitting'. This means that we cut up a batch across
@@ -103,9 +103,10 @@ def _tower_loss(images, labels, num_classes, scope):
   restore_logits = not FLAGS.fine_tune
 
   # Build inference Graph.
-  logits = inception.inference(images, num_classes, for_training=True,
-                               restore_logits=restore_logits,
-                               scope=scope)
+  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
+    logits = inception.inference(images, num_classes, for_training=True,
+                                 restore_logits=restore_logits,
+                                 scope=scope)
 
   # Build the portion of the Graph calculating the losses. Note that we will
   # assemble the total_loss using a custom function below.
@@ -220,13 +221,14 @@ def train(dataset):
     # Number of classes in the Dataset label set plus 1.
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
-    
+
      # Split the batch of images and labels for towers.
     images_splits = tf.split(0, FLAGS.num_gpus, images)
     labels_splits = tf.split(0, FLAGS.num_gpus, labels)
 
     # Calculate the gradients for each model tower.
     tower_grads = []
+    reuse_variables = None
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
@@ -236,10 +238,10 @@ def train(dataset):
             # function constructs the entire ImageNet model but shares the
             # variables across all towers.
             loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
-                               scope)
+                               scope, reuse_variables)
 
           # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
+          reuse_variables = True
 
           # Retain the summaries from the final tower.
           summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
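
As a complement to the diff, here is a minimal, hypothetical sketch of the pattern the commit adopts (stand-in model, loss, and shapes; not the actual inception code): the reuse flag is threaded through the tower-loss function and applied only around the model-building call, so variables created afterwards, such as the shadow variables from `tf.train.ExponentialMovingAverage.apply()`, are not forced into reuse.

```python
import tensorflow as tf

# Hypothetical sketch of the post-commit pattern: the reuse flag is passed
# explicitly and scoped to the model-building call only.
def tiny_tower_loss(x, reuse_variables=None):
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
    w = tf.get_variable('w', [4, 4])   # created by the first tower, reused by the rest
    logits = tf.matmul(x, w)
  return tf.reduce_mean(logits)        # stand-in for the real tower loss

images = [tf.ones([2, 4]), tf.ones([2, 4])]  # stand-ins for per-GPU splits
reuse_variables = None                        # first tower creates the variables
for i, x in enumerate(images):
  with tf.name_scope('tower_%d' % i):
    loss = tiny_tower_loss(x, reuse_variables)
    reuse_variables = True                    # later towers reuse them

# The root scope was never switched to reuse mode, so new variables such as
# the EMA shadow variables can still be created normally here.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
maintain_op = ema.apply(tf.trainable_variables())
```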