seq2seq_attention.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Trains a seq2seq model.
  16. WORK IN PROGRESS.
  17. Implement "Abstractive Text Summarization using Sequence-to-sequence RNNS and
  18. Beyond."
  19. """
import sys
import time

import tensorflow as tf

import batch_reader
import data
import seq2seq_attention_decode
import seq2seq_attention_model

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_path',
                           '', 'Path expression to tf.Example.')
tf.app.flags.DEFINE_string('vocab_path',
                           '', 'Path expression to text vocabulary file.')
tf.app.flags.DEFINE_string('article_key', 'article',
                           'tf.Example feature key for article.')
tf.app.flags.DEFINE_string('abstract_key', 'headline',
                           'tf.Example feature key for abstract.')
tf.app.flags.DEFINE_string('log_root', '', 'Directory for model root.')
tf.app.flags.DEFINE_string('train_dir', '', 'Directory for train.')
tf.app.flags.DEFINE_string('eval_dir', '', 'Directory for eval.')
tf.app.flags.DEFINE_string('decode_dir', '', 'Directory for decode summaries.')
tf.app.flags.DEFINE_string('mode', 'train', 'train/eval/decode mode.')
tf.app.flags.DEFINE_integer('max_run_steps', 10000000,
                            'Maximum number of run steps.')
tf.app.flags.DEFINE_integer('max_article_sentences', 2,
                            'Max number of first sentences to use from the '
                            'article.')
tf.app.flags.DEFINE_integer('max_abstract_sentences', 100,
                            'Max number of first sentences to use from the '
                            'abstract.')
tf.app.flags.DEFINE_integer('beam_size', 4,
                            'Beam size for beam search decoding.')
tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run eval.')
tf.app.flags.DEFINE_integer('checkpoint_secs', 60, 'How often to checkpoint.')
tf.app.flags.DEFINE_bool('use_bucketing', False,
                         'Whether to bucket articles of similar length.')
tf.app.flags.DEFINE_bool('truncate_input', False,
                         'Truncate inputs that are too long. If False, '
                         'examples that are too long are discarded.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of GPUs used.')
tf.app.flags.DEFINE_integer('random_seed', 111, 'A seed value for randomness.')
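# Typical invocations combining the flags above (all paths below are
# illustrative placeholders, not defaults shipped with this module):
#
#   python seq2seq_attention.py --mode=train \
#       --data_path=/tmp/data/train-* --vocab_path=/tmp/data/vocab \
#       --log_root=/tmp/textsum/log_root --train_dir=/tmp/textsum/log_root/train
#
#   python seq2seq_attention.py --mode=eval \
#       --data_path=/tmp/data/valid-* --vocab_path=/tmp/data/vocab \
#       --log_root=/tmp/textsum/log_root --eval_dir=/tmp/textsum/log_root/eval
#
#   python seq2seq_attention.py --mode=decode \
#       --data_path=/tmp/data/test-* --vocab_path=/tmp/data/vocab \
#       --log_root=/tmp/textsum/log_root \
#       --decode_dir=/tmp/textsum/log_root/decode --beam_size=8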
def _RunningAvgLoss(loss, running_avg_loss, summary_writer, step, decay=0.999):
  """Calculates the exponential running average of losses."""
  if running_avg_loss == 0:
    # First step: seed the average with the current loss.
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  # Cap the reported value so early spikes don't dominate the summary plot.
  running_avg_loss = min(running_avg_loss, 12)
  loss_sum = tf.Summary()
  loss_sum.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  sys.stdout.write('running_avg_loss: %f\n' % running_avg_loss)
  return running_avg_loss
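# Illustration of the decay behavior above (numbers are illustrative only, not
# part of the training pipeline): with decay=0.999 each new batch loss
# contributes 0.1%, so the average smooths over roughly the last
# 1 / (1 - decay) = 1000 steps. For example:
#   avg = 0
#   for loss in [4.0, 4.0, 2.0]:
#     avg = loss if avg == 0 else avg * 0.999 + 0.001 * loss
#   # avg == 3.998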
def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=FLAGS.checkpoint_secs,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session()
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          loss, running_avg_loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
    return running_avg_loss
def _Eval(model, data_batcher, vocab=None):
  """Runs model eval."""
  model.build_graph()
  saver = tf.train.Saver()
  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  running_avg_loss = 0
  step = 0
  while True:
    time.sleep(FLAGS.eval_interval_secs)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.train_dir)
      continue
    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)
    (article_batch, abstract_batch, targets, article_lens, abstract_lens,
     loss_weights, _, _) = data_batcher.NextBatch()
    (summaries, loss, train_step) = model.run_eval_step(
        sess, article_batch, abstract_batch, targets, article_lens,
        abstract_lens, loss_weights)
    tf.logging.info(
        'article:  %s',
        ' '.join(data.Ids2Words(article_batch[0][:].tolist(), vocab)))
    tf.logging.info(
        'abstract: %s',
        ' '.join(data.Ids2Words(abstract_batch[0][:].tolist(), vocab)))
    summary_writer.add_summary(summaries, train_step)
    running_avg_loss = _RunningAvgLoss(
        loss, running_avg_loss, summary_writer, train_step)
    step += 1
    if step % 100 == 0:
      summary_writer.flush()
def main(unused_argv):
  vocab = data.Vocab(FLAGS.vocab_path, 1000000)
  # Check for presence of required special tokens.
  assert vocab.WordToId(data.PAD_TOKEN) > 0
  assert vocab.WordToId(data.UNKNOWN_TOKEN) >= 0
  assert vocab.WordToId(data.SENTENCE_START) > 0
  assert vocab.WordToId(data.SENTENCE_END) > 0

  batch_size = 4
  if FLAGS.mode == 'decode':
    batch_size = FLAGS.beam_size

  hps = seq2seq_attention_model.HParams(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.15,  # learning rate
      batch_size=batch_size,
      enc_layers=4,
      enc_timesteps=120,
      dec_timesteps=30,
      min_input_len=2,  # discard articles/summaries shorter than this
      num_hidden=256,  # for rnn cell
      emb_dim=128,  # If 0, don't use embedding
      max_grad_norm=2,
      num_softmax_samples=4096)  # If 0, no sampled softmax.

  batcher = batch_reader.Batcher(
      FLAGS.data_path, vocab, hps, FLAGS.article_key,
      FLAGS.abstract_key, FLAGS.max_article_sentences,
      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
      truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)

  if hps.mode == 'train':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Train(model, batcher)
  elif hps.mode == 'eval':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Eval(model, batcher, vocab=vocab)
  elif hps.mode == 'decode':
    # Only need to restore the 1st step and reuse it since
    # we keep and feed in state for each step's output.
    decode_mdl_hps = hps._replace(dec_timesteps=1)
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
    decoder.DecodeLoop()


if __name__ == '__main__':
  tf.app.run()