# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Base wrapper class for performing inference with an image-to-text model.
  16. Subclasses must implement the following methods:
  17. build_model():
  18. Builds the model for inference and returns the model object.
  19. feed_image():
  20. Takes an encoded image and returns the initial model state, where "state"
  21. is a numpy array whose specifics are defined by the subclass, e.g.
  22. concatenated LSTM state. It's assumed that feed_image() will be called
  23. precisely once at the start of inference for each image. Subclasses may
  24. compute and/or save per-image internal context in this method.
  25. inference_step():
  26. Takes a batch of inputs and states at a single time-step. Returns the
  27. softmax output corresponding to the inputs, and the new states of the batch.
  28. Optionally also returns metadata about the current inference step, e.g. a
  29. serialized numpy array containing activations from a particular model layer.
  30. Client usage:
  31. 1. Build the model inference graph via build_graph_from_config() or
  32. build_graph_from_proto().
  33. 2. Call the resulting restore_fn to load the model checkpoint.
  34. 3. For each image in a batch of images:
  35. a) Call feed_image() once to get the initial state.
  36. b) For each step of caption generation, call inference_step().
  37. """

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

import tensorflow as tf

# pylint: disable=unused-argument
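
# The client loop below sketches the "Client usage" steps from the module
# docstring. It is illustrative only: `InferenceWrapper` (a concrete
# subclass), `model_config`, `checkpoint_path`, `encoded_image`, and the
# vocabulary ids `start_id`/`end_id` are assumptions, not defined in this
# file.
#
#   import numpy as np
#
#   model = InferenceWrapper()
#   restore_fn = model.build_graph_from_config(model_config, checkpoint_path)
#   with tf.Session() as sess:
#     restore_fn(sess)  # Load model variables from the checkpoint.
#     state = model.feed_image(sess, encoded_image)  # Shape [1, state_size].
#     caption = [start_id]
#     for _ in range(20):  # Maximum caption length.
#       input_feed = np.array([caption[-1]])  # Batch of size 1.
#       softmax, state, _ = model.inference_step(sess, input_feed, state)
#       caption.append(int(np.argmax(softmax[0])))
#       if caption[-1] == end_id:
#         break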


class InferenceWrapperBase(object):
  """Base wrapper class for performing inference with an image-to-text model."""

  def __init__(self):
    pass

  def build_model(self, model_config):
    """Builds the model for inference.

    Args:
      model_config: Object containing configuration for building the model.

    Returns:
      model: The model object.
    """
    tf.logging.fatal("Please implement build_model in subclass")

  def _create_restore_fn(self, checkpoint_path, saver):
    """Creates a function that restores a model from checkpoint.

    Args:
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.
      saver: Saver for restoring variables from the checkpoint file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.

    Raises:
      ValueError: If checkpoint_path does not refer to a checkpoint file or a
        directory containing a checkpoint file.
    """
    if tf.gfile.IsDirectory(checkpoint_path):
      # Resolve a checkpoint directory to its most recent checkpoint file.
      checkpoint_path = tf.train.latest_checkpoint(checkpoint_path)
      if not checkpoint_path:
        raise ValueError("No checkpoint file found in: %s" % checkpoint_path)

    def _restore_fn(sess):
      tf.logging.info("Loading model from checkpoint: %s", checkpoint_path)
      saver.restore(sess, checkpoint_path)
      tf.logging.info("Successfully loaded checkpoint: %s",
                      os.path.basename(checkpoint_path))

    return _restore_fn

  def build_graph_from_config(self, model_config, checkpoint_path):
    """Builds the inference graph from a configuration object.

    Args:
      model_config: Object containing configuration for building the model.
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.
    """
    tf.logging.info("Building model.")
    self.build_model(model_config)
    saver = tf.train.Saver()

    return self._create_restore_fn(checkpoint_path, saver)

  def build_graph_from_proto(self, graph_def_file, saver_def_file,
                             checkpoint_path):
    """Builds the inference graph from serialized GraphDef and SaverDef protos.

    Args:
      graph_def_file: File containing a serialized GraphDef proto.
      saver_def_file: File containing a serialized SaverDef proto.
      checkpoint_path: Checkpoint file or a directory containing a checkpoint
        file.

    Returns:
      restore_fn: A function such that restore_fn(sess) loads model variables
        from the checkpoint file.
    """
    # Load the Graph.
    tf.logging.info("Loading GraphDef from file: %s", graph_def_file)
    graph_def = tf.GraphDef()
    with tf.gfile.FastGFile(graph_def_file, "rb") as f:
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name="")

    # Load the Saver.
    tf.logging.info("Loading SaverDef from file: %s", saver_def_file)
    saver_def = tf.train.SaverDef()
    with tf.gfile.FastGFile(saver_def_file, "rb") as f:
      saver_def.ParseFromString(f.read())
    saver = tf.train.Saver(saver_def=saver_def)

    return self._create_restore_fn(checkpoint_path, saver)
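
  # A minimal usage sketch of the proto path above. The file names are
  # hypothetical, and both files are assumed to hold serialized binary protos:
  #
  #   model = InferenceWrapper()  # A concrete subclass; not defined here.
  #   restore_fn = model.build_graph_from_proto(
  #       "/tmp/model/graph.pb", "/tmp/model/saver.pb", "/tmp/model/ckpt")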

  def feed_image(self, sess, encoded_image):
    """Feeds an image and returns the initial model state.

    See comments at the top of file.

    Args:
      sess: TensorFlow Session object.
      encoded_image: An encoded image string.

    Returns:
      state: A numpy array of shape [1, state_size].
    """
    tf.logging.fatal("Please implement feed_image in subclass")

  def inference_step(self, sess, input_feed, state_feed):
    """Runs one step of inference.

    Args:
      sess: TensorFlow Session object.
      input_feed: A numpy array of shape [batch_size].
      state_feed: A numpy array of shape [batch_size, state_size].

    Returns:
      softmax_output: A numpy array of shape [batch_size, vocab_size].
      new_state: A numpy array of shape [batch_size, state_size].
      metadata: Optional. If not None, a string containing metadata about the
        current inference step (e.g. a serialized numpy array containing
        activations from a particular model layer).
    """
    tf.logging.fatal("Please implement inference_step in subclass")

# pylint: enable=unused-argument
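
# A minimal subclass sketch (illustrative only: `MyModel`, its `build()`
# method, and the tensor names below are assumptions, not part of this file):
#
#   class InferenceWrapper(InferenceWrapperBase):
#
#     def build_model(self, model_config):
#       model = MyModel(model_config, mode="inference")
#       model.build()
#       return model
#
#     def feed_image(self, sess, encoded_image):
#       # Run the image through the encoder to get the initial LSTM state.
#       return sess.run("lstm/initial_state:0",
#                       feed_dict={"image_feed:0": encoded_image})
#
#     def inference_step(self, sess, input_feed, state_feed):
#       softmax_output, state_output = sess.run(
#           fetches=["softmax:0", "lstm/state:0"],
#           feed_dict={"input_feed:0": input_feed,
#                      "lstm/state_feed:0": state_feed})
#       return softmax_output, state_output, None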