# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Class for generating captions from an image-to-text model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import heapq
import math

import numpy as np
  22. class Caption(object):
  23. """Represents a complete or partial caption."""
  24. def __init__(self, sentence, state, logprob, score, metadata=None):
  25. """Initializes the Caption.
  26. Args:
  27. sentence: List of word ids in the caption.
  28. state: Model state after generating the previous word.
  29. logprob: Log-probability of the caption.
  30. score: Score of the caption.
  31. metadata: Optional metadata associated with the partial sentence. If not
  32. None, a list of strings with the same length as 'sentence'.
  33. """
  34. self.sentence = sentence
  35. self.state = state
  36. self.logprob = logprob
  37. self.score = score
  38. self.metadata = metadata
  39. def __cmp__(self, other):
  40. """Compares Captions by score."""
  41. assert isinstance(other, Caption)
  42. if self.score == other.score:
  43. return 0
  44. elif self.score < other.score:
  45. return -1
  46. else:
  47. return 1
  48. # For Python 3 compatibility (__cmp__ is deprecated).
  49. def __lt__(self, other):
  50. assert isinstance(other, Caption)
  51. return self.score < other.score
  52. # Also for Python 3 compatibility.
  53. def __eq__(self, other):
  54. assert isinstance(other, Caption)
  55. return self.score == other.score
  56. class TopN(object):
  57. """Maintains the top n elements of an incrementally provided set."""
  58. def __init__(self, n):
  59. self._n = n
  60. self._data = []
  61. def size(self):
  62. assert self._data is not None
  63. return len(self._data)
  64. def push(self, x):
  65. """Pushes a new element."""
  66. assert self._data is not None
  67. if len(self._data) < self._n:
  68. heapq.heappush(self._data, x)
  69. else:
  70. heapq.heappushpop(self._data, x)
  71. def extract(self, sort=False):
  72. """Extracts all elements from the TopN. This is a destructive operation.
  73. The only method that can be called immediately after extract() is reset().
  74. Args:
  75. sort: Whether to return the elements in descending sorted order.
  76. Returns:
  77. A list of data; the top n elements provided to the set.
  78. """
  79. assert self._data is not None
  80. data = self._data
  81. self._data = None
  82. if sort:
  83. data.sort(reverse=True)
  84. return data
  85. def reset(self):
  86. """Returns the TopN to an empty state."""
  87. self._data = []
  88. class CaptionGenerator(object):
  89. """Class to generate captions from an image-to-text model."""
  90. def __init__(self,
  91. model,
  92. vocab,
  93. beam_size=3,
  94. max_caption_length=20,
  95. length_normalization_factor=0.0):
  96. """Initializes the generator.
  97. Args:
  98. model: Object encapsulating a trained image-to-text model. Must have
  99. methods feed_image() and inference_step(). For example, an instance of
  100. InferenceWrapperBase.
  101. vocab: A Vocabulary object.
  102. beam_size: Beam size to use when generating captions.
  103. max_caption_length: The maximum caption length before stopping the search.
  104. length_normalization_factor: If != 0, a number x such that captions are
  105. scored by logprob/length^x, rather than logprob. This changes the
  106. relative scores of captions depending on their lengths. For example, if
  107. x > 0 then longer captions will be favored.
  108. """
  109. self.vocab = vocab
  110. self.model = model
  111. self.beam_size = beam_size
  112. self.max_caption_length = max_caption_length
  113. self.length_normalization_factor = length_normalization_factor
  114. def beam_search(self, sess, encoded_image):
  115. """Runs beam search caption generation on a single image.
  116. Args:
  117. sess: TensorFlow Session object.
  118. encoded_image: An encoded image string.
  119. Returns:
  120. A list of Caption sorted by descending score.
  121. """
  122. # Feed in the image to get the initial state.
  123. initial_state = self.model.feed_image(sess, encoded_image)
  124. initial_beam = Caption(
  125. sentence=[self.vocab.start_id],
  126. state=initial_state[0],
  127. logprob=0.0,
  128. score=0.0,
  129. metadata=[""])
  130. partial_captions = TopN(self.beam_size)
  131. partial_captions.push(initial_beam)
  132. complete_captions = TopN(self.beam_size)
  133. # Run beam search.
  134. for _ in range(self.max_caption_length - 1):
  135. partial_captions_list = partial_captions.extract()
  136. partial_captions.reset()
  137. input_feed = np.array([c.sentence[-1] for c in partial_captions_list])
  138. state_feed = np.array([c.state for c in partial_captions_list])
  139. softmax, new_states, metadata = self.model.inference_step(sess,
  140. input_feed,
  141. state_feed)
  142. for i, partial_caption in enumerate(partial_captions_list):
  143. word_probabilities = softmax[i]
  144. state = new_states[i]
  145. # For this partial caption, get the beam_size most probable next words.
  146. words_and_probs = list(enumerate(word_probabilities))
  147. words_and_probs.sort(key=lambda x: -x[1])
  148. words_and_probs = words_and_probs[0:self.beam_size]
  149. # Each next word gives a new partial caption.
  150. for w, p in words_and_probs:
  151. if p < 1e-12:
  152. continue # Avoid log(0).
  153. sentence = partial_caption.sentence + [w]
  154. logprob = partial_caption.logprob + math.log(p)
  155. score = logprob
  156. if metadata:
  157. metadata_list = partial_caption.metadata + [metadata[i]]
  158. else:
  159. metadata_list = None
  160. if w == self.vocab.end_id:
  161. if self.length_normalization_factor > 0:
  162. score /= len(sentence)**self.length_normalization_factor
  163. beam = Caption(sentence, state, logprob, score, metadata_list)
  164. complete_captions.push(beam)
  165. else:
  166. beam = Caption(sentence, state, logprob, score, metadata_list)
  167. partial_captions.push(beam)
  168. if partial_captions.size() == 0:
  169. # We have run out of partial candidates; happens when beam_size = 1.
  170. break
  171. # If we have no complete captions then fall back to the partial captions.
  172. # But never output a mixture of complete and partial captions because a
  173. # partial caption could have a higher score than all the complete captions.
  174. if not complete_captions.size():
  175. complete_captions = partial_captions
  176. return complete_captions.extract(sort=True)