# vocabulary_expansion.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compute an expanded vocabulary of embeddings using a word2vec model.

This script loads the word embeddings from a trained skip-thoughts model and
from a trained word2vec model (typically with a larger vocabulary). It trains a
linear regression model without regularization to learn a linear mapping from
the word2vec embedding space to the skip-thoughts embedding space. The model is
then applied to all words in the word2vec vocabulary, yielding vectors in the
skip-thoughts word embedding space for the union of the two vocabularies.

The linear regression task is to learn a parameter matrix W to minimize
  || X - Y * W ||^2,
where X is a matrix of skip-thoughts embeddings of shape [num_words, dim1],
Y is a matrix of word2vec embeddings of shape [num_words, dim2], and W is a
matrix of shape [dim2, dim1].

This is based on the "Translation Matrix" method from the paper:

  "Exploiting Similarities among Languages for Machine Translation"
  Tomas Mikolov, Quoc V. Le, Ilya Sutskever
  https://arxiv.org/abs/1309.4168
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import os.path

import gensim.models
import numpy as np
import sklearn.linear_model
import tensorflow as tf

FLAGS = tf.flags.FLAGS

# Command-line flags; all four are required (validated in main()).
tf.flags.DEFINE_string("skip_thoughts_model", None,
                       "Checkpoint file or directory containing a checkpoint "
                       "file.")
tf.flags.DEFINE_string("skip_thoughts_vocab", None,
                       "Path to vocabulary file containing a list of newline-"
                       "separated words where the word id is the "
                       "corresponding 0-based index in the file.")
tf.flags.DEFINE_string("word2vec_model", None,
                       "File containing a word2vec model in binary format.")
tf.flags.DEFINE_string("output_dir", None, "Output directory.")

tf.logging.set_verbosity(tf.logging.INFO)
  53. def _load_skip_thoughts_embeddings(checkpoint_path):
  54. """Loads the embedding matrix from a skip-thoughts model checkpoint.
  55. Args:
  56. checkpoint_path: Model checkpoint file or directory containing a checkpoint
  57. file.
  58. Returns:
  59. word_embedding: A numpy array of shape [vocab_size, embedding_dim].
  60. Raises:
  61. ValueError: If no checkpoint file matches checkpoint_path.
  62. """
  63. if tf.gfile.IsDirectory(checkpoint_path):
  64. checkpoint_file = tf.train.latest_checkpoint(checkpoint_path)
  65. if not checkpoint_file:
  66. raise ValueError("No checkpoint file found in %s" % checkpoint_path)
  67. else:
  68. checkpoint_file = checkpoint_path
  69. tf.logging.info("Loading skip-thoughts embedding matrix from %s",
  70. checkpoint_file)
  71. reader = tf.train.NewCheckpointReader(checkpoint_file)
  72. word_embedding = reader.get_tensor("word_embedding")
  73. tf.logging.info("Loaded skip-thoughts embedding matrix of shape %s",
  74. word_embedding.shape)
  75. return word_embedding
  76. def _load_vocabulary(filename):
  77. """Loads a vocabulary file.
  78. Args:
  79. filename: Path to text file containing newline-separated words.
  80. Returns:
  81. vocab: A dictionary mapping word to word id.
  82. """
  83. tf.logging.info("Reading vocabulary from %s", filename)
  84. vocab = collections.OrderedDict()
  85. with tf.gfile.GFile(filename, mode="r") as f:
  86. for i, line in enumerate(f):
  87. word = line.decode("utf-8").strip()
  88. assert word not in vocab, "Attempting to add word twice: %s" % word
  89. vocab[word] = i
  90. tf.logging.info("Read vocabulary of size %d", len(vocab))
  91. return vocab
  92. def _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab, word2vec):
  93. """Runs vocabulary expansion on a skip-thoughts model using a word2vec model.
  94. Args:
  95. skip_thoughts_emb: A numpy array of shape [skip_thoughts_vocab_size,
  96. skip_thoughts_embedding_dim].
  97. skip_thoughts_vocab: A dictionary of word to id.
  98. word2vec: An instance of gensim.models.Word2Vec.
  99. Returns:
  100. combined_emb: A dictionary mapping words to embedding vectors.
  101. """
  102. # Find words shared between the two vocabularies.
  103. tf.logging.info("Finding shared words")
  104. shared_words = [w for w in word2vec.vocab if w in skip_thoughts_vocab]
  105. # Select embedding vectors for shared words.
  106. tf.logging.info("Selecting embeddings for %d shared words", len(shared_words))
  107. shared_st_emb = skip_thoughts_emb[[
  108. skip_thoughts_vocab[w] for w in shared_words
  109. ]]
  110. shared_w2v_emb = word2vec[shared_words]
  111. # Train a linear regression model on the shared embedding vectors.
  112. tf.logging.info("Training linear regression model")
  113. model = sklearn.linear_model.LinearRegression()
  114. model.fit(shared_w2v_emb, shared_st_emb)
  115. # Create the expanded vocabulary.
  116. tf.logging.info("Creating embeddings for expanded vocabuary")
  117. combined_emb = collections.OrderedDict()
  118. for w in word2vec.vocab:
  119. # Ignore words with underscores (spaces).
  120. if "_" not in w:
  121. w_emb = model.predict(word2vec[w].reshape(1, -1))
  122. combined_emb[w] = w_emb.reshape(-1)
  123. for w in skip_thoughts_vocab:
  124. combined_emb[w] = skip_thoughts_emb[skip_thoughts_vocab[w]]
  125. tf.logging.info("Created expanded vocabulary of %d words", len(combined_emb))
  126. return combined_emb
  127. def main(unused_argv):
  128. if not FLAGS.skip_thoughts_model:
  129. raise ValueError("--skip_thoughts_model is required.")
  130. if not FLAGS.skip_thoughts_vocab:
  131. raise ValueError("--skip_thoughts_vocab is required.")
  132. if not FLAGS.word2vec_model:
  133. raise ValueError("--word2vec_model is required.")
  134. if not FLAGS.output_dir:
  135. raise ValueError("--output_dir is required.")
  136. if not tf.gfile.IsDirectory(FLAGS.output_dir):
  137. tf.gfile.MakeDirs(FLAGS.output_dir)
  138. # Load the skip-thoughts embeddings and vocabulary.
  139. skip_thoughts_emb = _load_skip_thoughts_embeddings(FLAGS.skip_thoughts_model)
  140. skip_thoughts_vocab = _load_vocabulary(FLAGS.skip_thoughts_vocab)
  141. # Load the Word2Vec model.
  142. word2vec = gensim.models.Word2Vec.load_word2vec_format(
  143. FLAGS.word2vec_model, binary=True)
  144. # Run vocabulary expansion.
  145. embedding_map = _expand_vocabulary(skip_thoughts_emb, skip_thoughts_vocab,
  146. word2vec)
  147. # Save the output.
  148. vocab = embedding_map.keys()
  149. vocab_file = os.path.join(FLAGS.output_dir, "vocab.txt")
  150. with tf.gfile.GFile(vocab_file, "w") as f:
  151. f.write("\n".join(vocab))
  152. tf.logging.info("Wrote vocabulary file to %s", vocab_file)
  153. embeddings = np.array(embedding_map.values())
  154. embeddings_file = os.path.join(FLAGS.output_dir, "embeddings.npy")
  155. np.save(embeddings_file, embeddings)
  156. tf.logging.info("Wrote embeddings file to %s", embeddings_file)
if __name__ == "__main__":
  # tf.app.run() parses the command-line flags and then calls main().
  tf.app.run()