@@ -52,7 +52,6 @@ embeddings are stored in separate files.
 """
 
 from __future__ import print_function
-import argparse
 import glob
 import math
 import os
@@ -62,6 +61,7 @@ import threading
 
 import numpy as np
 import tensorflow as tf
+from tensorflow.python.client import device_lib
 
 flags = tf.app.flags
 
@@ -85,13 +85,26 @@ flags.DEFINE_float('confidence_base', 0.1, 'Base for l2 confidence function')
 flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate')
 flags.DEFINE_integer('num_concurrent_steps', 2,
                      'Number of threads to train with')
+flags.DEFINE_integer('num_readers', 4,
+                     'Number of threads to read the input data and feed it')
 flags.DEFINE_float('num_epochs', 40, 'Number epochs to train for')
-flags.DEFINE_float('per_process_gpu_memory_fraction', 0.25,
-                   'Fraction of GPU memory to use')
+flags.DEFINE_float('per_process_gpu_memory_fraction', 0,
+                   'Fraction of GPU memory to use, 0 means allow_growth')
+flags.DEFINE_integer('num_gpus', 0,
+                     'Number of GPUs to use, 0 means all available')
 
 FLAGS = flags.FLAGS
 
 
+def log(message, *args, **kwargs):
+  tf.logging.info(message, *args, **kwargs)
+
+
+def get_available_gpus():
+  return [d.name for d in device_lib.list_local_devices()
+          if d.device_type == 'GPU']
+
+
 def embeddings_with_init(vocab_size, embedding_dim, name):
   """Creates and initializes the embedding tensors."""
   return tf.get_variable(name=name,
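
Note on the two new helpers above; a standalone sketch follows (TensorFlow 1.x assumed, everything here is illustrative rather than part of the patch). tf.logging.info() is filtered at WARN verbosity by default, which is why main() below now calls tf.logging.set_verbosity(tf.logging.INFO). device_lib.list_local_devices() returns DeviceAttributes protos whose GPU names look like '/gpu:0' (or '/device:GPU:0' on later 1.x releases), either of which tf.device() accepts.

import tensorflow as tf
from tensorflow.python.client import device_lib

tf.logging.set_verbosity(tf.logging.INFO)  # default is WARN, which hides INFO

def get_available_gpus():
  # DeviceAttributes protos for every device TensorFlow can see; CPUs have
  # device_type == 'CPU', GPUs have 'GPU'.
  return [d.name for d in device_lib.list_local_devices()
          if d.device_type == 'GPU']

tf.logging.info('GPUs found: %s', get_available_gpus())
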
@@ -130,7 +143,7 @@ def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
   queued_global_row, queued_global_col, queued_count = tf.train.batch(
       [global_row, global_col, count],
       batch_size=1,
-      num_threads=4,
+      num_threads=FLAGS.num_readers,
       capacity=32)
 
   queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
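
tf.train.batch() puts a queue behind the given tensors and starts num_threads enqueuing threads, so this hunk turns the previously hard-coded reader parallelism into the --num_readers knob. A self-contained sketch of the mechanism, using a stand-in tensor instead of the parsed submatrix records:

import tensorflow as tf

example = tf.random_uniform([8])  # stand-in for one parsed record
batch = tf.train.batch([example], batch_size=1, num_threads=4, capacity=32)

with tf.Session() as sess:
  coord = tf.train.Coordinator()
  threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  print(sess.run(batch).shape)  # (1, 8)
  coord.request_stop()
  coord.join(threads)
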
@@ -164,16 +177,14 @@ def write_embeddings_to_disk(config, model, sess):
   # Row Embedding
   row_vocab_path = config.input_base_path + '/row_vocab.txt'
   row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
-  print('Writing row embeddings to:', row_embedding_output_path)
-  sys.stdout.flush()
+  log('Writing row embeddings to: %s', row_embedding_output_path)
   write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                  sess, model.row_embedding)
 
   # Column Embedding
   col_vocab_path = config.input_base_path + '/col_vocab.txt'
   col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
-  print('Writing column embeddings to:', col_embedding_output_path)
-  sys.stdout.flush()
+  log('Writing column embeddings to: %s', col_embedding_output_path)
   write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                  sess, model.col_embedding)
 
@@ -186,8 +197,7 @@ class SwivelModel(object):
     self._config = config
 
     # Create paths to input data files
-    print('Reading model from:', config.input_base_path)
-    sys.stdout.flush()
+    log('Reading model from: %s', config.input_base_path)
     count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
     row_sums_path = config.input_base_path + '/row_sums.txt'
     col_sums_path = config.input_base_path + '/col_sums.txt'
@@ -198,93 +208,129 @@ class SwivelModel(object):
 
     self.n_rows = len(row_sums)
     self.n_cols = len(col_sums)
-    print('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
-        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols))
-    sys.stdout.flush()
+    log('Matrix dim: (%d,%d) SubMatrix dim: (%d,%d)',
+        self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
     self.n_submatrices = (self.n_rows * self.n_cols /
                           (config.submatrix_rows * config.submatrix_cols))
-    print('n_submatrices: %d' % (self.n_submatrices))
-    sys.stdout.flush()
-
-    # ===== CREATE VARIABLES ======
-    # embeddings
-    self.row_embedding = embeddings_with_init(
-        embedding_dim=config.embedding_size,
-        vocab_size=self.n_rows,
-        name='row_embedding')
-    self.col_embedding = embeddings_with_init(
-        embedding_dim=config.embedding_size,
-        vocab_size=self.n_cols,
-        name='col_embedding')
-    tf.summary.histogram('row_emb', self.row_embedding)
-    tf.summary.histogram('col_emb', self.col_embedding)
-
-    matrix_log_sum = math.log(np.sum(row_sums) + 1)
-    row_bias_init = [math.log(x + 1) for x in row_sums]
-    col_bias_init = [math.log(x + 1) for x in col_sums]
-    self.row_bias = tf.Variable(
-        row_bias_init, trainable=config.trainable_bias)
-    self.col_bias = tf.Variable(
-        col_bias_init, trainable=config.trainable_bias)
-    tf.summary.histogram('row_bias', self.row_bias)
-    tf.summary.histogram('col_bias', self.col_bias)
-
-    # ===== CREATE GRAPH =====
-
-    # Get input
-    global_row, global_col, count = count_matrix_input(
-        count_matrix_files, config.submatrix_rows, config.submatrix_cols)
-
-    # Fetch embeddings.
-    selected_row_embedding = tf.nn.embedding_lookup(
-        self.row_embedding, global_row)
-    selected_col_embedding = tf.nn.embedding_lookup(
-        self.col_embedding, global_col)
-
-    # Fetch biases.
-    selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
-    selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
-
-    # Multiply the row and column embeddings to generate predictions.
-    predictions = tf.matmul(
-        selected_row_embedding, selected_col_embedding, transpose_b=True)
-
-    # These binary masks separate zero from non-zero values.
-    count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
-    count_is_zero = 1 - tf.to_float(tf.cast(count, tf.bool))
-
-    objectives = count_is_nonzero * tf.log(count + 1e-30)
-    objectives -= tf.reshape(selected_row_bias, [config.submatrix_rows, 1])
-    objectives -= selected_col_bias
-    objectives += matrix_log_sum
-
-    err = predictions - objectives
-
-    # The confidence function scales the L2 loss based on the raw co-occurrence
-    # count.
-    l2_confidence = (config.confidence_base + config.confidence_scale * tf.pow(
-        count, config.confidence_exponent))
-
-    l2_loss = config.loss_multiplier * tf.reduce_sum(
-        0.5 * l2_confidence * err * err * count_is_nonzero)
-
-    sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
-        tf.nn.softplus(err) * count_is_zero)
-
-    self.loss = l2_loss + sigmoid_loss
-
-    tf.summary.scalar("l2_loss", l2_loss)
-    tf.summary.scalar("sigmoid_loss", sigmoid_loss)
-    tf.summary.scalar("loss", self.loss)
-
-    # Add optimizer.
-    self.global_step = tf.Variable(0, name='global_step')
-    opt = tf.train.AdagradOptimizer(config.learning_rate)
-    self.train_op = opt.minimize(self.loss, global_step=self.global_step)
-    self.saver = tf.train.Saver(sharded=True)
+    log('n_submatrices: %d', self.n_submatrices)
+
+    with tf.device('/cpu:0'):
+      # ===== CREATE VARIABLES ======
+      # Get input
+      global_row, global_col, count = count_matrix_input(
+          count_matrix_files, config.submatrix_rows, config.submatrix_cols)
+
+      # Embeddings
+      self.row_embedding = embeddings_with_init(
+          embedding_dim=config.embedding_size,
+          vocab_size=self.n_rows,
+          name='row_embedding')
+      self.col_embedding = embeddings_with_init(
+          embedding_dim=config.embedding_size,
+          vocab_size=self.n_cols,
+          name='col_embedding')
+      tf.summary.histogram('row_emb', self.row_embedding)
+      tf.summary.histogram('col_emb', self.col_embedding)
+
+      matrix_log_sum = math.log(np.sum(row_sums) + 1)
+      row_bias_init = [math.log(x + 1) for x in row_sums]
+      col_bias_init = [math.log(x + 1) for x in col_sums]
+      self.row_bias = tf.Variable(
+          row_bias_init, trainable=config.trainable_bias)
+      self.col_bias = tf.Variable(
+          col_bias_init, trainable=config.trainable_bias)
+      tf.summary.histogram('row_bias', self.row_bias)
+      tf.summary.histogram('col_bias', self.col_bias)
+
+      # Add optimizer
+      l2_losses = []
+      sigmoid_losses = []
+      self.global_step = tf.Variable(0, name='global_step')
+      opt = tf.train.AdagradOptimizer(config.learning_rate)
+
+      all_grads = []
+
+    devices = ['/gpu:%d' % i for i in range(FLAGS.num_gpus)] \
+        if FLAGS.num_gpus > 0 else get_available_gpus()
+    self.devices_number = len(devices)
+    with tf.variable_scope(tf.get_variable_scope()):
+      for dev in devices:
+        with tf.device(dev):
+          with tf.name_scope(dev[1:].replace(':', '_')):
+            # ===== CREATE GRAPH =====
+            # Fetch embeddings.
+            selected_row_embedding = tf.nn.embedding_lookup(
+                self.row_embedding, global_row)
+            selected_col_embedding = tf.nn.embedding_lookup(
+                self.col_embedding, global_col)
+
+            # Fetch biases.
+            selected_row_bias = tf.nn.embedding_lookup(
+                [self.row_bias], global_row)
+            selected_col_bias = tf.nn.embedding_lookup(
+                [self.col_bias], global_col)
+
+            # Multiply the row and column embeddings to generate predictions.
+            predictions = tf.matmul(
+                selected_row_embedding, selected_col_embedding,
+                transpose_b=True)
+
+            # These binary masks separate zero from non-zero values.
+            count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
+            count_is_zero = 1 - count_is_nonzero
+
+            objectives = count_is_nonzero * tf.log(count + 1e-30)
+            objectives -= tf.reshape(
+                selected_row_bias, [config.submatrix_rows, 1])
+            objectives -= selected_col_bias
+            objectives += matrix_log_sum
+
+            err = predictions - objectives
+
+            # The confidence function scales the L2 loss based on the raw
+            # co-occurrence count.
+            l2_confidence = (config.confidence_base +
+                             config.confidence_scale * tf.pow(
+                                 count, config.confidence_exponent))
+
+            l2_loss = config.loss_multiplier * tf.reduce_sum(
+                0.5 * l2_confidence * err * err * count_is_nonzero)
+            l2_losses.append(tf.expand_dims(l2_loss, 0))
+
+            sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
+                tf.nn.softplus(err) * count_is_zero)
+            sigmoid_losses.append(tf.expand_dims(sigmoid_loss, 0))
+
+            loss = l2_loss + sigmoid_loss
+            grads = opt.compute_gradients(loss)
+            all_grads.append(grads)
+
+    with tf.device('/cpu:0'):
+      # ===== MERGE LOSSES =====
+      l2_loss = tf.reduce_mean(tf.concat(l2_losses, 0), 0, name="l2_loss")
+      sigmoid_loss = tf.reduce_mean(tf.concat(sigmoid_losses, 0), 0,
+                                    name="sigmoid_loss")
+      self.loss = l2_loss + sigmoid_loss
+      average = tf.train.ExponentialMovingAverage(0.8, self.global_step)
+      loss_average_op = average.apply((self.loss,))
+      tf.summary.scalar("l2_loss", l2_loss)
+      tf.summary.scalar("sigmoid_loss", sigmoid_loss)
+      tf.summary.scalar("loss", self.loss)
+
+    # Apply the gradients to adjust the shared variables.
+    apply_gradient_ops = []
+    for grads in all_grads:
+      apply_gradient_ops.append(opt.apply_gradients(
+          grads, global_step=self.global_step))
+
+    self.train_op = tf.group(loss_average_op, *apply_gradient_ops)
+    self.saver = tf.train.Saver(sharded=True)
 
 
 def main(_):
+  tf.logging.set_verbosity(tf.logging.INFO)
+  start_time = time.time()
+
   # Create the output path. If this fails, it really ought to fail
   # now. :)
   if not os.path.isdir(FLAGS.output_base_path):
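
The hunk above is the heart of the change: the input queue, embeddings, biases and optimizer are pinned to /cpu:0, one replica of the loss graph is built per GPU against those shared variables, and every replica's gradients flow through the same Adagrad optimizer. A stripped-down sketch of the tower pattern (build_towers and loss_fn are hypothetical names; only the scoping and build order mirror the patch):

import tensorflow as tf

def build_towers(devices, loss_fn, learning_rate=1.0):
  # loss_fn() must only read variables created before this point.
  opt = tf.train.AdagradOptimizer(learning_rate)
  global_step = tf.Variable(0, name='global_step', trainable=False)
  all_grads = []
  with tf.variable_scope(tf.get_variable_scope()):
    for dev in devices:
      with tf.device(dev):
        with tf.name_scope(dev[1:].replace(':', '_')):  # '/gpu:0' -> 'gpu_0'
          all_grads.append(opt.compute_gradients(loss_fn()))
  # One apply op per tower, grouped into a single train op; each sess.run()
  # of the group therefore performs len(devices) gradient updates.
  apply_ops = [opt.apply_gradients(grads, global_step=global_step)
               for grads in all_grads]
  return tf.group(*apply_ops)

Note that unlike the classic averaged-gradients tower setup, this patch applies each tower's gradients independently, so global_step advances once per tower per run; that is why the step budget in main() is divided by model.devices_number.
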
@@ -295,8 +341,13 @@ def main(_):
   model = SwivelModel(FLAGS)
 
   # Create a session for running Ops on the Graph.
-  gpu_options = tf.GPUOptions(
-      per_process_gpu_memory_fraction=FLAGS.per_process_gpu_memory_fraction)
+  gpu_opts = {}
+  if FLAGS.per_process_gpu_memory_fraction > 0:
+    gpu_opts["per_process_gpu_memory_fraction"] = \
+        FLAGS.per_process_gpu_memory_fraction
+  else:
+    gpu_opts["allow_growth"] = True
+  gpu_options = tf.GPUOptions(**gpu_opts)
   sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
 
   # Run the Op to initialize the variables.
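
On the session options: per_process_gpu_memory_fraction pre-allocates a fixed slice of every visible GPU up front, while allow_growth starts small and grows on demand, the friendlier default on shared machines; the patch overloads a fraction of 0 to mean the latter. The same logic condensed into a standalone sketch:

import tensorflow as tf

def make_session(memory_fraction=0.0):
  # A fraction of 0 is treated as "grow on demand" rather than "take nothing".
  if memory_fraction > 0:
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=memory_fraction)
  else:
    gpu_options = tf.GPUOptions(allow_growth=True)
  return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
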
@@ -309,21 +360,32 @@ def main(_):
   # Calculate how many steps each thread should run
   n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) / (
       FLAGS.submatrix_rows * FLAGS.submatrix_cols)
-  n_steps_per_thread = n_total_steps / FLAGS.num_concurrent_steps
+  n_steps_per_thread = n_total_steps / (
+      FLAGS.num_concurrent_steps * model.devices_number)
   n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs
   t0 = [time.time()]
+  n_steps_between_status_updates = 100
+  status_i = [0]
+  status_lock = threading.Lock()
+  msg = ('%%%dd/%%d submatrices trained (%%.1f%%%%), %%5.1f submatrices/sec |'
+         ' loss %%f') % len(str(n_submatrices_to_train))
 
   def TrainingFn():
     for _ in range(int(n_steps_per_thread)):
-      _, global_step = sess.run([model.train_op, model.global_step])
-      n_steps_between_status_updates = 100
-      if (global_step % n_steps_between_status_updates) == 0:
+      _, global_step, loss = sess.run((
+          model.train_op, model.global_step, model.loss))
+
+      show_status = False
+      with status_lock:
+        new_i = global_step // n_steps_between_status_updates
+        if new_i > status_i[0]:
+          status_i[0] = new_i
+          show_status = True
+      if show_status:
         elapsed = float(time.time() - t0[0])
-        print('%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
-            global_step, n_submatrices_to_train,
+        log(msg, global_step, n_submatrices_to_train,
             100.0 * global_step / n_submatrices_to_train,
-            n_steps_between_status_updates / elapsed))
-        sys.stdout.flush()
+            n_steps_between_status_updates / elapsed, loss)
         t0[0] = time.time()
 
   # Start training threads
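
Worth spelling out in the hunk above: the old code printed whenever global_step % 100 == 0, but with several trainer threads, and now several global-step increments per sess.run(), a given multiple of 100 can be skipped entirely or hit by many threads at once. Tracking the last reported bucket under a lock yields at most one status line per interval. The pattern in isolation:

import threading

_status_lock = threading.Lock()
_status_i = [0]  # last reported bucket, boxed so the closure can rebind it

def should_report(global_step, interval=100):
  # True at most once per `interval` steps, across all threads.
  with _status_lock:
    new_i = global_step // interval
    if new_i > _status_i[0]:
      _status_i[0] = new_i
      return True
  return False
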
@@ -343,8 +405,9 @@ def main(_):
   # Write out vectors
   write_embeddings_to_disk(FLAGS, model, sess)
 
-  #Shutdown
+  # Shutdown
   sess.close()
+  log("Elapsed: %s", time.time() - start_time)
 
 
 if __name__ == '__main__':