seq2seq_attention_model.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence-to-Sequence with attention model for text summarization."""
from collections import namedtuple

import numpy as np
import tensorflow as tf

import seq2seq_lib

HParams = namedtuple('HParams',
                     'mode, min_lr, lr, batch_size, '
                     'enc_layers, enc_timesteps, dec_timesteps, '
                     'min_input_len, num_hidden, emb_dim, max_grad_norm, '
                     'num_softmax_samples')
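
# An illustrative configuration (hypothetical values for the usage sketch at
# the bottom of this file; the real settings are supplied by whatever training
# script constructs HParams):
#
#   hps = HParams(mode='train', min_lr=0.01, lr=0.15, batch_size=4,
#                 enc_layers=4, enc_timesteps=120, dec_timesteps=30,
#                 min_input_len=2, num_hidden=256, emb_dim=128,
#                 max_grad_norm=2, num_softmax_samples=4096)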

def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and have B added to it.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """
  def loop_function(prev, _):
    """Feeds the previous model output rather than the ground truth."""
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(
          prev, output_projection[0], output_projection[1])
    prev_symbol = tf.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev
  return loop_function
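
# Note: loop_function implements greedy (argmax) feeding of the previous
# output. During beam search the graph is instead stepped one token at a time
# through decode_topk below, which returns the top-k candidates per step.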

class Seq2SeqAttentionModel(object):
  """Wrapper for the TensorFlow graph of the text summarization model."""

  def __init__(self, hps, vocab, num_gpus=0):
    self._hps = hps
    self._vocab = vocab
    self._num_gpus = num_gpus
    self._cur_gpu = 0

  def run_train_step(self, sess, article_batch, abstract_batch, targets,
                     article_lens, abstract_lens, loss_weights):
    """Runs one training step; returns (train op, summaries, loss, step)."""
    to_return = [self._train_op, self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_eval_step(self, sess, article_batch, abstract_batch, targets,
                    article_lens, abstract_lens, loss_weights):
    """Runs one eval step; returns (summaries, loss, step)."""
    to_return = [self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_decode_step(self, sess, article_batch, abstract_batch, targets,
                      article_lens, abstract_lens, loss_weights):
    """Runs one decode step; returns (decoded output ids, step)."""
    to_return = [self._outputs, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def _next_device(self):
    """Round-robin over GPU devices. (Reserves the last GPU for expensive ops.)"""
    if self._num_gpus == 0:
      return ''
    dev = '/gpu:%d' % self._cur_gpu
    if self._num_gpus > 1:
      # Guard the modulus: with a single GPU, num_gpus - 1 would be zero.
      self._cur_gpu = (self._cur_gpu + 1) % (self._num_gpus - 1)
    return dev

  def _get_gpu(self, gpu_id):
    """Returns the device string for gpu_id, or '' if it is unavailable."""
    if self._num_gpus <= 0 or gpu_id >= self._num_gpus:
      return ''
    return '/gpu:%d' % gpu_id

  def _add_placeholders(self):
    """Inputs to be fed to the graph."""
    hps = self._hps
    self._articles = tf.placeholder(tf.int32,
                                    [hps.batch_size, hps.enc_timesteps],
                                    name='articles')
    self._abstracts = tf.placeholder(tf.int32,
                                     [hps.batch_size, hps.dec_timesteps],
                                     name='abstracts')
    self._targets = tf.placeholder(tf.int32,
                                   [hps.batch_size, hps.dec_timesteps],
                                   name='targets')
    self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                        name='article_lens')
    self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                         name='abstract_lens')
    self._loss_weights = tf.placeholder(tf.float32,
                                        [hps.batch_size, hps.dec_timesteps],
                                        name='loss_weights')
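
  # Feeding convention (an assumption here, since the batching code lives
  # outside this file): `targets` are the abstract token ids shifted by one
  # step relative to `abstracts`, and `loss_weights` zero out padding.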

  def _add_seq2seq(self):
    """Builds the attention-based sequence-to-sequence graph."""
    hps = self._hps
    vsize = self._vocab.NumIds()

    with tf.variable_scope('seq2seq'):
      encoder_inputs = tf.unpack(tf.transpose(self._articles))
      decoder_inputs = tf.unpack(tf.transpose(self._abstracts))
      targets = tf.unpack(tf.transpose(self._targets))
      loss_weights = tf.unpack(tf.transpose(self._loss_weights))
      article_lens = self._article_lens

      # Embedding shared by the input and outputs.
      with tf.variable_scope('embedding'), tf.device('/cpu:0'):
        embedding = tf.get_variable(
            'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in encoder_inputs]
        emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in decoder_inputs]
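
      # Stacked bidirectional encoder: each layer consumes the previous
      # layer's outputs, so emb_encoder_inputs is rebound on every iteration
      # and fw_state ends up holding the top layer's forward state.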
      for layer_i in xrange(hps.enc_layers):
        with tf.variable_scope('encoder%d' % layer_i), tf.device(
            self._next_device()):
          cell_fw = tf.nn.rnn_cell.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123))
          cell_bw = tf.nn.rnn_cell.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
          (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
              cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
              sequence_length=article_lens)
      encoder_outputs = emb_encoder_inputs

      with tf.variable_scope('output_projection'):
        w = tf.get_variable(
            'w', [hps.num_hidden, vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        w_t = tf.transpose(w)
        v = tf.get_variable(
            'v', [vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))

      with tf.variable_scope('decoder'), tf.device(self._next_device()):
        # When decoding, use model output from the previous step
        # for the next step.
        loop_function = None
        if hps.mode == 'decode':
          loop_function = _extract_argmax_and_embed(
              embedding, (w, v), update_embedding=False)
        cell = tf.nn.rnn_cell.LSTMCell(
            hps.num_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113))
        encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2 * hps.num_hidden])
                           for x in encoder_outputs]
        self._enc_top_states = tf.concat(1, encoder_outputs)
        self._dec_in_state = fw_state
        # During decoding, the follow-up _dec_in_state is fed from beam_search.
        # dec_out_state is stored by beam_search for next-step feeding.
        initial_state_attention = (hps.mode == 'decode')
        decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
            emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
            cell, num_heads=1, loop_function=loop_function,
            initial_state_attention=initial_state_attention)
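
        # Shape check (given the definitions above): _enc_top_states is
        # [batch_size, enc_timesteps, 2 * num_hidden], and decoder_outputs is
        # a list of dec_timesteps tensors of shape [batch_size, num_hidden].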

      with tf.variable_scope('output'), tf.device(self._next_device()):
        model_outputs = []
        for i in xrange(len(decoder_outputs)):
          if i > 0:
            tf.get_variable_scope().reuse_variables()
          model_outputs.append(
              tf.nn.xw_plus_b(decoder_outputs[i], w, v))

      if hps.mode == 'decode':
        with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
          best_outputs = [tf.argmax(x, 1) for x in model_outputs]
          tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
          self._outputs = tf.concat(
              1, [tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
          self._topk_log_probs, self._topk_ids = tf.nn.top_k(
              tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size * 2)
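
      # Note: top_k keeps 2 * batch_size candidates; during beam search the
      # batch dimension presumably carries the set of live hypotheses, giving
      # beam_search enough extension candidates per step.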

      with tf.variable_scope('loss'), tf.device(self._next_device()):
        def sampled_loss_func(inputs, labels):
          """Sampled softmax loss against the shared output projection."""
          with tf.device('/cpu:0'):  # Try gpu.
            labels = tf.reshape(labels, [-1, 1])
            return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels,
                                              hps.num_softmax_samples, vsize)

        if hps.num_softmax_samples != 0 and hps.mode == 'train':
          self._loss = seq2seq_lib.sampled_sequence_loss(
              decoder_outputs, targets, loss_weights, sampled_loss_func)
        else:
          self._loss = tf.nn.seq2seq.sequence_loss(
              model_outputs, targets, loss_weights)
        tf.scalar_summary('loss', tf.minimum(12.0, self._loss))

  def _add_train_op(self):
    """Sets self._train_op, the op to run for training."""
    hps = self._hps

    self._lr_rate = tf.maximum(
        hps.min_lr,  # min_lr_rate.
        tf.train.exponential_decay(hps.lr, self.global_step, 30000, 0.98))

    tvars = tf.trainable_variables()
    with tf.device(self._get_gpu(self._num_gpus - 1)):
      grads, global_norm = tf.clip_by_global_norm(
          tf.gradients(self._loss, tvars), hps.max_grad_norm)
    tf.scalar_summary('global_norm', global_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
    tf.scalar_summary('learning rate', self._lr_rate)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars), global_step=self.global_step, name='train_step')

  def encode_top_state(self, sess, enc_inputs, enc_len):
    """Returns the top encoder states for the decoder.

    Args:
      sess: tensorflow session.
      enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
      enc_len: encoder input lengths of shape [batch_size].

    Returns:
      enc_top_states: The top-level encoder states.
      dec_in_state: The decoder layer initial state.
    """
    results = sess.run([self._enc_top_states, self._dec_in_state],
                       feed_dict={self._articles: enc_inputs,
                                  self._article_lens: enc_len})
    return results[0], results[1][0]
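
  # Beam-search driver contract (inferred from the signatures above): call
  # encode_top_state once per input, then decode_topk once per decode step,
  # threading each step's dec_out_state back in as dec_init_states.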

  def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
    """Return the topK results and new decoder states."""
    feed = {
        self._enc_top_states: enc_top_states,
        self._dec_in_state: np.squeeze(np.array(dec_init_states)),
        self._abstracts: np.transpose(np.array([latest_tokens])),
        self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}
    results = sess.run(
        [self._topk_ids, self._topk_log_probs, self._dec_out_state],
        feed_dict=feed)
    ids, probs, states = results[0], results[1], results[2]
    new_states = [s for s in states]
    return ids, probs, new_states

  def build_graph(self):
    """Assembles placeholders, the seq2seq graph, and (in train mode) the train op."""
    self._add_placeholders()
    self._add_seq2seq()
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    if self._hps.mode == 'train':
      self._add_train_op()
    self._summaries = tf.merge_all_summaries()
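
# Example usage (a minimal sketch, kept as comments so the module stays
# importable; `vocab` is assumed to be a vocabulary object exposing NumIds(),
# supplied by code outside this file, and `hps` is an HParams instance such
# as the illustrative one near the top):
#
#   model = Seq2SeqAttentionModel(hps, vocab, num_gpus=1)
#   model.build_graph()
#   with tf.Session() as sess:
#     sess.run(tf.initialize_all_variables())
#     _, summaries, loss, step = model.run_train_step(
#         sess, article_batch, abstract_batch, targets,
#         article_lens, abstract_lens, loss_weights)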