seq2seq_attention_model.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sequence-to-Sequence with attention model for text summarization."""
from collections import namedtuple

import numpy as np
from six.moves import xrange
import tensorflow as tf

import seq2seq_lib

HParams = namedtuple('HParams',
                     'mode, min_lr, lr, batch_size, '
                     'enc_layers, enc_timesteps, dec_timesteps, '
                     'min_input_len, num_hidden, emb_dim, max_grad_norm, '
                     'num_softmax_samples')
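# Illustrative example (not part of the original file): constructing the
# hyperparameter tuple for training. The values shown are hypothetical
# placeholders, not prescribed defaults.
#
#   hps = HParams(mode='train', min_lr=0.01, lr=0.15, batch_size=4,
#                 enc_layers=4, enc_timesteps=120, dec_timesteps=30,
#                 min_input_len=2, num_hidden=256, emb_dim=128,
#                 max_grad_norm=2, num_softmax_samples=4096)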
def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and have B added.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """

  def loop_function(prev, _):
    """Feeds the previous model output rather than the ground truth."""
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(
          prev, output_projection[0], output_projection[1])
    prev_symbol = tf.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev

  return loop_function
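# Illustrative note (not part of the original file): when hps.mode == 'decode',
# the attention decoder calls the returned loop_function between time steps.
# A minimal sketch of a single invocation, where `prev_output` is a
# hypothetical name for the decoder cell output from the previous step:
#
#   loop_fn = _extract_argmax_and_embed(embedding, (w, v),
#                                       update_embedding=False)
#   emb_prev = loop_fn(prev_output, t)  # project, take argmax, embed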

class Seq2SeqAttentionModel(object):
  """Wrapper for the TensorFlow model graph for text summarization."""

  def __init__(self, hps, vocab, num_gpus=0):
    self._hps = hps
    self._vocab = vocab
    self._num_gpus = num_gpus
    self._cur_gpu = 0
  def run_train_step(self, sess, article_batch, abstract_batch, targets,
                     article_lens, abstract_lens, loss_weights):
    """Runs one training step and returns (train_op, summaries, loss, step)."""
    to_return = [self._train_op, self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_eval_step(self, sess, article_batch, abstract_batch, targets,
                    article_lens, abstract_lens, loss_weights):
    """Runs one eval step and returns (summaries, loss, step)."""
    to_return = [self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_decode_step(self, sess, article_batch, abstract_batch, targets,
                      article_lens, abstract_lens, loss_weights):
    """Runs one decode step and returns (decoded output ids, step)."""
    to_return = [self._outputs, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})
  def _next_device(self):
    """Round-robins the GPU devices, reserving the last one for expensive ops."""
    if self._num_gpus == 0:
      return ''
    dev = '/gpu:%d' % self._cur_gpu
    if self._num_gpus > 1:
      self._cur_gpu = (self._cur_gpu + 1) % (self._num_gpus - 1)
    return dev

  def _get_gpu(self, gpu_id):
    """Returns the device string for gpu_id, or '' if it is unavailable."""
    if self._num_gpus <= 0 or gpu_id >= self._num_gpus:
      return ''
    return '/gpu:%d' % gpu_id
  def _add_placeholders(self):
    """Inputs to be fed to the graph."""
    hps = self._hps
    self._articles = tf.placeholder(tf.int32,
                                    [hps.batch_size, hps.enc_timesteps],
                                    name='articles')
    self._abstracts = tf.placeholder(tf.int32,
                                     [hps.batch_size, hps.dec_timesteps],
                                     name='abstracts')
    self._targets = tf.placeholder(tf.int32,
                                   [hps.batch_size, hps.dec_timesteps],
                                   name='targets')
    self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                        name='article_lens')
    self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                         name='abstract_lens')
    self._loss_weights = tf.placeholder(tf.float32,
                                        [hps.batch_size, hps.dec_timesteps],
                                        name='loss_weights')
  def _add_seq2seq(self):
    """Builds the embedding, encoder, attention decoder, outputs and loss."""
    hps = self._hps
    vsize = self._vocab.NumIds()

    with tf.variable_scope('seq2seq'):
      encoder_inputs = tf.unstack(tf.transpose(self._articles))
      decoder_inputs = tf.unstack(tf.transpose(self._abstracts))
      targets = tf.unstack(tf.transpose(self._targets))
      loss_weights = tf.unstack(tf.transpose(self._loss_weights))
      article_lens = self._article_lens

      # Embedding shared by the input and outputs.
      with tf.variable_scope('embedding'), tf.device('/cpu:0'):
        embedding = tf.get_variable(
            'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in encoder_inputs]
        emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in decoder_inputs]

      for layer_i in xrange(hps.enc_layers):
        with tf.variable_scope('encoder%d' % layer_i), tf.device(
            self._next_device()):
          cell_fw = tf.nn.rnn_cell.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
              state_is_tuple=False)
          cell_bw = tf.nn.rnn_cell.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
              state_is_tuple=False)
          (emb_encoder_inputs, fw_state, _) = tf.nn.bidirectional_rnn(
              cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
              sequence_length=article_lens)
      encoder_outputs = emb_encoder_inputs

      with tf.variable_scope('output_projection'):
        w = tf.get_variable(
            'w', [hps.num_hidden, vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        w_t = tf.transpose(w)
        v = tf.get_variable(
            'v', [vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))

      with tf.variable_scope('decoder'), tf.device(self._next_device()):
        # When decoding, use model output from the previous step
        # for the next step.
        loop_function = None
        if hps.mode == 'decode':
          loop_function = _extract_argmax_and_embed(
              embedding, (w, v), update_embedding=False)

        cell = tf.nn.rnn_cell.LSTMCell(
            hps.num_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
            state_is_tuple=False)

        encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2*hps.num_hidden])
                           for x in encoder_outputs]
        self._enc_top_states = tf.concat(axis=1, values=encoder_outputs)
        self._dec_in_state = fw_state
        # During decoding, follow up _dec_in_state are fed from beam_search.
        # dec_out_state are stored by beam_search for next step feeding.
        initial_state_attention = (hps.mode == 'decode')
        decoder_outputs, self._dec_out_state = tf.nn.seq2seq.attention_decoder(
            emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
            cell, num_heads=1, loop_function=loop_function,
            initial_state_attention=initial_state_attention)

      with tf.variable_scope('output'), tf.device(self._next_device()):
        model_outputs = []
        for i in xrange(len(decoder_outputs)):
          if i > 0:
            tf.get_variable_scope().reuse_variables()
          model_outputs.append(
              tf.nn.xw_plus_b(decoder_outputs[i], w, v))

      if hps.mode == 'decode':
        with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
          best_outputs = [tf.argmax(x, 1) for x in model_outputs]
          tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
          self._outputs = tf.concat(
              axis=1,
              values=[tf.reshape(x, [hps.batch_size, 1]) for x in best_outputs])
          self._topk_log_probs, self._topk_ids = tf.nn.top_k(
              tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size*2)

      with tf.variable_scope('loss'), tf.device(self._next_device()):

        def sampled_loss_func(inputs, labels):
          with tf.device('/cpu:0'):  # Try gpu.
            labels = tf.reshape(labels, [-1, 1])
            return tf.nn.sampled_softmax_loss(w_t, v, inputs, labels,
                                              hps.num_softmax_samples, vsize)

        if hps.num_softmax_samples != 0 and hps.mode == 'train':
          self._loss = seq2seq_lib.sampled_sequence_loss(
              decoder_outputs, targets, loss_weights, sampled_loss_func)
        else:
          self._loss = tf.nn.seq2seq.sequence_loss(
              model_outputs, targets, loss_weights)
        tf.summary.scalar('loss', tf.minimum(12.0, self._loss))
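  # Shape notes added for readability (not part of the original file), derived
  # from the graph construction above:
  #   _enc_top_states: [batch_size, enc_timesteps, 2 * num_hidden]
  #   _dec_in_state:   final state of the forward cell of the top encoder layer
  #   model_outputs:   list of dec_timesteps tensors, each [batch_size, vsize]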
  def _add_train_op(self):
    """Sets self._train_op, the op to run for training."""
    hps = self._hps

    self._lr_rate = tf.maximum(
        hps.min_lr,  # min_lr_rate.
        tf.train.exponential_decay(hps.lr, self.global_step, 30000, 0.98))

    tvars = tf.trainable_variables()
    with tf.device(self._get_gpu(self._num_gpus - 1)):
      grads, global_norm = tf.clip_by_global_norm(
          tf.gradients(self._loss, tvars), hps.max_grad_norm)
    tf.summary.scalar('global_norm', global_norm)

    optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
    tf.summary.scalar('learning rate', self._lr_rate)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars), global_step=self.global_step, name='train_step')
  def encode_top_state(self, sess, enc_inputs, enc_len):
    """Return the top states from encoder for decoder.

    Args:
      sess: tensorflow session.
      enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
      enc_len: encoder input length of shape [batch_size].

    Returns:
      enc_top_states: The top level encoder states.
      dec_in_state: The decoder layer initial state.
    """
    results = sess.run([self._enc_top_states, self._dec_in_state],
                       feed_dict={self._articles: enc_inputs,
                                  self._article_lens: enc_len})
    return results[0], results[1][0]
  def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
    """Return the topK results and new decoder states."""
    feed = {
        self._enc_top_states: enc_top_states,
        self._dec_in_state: np.squeeze(np.array(dec_init_states)),
        self._abstracts: np.transpose(np.array([latest_tokens])),
        self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}
    results = sess.run(
        [self._topk_ids, self._topk_log_probs, self._dec_out_state],
        feed_dict=feed)
    ids, probs, states = results[0], results[1], results[2]
    new_states = [s for s in states]
    return ids, probs, new_states
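  # Illustrative decode-time call sequence (not part of the original file); the
  # beam-search driver that issues these calls is assumed to live elsewhere:
  #
  #   enc_top_states, dec_in_state = model.encode_top_state(
  #       sess, enc_inputs, enc_len)
  #   ids, probs, new_states = model.decode_topk(
  #       sess, latest_tokens, enc_top_states, dec_init_states)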
  def build_graph(self):
    """Builds the whole graph: placeholders, seq2seq model, and train op."""
    self._add_placeholders()
    self._add_seq2seq()
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    if self._hps.mode == 'train':
      self._add_train_op()
    self._summaries = tf.summary.merge_all()
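# Illustrative training-loop sketch (not part of the original file). The
# `vocab` object is assumed to expose NumIds(), and the batcher producing the
# numpy arrays below is hypothetical:
#
#   model = Seq2SeqAttentionModel(hps, vocab, num_gpus=1)
#   model.build_graph()
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     _, summaries, loss, step = model.run_train_step(
#         sess, article_batch, abstract_batch, targets,
#         article_lens, abstract_lens, loss_weights)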