seq2seq_attention_decode.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Module for decoding."""
  16. import os
  17. import time
  18. import beam_search
  19. import data
  20. from six.moves import xrange
  21. import tensorflow as tf
  22. FLAGS = tf.app.flags.FLAGS
  23. tf.app.flags.DEFINE_integer('max_decode_steps', 1000000,
  24. 'Number of decoding steps.')
  25. tf.app.flags.DEFINE_integer('decode_batches_per_ckpt', 8000,
  26. 'Number of batches to decode before restoring next '
  27. 'checkpoint')
  28. DECODE_LOOP_DELAY_SECS = 60
  29. DECODE_IO_FLUSH_INTERVAL = 100
  30. class DecodeIO(object):
  31. """Writes the decoded and references to RKV files for Rouge score.
  32. See nlp/common/utils/internal/rkv_parser.py for detail about rkv file.
  33. """
  34. def __init__(self, outdir):
  35. self._cnt = 0
  36. self._outdir = outdir
  37. if not os.path.exists(self._outdir):
  38. os.mkdir(self._outdir)
  39. self._ref_file = None
  40. self._decode_file = None
  41. def Write(self, reference, decode):
  42. """Writes the reference and decoded outputs to RKV files.
  43. Args:
  44. reference: The human (correct) result.
  45. decode: The machine-generated result
  46. """
  47. self._ref_file.write('output=%s\n' % reference)
  48. self._decode_file.write('output=%s\n' % decode)
  49. self._cnt += 1
  50. if self._cnt % DECODE_IO_FLUSH_INTERVAL == 0:
  51. self._ref_file.flush()
  52. self._decode_file.flush()
  53. def ResetFiles(self):
  54. """Resets the output files. Must be called once before Write()."""
  55. if self._ref_file: self._ref_file.close()
  56. if self._decode_file: self._decode_file.close()
  57. timestamp = int(time.time())
  58. self._ref_file = open(
  59. os.path.join(self._outdir, 'ref%d'%timestamp), 'w')
  60. self._decode_file = open(
  61. os.path.join(self._outdir, 'decode%d'%timestamp), 'w')
  62. class BSDecoder(object):
  63. """Beam search decoder."""
  64. def __init__(self, model, batch_reader, hps, vocab):
  65. """Beam search decoding.
  66. Args:
  67. model: The seq2seq attentional model.
  68. batch_reader: The batch data reader.
  69. hps: Hyperparamters.
  70. vocab: Vocabulary
  71. """
  72. self._model = model
  73. self._model.build_graph()
  74. self._batch_reader = batch_reader
  75. self._hps = hps
  76. self._vocab = vocab
  77. self._saver = tf.train.Saver()
  78. self._decode_io = DecodeIO(FLAGS.decode_dir)
  79. def DecodeLoop(self):
  80. """Decoding loop for long running process."""
  81. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  82. step = 0
  83. while step < FLAGS.max_decode_steps:
  84. time.sleep(DECODE_LOOP_DELAY_SECS)
  85. if not self._Decode(self._saver, sess):
  86. continue
  87. step += 1
  88. def _Decode(self, saver, sess):
  89. """Restore a checkpoint and decode it.
  90. Args:
  91. saver: Tensorflow checkpoint saver.
  92. sess: Tensorflow session.
  93. Returns:
  94. If success, returns true, otherwise, false.
  95. """
  96. ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
  97. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  98. tf.logging.info('No model to decode yet at %s', FLAGS.log_root)
  99. return False
  100. tf.logging.info('checkpoint path %s', ckpt_state.model_checkpoint_path)
  101. ckpt_path = os.path.join(
  102. FLAGS.log_root, os.path.basename(ckpt_state.model_checkpoint_path))
  103. tf.logging.info('renamed checkpoint path %s', ckpt_path)
  104. saver.restore(sess, ckpt_path)
  105. self._decode_io.ResetFiles()
  106. for _ in xrange(FLAGS.decode_batches_per_ckpt):
  107. (article_batch, _, _, article_lens, _, _, origin_articles,
  108. origin_abstracts) = self._batch_reader.NextBatch()
  109. for i in xrange(self._hps.batch_size):
  110. bs = beam_search.BeamSearch(
  111. self._model, self._hps.batch_size,
  112. self._vocab.WordToId(data.SENTENCE_START),
  113. self._vocab.WordToId(data.SENTENCE_END),
  114. self._hps.dec_timesteps)
  115. article_batch_cp = article_batch.copy()
  116. article_batch_cp[:] = article_batch[i:i+1]
  117. article_lens_cp = article_lens.copy()
  118. article_lens_cp[:] = article_lens[i:i+1]
  119. best_beam = bs.BeamSearch(sess, article_batch_cp, article_lens_cp)[0]
  120. decode_output = [int(t) for t in best_beam.tokens[1:]]
  121. self._DecodeBatch(
  122. origin_articles[i], origin_abstracts[i], decode_output)
  123. return True
  124. def _DecodeBatch(self, article, abstract, output_ids):
  125. """Convert id to words and writing results.
  126. Args:
  127. article: The original article string.
  128. abstract: The human (correct) abstract string.
  129. output_ids: The abstract word ids output by machine.
  130. """
  131. decoded_output = ' '.join(data.Ids2Words(output_ids, self._vocab))
  132. end_p = decoded_output.find(data.SENTENCE_END, 0)
  133. if end_p != -1:
  134. decoded_output = decoded_output[:end_p]
  135. tf.logging.info('article: %s', article)
  136. tf.logging.info('abstract: %s', abstract)
  137. tf.logging.info('decoded: %s', decoded_output)
  138. self._decode_io.Write(abstract, decoded_output.strip())