# beam_search.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Beam search module.

Beam search takes the top K results from the model, predicts the K results for
each of the previous K result, getting K*K results. Pick the top K results from
K*K results, and start over again until certain number of results are fully
decoded.
"""

from six.moves import xrange
import tensorflow as tf

FLAGS = tf.flags.FLAGS

# When True, hypotheses are ranked by log_prob / num_tokens instead of raw
# log_prob (see BeamSearch._BestHyps below), which keeps the search from
# systematically preferring shorter outputs.
tf.flags.DEFINE_bool('normalize_by_length', True, 'Whether to normalize')
  25. class Hypothesis(object):
  26. """Defines a hypothesis during beam search."""
  27. def __init__(self, tokens, log_prob, state):
  28. """Hypothesis constructor.
  29. Args:
  30. tokens: start tokens for decoding.
  31. log_prob: log prob of the start tokens, usually 1.
  32. state: decoder initial states.
  33. """
  34. self.tokens = tokens
  35. self.log_prob = log_prob
  36. self.state = state
  37. def Extend(self, token, log_prob, new_state):
  38. """Extend the hypothesis with result from latest step.
  39. Args:
  40. token: latest token from decoding.
  41. log_prob: log prob of the latest decoded tokens.
  42. new_state: decoder output state. Fed to the decoder for next step.
  43. Returns:
  44. New Hypothesis with the results from latest step.
  45. """
  46. return Hypothesis(self.tokens + [token], self.log_prob + log_prob,
  47. new_state)
  48. @property
  49. def latest_token(self):
  50. return self.tokens[-1]
  51. def __str__(self):
  52. return ('Hypothesis(log prob = %.4f, tokens = %s)' % (self.log_prob,
  53. self.tokens))
  54. class BeamSearch(object):
  55. """Beam search."""
  56. def __init__(self, model, beam_size, start_token, end_token, max_steps):
  57. """Creates BeamSearch object.
  58. Args:
  59. model: Seq2SeqAttentionModel.
  60. beam_size: int.
  61. start_token: int, id of the token to start decoding with
  62. end_token: int, id of the token that completes an hypothesis
  63. max_steps: int, upper limit on the size of the hypothesis
  64. """
  65. self._model = model
  66. self._beam_size = beam_size
  67. self._start_token = start_token
  68. self._end_token = end_token
  69. self._max_steps = max_steps
  70. def BeamSearch(self, sess, enc_inputs, enc_seqlen):
  71. """Performs beam search for decoding.
  72. Args:
  73. sess: tf.Session, session
  74. enc_inputs: ndarray of shape (enc_length, 1), the document ids to encode
  75. enc_seqlen: ndarray of shape (1), the length of the sequnce
  76. Returns:
  77. hyps: list of Hypothesis, the best hypotheses found by beam search,
  78. ordered by score
  79. """
  80. # Run the encoder and extract the outputs and final state.
  81. enc_top_states, dec_in_state = self._model.encode_top_state(
  82. sess, enc_inputs, enc_seqlen)
  83. # Replicate the initial states K times for the first step.
  84. hyps = [Hypothesis([self._start_token], 0.0, dec_in_state)
  85. ] * self._beam_size
  86. results = []
  87. steps = 0
  88. while steps < self._max_steps and len(results) < self._beam_size:
  89. latest_tokens = [h.latest_token for h in hyps]
  90. states = [h.state for h in hyps]
  91. topk_ids, topk_log_probs, new_states = self._model.decode_topk(
  92. sess, latest_tokens, enc_top_states, states)
  93. # Extend each hypothesis.
  94. all_hyps = []
  95. # The first step takes the best K results from first hyps. Following
  96. # steps take the best K results from K*K hyps.
  97. num_beam_source = 1 if steps == 0 else len(hyps)
  98. for i in xrange(num_beam_source):
  99. h, ns = hyps[i], new_states[i]
  100. for j in xrange(self._beam_size*2):
  101. all_hyps.append(h.Extend(topk_ids[i, j], topk_log_probs[i, j], ns))
  102. # Filter and collect any hypotheses that have the end token.
  103. hyps = []
  104. for h in self._BestHyps(all_hyps):
  105. if h.latest_token == self._end_token:
  106. # Pull the hypothesis off the beam if the end token is reached.
  107. results.append(h)
  108. else:
  109. # Otherwise continue to the extend the hypothesis.
  110. hyps.append(h)
  111. if len(hyps) == self._beam_size or len(results) == self._beam_size:
  112. break
  113. steps += 1
  114. if steps == self._max_steps:
  115. results.extend(hyps)
  116. return self._BestHyps(results)
  117. def _BestHyps(self, hyps):
  118. """Sort the hyps based on log probs and length.
  119. Args:
  120. hyps: A list of hypothesis.
  121. Returns:
  122. hyps: A list of sorted hypothesis in reverse log_prob order.
  123. """
  124. # This length normalization is only effective for the final results.
  125. if FLAGS.normalize_by_length:
  126. return sorted(hyps, key=lambda h: h.log_prob/len(h.tokens), reverse=True)
  127. else:
  128. return sorted(hyps, key=lambda h: h.log_prob, reverse=True)