@@ -220,28 +220,23 @@ def train(dataset):
     # Number of classes in the Dataset label set plus 1.
     # Label 0 is reserved for an (unused) background class.
     num_classes = dataset.num_classes() + 1
+
+    # Split the batch of images and labels for towers.
+    images_splits = tf.split(0, FLAGS.num_gpus, images)
+    labels_splits = tf.split(0, FLAGS.num_gpus, labels)
 
     # Calculate the gradients for each model tower.
     tower_grads = []
     for i in xrange(FLAGS.num_gpus):
       with tf.device('/gpu:%d' % i):
         with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
-          # Split the batch of images and labels.
-          batch_start = split_batch_size * i
-          images_batch = tf.slice(images,
-                                  begin=[batch_start, 0, 0, 0],
-                                  size=[split_batch_size, -1, -1, -1])
-          labels_batch = tf.slice(labels,
-                                  begin=[batch_start],
-                                  size=[split_batch_size])
-
-
           # Force all Variables to reside on the CPU.
           with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
             # Calculate the loss for one tower of the ImageNet model. This
             # function constructs the entire ImageNet model but shares the
             # variables across all towers.
-            loss = _tower_loss(images_batch, labels_batch, num_classes, scope)
+            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
+                               scope)
 
           # Reuse variables for the next tower.
           tf.get_variable_scope().reuse_variables()
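
The hunk moves the batch partitioning out of the tower loop: instead of computing batch_start and calling tf.slice inside every tower, the batch is split once along dimension 0 with tf.split (the pre-1.0 signature tf.split(split_dim, num_split, value), as used in the hunk), and tower i then reads images_splits[i] and labels_splits[i]. A minimal NumPy sketch of why the two forms pick out the same sub-batches; the concrete sizes below are illustrative only and are not taken from the surrounding file:

```python
import numpy as np

# Illustrative values; the real batch size and FLAGS.num_gpus come from the
# training script, not from this sketch.
num_gpus = 4
batch_size = 32                        # assumed divisible by num_gpus
split_batch_size = batch_size // num_gpus

images = np.random.rand(batch_size, 299, 299, 3).astype(np.float32)

# New form: one split along the batch dimension, analogous to
# tf.split(0, FLAGS.num_gpus, images).
images_splits = np.split(images, num_gpus, axis=0)

for i in range(num_gpus):
    # Old form: slice out tower i's sub-batch, analogous to the removed
    # tf.slice(images, begin=[batch_start, 0, 0, 0],
    #          size=[split_batch_size, -1, -1, -1]) call.
    batch_start = split_batch_size * i
    images_batch = images[batch_start:batch_start + split_batch_size]
    assert np.array_equal(images_splits[i], images_batch)
```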