# skip_thoughts_encoder.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class for encoding text using a trained SkipThoughtsModel.

Example usage:
    g = tf.Graph()
    with g.as_default():
        encoder = SkipThoughtsEncoder(embeddings)
        restore_fn = encoder.build_graph_from_config(model_config,
                                                     checkpoint_path)

    with tf.Session(graph=g) as sess:
        restore_fn(sess)
        skip_thought_vectors = encoder.encode(sess, data)
"""
  25. from __future__ import absolute_import
  26. from __future__ import division
  27. from __future__ import print_function
  28. import os.path
  29. import nltk
  30. import nltk.tokenize
  31. import numpy as np
  32. import tensorflow as tf
  33. from skip_thoughts import skip_thoughts_model
  34. from skip_thoughts.data import special_words
  35. def _pad(seq, target_len):
  36. """Pads a sequence of word embeddings up to the target length.
  37. Args:
  38. seq: Sequence of word embeddings.
  39. target_len: Desired padded sequence length.
  40. Returns:
  41. embeddings: Input sequence padded with zero embeddings up to the target
  42. length.
  43. mask: A 0/1 vector with zeros corresponding to padded embeddings.
  44. Raises:
  45. ValueError: If len(seq) is not in the interval (0, target_len].
  46. """
  47. seq_len = len(seq)
  48. if seq_len <= 0 or seq_len > target_len:
  49. raise ValueError("Expected 0 < len(seq) <= %d, got %d" % (target_len,
  50. seq_len))
  51. emb_dim = seq[0].shape[0]
  52. padded_seq = np.zeros(shape=(target_len, emb_dim), dtype=seq[0].dtype)
  53. mask = np.zeros(shape=(target_len,), dtype=np.int8)
  54. for i in range(seq_len):
  55. padded_seq[i] = seq[i]
  56. mask[i] = 1
  57. return padded_seq, mask
  58. def _batch_and_pad(sequences):
  59. """Batches and pads sequences of word embeddings into a 2D array.
  60. Args:
  61. sequences: A list of batch_size sequences of word embeddings.
  62. Returns:
  63. embeddings: A numpy array with shape [batch_size, padded_length, emb_dim].
  64. mask: A numpy 0/1 array with shape [batch_size, padded_length] with zeros
  65. corresponding to padded elements.
  66. """
  67. batch_embeddings = []
  68. batch_mask = []
  69. batch_len = max([len(seq) for seq in sequences])
  70. for seq in sequences:
  71. embeddings, mask = _pad(seq, batch_len)
  72. batch_embeddings.append(embeddings)
  73. batch_mask.append(mask)
  74. return np.array(batch_embeddings), np.array(batch_mask)
  75. class SkipThoughtsEncoder(object):
  76. """Skip-thoughts sentence encoder."""
  77. def __init__(self, embeddings):
  78. """Initializes the encoder.
  79. Args:
  80. embeddings: Dictionary of word to embedding vector (1D numpy array).
  81. """
  82. self._sentence_detector = nltk.data.load("tokenizers/punkt/english.pickle")
  83. self._embeddings = embeddings
  84. def _create_restore_fn(self, checkpoint_path, saver):
  85. """Creates a function that restores a model from checkpoint.
  86. Args:
  87. checkpoint_path: Checkpoint file or a directory containing a checkpoint
  88. file.
  89. saver: Saver for restoring variables from the checkpoint file.
  90. Returns:
  91. restore_fn: A function such that restore_fn(sess) loads model variables
  92. from the checkpoint file.
  93. Raises:
  94. ValueError: If checkpoint_path does not refer to a checkpoint file or a
  95. directory containing a checkpoint file.
  96. """
  97. if tf.gfile.IsDirectory(checkpoint_path):
  98. latest_checkpoint = tf.train.latest_checkpoint(checkpoint_path)
  99. if not latest_checkpoint:
  100. raise ValueError("No checkpoint file found in: %s" % checkpoint_path)
  101. checkpoint_path = latest_checkpoint
  102. def _restore_fn(sess):
  103. tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
  104. saver.restore(sess, checkpoint_path)
  105. tf.logging.info("Successfully loaded checkpoint: %s",
  106. os.path.basename(checkpoint_path))
  107. return _restore_fn
  108. def build_graph_from_config(self, model_config, checkpoint_path):
  109. """Builds the inference graph from a configuration object.
  110. Args:
  111. model_config: Object containing configuration for building the model.
  112. checkpoint_path: Checkpoint file or a directory containing a checkpoint
  113. file.
  114. Returns:
  115. restore_fn: A function such that restore_fn(sess) loads model variables
  116. from the checkpoint file.
  117. """
  118. tf.logging.info("Building model.")
  119. model = skip_thoughts_model.SkipThoughtsModel(model_config, mode="encode")
  120. model.build()
  121. saver = tf.train.Saver()
  122. return self._create_restore_fn(checkpoint_path, saver)
  123. def build_graph_from_proto(self, graph_def_file, saver_def_file,
  124. checkpoint_path):
  125. """Builds the inference graph from serialized GraphDef and SaverDef protos.
  126. Args:
  127. graph_def_file: File containing a serialized GraphDef proto.
  128. saver_def_file: File containing a serialized SaverDef proto.
  129. checkpoint_path: Checkpoint file or a directory containing a checkpoint
  130. file.
  131. Returns:
  132. restore_fn: A function such that restore_fn(sess) loads model variables
  133. from the checkpoint file.
  134. """
  135. # Load the Graph.
  136. tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
  137. graph_def = tf.GraphDef()
  138. with tf.gfile.FastGFile(graph_def_file, "rb") as f:
  139. graph_def.ParseFromString(f.read())
  140. tf.import_graph_def(graph_def, name="")
  141. # Load the Saver.
  142. tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
  143. saver_def = tf.train.SaverDef()
  144. with tf.gfile.FastGFile(saver_def_file, "rb") as f:
  145. saver_def.ParseFromString(f.read())
  146. saver = tf.train.Saver(saver_def=saver_def)
  147. return self._create_restore_fn(checkpoint_path, saver)
  148. def _tokenize(self, item):
  149. """Tokenizes an input string into a list of words."""
  150. tokenized = []
  151. for s in self._sentence_detector.tokenize(item):
  152. tokenized.extend(nltk.tokenize.word_tokenize(s))
  153. return tokenized
  154. def _word_to_embedding(self, w):
  155. """Returns the embedding of a word."""
  156. return self._embeddings.get(w, self._embeddings[special_words.UNK])
  157. def _preprocess(self, data, use_eos):
  158. """Preprocesses text for the encoder.
  159. Args:
  160. data: A list of input strings.
  161. use_eos: Whether to append the end-of-sentence word to each sentence.
  162. Returns:
  163. embeddings: A list of word embedding sequences corresponding to the input
  164. strings.
  165. """
  166. preprocessed_data = []
  167. for item in data:
  168. tokenized = self._tokenize(item)
  169. if use_eos:
  170. tokenized.append(special_words.EOS)
  171. preprocessed_data.append([self._word_to_embedding(w) for w in tokenized])
  172. return preprocessed_data
  173. def encode(self,
  174. sess,
  175. data,
  176. use_norm=True,
  177. verbose=True,
  178. batch_size=128,
  179. use_eos=False):
  180. """Encodes a sequence of sentences as skip-thought vectors.
  181. Args:
  182. sess: TensorFlow Session.
  183. data: A list of input strings.
  184. use_norm: Whether to normalize skip-thought vectors to unit L2 norm.
  185. verbose: Whether to log every batch.
  186. batch_size: Batch size for the encoder.
  187. use_eos: Whether to append the end-of-sentence word to each input
  188. sentence.
  189. Returns:
  190. thought_vectors: A list of numpy arrays corresponding to the skip-thought
  191. encodings of sentences in 'data'.
  192. """
  193. data = self._preprocess(data, use_eos)
  194. thought_vectors = []
  195. batch_indices = np.arange(0, len(data), batch_size)
  196. for batch, start_index in enumerate(batch_indices):
  197. if verbose:
  198. tf.logging.info("Batch %d / %d.", batch, len(batch_indices))
  199. embeddings, mask = _batch_and_pad(
  200. data[start_index:start_index + batch_size])
  201. feed_dict = {
  202. "encode_emb:0": embeddings,
  203. "encode_mask:0": mask,
  204. }
  205. thought_vectors.extend(
  206. sess.run("encoder/thought_vectors:0", feed_dict=feed_dict))
  207. if use_norm:
  208. thought_vectors = [v / np.linalg.norm(v) for v in thought_vectors]
  209. return thought_vectors