# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- """Trains a seq2seq model.
- WORK IN PROGRESS.
- Implement "Abstractive Text Summarization using Sequence-to-sequence RNNS and
- Beyond."
- """
import sys
import time

import tensorflow as tf

import batch_reader
import data
import seq2seq_attention_decode
import seq2seq_attention_model

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('data_path',
                           '', 'Path expression to tf.Example.')
tf.app.flags.DEFINE_string('vocab_path',
                           '', 'Path expression to text vocabulary file.')
tf.app.flags.DEFINE_string('article_key', 'article',
                           'tf.Example feature key for article.')
tf.app.flags.DEFINE_string('abstract_key', 'headline',
                           'tf.Example feature key for abstract.')
tf.app.flags.DEFINE_string('log_root', '', 'Directory for model root.')
tf.app.flags.DEFINE_string('train_dir', '', 'Directory for train.')
tf.app.flags.DEFINE_string('eval_dir', '', 'Directory for eval.')
tf.app.flags.DEFINE_string('decode_dir', '', 'Directory for decode summaries.')
tf.app.flags.DEFINE_string('mode', 'train', 'train/eval/decode mode')
tf.app.flags.DEFINE_integer('max_run_steps', 10000000,
                            'Maximum number of run steps.')
tf.app.flags.DEFINE_integer('max_article_sentences', 2,
                            'Max number of first sentences to use from the '
                            'article.')
tf.app.flags.DEFINE_integer('max_abstract_sentences', 100,
                            'Max number of first sentences to use from the '
                            'abstract.')
tf.app.flags.DEFINE_integer('beam_size', 4,
                            'Beam size for beam search decoding.')
tf.app.flags.DEFINE_integer('eval_interval_secs', 60, 'How often to run eval.')
tf.app.flags.DEFINE_integer('checkpoint_secs', 60, 'How often to checkpoint.')
tf.app.flags.DEFINE_bool('use_bucketing', False,
                         'Whether to bucket articles of similar length.')
tf.app.flags.DEFINE_bool('truncate_input', False,
                         'Truncate inputs that are too long. If False, '
                         'examples that are too long are discarded.')
tf.app.flags.DEFINE_integer('num_gpus', 0, 'Number of GPUs used.')
tf.app.flags.DEFINE_integer('random_seed', 111, 'A seed value for randomness.')
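
# Example invocation (a sketch; the script name and the paths below are
# hypothetical placeholders, only the flags themselves are defined in this
# file):
#
#   python seq2seq_attention.py \
#     --mode=train \
#     --article_key=article \
#     --abstract_key=headline \
#     --data_path=<tf.Example file pattern> \
#     --vocab_path=<vocab file> \
#     --log_root=<model root dir> \
#     --train_dir=<model root dir>/train
#
# Eval (--mode=eval with --eval_dir) and decode (--mode=decode with
# --decode_dir) are run the same way against the same --log_root, so they can
# pick up the checkpoints written by the training job.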


def _RunningAvgLoss(loss, running_avg_loss, summary_writer, step, decay=0.999):
  """Calculate the running average of losses."""
  if running_avg_loss == 0:
    running_avg_loss = loss
  else:
    running_avg_loss = running_avg_loss * decay + (1 - decay) * loss
  running_avg_loss = min(running_avg_loss, 12)
  loss_sum = tf.Summary()
  loss_sum.value.add(tag='running_avg_loss', simple_value=running_avg_loss)
  summary_writer.add_summary(loss_sum, step)
  sys.stdout.write('running_avg_loss: %f\n' % running_avg_loss)
  return running_avg_loss
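
# Worked example (illustration only, not part of the original module): with
# the default decay of 0.999, a previous running average of 5.0 and a new
# batch loss of 3.0 give
#   5.0 * 0.999 + (1 - 0.999) * 3.0 = 4.998
# i.e. the average moves only slightly toward each new batch loss, and the
# min(running_avg_loss, 12) cap bounds the value that gets logged and
# returned.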


def _Train(model, data_batcher):
  """Runs model training."""
  with tf.device('/cpu:0'):
    model.build_graph()
    saver = tf.train.Saver()
    # Train dir is different from log_root to avoid summary directory
    # conflict with Supervisor.
    summary_writer = tf.summary.FileWriter(FLAGS.train_dir)
    # The Supervisor saves checkpoints to log_root every checkpoint_secs and
    # handles session recovery, so the loop below never calls saver.save().
    sv = tf.train.Supervisor(logdir=FLAGS.log_root,
                             is_chief=True,
                             saver=saver,
                             summary_op=None,
                             save_summaries_secs=60,
                             save_model_secs=FLAGS.checkpoint_secs,
                             global_step=model.global_step)
    sess = sv.prepare_or_wait_for_session(config=tf.ConfigProto(
        allow_soft_placement=True))
    running_avg_loss = 0
    step = 0
    while not sv.should_stop() and step < FLAGS.max_run_steps:
      (article_batch, abstract_batch, targets, article_lens, abstract_lens,
       loss_weights, _, _) = data_batcher.NextBatch()
      (_, summaries, loss, train_step) = model.run_train_step(
          sess, article_batch, abstract_batch, targets, article_lens,
          abstract_lens, loss_weights)
      summary_writer.add_summary(summaries, train_step)
      running_avg_loss = _RunningAvgLoss(
          loss, running_avg_loss, summary_writer, train_step)
      step += 1
      if step % 100 == 0:
        summary_writer.flush()
    sv.Stop()
    return running_avg_loss


def _Eval(model, data_batcher, vocab=None):
  """Runs model eval."""
  model.build_graph()
  saver = tf.train.Saver()
  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  running_avg_loss = 0
  step = 0
  while True:
    time.sleep(FLAGS.eval_interval_secs)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      tf.logging.error('Cannot restore checkpoint: %s', e)
      continue

    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
      continue

    tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    (article_batch, abstract_batch, targets, article_lens, abstract_lens,
     loss_weights, _, _) = data_batcher.NextBatch()
    (summaries, loss, train_step) = model.run_eval_step(
        sess, article_batch, abstract_batch, targets, article_lens,
        abstract_lens, loss_weights)
    tf.logging.info(
        'article: %s',
        ' '.join(data.Ids2Words(article_batch[0][:].tolist(), vocab)))
    tf.logging.info(
        'abstract: %s',
        ' '.join(data.Ids2Words(abstract_batch[0][:].tolist(), vocab)))

    summary_writer.add_summary(summaries, train_step)
    running_avg_loss = _RunningAvgLoss(
        loss, running_avg_loss, summary_writer, train_step)
    step += 1
    if step % 100 == 0:
      summary_writer.flush()
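
# Note on the eval loop above: _Eval never trains. It wakes up every
# eval_interval_secs, reloads the newest checkpoint found under log_root,
# scores a single batch, and logs the first example's article and abstract
# (converted from ids back to words) along with a running average loss. It is
# typically run as a separate process alongside the training job that writes
# those checkpoints.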


def main(unused_argv):
  vocab = data.Vocab(FLAGS.vocab_path, 1000000)
  # Check for presence of required special tokens.
  assert vocab.CheckVocab(data.PAD_TOKEN) > 0
  assert vocab.CheckVocab(data.UNKNOWN_TOKEN) >= 0
  assert vocab.CheckVocab(data.SENTENCE_START) > 0
  assert vocab.CheckVocab(data.SENTENCE_END) > 0

  batch_size = 4
  if FLAGS.mode == 'decode':
    batch_size = FLAGS.beam_size

  hps = seq2seq_attention_model.HParams(
      mode=FLAGS.mode,  # train, eval, decode
      min_lr=0.01,  # min learning rate.
      lr=0.15,  # learning rate
      batch_size=batch_size,
      enc_layers=4,
      enc_timesteps=120,
      dec_timesteps=30,
      min_input_len=2,  # discard articles/summaries shorter than this
      num_hidden=256,  # for rnn cell
      emb_dim=128,  # If 0, don't use embedding
      max_grad_norm=2,
      num_softmax_samples=4096)  # If 0, no sampled softmax.

  batcher = batch_reader.Batcher(
      FLAGS.data_path, vocab, hps, FLAGS.article_key,
      FLAGS.abstract_key, FLAGS.max_article_sentences,
      FLAGS.max_abstract_sentences, bucketing=FLAGS.use_bucketing,
      truncate_input=FLAGS.truncate_input)
  tf.set_random_seed(FLAGS.random_seed)

  if hps.mode == 'train':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Train(model, batcher)
  elif hps.mode == 'eval':
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        hps, vocab, num_gpus=FLAGS.num_gpus)
    _Eval(model, batcher, vocab=vocab)
  elif hps.mode == 'decode':
    # Only need to restore the 1st step and reuse it since
    # we keep and feed in state for each step's output.
    decode_mdl_hps = hps._replace(dec_timesteps=1)
    model = seq2seq_attention_model.Seq2SeqAttentionModel(
        decode_mdl_hps, vocab, num_gpus=FLAGS.num_gpus)
    decoder = seq2seq_attention_decode.BSDecoder(model, batcher, hps, vocab)
    decoder.DecodeLoop()


if __name__ == '__main__':
  tf.app.run()