seq2seq_attention_decode.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Module for decoding."""
  16. import os
  17. import time
  18. import tensorflow as tf
  19. import beam_search
  20. import data
  21. FLAGS = tf.app.flags.FLAGS
  22. tf.app.flags.DEFINE_integer('max_decode_steps', 1000000,
  23. 'Number of decoding steps.')
  24. tf.app.flags.DEFINE_integer('decode_batches_per_ckpt', 8000,
  25. 'Number of batches to decode before restoring next '
  26. 'checkpoint')
  27. DECODE_LOOP_DELAY_SECS = 60
  28. DECODE_IO_FLUSH_INTERVAL = 100
  29. class DecodeIO(object):
  30. """Writes the decoded and references to RKV files for Rouge score.
  31. See nlp/common/utils/internal/rkv_parser.py for detail about rkv file.
  32. """
  33. def __init__(self, outdir):
  34. self._cnt = 0
  35. self._outdir = outdir
  36. if not os.path.exists(self._outdir):
  37. os.mkdir(self._outdir)
  38. self._ref_file = None
  39. self._decode_file = None
  40. def Write(self, reference, decode):
  41. """Writes the reference and decoded outputs to RKV files.
  42. Args:
  43. reference: The human (correct) result.
  44. decode: The machine-generated result
  45. """
  46. self._ref_file.write('output=%s\n' % reference)
  47. self._decode_file.write('output=%s\n' % decode)
  48. self._cnt += 1
  49. if self._cnt % DECODE_IO_FLUSH_INTERVAL == 0:
  50. self._ref_file.flush()
  51. self._decode_file.flush()
  52. def ResetFiles(self):
  53. """Resets the output files. Must be called once before Write()."""
  54. if self._ref_file: self._ref_file.close()
  55. if self._decode_file: self._decode_file.close()
  56. timestamp = int(time.time())
  57. self._ref_file = open(
  58. os.path.join(self._outdir, 'ref%d'%timestamp), 'w')
  59. self._decode_file = open(
  60. os.path.join(self._outdir, 'decode%d'%timestamp), 'w')
  61. class BSDecoder(object):
  62. """Beam search decoder."""
  63. def __init__(self, model, batch_reader, hps, vocab):
  64. """Beam search decoding.
  65. Args:
  66. model: The seq2seq attentional model.
  67. batch_reader: The batch data reader.
  68. hps: Hyperparamters.
  69. vocab: Vocabulary
  70. """
  71. self._model = model
  72. self._model.build_graph()
  73. self._batch_reader = batch_reader
  74. self._hps = hps
  75. self._vocab = vocab
  76. self._saver = tf.train.Saver()
  77. self._decode_io = DecodeIO(FLAGS.decode_dir)
  78. def DecodeLoop(self):
  79. """Decoding loop for long running process."""
  80. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  81. step = 0
  82. while step < FLAGS.max_decode_steps:
  83. time.sleep(DECODE_LOOP_DELAY_SECS)
  84. if not self._Decode(self._saver, sess):
  85. continue
  86. step += 1
  87. def _Decode(self, saver, sess):
  88. """Restore a checkpoint and decode it.
  89. Args:
  90. saver: Tensorflow checkpoint saver.
  91. sess: Tensorflow session.
  92. Returns:
  93. If success, returns true, otherwise, false.
  94. """
  95. ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
  96. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  97. tf.logging.info('No model to decode yet at %s', FLAGS.log_root)
  98. return False
  99. tf.logging.info('checkpoint path %s', ckpt_state.model_checkpoint_path)
  100. ckpt_path = os.path.join(
  101. FLAGS.log_root, os.path.basename(ckpt_state.model_checkpoint_path))
  102. tf.logging.info('renamed checkpoint path %s', ckpt_path)
  103. saver.restore(sess, ckpt_path)
  104. self._decode_io.ResetFiles()
  105. for _ in xrange(FLAGS.decode_batches_per_ckpt):
  106. (article_batch, _, _, article_lens, _, _, origin_articles,
  107. origin_abstracts) = self._batch_reader.NextBatch()
  108. for i in xrange(self._hps.batch_size):
  109. bs = beam_search.BeamSearch(
  110. self._model, self._hps.batch_size,
  111. self._vocab.WordToId(data.SENTENCE_START),
  112. self._vocab.WordToId(data.SENTENCE_END),
  113. self._hps.dec_timesteps)
  114. article_batch_cp = article_batch.copy()
  115. article_batch_cp[:] = article_batch[i:i+1]
  116. article_lens_cp = article_lens.copy()
  117. article_lens_cp[:] = article_lens[i:i+1]
  118. best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
  119. decode_output = [int(t) for t in best_beam.tokens[1:]]
  120. self._DecodeBatch(
  121. origin_articles[i], origin_abstracts[i], decode_output)
  122. return True
  123. def _DecodeBatch(self, article, abstract, output_ids):
  124. """Convert id to words and writing results.
  125. Args:
  126. article: The original article string.
  127. abstract: The human (correct) abstract string.
  128. output_ids: The abstract word ids output by machine.
  129. """
  130. decoded_output = ' '.join(data.Ids2Words(output_ids, self._vocab))
  131. end_p = decoded_output.find(data.SENTENCE_END, 0)
  132. if end_p != -1:
  133. decoded_output = decoded_output[:end_p]
  134. tf.logging.info('article: %s', article)
  135. tf.logging.info('abstract: %s', abstract)
  136. tf.logging.info('decoded: %s', decoded_output)
  137. self._decode_io.Write(abstract, decoded_output.strip())