swivel.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. #!/usr/bin/env python
  2. #
  3. # Copyright 2016 Google Inc. All Rights Reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """Submatrix-wise Vector Embedding Learner.
  17. Implementation of SwiVel algorithm described at:
  18. http://arxiv.org/abs/1602.02215
This program expects an input directory that contains the following files.
row_vocab.txt, col_vocab.txt
The row and column vocabulary files. Each file should contain one token per
line; these will be used to generate a tab-separated file containing the
trained embeddings.
row_sums.txt, col_sums.txt
The matrix row and column marginal sums. Each file should contain one
decimal floating point number per line which corresponds to the marginal
count of the matrix for that row or column.
shard-*.pb
The sub-matrix shards, one file per shard. Each file is expected to contain
a single serialized tf.Example protocol buffer with the following
properties:
global_row: the global row indices contained in the shard
global_col: the global column indices contained in the shard
sparse_local_row, sparse_local_col, sparse_value: three parallel arrays
that are a sparse representation of the submatrix counts.
It will generate embeddings, training from the input directory for the specified
number of epochs. When complete, it will output the trained vectors to
tab-separated files that contain one line per embedding: row embeddings are
written to row_embedding.tsv and column embeddings to col_embedding.tsv, both
in the output directory.
  40. """
  41. import argparse
  42. import glob
  43. import math
  44. import os
  45. import sys
  46. import time
  47. import threading
  48. import numpy as np
  49. import tensorflow as tf
  50. flags = tf.app.flags
  51. flags.DEFINE_string('input_base_path', '/tmp/swivel_data',
  52. 'Directory containing input shards, vocabularies, '
  53. 'and marginals.')
  54. flags.DEFINE_string('output_base_path', '/tmp/swivel_data',
  55. 'Path where to write the trained embeddings.')
  56. flags.DEFINE_integer('embedding_size', 300, 'Size of the embeddings')
  57. flags.DEFINE_boolean('trainable_bias', False, 'Biases are trainable')
  58. flags.DEFINE_integer('submatrix_rows', 4096, 'Rows in each training submatrix. '
  59. 'This must match the training data.')
  60. flags.DEFINE_integer('submatrix_cols', 4096, 'Rows in each training submatrix. '
  61. 'This must match the training data.')
  62. flags.DEFINE_float('loss_multiplier', 1.0 / 4096,
  63. 'constant multiplier on loss.')
  64. flags.DEFINE_float('confidence_exponent', 0.5,
  65. 'Exponent for l2 confidence function')
  66. flags.DEFINE_float('confidence_scale', 0.25, 'Scale for l2 confidence function')
  67. flags.DEFINE_float('confidence_base', 0.1, 'Base for l2 confidence function')
  68. flags.DEFINE_float('learning_rate', 1.0, 'Initial learning rate')
  69. flags.DEFINE_integer('num_concurrent_steps', 2,
  70. 'Number of threads to train with')
  71. flags.DEFINE_float('num_epochs', 40, 'Number epochs to train for')
  72. flags.DEFINE_float('per_process_gpu_memory_fraction', 0.25,
  73. 'Fraction of GPU memory to use')
  74. FLAGS = flags.FLAGS
  75. def embeddings_with_init(vocab_size, embedding_dim, name):
  76. """Creates and initializes the embedding tensors."""
  77. return tf.get_variable(name=name,
  78. shape=[vocab_size, embedding_dim],
  79. initializer=tf.random_normal_initializer(
  80. stddev=math.sqrt(1.0 / embedding_dim)))
  81. def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
  82. """Reads submatrix shards from disk."""
  83. filename_queue = tf.train.string_input_producer(filenames)
  84. reader = tf.WholeFileReader()
  85. _, serialized_example = reader.read(filename_queue)
  86. features = tf.parse_single_example(
  87. serialized_example,
  88. features={
  89. 'global_row': tf.FixedLenFeature([submatrix_rows], dtype=tf.int64),
  90. 'global_col': tf.FixedLenFeature([submatrix_cols], dtype=tf.int64),
  91. 'sparse_local_row': tf.VarLenFeature(dtype=tf.int64),
  92. 'sparse_local_col': tf.VarLenFeature(dtype=tf.int64),
  93. 'sparse_value': tf.VarLenFeature(dtype=tf.float32)
  94. })
  95. global_row = features['global_row']
  96. global_col = features['global_col']
  97. sparse_local_row = features['sparse_local_row'].values
  98. sparse_local_col = features['sparse_local_col'].values
  99. sparse_count = features['sparse_value'].values
  100. sparse_indices = tf.concat(1, [tf.expand_dims(sparse_local_row, 1),
  101. tf.expand_dims(sparse_local_col, 1)])
  102. count = tf.sparse_to_dense(sparse_indices, [submatrix_rows, submatrix_cols],
  103. sparse_count)
  104. queued_global_row, queued_global_col, queued_count = tf.train.batch(
  105. [global_row, global_col, count],
  106. batch_size=1,
  107. num_threads=4,
  108. capacity=32)
  109. queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
  110. queued_global_col = tf.reshape(queued_global_col, [submatrix_cols])
  111. queued_count = tf.reshape(queued_count, [submatrix_rows, submatrix_cols])
  112. return queued_global_row, queued_global_col, queued_count
  113. def read_marginals_file(filename):
  114. """Reads text file with one number per line to an array."""
  115. with open(filename) as lines:
  116. return [float(line) for line in lines]
  117. def write_embedding_tensor_to_disk(vocab_path, output_path, sess, embedding):
  118. """Writes tensor to output_path as tsv"""
  119. # Fetch the embedding values from the model
  120. embeddings = sess.run(embedding)
  121. with open(output_path, 'w') as out_f:
  122. with open(vocab_path) as vocab_f:
  123. for index, word in enumerate(vocab_f):
  124. word = word.strip()
  125. embedding = embeddings[index]
  126. out_f.write(word + '\t' + '\t'.join([str(x) for x in embedding]) + '\n')
  127. def write_embeddings_to_disk(config, model, sess):
  128. """Writes row and column embeddings disk"""
  129. # Row Embedding
  130. row_vocab_path = config.input_base_path + '/row_vocab.txt'
  131. row_embedding_output_path = config.output_base_path + '/row_embedding.tsv'
  132. print 'Writing row embeddings to:', row_embedding_output_path
  133. write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
  134. sess, model.row_embedding)
  135. # Column Embedding
  136. col_vocab_path = config.input_base_path + '/col_vocab.txt'
  137. col_embedding_output_path = config.output_base_path + '/col_embedding.tsv'
  138. print 'Writing column embeddings to:', col_embedding_output_path
  139. write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
  140. sess, model.col_embedding)
  141. class SwivelModel(object):
  142. """Small class to gather needed pieces from a Graph being built."""
  143. def __init__(self, config):
  144. """Construct graph for dmc."""
  145. self._config = config
  146. # Create paths to input data files
  147. print 'Reading model from:', config.input_base_path
  148. count_matrix_files = glob.glob(config.input_base_path + '/shard-*.pb')
  149. row_sums_path = config.input_base_path + '/row_sums.txt'
  150. col_sums_path = config.input_base_path + '/col_sums.txt'
  151. # Read marginals
  152. row_sums = read_marginals_file(row_sums_path)
  153. col_sums = read_marginals_file(col_sums_path)
  154. self.n_rows = len(row_sums)
  155. self.n_cols = len(col_sums)
  156. print 'Matrix dim: (%d,%d) SubMatrix dim: (%d,%d) ' % (
  157. self.n_rows, self.n_cols, config.submatrix_rows, config.submatrix_cols)
  158. self.n_submatrices = (self.n_rows * self.n_cols /
  159. (config.submatrix_rows * config.submatrix_cols))
  160. print 'n_submatrices: %d' % (self.n_submatrices)
  161. # ===== CREATE VARIABLES ======
  162. with tf.device('/cpu:0'):
  163. # embeddings
  164. self.row_embedding = embeddings_with_init(
  165. embedding_dim=config.embedding_size,
  166. vocab_size=self.n_rows,
  167. name='row_embedding')
  168. self.col_embedding = embeddings_with_init(
  169. embedding_dim=config.embedding_size,
  170. vocab_size=self.n_cols,
  171. name='col_embedding')
  172. tf.histogram_summary('row_emb', self.row_embedding)
  173. tf.histogram_summary('col_emb', self.col_embedding)
  174. matrix_log_sum = math.log(np.sum(row_sums) + 1)
  175. row_bias_init = [math.log(x + 1) for x in row_sums]
  176. col_bias_init = [math.log(x + 1) for x in col_sums]
  177. self.row_bias = tf.Variable(row_bias_init,
  178. trainable=config.trainable_bias)
  179. self.col_bias = tf.Variable(col_bias_init,
  180. trainable=config.trainable_bias)
  181. tf.histogram_summary('row_bias', self.row_bias)
  182. tf.histogram_summary('col_bias', self.col_bias)
  183. # ===== CREATE GRAPH =====
  184. # Get input
  185. with tf.device('/cpu:0'):
  186. global_row, global_col, count = count_matrix_input(
  187. count_matrix_files, config.submatrix_rows, config.submatrix_cols)
  188. # Fetch embeddings.
  189. selected_row_embedding = tf.nn.embedding_lookup(self.row_embedding,
  190. global_row)
  191. selected_col_embedding = tf.nn.embedding_lookup(self.col_embedding,
  192. global_col)
  193. # Fetch biases.
  194. selected_row_bias = tf.nn.embedding_lookup([self.row_bias], global_row)
  195. selected_col_bias = tf.nn.embedding_lookup([self.col_bias], global_col)
  196. # Multiply the row and column embeddings to generate predictions.
  197. predictions = tf.matmul(
  198. selected_row_embedding, selected_col_embedding, transpose_b=True)
  199. # These binary masks separate zero from non-zero values.
  200. count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
  201. count_is_zero = 1 - tf.to_float(tf.cast(count, tf.bool))
  202. objectives = count_is_nonzero * tf.log(count + 1e-30)
  203. objectives -= tf.reshape(selected_row_bias, [config.submatrix_rows, 1])
  204. objectives -= selected_col_bias
  205. objectives += matrix_log_sum
  206. err = predictions - objectives
  207. # The confidence function scales the L2 loss based on the raw co-occurrence
  208. # count.
  209. l2_confidence = (config.confidence_base + config.confidence_scale * tf.pow(
  210. count, config.confidence_exponent))
  211. l2_loss = config.loss_multiplier * tf.reduce_sum(
  212. 0.5 * l2_confidence * err * err * count_is_nonzero)
  213. sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
  214. tf.nn.softplus(err) * count_is_zero)
  215. self.loss = l2_loss + sigmoid_loss
  216. tf.scalar_summary("l2_loss", l2_loss)
  217. tf.scalar_summary("sigmoid_loss", sigmoid_loss)
  218. tf.scalar_summary("loss", self.loss)
  219. # Add optimizer.
  220. self.global_step = tf.Variable(0, name='global_step')
  221. opt = tf.train.AdagradOptimizer(config.learning_rate)
  222. self.train_op = opt.minimize(self.loss, global_step=self.global_step)
  223. self.saver = tf.train.Saver(sharded=True)
  224. def main(_):
  225. # Create the output path. If this fails, it really ought to fail
  226. # now. :)
  227. if not os.path.isdir(FLAGS.output_base_path):
  228. os.makedirs(FLAGS.output_base_path)
  229. # Create and run model
  230. with tf.Graph().as_default():
  231. model = SwivelModel(FLAGS)
  232. # Create a session for running Ops on the Graph.
  233. gpu_options = tf.GPUOptions(
  234. per_process_gpu_memory_fraction=FLAGS.per_process_gpu_memory_fraction)
  235. sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
  236. # Run the Op to initialize the variables.
  237. sess.run(tf.initialize_all_variables())
  238. # Start feeding input
  239. coord = tf.train.Coordinator()
  240. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  241. # Calculate how many steps each thread should run
  242. n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) / (
  243. FLAGS.submatrix_rows * FLAGS.submatrix_cols)
  244. n_steps_per_thread = n_total_steps / FLAGS.num_concurrent_steps
  245. n_submatrices_to_train = model.n_submatrices * FLAGS.num_epochs
  246. t0 = [time.time()]
  247. def TrainingFn():
  248. for _ in range(n_steps_per_thread):
  249. _, global_step = sess.run([model.train_op, model.global_step])
  250. n_steps_between_status_updates = 100
  251. if (global_step % n_steps_between_status_updates) == 0:
  252. elapsed = float(time.time() - t0[0])
  253. print '%d/%d submatrices trained (%.1f%%), %.1f submatrices/sec' % (
  254. global_step, n_submatrices_to_train,
  255. 100.0 * global_step / n_submatrices_to_train,
  256. n_steps_between_status_updates / elapsed)
  257. t0[0] = time.time()
  258. # Start training threads
  259. train_threads = []
  260. for _ in range(FLAGS.num_concurrent_steps):
  261. t = threading.Thread(target=TrainingFn)
  262. train_threads.append(t)
  263. t.start()
  264. # Wait for threads to finish.
  265. for t in train_threads:
  266. t.join()
  267. coord.request_stop()
  268. coord.join(threads)
  269. # Write out vectors
  270. write_embeddings_to_disk(FLAGS, model, sess)
  271. #Shutdown
  272. sess.close()
  273. if __name__ == '__main__':
  274. tf.app.run()