@@ -79,7 +79,7 @@ RMSPROP_MOMENTUM = 0.9  # Momentum in RMSProp.
 RMSPROP_EPSILON = 1.0  # Epsilon term for RMSProp.
 
 
-def _tower_loss(images, labels, num_classes, scope):
+def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
   """Calculate the total loss on a single tower running the ImageNet model.
 
   We perform 'batch splitting'. This means that we cut up a batch across
@@ -103,9 +103,10 @@ def _tower_loss(images, labels, num_classes, scope):
   restore_logits = not FLAGS.fine_tune
 
   # Build inference Graph.
-  logits = inception.inference(images, num_classes, for_training=True,
-                               restore_logits=restore_logits,
-                               scope=scope)
+  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
+    logits = inception.inference(images, num_classes, for_training=True,
+                                 restore_logits=restore_logits,
+                                 scope=scope)
 
   # Build the portion of the Graph calculating the losses. Note that we will
   # assemble the total_loss using a custom function below.
@@ -220,13 +221,14 @@ def train(dataset):
     # Number of classes in the Dataset label set plus 1.
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
-
+
     # Split the batch of images and labels for towers.
     images_splits = tf.split(0, FLAGS.num_gpus, images)
     labels_splits = tf.split(0, FLAGS.num_gpus, labels)
 
     # Calculate the gradients for each model tower.
     tower_grads = []
+    reuse_variables = None
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
@@ -236,10 +238,10 @@ def train(dataset):
           # function constructs the entire ImageNet model but shares the
           # variables across all towers.
           loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
-                             scope)
+                             scope, reuse_variables)
 
           # Reuse variables for the next tower.
-          tf.get_variable_scope().reuse_variables()
+          reuse_variables = True
 
           # Retain the summaries from the final tower.
           summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)
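
Note on the pattern above: calling tf.get_variable_scope().reuse_variables()
after the first tower (the old code) switches the enclosing variable scope
into reuse mode for everything built afterwards, and that switch cannot be
undone within the scope. Threading an explicit reuse_variables flag through
_tower_loss instead keeps the create-on-first-tower / reuse-afterwards
decision local to each call. Below is a minimal sketch of the mechanism,
assuming graph-mode TensorFlow of the same era; build_tower is a
hypothetical stand-in for inception.inference:

    import tensorflow as tf

    def build_tower(images):
      # tf.get_variable creates 'w' on the first call; when the enclosing
      # scope has reuse=True it returns the existing variable instead.
      w = tf.get_variable('w', shape=[4, 2],
                          initializer=tf.truncated_normal_initializer())
      return tf.matmul(images, w)

    images = tf.placeholder(tf.float32, [None, 4])
    reuse_variables = None  # First tower creates the variables.
    logits = []
    for i in range(2):  # Two "towers" sharing one set of weights.
      with tf.variable_scope(tf.get_variable_scope(),
                             reuse=reuse_variables):
        logits.append(build_tower(images))
      reuse_variables = True  # Later towers reuse rather than recreate.

    # Only one copy of 'w' exists despite two tower subgraphs.
    assert len(tf.trainable_variables()) == 1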