
Wrap the cifar10 multi-GPU model construction in a variable_scope

Without the new variable_scope, creating apply_gradient_op raises an
error because the additional moving-average or optimizer slot variables
cannot be created. This is caused by the 'leaky reuse' of the variable
scope, so we fix the problem by explicitly introducing a new variable scope.

Related issues: tensorflow/models#901, tensorflow/tensorflow#6220
Jongwook Choi, 8 years ago
commit 9d96e9fe2b
1 changed file with 20 additions and 19 deletions

+ 20 - 19
tutorials/image/cifar10/cifar10_multi_gpu_train.py

@@ -162,25 +162,26 @@ def train():
 
     # Calculate the gradients for each model tower.
     tower_grads = []
-    for i in xrange(FLAGS.num_gpus):
-      with tf.device('/gpu:%d' % i):
-        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
-          # Calculate the loss for one tower of the CIFAR model. This function
-          # constructs the entire CIFAR model but shares the variables across
-          # all towers.
-          loss = tower_loss(scope)
-
-          # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
-
-          # Retain the summaries from the final tower.
-          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
-
-          # Calculate the gradients for the batch of data on this CIFAR tower.
-          grads = opt.compute_gradients(loss)
-
-          # Keep track of the gradients across all towers.
-          tower_grads.append(grads)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for i in xrange(FLAGS.num_gpus):
+        with tf.device('/gpu:%d' % i):
+          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+            # Calculate the loss for one tower of the CIFAR model. This function
+            # constructs the entire CIFAR model but shares the variables across
+            # all towers.
+            loss = tower_loss(scope)
+
+            # Reuse variables for the next tower.
+            tf.get_variable_scope().reuse_variables()
+
+            # Retain the summaries from the final tower.
+            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+            # Calculate the gradients for the batch of data on this CIFAR tower.
+            grads = opt.compute_gradients(loss)
+
+            # Keep track of the gradients across all towers.
+            tower_grads.append(grads)
 
     # We must calculate the mean of each gradient. Note that this is the
     # synchronization point across all towers.
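
For reference, a minimal standalone sketch of the same pattern, assuming the TF 1.x graph-mode API (the tower function, optimizer, and shapes below are illustrative, not the tutorial's): re-entering the current variable scope gives the loop its own copy of that scope, so the reuse_variables() call inside it does not leak to the code that later builds the moving-average and gradient-application ops.

# Minimal sketch of the scope-wrapping pattern (assumes TF 1.x graph mode;
# the tower function, optimizer and shapes are illustrative only).
import tensorflow as tf

def tower_loss():
  # Shared weight: the first tower creates it, later towers reuse it.
  w = tf.get_variable('w', shape=[10, 10],
                      initializer=tf.truncated_normal_initializer())
  return tf.reduce_sum(tf.square(w))

opt = tf.train.GradientDescentOptimizer(0.1)
tower_grads = []

# Re-entering the current scope pushes a fresh copy of it, so setting
# reuse inside this block does not leak to the outer scope.
with tf.variable_scope(tf.get_variable_scope()):
  for i in range(2):
    with tf.device('/gpu:%d' % i):
      loss = tower_loss()
      # Affects only the inner copy of the scope.
      tf.get_variable_scope().reuse_variables()
      tower_grads.append(opt.compute_gradients(loss))

# Back outside, reuse is False again, so new variables (e.g. the
# moving-average shadow variables, or slot variables for optimizers
# that have them) can still be created.
apply_gradient_op = opt.apply_gradients(tower_grads[0])
variable_averages = tf.train.ExponentialMovingAverage(0.9999)
variables_averages_op = variable_averages.apply(tf.trainable_variables())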