beam_search.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Beam search module.

Beam search takes the top K results from the model and predicts the top 2*K
next tokens for each of the previous K results, giving 2*K*K candidates.
Candidates that end with the end token are set aside as results; the best K of
the remaining candidates are kept, and the process starts over again until K
results are fully decoded or the step limit is reached.
"""
  21. import tensorflow as tf
  22. FLAGS = tf.flags.FLAGS
  23. tf.flags.DEFINE_bool('normalize_by_length', True, 'Whether normalize')
  24. class Hypothesis(object):
  25. """Defines a hypothesis during beam search."""
  26. def __init__(self, tokens, log_prob, state):
  27. """Hypothesis constructor.
  28. Args:
  29. tokens: start tokens for decoding.
  30. log_prob: log prob of the start tokens, usually 1.
  31. state: decoder initial states.
  32. """
  33. self.tokens = tokens
  34. self.log_prob = log_prob
  35. self.state = state
  36. def Extend(self, token, log_prob, new_state):
  37. """Extend the hypothesis with result from latest step.
  38. Args:
  39. token: latest token from decoding.
  40. log_prob: log prob of the latest decoded tokens.
  41. new_state: decoder output state. Fed to the decoder for next step.
  42. Returns:
  43. New Hypothesis with the results from latest step.
  44. """
  45. return Hypothesis(self.tokens + [token], self.log_prob + log_prob,
  46. new_state)
  47. @property
  48. def latest_token(self):
  49. return self.tokens[-1]
  50. def __str__(self):
  51. return ('Hypothesis(log prob = %.4f, tokens = %s)' % (self.log_prob,
  52. self.tokens))
  53. class BeamSearch(object):
  54. """Beam search."""
  55. def __init__(self, model, beam_size, start_token, end_token, max_steps):
  56. """Creates BeamSearch object.
  57. Args:
  58. model: Seq2SeqAttentionModel.
  59. beam_size: int.
  60. start_token: int, id of the token to start decoding with
  61. end_token: int, id of the token that completes an hypothesis
  62. max_steps: int, upper limit on the size of the hypothesis
  63. """
  64. self._model = model
  65. self._beam_size = beam_size
  66. self._start_token = start_token
  67. self._end_token = end_token
  68. self._max_steps = max_steps
  69. def BeamSearch(self, sess, enc_inputs, enc_seqlen):
  70. """Performs beam search for decoding.
  71. Args:
  72. sess: tf.Session, session
  73. enc_inputs: ndarray of shape (enc_length, 1), the document ids to encode
  74. enc_seqlen: ndarray of shape (1), the length of the sequnce
  75. Returns:
  76. hyps: list of Hypothesis, the best hypotheses found by beam search,
  77. ordered by score
  78. """
  79. # Run the encoder and extract the outputs and final state.
  80. enc_top_states, dec_in_state = self._model.encode_top_state(
  81. sess, enc_inputs, enc_seqlen)
  82. # Replicate the initial states K times for the first step.
  83. hyps = [Hypothesis([self._start_token], 0.0, dec_in_state)
  84. ] * self._beam_size
  85. results = []
  86. steps = 0
  87. while steps < self._max_steps and len(results) < self._beam_size:
  88. latest_tokens = [h.latest_token for h in hyps]
  89. states = [h.state for h in hyps]
  90. topk_ids, topk_log_probs, new_states = self._model.decode_topk(
  91. sess, latest_tokens, enc_top_states, states)
  92. # Extend each hypothesis.
  93. all_hyps = []
  94. # The first step takes the best K results from first hyps. Following
  95. # steps take the best K results from K*K hyps.
  96. num_beam_source = 1 if steps == 0 else len(hyps)
  97. for i in xrange(num_beam_source):
  98. h, ns = hyps[i], new_states[i]
  99. for j in xrange(self._beam_size*2):
  100. all_hyps.append(h.Extend(topk_ids[i, j], topk_log_probs[i, j], ns))
  101. # Filter and collect any hypotheses that have the end token.
  102. hyps = []
  103. for h in self._BestHyps(all_hyps):
  104. if h.latest_token == self._end_token:
  105. # Pull the hypothesis off the beam if the end token is reached.
  106. results.append(h)
  107. else:
  108. # Otherwise continue to the extend the hypothesis.
  109. hyps.append(h)
  110. if len(hyps) == self._beam_size or len(results) == self._beam_size:
  111. break
  112. steps += 1
  113. if steps == self._max_steps:
  114. results.extend(hyps)
  115. return self._BestHyps(results)
  116. def _BestHyps(self, hyps):
  117. """Sort the hyps based on log probs and length.
  118. Args:
  119. hyps: A list of hypothesis.
  120. Returns:
  121. hyps: A list of sorted hypothesis in reverse log_prob order.
  122. """
  123. # This length normalization is only effective for the final results.
  124. if FLAGS.normalize_by_length:
  125. return sorted(hyps, key=lambda h: h.log_prob/len(h.tokens), reverse=True)
  126. else:
  127. return sorted(hyps, key=lambda h: h.log_prob, reverse=True)