Bladeren bron

Merge pull request #911 from wookayin/cifar10

Fix bugs and API usage on cifar10 and cifar10_multi_gpu_train
Neal Wu 8 jaar geleden
bovenliggende
commit
596c9e2367
1 gewijzigd bestand met 20 toevoegingen en 19 verwijderingen
  1. 20 19
      tutorials/image/cifar10/cifar10_multi_gpu_train.py

+ 20 - 19
tutorials/image/cifar10/cifar10_multi_gpu_train.py

@@ -162,25 +162,26 @@ def train():
 
     # Calculate the gradients for each model tower.
     tower_grads = []
-    for i in xrange(FLAGS.num_gpus):
-      with tf.device('/gpu:%d' % i):
-        with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
-          # Calculate the loss for one tower of the CIFAR model. This function
-          # constructs the entire CIFAR model but shares the variables across
-          # all towers.
-          loss = tower_loss(scope)
-
-          # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
-
-          # Retain the summaries from the final tower.
-          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
-
-          # Calculate the gradients for the batch of data on this CIFAR tower.
-          grads = opt.compute_gradients(loss)
-
-          # Keep track of the gradients across all towers.
-          tower_grads.append(grads)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for i in xrange(FLAGS.num_gpus):
+        with tf.device('/gpu:%d' % i):
+          with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
+            # Calculate the loss for one tower of the CIFAR model. This function
+            # constructs the entire CIFAR model but shares the variables across
+            # all towers.
+            loss = tower_loss(scope)
+
+            # Reuse variables for the next tower.
+            tf.get_variable_scope().reuse_variables()
+
+            # Retain the summaries from the final tower.
+            summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
+
+            # Calculate the gradients for the batch of data on this CIFAR tower.
+            grads = opt.compute_gradients(loss)
+
+            # Keep track of the gradients across all towers.
+            tower_grads.append(grads)
 
     # We must calculate the mean of each gradient. Note that this is the
     # synchronization point across all towers.