seq2seq_attention_model.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Sequence-to-Sequence with attention model for text summarization."""

from collections import namedtuple

import numpy as np
import seq2seq_lib
from six.moves import xrange
import tensorflow as tf

HParams = namedtuple('HParams',
                     'mode, min_lr, lr, batch_size, '
                     'enc_layers, enc_timesteps, dec_timesteps, '
                     'min_input_len, num_hidden, emb_dim, max_grad_norm, '
                     'num_softmax_samples')
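
# Example of constructing an HParams tuple. These particular values are
# illustrative assumptions only, not settings mandated by this file:
#
#   hps = HParams(mode='train', min_lr=0.01, lr=0.15, batch_size=4,
#                 enc_layers=4, enc_timesteps=120, dec_timesteps=30,
#                 min_input_len=2, num_hidden=256, emb_dim=128,
#                 max_grad_norm=2, num_softmax_samples=4096)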


def _extract_argmax_and_embed(embedding, output_projection=None,
                              update_embedding=True):
  """Get a loop_function that extracts the previous symbol and embeds it.

  Args:
    embedding: embedding tensor for symbols.
    output_projection: None or a pair (W, B). If provided, each fed previous
      output will first be multiplied by W and have B added to it.
    update_embedding: Boolean; if False, the gradients will not propagate
      through the embeddings.

  Returns:
    A loop function.
  """
  def loop_function(prev, _):
    """Feeds the previous model output back in, rather than the ground truth."""
    if output_projection is not None:
      prev = tf.nn.xw_plus_b(
          prev, output_projection[0], output_projection[1])
    prev_symbol = tf.argmax(prev, 1)
    # Note that gradients will not propagate through the second parameter of
    # embedding_lookup.
    emb_prev = tf.nn.embedding_lookup(embedding, prev_symbol)
    if not update_embedding:
      emb_prev = tf.stop_gradient(emb_prev)
    return emb_prev
  return loop_function
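
# For reference: attention_decoder (used in _add_seq2seq below) invokes the
# returned loop_function as loop_function(prev, i), where prev is the decoder
# output of the previous step, shape [batch_size, num_hidden], and i is the
# step index. With output_projection=(w, v), prev is first projected to
# [batch_size, vsize] logits before the argmax picks the previous symbol.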


class Seq2SeqAttentionModel(object):
  """Wrapper around the TensorFlow model graph for text summarization."""

  def __init__(self, hps, vocab, num_gpus=0):
    self._hps = hps
    self._vocab = vocab
    self._num_gpus = num_gpus
    self._cur_gpu = 0

  def run_train_step(self, sess, article_batch, abstract_batch, targets,
                     article_lens, abstract_lens, loss_weights):
    to_return = [self._train_op, self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_eval_step(self, sess, article_batch, abstract_batch, targets,
                    article_lens, abstract_lens, loss_weights):
    to_return = [self._summaries, self._loss, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def run_decode_step(self, sess, article_batch, abstract_batch, targets,
                      article_lens, abstract_lens, loss_weights):
    to_return = [self._outputs, self.global_step]
    return sess.run(to_return,
                    feed_dict={self._articles: article_batch,
                               self._abstracts: abstract_batch,
                               self._targets: targets,
                               self._article_lens: article_lens,
                               self._abstract_lens: abstract_lens,
                               self._loss_weights: loss_weights})

  def _next_device(self):
    """Round-robin the GPU devices, reserving the last GPU for the expensive op."""
    if self._num_gpus == 0:
      return ''
    dev = '/gpu:%d' % self._cur_gpu
    if self._num_gpus > 1:
      self._cur_gpu = (self._cur_gpu + 1) % (self._num_gpus - 1)
    return dev
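
  # For example, with num_gpus=4 successive calls to _next_device() yield
  # '/gpu:0', '/gpu:1', '/gpu:2', '/gpu:0', and so on; /gpu:3 is never handed
  # out here and stays reserved for the training op (see _add_train_op, which
  # pins the gradient computation to _get_gpu(num_gpus - 1)).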

  def _get_gpu(self, gpu_id):
    if self._num_gpus <= 0 or gpu_id >= self._num_gpus:
      return ''
    return '/gpu:%d' % gpu_id

  def _add_placeholders(self):
    """Inputs to be fed to the graph."""
    hps = self._hps
    self._articles = tf.placeholder(tf.int32,
                                    [hps.batch_size, hps.enc_timesteps],
                                    name='articles')
    self._abstracts = tf.placeholder(tf.int32,
                                     [hps.batch_size, hps.dec_timesteps],
                                     name='abstracts')
    self._targets = tf.placeholder(tf.int32,
                                   [hps.batch_size, hps.dec_timesteps],
                                   name='targets')
    self._article_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                        name='article_lens')
    self._abstract_lens = tf.placeholder(tf.int32, [hps.batch_size],
                                         name='abstract_lens')
    self._loss_weights = tf.placeholder(tf.float32,
                                        [hps.batch_size, hps.dec_timesteps],
                                        name='loss_weights')

  def _add_seq2seq(self):
    hps = self._hps
    vsize = self._vocab.NumIds()

    with tf.variable_scope('seq2seq'):
      encoder_inputs = tf.unstack(tf.transpose(self._articles))
      decoder_inputs = tf.unstack(tf.transpose(self._abstracts))
      targets = tf.unstack(tf.transpose(self._targets))
      loss_weights = tf.unstack(tf.transpose(self._loss_weights))
      article_lens = self._article_lens

      # Embedding shared by the input and outputs.
      with tf.variable_scope('embedding'), tf.device('/cpu:0'):
        embedding = tf.get_variable(
            'embedding', [vsize, hps.emb_dim], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        emb_encoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in encoder_inputs]
        emb_decoder_inputs = [tf.nn.embedding_lookup(embedding, x)
                              for x in decoder_inputs]

      for layer_i in xrange(hps.enc_layers):
        with tf.variable_scope('encoder%d' % layer_i), tf.device(
            self._next_device()):
          cell_fw = tf.contrib.rnn.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=123),
              state_is_tuple=False)
          cell_bw = tf.contrib.rnn.LSTMCell(
              hps.num_hidden,
              initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
              state_is_tuple=False)
          (emb_encoder_inputs, fw_state, _) = tf.contrib.rnn.static_bidirectional_rnn(
              cell_fw, cell_bw, emb_encoder_inputs, dtype=tf.float32,
              sequence_length=article_lens)
      encoder_outputs = emb_encoder_inputs

      with tf.variable_scope('output_projection'):
        w = tf.get_variable(
            'w', [hps.num_hidden, vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))
        w_t = tf.transpose(w)
        v = tf.get_variable(
            'v', [vsize], dtype=tf.float32,
            initializer=tf.truncated_normal_initializer(stddev=1e-4))

      with tf.variable_scope('decoder'), tf.device(self._next_device()):
        # When decoding, use model output from the previous step
        # for the next step.
        loop_function = None
        if hps.mode == 'decode':
          loop_function = _extract_argmax_and_embed(
              embedding, (w, v), update_embedding=False)

        cell = tf.contrib.rnn.LSTMCell(
            hps.num_hidden,
            initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=113),
            state_is_tuple=False)

        encoder_outputs = [tf.reshape(x, [hps.batch_size, 1, 2 * hps.num_hidden])
                           for x in encoder_outputs]
        self._enc_top_states = tf.concat(axis=1, values=encoder_outputs)
        self._dec_in_state = fw_state
        # During decoding, follow-up _dec_in_state values are fed in from
        # beam_search; dec_out_state is stored by beam_search for feeding at
        # the next step.
        initial_state_attention = (hps.mode == 'decode')
        decoder_outputs, self._dec_out_state = tf.contrib.legacy_seq2seq.attention_decoder(
            emb_decoder_inputs, self._dec_in_state, self._enc_top_states,
            cell, num_heads=1, loop_function=loop_function,
            initial_state_attention=initial_state_attention)

      with tf.variable_scope('output'), tf.device(self._next_device()):
        model_outputs = []
        for i in xrange(len(decoder_outputs)):
          if i > 0:
            tf.get_variable_scope().reuse_variables()
          model_outputs.append(
              tf.nn.xw_plus_b(decoder_outputs[i], w, v))

      if hps.mode == 'decode':
        with tf.variable_scope('decode_output'), tf.device('/cpu:0'):
          best_outputs = [tf.argmax(x, 1) for x in model_outputs]
          tf.logging.info('best_outputs%s', best_outputs[0].get_shape())
          self._outputs = tf.concat(
              axis=1, values=[tf.reshape(x, [hps.batch_size, 1])
                              for x in best_outputs])

          self._topk_log_probs, self._topk_ids = tf.nn.top_k(
              tf.log(tf.nn.softmax(model_outputs[-1])), hps.batch_size * 2)

      with tf.variable_scope('loss'), tf.device(self._next_device()):

        def sampled_loss_func(inputs, labels):
          with tf.device('/cpu:0'):  # Try gpu.
            labels = tf.reshape(labels, [-1, 1])
            return tf.nn.sampled_softmax_loss(
                weights=w_t, biases=v, labels=labels, inputs=inputs,
                num_sampled=hps.num_softmax_samples, num_classes=vsize)

        if hps.num_softmax_samples != 0 and hps.mode == 'train':
          self._loss = seq2seq_lib.sampled_sequence_loss(
              decoder_outputs, targets, loss_weights, sampled_loss_func)
        else:
          self._loss = tf.contrib.legacy_seq2seq.sequence_loss(
              model_outputs, targets, loss_weights)
        tf.summary.scalar('loss', tf.minimum(12.0, self._loss))

  def _add_train_op(self):
    """Sets self._train_op, the op to run for training."""
    hps = self._hps

    self._lr_rate = tf.maximum(
        hps.min_lr,  # min_lr_rate.
        tf.train.exponential_decay(hps.lr, self.global_step, 30000, 0.98))

    tvars = tf.trainable_variables()
    with tf.device(self._get_gpu(self._num_gpus - 1)):
      grads, global_norm = tf.clip_by_global_norm(
          tf.gradients(self._loss, tvars), hps.max_grad_norm)
    tf.summary.scalar('global_norm', global_norm)
    optimizer = tf.train.GradientDescentOptimizer(self._lr_rate)
    tf.summary.scalar('learning_rate', self._lr_rate)
    self._train_op = optimizer.apply_gradients(
        zip(grads, tvars), global_step=self.global_step, name='train_step')
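
  # With staircase=False (the exponential_decay default), the schedule above
  # works out to
  #   lr_t = max(min_lr, lr * 0.98 ** (global_step / 30000.0))
  # i.e. the rate decays smoothly by a factor of 0.98 every 30k steps and is
  # floored at hps.min_lr.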

  def encode_top_state(self, sess, enc_inputs, enc_len):
    """Return the top states from encoder for decoder.

    Args:
      sess: tensorflow session.
      enc_inputs: encoder inputs of shape [batch_size, enc_timesteps].
      enc_len: encoder input lengths of shape [batch_size].

    Returns:
      enc_top_states: The top level encoder states.
      dec_in_state: The decoder layer initial state.
    """
    results = sess.run([self._enc_top_states, self._dec_in_state],
                       feed_dict={self._articles: enc_inputs,
                                  self._article_lens: enc_len})
    # results[1] has shape [batch_size, state_size]; in decode mode the
    # beam-search driver typically fills the batch with copies of a single
    # example, so the first row suffices as the initial decoder state.
    return results[0], results[1][0]

  def decode_topk(self, sess, latest_tokens, enc_top_states, dec_init_states):
    """Return the topK results and new decoder states."""
    feed = {
        self._enc_top_states: enc_top_states,
        self._dec_in_state:
            np.squeeze(np.array(dec_init_states)),
        self._abstracts:
            np.transpose(np.array([latest_tokens])),
        self._abstract_lens: np.ones([len(dec_init_states)], np.int32)}

    results = sess.run(
        [self._topk_ids, self._topk_log_probs, self._dec_out_state],
        feed_dict=feed)

    ids, probs, states = results[0], results[1], results[2]
    new_states = [s for s in states]
    return ids, probs, new_states
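
  # Sketch of how a beam-search driver might chain these two methods (the
  # surrounding variable names here are assumptions, not part of this file):
  #
  #   enc_top_states, dec_in_state = model.encode_top_state(
  #       sess, article_batch, article_lens)
  #   topk_ids, topk_log_probs, new_states = model.decode_topk(
  #       sess, latest_tokens, enc_top_states, states_for_hyps)
  #
  # topk_ids/topk_log_probs carry batch_size * 2 candidates per hypothesis
  # (matching the top_k call in _add_seq2seq), which the caller prunes back
  # down to the beam width.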

  def build_graph(self):
    self._add_placeholders()
    self._add_seq2seq()
    self.global_step = tf.Variable(0, name='global_step', trainable=False)
    if self._hps.mode == 'train':
      self._add_train_op()
    self._summaries = tf.summary.merge_all()
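

# Minimal usage sketch, assuming a vocab object exposing NumIds() and the
# illustrative HParams values from the example near the top of this file:
#
#   model = Seq2SeqAttentionModel(hps, vocab, num_gpus=1)
#   model.build_graph()
#   with tf.Session() as sess:
#     sess.run(tf.global_variables_initializer())
#     _, summaries, loss, step = model.run_train_step(
#         sess, article_batch, abstract_batch, targets,
#         article_lens, abstract_lens, loss_weights)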