skip_thoughts_model.py

# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Skip-Thoughts model for learning sentence vectors.

The model is based on the paper:

  "Skip-Thought Vectors"
  Ryan Kiros, Yukun Zhu, Ruslan Salakhutdinov, Richard S. Zemel,
  Antonio Torralba, Raquel Urtasun, Sanja Fidler.
  https://papers.nips.cc/paper/5950-skip-thought-vectors.pdf

Layer normalization is applied based on the paper:

  "Layer Normalization"
  Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton
  https://arxiv.org/abs/1607.06450
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from skip_thoughts.ops import gru_cell
from skip_thoughts.ops import input_ops


def random_orthonormal_initializer(shape, dtype=tf.float32,
                                   partition_info=None):  # pylint: disable=unused-argument
  """Variable initializer that produces a random orthonormal matrix."""
  if len(shape) != 2 or shape[0] != shape[1]:
    raise ValueError("Expecting square shape, got %s" % shape)
  _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
  return u
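
# A quick sanity check (a hedged sketch, not part of the original module):
# the returned matrix should be orthonormal, i.e. u^T u ~= identity.
# Under TF1 graph mode:
#
#   u = random_orthonormal_initializer([4, 4])
#   with tf.Session() as sess:
#     gram = sess.run(tf.matmul(u, u, transpose_a=True))
#     # gram is close to the 4x4 identity matrix.
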
class SkipThoughtsModel(object):
  """Skip-thoughts model."""

  def __init__(self, config, mode="train", input_reader=None):
    """Basic setup. The actual TensorFlow graph is constructed in build().

    Args:
      config: Object containing configuration parameters.
      mode: "train", "eval" or "encode".
      input_reader: Subclass of tf.ReaderBase for reading the input serialized
        tf.Example protocol buffers. Defaults to TFRecordReader.

    Raises:
      ValueError: If mode is invalid.
    """
    if mode not in ["train", "eval", "encode"]:
      raise ValueError("Unrecognized mode: %s" % mode)

    self.config = config
    self.mode = mode
    self.reader = input_reader if input_reader else tf.TFRecordReader()

    # Initializer used for non-recurrent weights.
    self.uniform_initializer = tf.random_uniform_initializer(
        minval=-self.config.uniform_init_scale,
        maxval=self.config.uniform_init_scale)

    # Input sentences represented as sequences of word ids. "encode" is the
    # source sentence, "decode_pre" is the previous sentence and "decode_post"
    # is the next sentence.
    # Each is an int64 Tensor with shape [batch_size, padded_length].
    self.encode_ids = None
    self.decode_pre_ids = None
    self.decode_post_ids = None

    # Boolean masks distinguishing real words (1) from padded words (0).
    # Each is an int32 Tensor with shape [batch_size, padded_length].
    self.encode_mask = None
    self.decode_pre_mask = None
    self.decode_post_mask = None

    # Input sentences represented as sequences of word embeddings.
    # Each is a float32 Tensor with shape [batch_size, padded_length, emb_dim].
    self.encode_emb = None
    self.decode_pre_emb = None
    self.decode_post_emb = None

    # The output from the sentence encoder.
    # A float32 Tensor with shape [batch_size, num_gru_units].
    self.thought_vectors = None

    # The cross entropy losses and corresponding weights of the decoders. Used
    # for evaluation.
    self.target_cross_entropy_losses = []
    self.target_cross_entropy_loss_weights = []

    # The total loss to optimize.
    self.total_loss = None
  def build_inputs(self):
    """Builds the ops for reading input data.

    Outputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids
      self.encode_mask
      self.decode_pre_mask
      self.decode_post_mask
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_ids = None
      decode_pre_ids = None
      decode_post_ids = None
      encode_mask = tf.placeholder(tf.int8, (None, None), name="encode_mask")
      decode_pre_mask = None
      decode_post_mask = None
    else:
      # Prefetch serialized tf.Example protos.
      input_queue = input_ops.prefetch_input_data(
          self.reader,
          self.config.input_file_pattern,
          shuffle=self.config.shuffle_input_data,
          capacity=self.config.input_queue_capacity,
          num_reader_threads=self.config.num_input_reader_threads)

      # Deserialize a batch.
      serialized = input_queue.dequeue_many(self.config.batch_size)
      encode, decode_pre, decode_post = input_ops.parse_example_batch(
          serialized)

      encode_ids = encode.ids
      decode_pre_ids = decode_pre.ids
      decode_post_ids = decode_post.ids

      encode_mask = encode.mask
      decode_pre_mask = decode_pre.mask
      decode_post_mask = decode_post.mask

    self.encode_ids = encode_ids
    self.decode_pre_ids = decode_pre_ids
    self.decode_post_ids = decode_post_ids

    self.encode_mask = encode_mask
    self.decode_pre_mask = decode_pre_mask
    self.decode_post_mask = decode_post_mask
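
  # For reference, each serialized tf.Example consumed above is expected to
  # carry the three sentences as int64 feature lists under the keys parsed by
  # input_ops.parse_example_batch (a hedged sketch; the preprocessing scripts
  # in this package define the authoritative format):
  #
  #   features {
  #     feature { key: "decode_pre"  value { int64_list { value: [4, 0, 2] } } }
  #     feature { key: "encode"      value { int64_list { value: [5, 9, 1] } } }
  #     feature { key: "decode_post" value { int64_list { value: [7, 3, 8] } } }
  #   }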
  def build_word_embeddings(self):
    """Builds the word embeddings.

    Inputs:
      self.encode_ids
      self.decode_pre_ids
      self.decode_post_ids

    Outputs:
      self.encode_emb
      self.decode_pre_emb
      self.decode_post_emb
    """
    if self.mode == "encode":
      # Word embeddings are fed from an external vocabulary which has possibly
      # been expanded (see vocabulary_expansion.py).
      encode_emb = tf.placeholder(tf.float32, (
          None, None, self.config.word_embedding_dim), "encode_emb")
      # No sequences to decode.
      decode_pre_emb = None
      decode_post_emb = None
    else:
      word_emb = tf.get_variable(
          name="word_embedding",
          shape=[self.config.vocab_size, self.config.word_embedding_dim],
          initializer=self.uniform_initializer)
      encode_emb = tf.nn.embedding_lookup(word_emb, self.encode_ids)
      decode_pre_emb = tf.nn.embedding_lookup(word_emb, self.decode_pre_ids)
      decode_post_emb = tf.nn.embedding_lookup(word_emb, self.decode_post_ids)

    self.encode_emb = encode_emb
    self.decode_pre_emb = decode_pre_emb
    self.decode_post_emb = decode_post_emb
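
  # Encode-mode feeding sketch (hedged; the names `sess`, `emb` and `mask`
  # are illustrative, and in practice the package's encoder manager performs
  # this step):
  #
  #   thought_vectors = sess.run(
  #       "encoder/thought_vectors:0",
  #       feed_dict={
  #           "encode_emb:0": emb,    # [batch, padded_length, word_embedding_dim]
  #           "encode_mask:0": mask,  # [batch, padded_length], 1 = real word
  #       })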
  def _initialize_gru_cell(self, num_units):
    """Initializes a GRU cell.

    The Variables of the GRU cell are initialized in a way that exactly matches
    the skip-thoughts paper: recurrent weights are initialized from random
    orthonormal matrices and non-recurrent weights are initialized from random
    uniform matrices.

    Args:
      num_units: Number of output units.

    Returns:
      cell: An instance of RNNCell with variable initializers that match the
        skip-thoughts paper.
    """
    return gru_cell.LayerNormGRUCell(
        num_units,
        w_initializer=self.uniform_initializer,
        u_initializer=random_orthonormal_initializer,
        b_initializer=tf.constant_initializer(0.0))
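
  # Initializer mapping (an explanatory summary, not original to this file):
  # `w_initializer` covers the input-to-hidden (non-recurrent) weights,
  # `u_initializer` the hidden-to-hidden (recurrent) weights, and
  # `b_initializer` the biases, matching the docstring above. LayerNormGRUCell
  # is defined in skip_thoughts/ops/gru_cell.py.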
  def build_encoder(self):
    """Builds the sentence encoder.

    Inputs:
      self.encode_emb
      self.encode_mask

    Outputs:
      self.thought_vectors

    Raises:
      ValueError: if config.bidirectional_encoder is True and
        config.encoder_dim is odd.
    """
    with tf.variable_scope("encoder") as scope:
      length = tf.to_int32(tf.reduce_sum(self.encode_mask, 1), name="length")

      if self.config.bidirectional_encoder:
        if self.config.encoder_dim % 2:
          raise ValueError(
              "encoder_dim must be even when using a bidirectional encoder.")
        num_units = self.config.encoder_dim // 2
        cell_fw = self._initialize_gru_cell(num_units)  # Forward encoder
        cell_bw = self._initialize_gru_cell(num_units)  # Backward encoder
        _, states = tf.nn.bidirectional_dynamic_rnn(
            cell_fw=cell_fw,
            cell_bw=cell_bw,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
        thought_vectors = tf.concat(states, 1, name="thought_vectors")
      else:
        cell = self._initialize_gru_cell(self.config.encoder_dim)
        _, state = tf.nn.dynamic_rnn(
            cell=cell,
            inputs=self.encode_emb,
            sequence_length=length,
            dtype=tf.float32,
            scope=scope)
        # Use an identity operation to name the Tensor in the Graph.
        thought_vectors = tf.identity(state, name="thought_vectors")

    self.thought_vectors = thought_vectors
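
  # Shape note (a summary, not original to this file): with a bidirectional
  # encoder, `states` is a (forward, backward) pair of final GRU states, each
  # [batch_size, encoder_dim // 2], so concatenating along axis 1 yields
  # [batch_size, encoder_dim]. The unidirectional branch emits its final state
  # directly, so both branches produce thought vectors of the same width.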
  def _build_decoder(self, name, embeddings, targets, mask, initial_state,
                     reuse_logits):
    """Builds a sentence decoder.

    Args:
      name: Decoder name.
      embeddings: Batch of sentences to decode; a float32 Tensor with shape
        [batch_size, padded_length, emb_dim].
      targets: Batch of target word ids; an int64 Tensor with shape
        [batch_size, padded_length].
      mask: A 0/1 Tensor with shape [batch_size, padded_length].
      initial_state: Initial state of the GRU. A float32 Tensor with shape
        [batch_size, num_gru_cells].
      reuse_logits: Whether to reuse the logits weights.
    """
    # Decoder RNN.
    cell = self._initialize_gru_cell(self.config.encoder_dim)
    with tf.variable_scope(name) as scope:
      # Add a padding word at the start of each sentence (to correspond to the
      # prediction of the first word) and remove the last word.
      decoder_input = tf.pad(
          embeddings[:, :-1, :], [[0, 0], [1, 0], [0, 0]], name="input")
      length = tf.reduce_sum(mask, 1, name="length")
      decoder_output, _ = tf.nn.dynamic_rnn(
          cell=cell,
          inputs=decoder_input,
          sequence_length=length,
          initial_state=initial_state,
          scope=scope)

    # Stack batch vertically.
    decoder_output = tf.reshape(decoder_output, [-1, self.config.encoder_dim])
    targets = tf.reshape(targets, [-1])
    weights = tf.to_float(tf.reshape(mask, [-1]))

    # Logits.
    with tf.variable_scope("logits", reuse=reuse_logits) as scope:
      logits = tf.contrib.layers.fully_connected(
          inputs=decoder_output,
          num_outputs=self.config.vocab_size,
          activation_fn=None,
          weights_initializer=self.uniform_initializer,
          scope=scope)

    losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=targets, logits=logits)
    batch_loss = tf.reduce_sum(losses * weights)
    tf.losses.add_loss(batch_loss)

    tf.summary.scalar("losses/" + name, batch_loss)

    self.target_cross_entropy_losses.append(losses)
    self.target_cross_entropy_loss_weights.append(weights)
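
  # Teacher-forcing illustration (an explanatory sketch, not original code):
  # for a target sentence [w1, w2, w3, <pad>], the pad-and-shift above feeds
  # the decoder [0, e(w1), e(w2), e(w3)], so the GRU predicts w1 from the
  # thought vector alone, w2 given w1, and so on; the 0/1 mask zeroes out the
  # loss on padded positions.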
  def build_decoders(self):
    """Builds the sentence decoders.

    Inputs:
      self.decode_pre_emb
      self.decode_post_emb
      self.decode_pre_ids
      self.decode_post_ids
      self.decode_pre_mask
      self.decode_post_mask
      self.thought_vectors

    Outputs:
      self.target_cross_entropy_losses
      self.target_cross_entropy_loss_weights
    """
    if self.mode != "encode":
      # Pre-sentence decoder.
      self._build_decoder("decoder_pre", self.decode_pre_emb,
                          self.decode_pre_ids, self.decode_pre_mask,
                          self.thought_vectors, False)

      # Post-sentence decoder. Logits weights are reused.
      self._build_decoder("decoder_post", self.decode_post_emb,
                          self.decode_post_ids, self.decode_post_mask,
                          self.thought_vectors, True)

  def build_loss(self):
    """Builds the loss Tensor.

    Outputs:
      self.total_loss
    """
    if self.mode != "encode":
      total_loss = tf.losses.get_total_loss()
      tf.summary.scalar("losses/total", total_loss)

      self.total_loss = total_loss
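
  # Note (a summary, not original to this file): tf.losses.get_total_loss
  # sums everything registered via tf.losses.add_loss, i.e. the two decoder
  # batch losses above, plus any losses in the REGULARIZATION_LOSSES
  # collection.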
  def build_global_step(self):
    """Builds the global step Tensor.

    Outputs:
      self.global_step
    """
    self.global_step = tf.contrib.framework.create_global_step()

  def build(self):
    """Creates all ops for training, evaluation or encoding."""
    self.build_inputs()
    self.build_word_embeddings()
    self.build_encoder()
    self.build_decoders()
    self.build_loss()
    self.build_global_step()
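

if __name__ == "__main__":
  # A minimal usage sketch (a hedged example, not part of the original
  # module). Assumptions: the sibling `configuration` module of the
  # skip_thoughts package is importable, and `input_file_pattern` points at
  # preprocessed TFRecord shards.
  from skip_thoughts import configuration

  model_config = configuration.model_config(
      input_file_pattern="/path/to/train-?????-of-00100")
  model = SkipThoughtsModel(model_config, mode="train")
  model.build()
  print("Built skip-thoughts graph; total loss op:", model.total_loss)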