Browse Source

Swivel: add multiple GPU support

This speeds up the training accordingly.
Vadim Markovtsev 8 years ago
parent
commit
ceee992a6e
1 changed files with 166 additions and 103 deletions
  1. 166 103
      swivel/swivel.py

+ 166 - 103
swivel/swivel.py

@@ -52,7 +52,6 @@ embeddings are stored in separate files.
 """
 
 from __future__ import print_function
-import argparse
 import glob
 import math
 import os
@@ -62,6 +61,7 @@ import threading
 
 import numpy as np
 import tensorflow as tf
+from tensorflow.python.client import device_lib
 
 flags = tf.app.flags
 
@@ -85,13 +85,26 @@ flags.DEFINE_float('confidence_base', 0.1, 'Base for l2 confidence function')
 flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate')
 flags.DEFINE_integer('num_concurrent_steps', 2,
                      'Number of threads to train with')
+flags.DEFINE_integer('num_readers', 4,
+                     'Number of threads to read the input data and feed it')
 flags.DEFINE_float('num_epochs', 40, 'Number epochs to train for')
-flags.DEFINE_float('per_process_gpu_memory_fraction', 0.25,
-                   'Fraction of GPU memory to use')
+flags.DEFINE_float('per_process_gpu_memory_fraction', 0,
+                   'Fraction of GPU memory to use, 0 means allow_growth')
+flags.DEFINE_integer('num_gpus', 0,
+                     'Number of GPUs to use, 0 means all available')
 
 FLAGS = flags.FLAGS
 
 
+def log(message, *args, **kwargs):
+    tf.logging.info(message, *args, **kwargs)
+
+
+def get_available_gpus():
+    return [d.name for d in device_lib.list_local_devices()
+            if d.device_type == 'GPU']
+
+
 def embeddings_with_init(vocab_size, embedding_dim, name):
   """Creates and initializes the embedding tensors."""
   return tf.get_variable(name=name,
@@ -130,7 +143,7 @@ def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
   queued_global_row, queued_global_col, queued_count = tf.train.batch(
       [global_row, global_col, count],
       batch_size=1,
-      num_threads=4,
+      num_threads=FLAGS.num_readers,
       capacity=32)
 
   queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
@@ -164,16 +177,14 @@ def write_embeddings_to_disk(config, model, sess):
   # Row Embedding
   row_vocab_path = config.input_base_path + '/row_vocab.txt'
   row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
-  print('Writing row embeddings to:', row_embedding_output_path)
-  sys.stdout.flush()
+  log('Writing row embeddings to: %s', row_embedding_output_path)
   write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                  sess, model.row_embedding)
 
   # Column Embedding
   col_vocab_path = config.input_base_path + '/col_vocab.txt'
   col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
-  print('Writing column embeddings to:', col_embedding_output_path)
-  sys.stdout.flush()
+  log('Writing column embeddings to: %s', col_embedding_output_path)
   write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                  sess, model.col_embedding)
 
@@ -186,8 +197,7 @@ class SwivelModel(object):
     self._config = config
 
     # Create paths to input data files
-    print('Reading model from:', config.input_base_path)
-    sys.stdout.flush()
+    log('Reading model from: %s', config.input_base_path)
     count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
     row_sums_path = config.input_base_path + '/row_sums.txt'
     col_sums_path = config.input_base_path + '/col_sums.txt'
@@ -198,93 +208,129 @@ class SwivelModel(object):
 
     self.n_rows = len(row_sums)
     self.n_cols = len(col_sums)
-    print('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
-        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols))
-    sys.stdout.flush()
+    log('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d)',
+        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
     self.n_submatrices = (self.n_rows * self.n_cols /
                           (config.submatrix_rows * config.submatrix_cols))
-    print('n_submatrices: %d' % (self.n_submatrices))
-    sys.stdout.flush()
-
-    # ===== CREATE VARIABLES ======
-    # embeddings
-    self.row_embedding = embeddings_with_init(
-      embedding_dim=config.embedding_size,
-      vocab_size=self.n_rows,
-      name='row_embedding')
-    self.col_embedding = embeddings_with_init(
-      embedding_dim=config.embedding_size,
-      vocab_size=self.n_cols,
-      name='col_embedding')
-    tf.summary.histogram('row_emb', self.row_embedding)
-    tf.summary.histogram('col_emb', self.col_embedding)
-
-    matrix_log_sum = math.log(np.sum(row_sums) + 1)
-    row_bias_init = [math.log(x + 1) for x in row_sums]
-    col_bias_init = [math.log(x + 1) for x in col_sums]
-    self.row_bias = tf.Variable(
-        row_bias_init, trainable=config.trainable_bias)
-    self.col_bias = tf.Variable(
-        col_bias_init, trainable=config.trainable_bias)
-    tf.summary.histogram('row_bias', self.row_bias)
-    tf.summary.histogram('col_bias', self.col_bias)
-
-    # ===== CREATE GRAPH =====
-
-    # Get input
-    global_row, global_col, count = count_matrix_input(
-      count_matrix_files, config.submatrix_rows, config.submatrix_cols)
-
-    # Fetch embeddings.
-    selected_row_embedding = tf.nn.embedding_lookup(
-        self.row_embedding, global_row)
-    selected_col_embedding = tf.nn.embedding_lookup(
-        self.col_embedding, global_col)
-
-    # Fetch biases.
-    selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
-    selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
-
-    # Multiply the row and column embeddings to generate predictions.
-    predictions = tf.matmul(
-        selected_row_embedding, selected_col_embedding, transpose_b=True)
-
-    # These binary masks separate zero from non-zero values.
-    count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
-    count_is_zero = 1 - tf.to_float(tf.cast(count, tf.bool))
-
-    objectives = count_is_nonzero * tf.log(count + 1e-30)
-    objectives -= tf.reshape(selected_row_bias, [config.submatrix_rows, 1])
-    objectives -= selected_col_bias
-    objectives += matrix_log_sum
-
-    err = predictions - objectives
-
-    # The confidence function scales the L2 loss based on the raw co-occurrence
-    # count.
-    l2_confidence = (config.confidence_base + config.confidence_scale * tf.pow(
-        count, config.confidence_exponent))
-
-    l2_loss = config.loss_multiplier * tf.reduce_sum(
-        0.5 * l2_confidence * err * err * count_is_nonzero)
-
-    sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
-        tf.nn.softplus(err) * count_is_zero)
-
-    self.loss = l2_loss + sigmoid_loss
-
-    tf.summary.scalar("l2_loss", l2_loss)
-    tf.summary.scalar("sigmoid_loss", sigmoid_loss)
-    tf.summary.scalar("loss", self.loss)
-
-    # Add optimizer.
-    self.global_step = tf.Variable(0, name='global_step')
-    opt = tf.train.AdagradOptimizer(config.learning_rate)
-    self.train_op = opt.minimize(self.loss, global_step=self.global_step)
-    self.saver = tf.train.Saver(sharded=True)
+    log('n_submatrices: %d', self.n_submatrices)
+
+    with tf.device('/cpu:0'):
+      # ===== CREATE VARIABLES ======
+      # Get input
+      global_row, global_col, count = count_matrix_input(
+        count_matrix_files, config.submatrix_rows, config.submatrix_cols)
+
+      # Embeddings
+      self.row_embedding = embeddings_with_init(
+        embedding_dim=config.embedding_size,
+        vocab_size=self.n_rows,
+        name='row_embedding')
+      self.col_embedding = embeddings_with_init(
+        embedding_dim=config.embedding_size,
+        vocab_size=self.n_cols,
+        name='col_embedding')
+      tf.summary.histogram('row_emb', self.row_embedding)
+      tf.summary.histogram('col_emb', self.col_embedding)
+
+      matrix_log_sum = math.log(np.sum(row_sums) + 1)
+      row_bias_init = [math.log(x + 1) for x in row_sums]
+      col_bias_init = [math.log(x + 1) for x in col_sums]
+      self.row_bias = tf.Variable(
+          row_bias_init, trainable=config.trainable_bias)
+      self.col_bias = tf.Variable(
+          col_bias_init, trainable=config.trainable_bias)
+      tf.summary.histogram('row_bias', self.row_bias)
+      tf.summary.histogram('col_bias', self.col_bias)
+
+      # Add optimizer
+      l2_losses = []
+      sigmoid_losses = []
+      self.global_step = tf.Variable(0, name='global_step')
+      opt = tf.train.AdagradOptimizer(config.learning_rate)
+
+      all_grads = []
+
+    devices = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)] \
+        if FLAGS.num_gpus > 0 else get_available_gpus()
+    self.devices_number = len(devices)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for dev in devices:
+        with tf.device(dev):
+          with tf.name_scope(dev[1:].replace(':', '_')):
+            # ===== CREATE GRAPH =====
+            # Fetch embeddings.
+            selected_row_embedding = tf.nn.embedding_lookup(
+                self.row_embedding, global_row)
+            selected_col_embedding = tf.nn.embedding_lookup(
+                self.col_embedding, global_col)
+
+            # Fetch biases.
+            selected_row_bias = tf.nn.embedding_lookup(
+                [self.row_bias], global_row)
+            selected_col_bias = tf.nn.embedding_lookup(
+                [self.col_bias], global_col)
+
+            # Multiply the row and column embeddings to generate predictions.
+            predictions = tf.matmul(
+                selected_row_embedding, selected_col_embedding,
+                transpose_b=True)
+
+            # These binary masks separate zero from non-zero values.
+            count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
+            count_is_zero = 1 - count_is_nonzero
+
+            objectives = count_is_nonzero * tf.log(count + 1e-30)
+            objectives -= tf.reshape(
+                selected_row_bias, [config.submatrix_rows, 1])
+            objectives -= selected_col_bias
+            objectives += matrix_log_sum
+
+            err = predictions - objectives
+
+            # The confidence function scales the L2 loss based on the raw
+            # co-occurrence count.
+            l2_confidence = (config.confidence_base +
+                             config.confidence_scale * tf.pow(
+                                 count, config.confidence_exponent))
+
+            l2_loss = config.loss_multiplier * tf.reduce_sum(
+                0.5 * l2_confidence * err * err * count_is_nonzero)
+            l2_losses.append(tf.expand_dims(l2_loss, 0))
+
+            sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
+                tf.nn.softplus(err) * count_is_zero)
+            sigmoid_losses.append(tf.expand_dims(sigmoid_loss, 0))
+
+            loss = l2_loss + sigmoid_loss
+            grads = opt.compute_gradients(loss)
+            all_grads.append(grads)
+
+    with tf.device('/cpu:0'):
+      # ===== MERGE LOSSES =====
+      l2_loss = tf.reduce_mean(tf.concat(l2_losses, 0), 0, name="l2_loss")
+      sigmoid_loss = tf.reduce_mean(tf.concat(sigmoid_losses, 0), 0,
+                                    name="sigmoid_loss")
+      self.loss = l2_loss + sigmoid_loss
+      average = tf.train.ExponentialMovingAverage(0.8, self.global_step)
+      loss_average_op = average.apply((self.loss,))
+      tf.summary.scalar("l2_loss", l2_loss)
+      tf.summary.scalar("sigmoid_loss", sigmoid_loss)
+      tf.summary.scalar("loss", self.loss)
+
+      # Apply the gradients to adjust the shared variables.
+      apply_gradient_ops = []
+      for grads in all_grads:
+        apply_gradient_ops.append(opt.apply_gradients(
+            grads, global_step=self.global_step))
+
+      self.train_op = tf.group(loss_average_op, *apply_gradient_ops)
+      self.saver = tf.train.Saver(sharded=True)
 
 
 def main(_):
+  tf.logging.set_verbosity(tf.logging.INFO)
+  start_time = time.time()
+
   # Create the output path.  If this fails, it really ought to fail
   # now. :)
   if not os.path.isdir(FLAGS.output_base_path):
@@ -295,8 +341,13 @@ def main(_):
     model = SwivelModel(FLAGS)
 
     # Create a session for running Ops on the Graph.
-    gpu_options = tf.GPUOptions(
-        per_process_gpu_memory_fraction=FLAGS.per_process_gpu_memory_fraction)
+    gpu_opts = {}
+    if FLAGS.per_process_gpu_memory_fraction > 0:
+        gpu_opts["per_process_gpu_memory_fraction"] = \
+            FLAGS.per_process_gpu_memory_fraction
+    else:
+        gpu_opts["allow_growth"] = True
+    gpu_options = tf.GPUOptions(**gpu_opts)
     sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
 
     # Run the Op to initialize the variables.
@@ -309,21 +360,32 @@ def main(_):
     # Calculate how many steps each thread should run
     n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) / (
         FLAGS.submatrix_rows * FLAGS.submatrix_cols)
-    n_steps_per_thread = n_total_steps / FLAGS.num_concurrent_steps
+    n_steps_per_thread = n_total_steps / (
+        FLAGS.num_concurrent_steps * model.devices_number)
     n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs
     t0 = [time.time()]
+    n_steps_between_status_updates = 100
+    status_i = [0]
+    status_lock = threading.Lock()
+    msg = ('%%%dd/%%d submatrices trained (%%.1f%%%%), %%5.1f submatrices/sec |'
+           ' loss %%f') % len(str(n_submatrices_to_train))
 
     def TrainingFn():
       for _ in range(int(n_steps_per_thread)):
-        _, global_step = sess.run([model.train_op, model.global_step])
-        n_steps_between_status_updates = 100
-        if (global_step % n_steps_between_status_updates) == 0:
+        _, global_step, loss = sess.run((
+            model.train_op, model.global_step, model.loss))
+
+        show_status = False
+        with status_lock:
+          new_i = global_step // n_steps_between_status_updates
+          if new_i > status_i[0]:
+            status_i[0] = new_i
+            show_status = True
+        if show_status:
           elapsed = float(time.time() - t0[0])
-          print('%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
-              global_step, n_submatrices_to_train,
+          log(msg, global_step, n_submatrices_to_train,
               100.0 * global_step / n_submatrices_to_train,
-              n_steps_between_status_updates / elapsed))
-          sys.stdout.flush()
+              n_steps_between_status_updates / elapsed, loss)
           t0[0] = time.time()
 
     # Start training threads
@@ -343,8 +405,9 @@ def main(_):
     # Write out vectors
     write_embeddings_to_disk(FLAGS, model, sess)
 
-    #Shutdown
+    # Shutdown
     sess.close()
+    log("Elapsed: %s", time.time() - start_time)
 
 
 if __name__ == '__main__':