seq2seq_attention.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Trains a seq2seq model.
  16. WORK IN PROGRESS.
  17. Implement "Abstractive Text Summarization using Sequence-to-sequence RNNS and
  18. Beyond."
  19. """
import sys
import time

import tensorflow as tf

import batch_reader
import data
import seq2seq_attention_decode
import seq2seq_attention_model

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_path',
                           '', 'Path expression to tf.Example.')
tf.app.flags.DEFINE_string('vocab_path',
                           '', 'Path expression to text vocabulary file.')
tf.app.flags.DEFINE_string('article_key', 'article',
                           'tf.Example feature key for article.')
tf.app.flags.DEFINE_string('abstract_key', 'headline',
                           'tf.Example feature key for abstract.')
tf.app.flags.DEFINE_string('log_root', '', 'Directory for model root.')
tf.app.flags.DEFINE_string('train_dir', '', 'Directory for train.')
tf.app.flags.DEFINE_string('eval_dir', '', 'Directory for eval.')
tf.app.flags.DEFINE_string('decode_dir', '', 'Directory for decode summaries.')
tf.app.flags.DEFINE_string('mode', 'train', 'train/eval/decode mode.')
tf.app.flags.DEFINE_integer('max_run_steps', 10000000,
                            'Maximum number of run steps.')
tf.app.flags.DEFINE_integer('max_article_sentences', 2,
                            'Max number of first sentences to use from the '
                            'article.')
tf.app.flags.DEFINE_integer('max_abstract_sentences', 100,
                            'Max number of first sentences to use from the '
                            'abstract.')
tf.app.flags.DEFINE_integer('beam_size', 4,
                            'Beam size for beam search decoding.')
tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run eval.')
tf.app.flags.DEFINE_integer('checkpoint_secs', 60, 'How often to checkpoint.')
tf.app.flags.DEFINE_bool('use_bucketing', False,
                         'Whether to bucket articles of similar length.')
tf.app.flags.DEFINE_bool('truncate_input', False,
                         'Truncate inputs that are too long. If False, '
                         'examples that are too long are discarded.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of GPUs used.')
tf.app.flags.DEFINE_integer('random_seed', 111, 'A seed value for randomness.')
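
# Example invocations (a sketch; the angle-bracketed paths are placeholders
# and assume input tf.Example records keyed by the flags above):
#
#   python seq2seq_attention.py --mode=train --data_path=<train-examples> \
#       --vocab_path=<vocab-file> --log_root=<log-root> \
#       --train_dir=<log-root>/train
#
#   python seq2seq_attention.py --mode=decode --data_path=<test-examples> \
#       --vocab_path=<vocab-file> --log_root=<log-root> \
#       --decode_dir=<log-root>/decode --beam_size=4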


def _RunningAvgLoss(loss, running_avg_loss, summary_writer, step, decay=0.999):
  """Calculates the exponential running average of losses and logs it."""
  if running_avg_loss == 0:  # On the first step, just take the loss.
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  # Cap the average so a divergent batch cannot dominate the summary plot.
  running_avg_loss = min(running_avg_loss, 12)
  loss_sum = tf.Summary()
  loss_sum.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  sys.stdout.write('running_avg_loss: %f\n' % running_avg_loss)
  return running_avg_loss
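
# A quick worked example of the smoothing above (illustrative numbers only):
# decay=0.999 means the average roughly tracks the last
# 1 / (1 - decay) = 1000 steps, so running_avg_loss=2.0 with a new loss=3.0
# updates to 2.0 * 0.999 + 0.001 * 3.0 = 2.001.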


def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=FLAGS.checkpoint_secs,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          loss, running_avg_loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.stop()
    return running_avg_loss
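
# _Train and _Eval are intended to run as separate processes sharing
# FLAGS.log_root: the Supervisor above periodically saves checkpoints there,
# and the eval loop below polls the same directory for the latest one.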


def _Eval(model, data_batcher, vocab=None):
  """Runs model eval."""
  model.build_graph()
  saver = tf.train.Saver()
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  running_avg_loss = 0
  step = 0
  while True:
    time.sleep(FLAGS.eval_interval_secs)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      continue

    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
      continue

    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    (article_batch, abstract_batch, targets, article_lens, abstract_lens,
     loss_weights, _, _) = data_batcher.NextBatch()
    (summaries, loss, train_step) = model.run_eval_step(
        sess, article_batch, abstract_batch, targets, article_lens,
        abstract_lens, loss_weights)
    tf.logging.info(
        'article: %s',
        ' '.join(data.Ids2Words(article_batch[0][:].tolist(), vocab)))
    tf.logging.info(
        'abstract: %s',
        ' '.join(data.Ids2Words(abstract_batch[0][:].tolist(), vocab)))

    summary_writer.add_summary(summaries, train_step)
    running_avg_loss = _RunningAvgLoss(
        loss, running_avg_loss, summary_writer, train_step)
    step += 1
    if step % 100 == 0:
      summary_writer.flush()
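
# Note: the eval loop above never exits on its own; it reloads the newest
# checkpoint every FLAGS.eval_interval_secs and reports the smoothed loss, so
# it is typically stopped manually once the loss plateaus.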


def main(unused_argv):
  vocab = data.Vocab(FLAGS.vocab_path, 1000000)
  # Check for presence of required special tokens.
  assert vocab.CheckVocab(data.PAD_TOKEN) > 0
  assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
  assert vocab.CheckVocab(data.SENTENCE_START) > 0
  assert vocab.CheckVocab(data.SENTENCE_END) > 0

  batch_size = 4
  if FLAGS.mode == 'decode':
    batch_size = FLAGS.beam_size

  hps = seq2seq_attention_model.HParams(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.15,  # learning rate
      batch_size=batch_size,
      enc_layers=4,
      enc_timesteps=120,
      dec_timesteps=30,
      min_input_len=2,  # discard articles/abstracts shorter than this
      num_hidden=256,  # for rnn cell
      emb_dim=128,  # If 0, don't use embedding
      max_grad_norm=2,
      num_softmax_samples=4096)  # If 0, no sampled softmax.

  batcher = batch_reader.Batcher(
      FLAGS.data_path, vocab, hps, FLAGS.article_key,
      FLAGS.abstract_key, FLAGS.max_article_sentences,
      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
      truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)

  if hps.mode == 'train':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Train(model, batcher)
  elif hps.mode == 'eval':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Eval(model, batcher, vocab=vocab)
  elif hps.mode == 'decode':
    # Only the first decode step's graph needs to be built; it is reused
    # because we keep and feed in the state for each step's output.
    decode_mdl_hps = hps._replace(dec_timesteps=1)
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
    decoder.DecodeLoop()


if __name__ == '__main__':
  tf.app.run()