#!/usr/bin/env python3
#
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2017 Sourced Technologies S. L.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  17. """Submatrix-wise Vector Embedding Learner.
  18. Implementation of SwiVel algorithm described at:
  19. http://arxiv.org/abs/1602.02215
  20. This program expects an input directory that contains the following files.
  21. row_vocab.txt, col_vocab.txt
  22. The row an column vocabulary files. Each file should contain one token per
  23. line; these will be used to generate a tab-separate file containing the
  24. trained embeddings.
  25. row_sums.txt, col_sum.txt
  26. The matrix row and column marginal sums. Each file should contain one
  27. decimal floating point number per line which corresponds to the marginal
  28. count of the matrix for that row or column.
  29. shards.recs
  30. A file containing the sub-matrix shards, stored as TFRecords. Each shard is
  31. expected to be a serialzed tf.Example protocol buffer with the following
  32. properties:
  33. global_row: the global row indicies contained in the shard
  34. global_col: the global column indicies contained in the shard
  35. sparse_local_row, sparse_local_col, sparse_value: three parallel arrays
  36. that are a sparse representation of the submatrix counts.
  37. It will generate embeddings, training from the input directory for
  38. the specified number of epochs. When complete, it will output the trained
  39. vectors to a tab-separated file that contains one line per embedding. Row and
  40. column embeddings are stored in separate files.
  41. """
import glob
import math
import os
import threading
import time

import numpy as np
import tensorflow as tf
from tensorflow.python.client import device_lib

flags = tf.app.flags
flags.DEFINE_string("input_base_path", None,
                    "Directory containing input shards, vocabularies, "
                    "and marginals.")
flags.DEFINE_string("output_base_path", None,
                    "Path where to write the trained embeddings.")
flags.DEFINE_integer("embedding_size", 300, "Size of the embeddings")
flags.DEFINE_boolean("trainable_bias", False, "Biases are trainable")
flags.DEFINE_integer("submatrix_rows", 4096,
                     "Rows in each training submatrix. This must match "
                     "the training data.")
flags.DEFINE_integer("submatrix_cols", 4096,
                     "Columns in each training submatrix. This must match "
                     "the training data.")
flags.DEFINE_float("loss_multiplier", 1.0 / 4096,
                   "Constant multiplier on loss.")
flags.DEFINE_float("confidence_exponent", 0.5,
                   "Exponent for l2 confidence function")
flags.DEFINE_float("confidence_scale", 0.25,
                   "Scale for l2 confidence function")
flags.DEFINE_float("confidence_base", 0.1, "Base for l2 confidence function")
flags.DEFINE_float("learning_rate", 1.0, "Initial learning rate")
flags.DEFINE_string("optimizer", "Adagrad",
                    "SGD optimizer (tf.train.*Optimizer)")
flags.DEFINE_integer("num_concurrent_steps", 2,
                     "Number of threads to train with")
flags.DEFINE_integer("num_readers", 4,
                     "Number of threads to read the input data and feed it")
flags.DEFINE_float("num_epochs", 40, "Number of epochs to train for")
flags.DEFINE_float("per_process_gpu_memory_fraction", 0,
                   "Fraction of GPU memory to use, 0 means allow_growth")
flags.DEFINE_integer("num_gpus", 0,
                     "Number of GPUs to use, 0 means all available")
flags.DEFINE_string("logs", "",
                    "Path for TensorBoard logs (empty value disables them)")

FLAGS = flags.FLAGS


def log(message, *args, **kwargs):
  tf.logging.info(message, *args, **kwargs)


def get_available_gpus():
  return [d.name for d in device_lib.list_local_devices()
          if d.device_type == "GPU"]


def embeddings_with_init(vocab_size, embedding_dim, name):
  """Creates and initializes the embedding tensors."""
  return tf.get_variable(name=name,
                         shape=[vocab_size, embedding_dim],
                         initializer=tf.random_normal_initializer(
                             stddev=math.sqrt(1.0 / embedding_dim)))
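
# Note: with the default --embedding_size of 300, the initializer above draws
# entries with stddev = sqrt(1/300) ~= 0.058, so every embedding vector starts
# with an expected squared norm of 1, independent of the dimensionality.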


def count_matrix_input(filenames, submatrix_rows, submatrix_cols):
  """Reads submatrix shards from disk."""
  filename_queue = tf.train.string_input_producer(filenames)
  reader = tf.WholeFileReader()
  _, serialized_example = reader.read(filename_queue)
  features = tf.parse_single_example(
      serialized_example,
      features={
          "global_row": tf.FixedLenFeature([submatrix_rows], dtype=tf.int64),
          "global_col": tf.FixedLenFeature([submatrix_cols], dtype=tf.int64),
          "sparse_local_row": tf.VarLenFeature(dtype=tf.int64),
          "sparse_local_col": tf.VarLenFeature(dtype=tf.int64),
          "sparse_value": tf.VarLenFeature(dtype=tf.float32)
      })

  global_row = features["global_row"]
  global_col = features["global_col"]

  sparse_local_row = features["sparse_local_row"].values
  sparse_local_col = features["sparse_local_col"].values
  sparse_count = features["sparse_value"].values

  sparse_indices = tf.concat(axis=1, values=[
      tf.expand_dims(sparse_local_row, 1),
      tf.expand_dims(sparse_local_col, 1)])
  count = tf.sparse_to_dense(sparse_indices, [submatrix_rows, submatrix_cols],
                             sparse_count, validate_indices=False)

  queued_global_row, queued_global_col, queued_count = tf.train.batch(
      [global_row, global_col, count],
      batch_size=1,
      num_threads=FLAGS.num_readers,
      capacity=32)

  queued_global_row = tf.reshape(queued_global_row, [submatrix_rows])
  queued_global_col = tf.reshape(queued_global_col, [submatrix_cols])
  queued_count = tf.reshape(queued_count, [submatrix_rows, submatrix_cols])

  return queued_global_row, queued_global_col, queued_count
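

# The following helper is an illustrative sketch, not part of the original
# pipeline (shards are normally produced by a separate preparation tool). It
# shows how a single shard file in the format that count_matrix_input expects
# could be written; the helper name and its dense `counts` argument are
# assumptions made for the example.
def _example_write_shard(path, global_rows, global_cols, counts):
  """Writes one dense numpy `counts` submatrix as a serialized tf.Example."""
  local_rows, local_cols = np.nonzero(counts)
  example = tf.train.Example(features=tf.train.Features(feature={
      "global_row": tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(global_rows))),
      "global_col": tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(global_cols))),
      "sparse_local_row": tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(local_rows))),
      "sparse_local_col": tf.train.Feature(
          int64_list=tf.train.Int64List(value=list(local_cols))),
      "sparse_value": tf.train.Feature(
          float_list=tf.train.FloatList(
              value=[float(x) for x in counts[local_rows, local_cols]])),
  }))
  # count_matrix_input reads each shard as a whole file, so one serialized
  # example per "shard-*.pb" file is enough.
  with open(path, "wb") as out:
    out.write(example.SerializeToString())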


def read_marginals_file(filename):
  """Reads text file with one number per line to an array."""
  with open(filename) as lines:
    return [float(line) for line in lines]


def write_embedding_tensor_to_disk(vocab_path, output_path, sess, embedding):
  """Writes tensor to output_path as tsv."""
  # Fetch the embedding values from the model
  embeddings = sess.run(embedding)

  with open(output_path, "w") as out_f:
    with open(vocab_path) as vocab_f:
      for index, word in enumerate(vocab_f):
        word = word.strip()
        embedding = embeddings[index]
        out_f.write(word + "\t" + "\t".join(
            [str(x) for x in embedding]) + "\n")


def write_embeddings_to_disk(config, model, sess):
  """Writes row and column embeddings to disk."""
  # Row embedding
  row_vocab_path = config.input_base_path + "/row_vocab.txt"
  row_embedding_output_path = config.output_base_path + "/row_embedding.tsv"
  log("Writing row embeddings to: %s", row_embedding_output_path)
  write_embedding_tensor_to_disk(row_vocab_path, row_embedding_output_path,
                                 sess, model.row_embedding)

  # Column embedding
  col_vocab_path = config.input_base_path + "/col_vocab.txt"
  col_embedding_output_path = config.output_base_path + "/col_embedding.tsv"
  log("Writing column embeddings to: %s", col_embedding_output_path)
  write_embedding_tensor_to_disk(col_vocab_path, col_embedding_output_path,
                                 sess, model.col_embedding)
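

# Another illustrative sketch (this function is not used by the trainer): how
# the TSV files written above can be loaded back for downstream use.
def _example_load_embeddings_tsv(path):
  """Returns (vocab list, [n_tokens, embedding_size] numpy matrix)."""
  vocab, vectors = [], []
  with open(path) as tsv:
    for line in tsv:
      parts = line.rstrip("\n").split("\t")
      vocab.append(parts[0])
      vectors.append([float(x) for x in parts[1:]])
  return vocab, np.array(vectors, dtype=np.float32)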


class SwivelModel:
  """Small class to gather needed pieces from a Graph being built."""

  def __init__(self, config):
    """Construct graph for dmc."""
    self._config = config

    # Create paths to input data files
    log("Reading model from: %s", config.input_base_path)
    count_matrix_files = glob.glob(
        os.path.join(config.input_base_path, "shard-*.pb"))
    row_sums_path = os.path.join(config.input_base_path, "row_sums.txt")
    col_sums_path = os.path.join(config.input_base_path, "col_sums.txt")

    # Read marginals
    row_sums = read_marginals_file(row_sums_path)
    col_sums = read_marginals_file(col_sums_path)

    self.n_rows = len(row_sums)
    self.n_cols = len(col_sums)
    log("Matrix dim: (%d,%d) SubMatrix dim: (%d,%d)",
        self.n_rows, self.n_cols, config.submatrix_rows,
        config.submatrix_cols)
    if self.n_cols < config.submatrix_cols:
      raise ValueError(
          "submatrix_cols={0} can not be bigger than the number of "
          "columns={1} (specify submatrix_cols={1})".format(
              config.submatrix_cols, self.n_cols))
    if self.n_rows < config.submatrix_rows:
      raise ValueError(
          "submatrix_rows={0} can not be bigger than the number of "
          "rows={1} (specify submatrix_rows={1})".format(
              config.submatrix_rows, self.n_rows))

    self.n_submatrices = (self.n_rows * self.n_cols //
                          (config.submatrix_rows * config.submatrix_cols))
    log("n_submatrices: %d", self.n_submatrices)

    with tf.device("/cpu:0"):
      # ===== CREATE VARIABLES =====
      # Get input
      global_row, global_col, count = count_matrix_input(
          count_matrix_files, config.submatrix_rows,
          config.submatrix_cols)

      # Embeddings
      self.row_embedding = embeddings_with_init(
          embedding_dim=config.embedding_size,
          vocab_size=self.n_rows,
          name="row_embedding")
      self.col_embedding = embeddings_with_init(
          embedding_dim=config.embedding_size,
          vocab_size=self.n_cols,
          name="col_embedding")
      tf.summary.histogram("row_emb", self.row_embedding)
      tf.summary.histogram("col_emb", self.col_embedding)
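
      # The bias terms below are initialized to the log marginal counts and
      # matrix_log_sum to the log of the total count; with the default
      # --trainable_bias=False they stay fixed, so the training target
      # reduces to a PMI approximation.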
      matrix_log_sum = math.log(np.sum(row_sums) + 1)
      row_bias_init = [math.log(x + 1) for x in row_sums]
      col_bias_init = [math.log(x + 1) for x in col_sums]
      self.row_bias = tf.Variable(
          row_bias_init, trainable=config.trainable_bias)
      self.col_bias = tf.Variable(
          col_bias_init, trainable=config.trainable_bias)
      tf.summary.histogram("row_bias", self.row_bias)
      tf.summary.histogram("col_bias", self.col_bias)

      # Add optimizer
      l2_losses = []
      sigmoid_losses = []
      self.global_step = tf.Variable(0, name="global_step")
      learning_rate = tf.Variable(config.learning_rate,
                                  name="learning_rate")
      opt = getattr(tf.train, FLAGS.optimizer + "Optimizer")(
          learning_rate)
      tf.summary.scalar("learning_rate", learning_rate)
      all_grads = []

    devices = ["/gpu:%d" % i for i in range(FLAGS.num_gpus)] \
        if FLAGS.num_gpus > 0 else get_available_gpus()
    self.devices_number = len(devices)
    if not self.devices_number:
      devices = ["/cpu:0"]
      self.devices_number = 1
    for dev in devices:
      with tf.device(dev):
        with tf.name_scope(dev[1:].replace(":", "_")):
          # ===== CREATE GRAPH =====
          # Fetch embeddings.
          selected_row_embedding = tf.nn.embedding_lookup(
              self.row_embedding, global_row)
          selected_col_embedding = tf.nn.embedding_lookup(
              self.col_embedding, global_col)

          # Fetch biases.
          selected_row_bias = tf.nn.embedding_lookup(
              [self.row_bias], global_row)
          selected_col_bias = tf.nn.embedding_lookup(
              [self.col_bias], global_col)

          # Multiply the row and column embeddings to generate predictions.
          predictions = tf.matmul(
              selected_row_embedding, selected_col_embedding,
              transpose_b=True)

          # These binary masks separate zero from non-zero values.
          count_is_nonzero = tf.to_float(tf.cast(count, tf.bool))
          count_is_zero = 1 - count_is_nonzero

          objectives = count_is_nonzero * tf.log(count + 1e-30)
          objectives -= tf.reshape(
              selected_row_bias, [config.submatrix_rows, 1])
          objectives -= selected_col_bias
          objectives += matrix_log_sum
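          # With the log-marginal bias initialization above, `objectives`
          # is log(count) - log(row_sum) - log(col_sum) + log(total_sum),
          # i.e. the pointwise mutual information of each (row, col) pair,
          # which the embedding dot products are trained to match.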

          err = predictions - objectives

          # The confidence function scales the L2 loss based on the raw
          # co-occurrence count.
          l2_confidence = (config.confidence_base +
                           config.confidence_scale * tf.pow(
                               count, config.confidence_exponent))
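
          # With the default flags (base=0.1, scale=0.25, exponent=0.5), a
          # pair seen once gets confidence 0.1 + 0.25 * 1 = 0.35, while a
          # pair seen 100 times gets 0.1 + 0.25 * 10 = 2.6, so frequent
          # co-occurrences dominate the L2 term.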
          l2_loss = config.loss_multiplier * tf.reduce_sum(
              0.5 * l2_confidence * err * err * count_is_nonzero)
          l2_losses.append(tf.expand_dims(l2_loss, 0))

          sigmoid_loss = config.loss_multiplier * tf.reduce_sum(
              tf.nn.softplus(err) * count_is_zero)
          sigmoid_losses.append(tf.expand_dims(sigmoid_loss, 0))
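
          # Observed co-occurrences incur the confidence-weighted L2
          # penalty above; unobserved ones (count == 0) incur
          # softplus(err), the "soft hinge" of the Swivel paper, which
          # penalizes over-estimating the PMI of pairs never seen together.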
          loss = l2_loss + sigmoid_loss
          grads = opt.compute_gradients(loss)
          all_grads.append(grads)

    with tf.device("/cpu:0"):
      # ===== MERGE LOSSES =====
      l2_loss = tf.reduce_mean(tf.concat(axis=0, values=l2_losses), 0,
                               name="l2_loss")
      sigmoid_loss = tf.reduce_mean(
          tf.concat(axis=0, values=sigmoid_losses), 0,
          name="sigmoid_loss")
      overall_loss = l2_loss + sigmoid_loss
      average = tf.train.ExponentialMovingAverage(0.999)
      loss_average_op = average.apply(
          (overall_loss, l2_loss, sigmoid_loss))
      self.loss = average.average(overall_loss)
      tf.summary.scalar("overall_loss", self.loss)
      tf.summary.scalar("l2_loss", average.average(l2_loss))
      tf.summary.scalar("sigmoid_loss", average.average(sigmoid_loss))

      # Apply the gradients to adjust the shared variables.
      apply_gradient_ops = []
      for grads in all_grads:
        apply_gradient_ops.append(opt.apply_gradients(
            grads, global_step=self.global_step))

      self.train_op = tf.group(loss_average_op, *apply_gradient_ops)
      self.saver = tf.train.Saver(sharded=True)

  def initialize_summary(self, sess):
    log("creating TensorBoard stuff...")
    self.summary = tf.summary.merge_all()
    self.writer = tf.summary.FileWriter(FLAGS.logs, sess.graph)
    projector_config = \
        tf.contrib.tensorboard.plugins.projector.ProjectorConfig()
    embedding_config = projector_config.embeddings.add()
    length = min(10000, self.n_rows, self.n_cols)
    self.embedding10k = tf.Variable(
        tf.zeros((length, self._config.embedding_size)),
        name="top10k_embedding")
    embedding_config.tensor_name = self.embedding10k.name
    embedding_config.metadata_path = os.path.join(
        self._config.input_base_path, "row_vocab.txt")
    tf.contrib.tensorboard.plugins.projector.visualize_embeddings(
        self.writer, projector_config)
    self.saver = tf.train.Saver((self.embedding10k,), max_to_keep=1)

  def write_summary(self, sess):
    log("writing the summary...")
    length = min(10000, self.n_rows, self.n_cols)
    assignment = self.embedding10k.assign(
        (self.row_embedding[:length] + self.col_embedding[:length]) / 2)
    summary, _, global_step = sess.run(
        (self.summary, assignment, self.global_step))
    self.writer.add_summary(summary, global_step)
    self.saver.save(
        sess, os.path.join(FLAGS.logs, "embeddings10k.checkpoint"),
        global_step)


def main(_):
  tf.logging.set_verbosity(tf.logging.INFO)
  start_time = time.time()

  # Create the output path. If this fails, it really ought to fail now. :)
  if not os.path.isdir(FLAGS.output_base_path):
    os.makedirs(FLAGS.output_base_path)

  # Create and run model
  with tf.Graph().as_default():
    log("creating the model...")
    model = SwivelModel(FLAGS)

    # Create a session for running Ops on the Graph.
    gpu_opts = {}
    if FLAGS.per_process_gpu_memory_fraction > 0:
      gpu_opts["per_process_gpu_memory_fraction"] = \
          FLAGS.per_process_gpu_memory_fraction
    else:
      gpu_opts["allow_growth"] = True
    gpu_options = tf.GPUOptions(**gpu_opts)
    sess = tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
    if FLAGS.logs:
      model.initialize_summary(sess)

    # Run the Op to initialize the variables.
    log("initializing the variables...")
    sess.run(tf.global_variables_initializer())

    # Start feeding input
    log("starting the input threads...")
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # Calculate how many steps each thread should run
    n_total_steps = int(FLAGS.num_epochs * model.n_rows * model.n_cols) // (
        FLAGS.submatrix_rows * FLAGS.submatrix_cols)
    n_steps_per_thread = n_total_steps // (
        FLAGS.num_concurrent_steps * model.devices_number)
    n_submatrices_to_train = int(model.n_submatrices * FLAGS.num_epochs)
    t0 = [time.time()]
    n_steps_between_status_updates = 100
    n_steps_between_summary_updates = 10000
    status_i = [0, 0]
    status_lock = threading.Lock()
    msg = ("%%%dd/%%d submatrices trained (%%.1f%%%%), "
           "%%5.1f submatrices/sec | loss %%f") % \
        len(str(n_submatrices_to_train))
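
    # `msg` is built in two stages: the "%%%dd" above first bakes in a field
    # width equal to the digit count of n_submatrices_to_train, producing a
    # final format string like "%5d/%d submatrices trained (%.1f%%), ...".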

    def TrainingFn():
      for _ in range(int(n_steps_per_thread)):
        _, global_step, loss = sess.run((
            model.train_op, model.global_step, model.loss))
        show_status = False
        update_summary = False
        with status_lock:
          new_i = global_step // n_steps_between_status_updates
          if new_i > status_i[0]:
            status_i[0] = new_i
            show_status = True
          new_i = global_step // n_steps_between_summary_updates
          if new_i > status_i[1]:
            status_i[1] = new_i
            update_summary = True
        if show_status:
          elapsed = float(time.time() - t0[0])
          log(msg, global_step, n_submatrices_to_train,
              100.0 * global_step / n_submatrices_to_train,
              n_steps_between_status_updates / elapsed, loss)
          t0[0] = time.time()
        if update_summary and FLAGS.logs:
          model.write_summary(sess)

    # Start training threads
    train_threads = []
    for _ in range(FLAGS.num_concurrent_steps):
      t = threading.Thread(target=TrainingFn)
      train_threads.append(t)
      t.start()

    # Wait for threads to finish.
    for t in train_threads:
      t.join()

    coord.request_stop()
    coord.join(threads)

    # Write out vectors
    write_embeddings_to_disk(FLAGS, model, sess)

    # Shutdown
    sess.close()
    log("Elapsed: %s", time.time() - start_time)


if __name__ == "__main__":
  tf.app.run()