inference_wrapper.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Model wrapper class for performing inference with a ShowAndTellModel."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from im2txt import show_and_tell_model
  20. from im2txt.inference_utils import inference_wrapper_base
  21. class InferenceWrapper(inference_wrapper_base.InferenceWrapperBase):
  22. """Model wrapper class for performing inference with a ShowAndTellModel."""
  23. def __init__(self):
  24. super(InferenceWrapper, self).__init__()
  25. def build_model(self, model_config):
  26. model = show_and_tell_model.ShowAndTellModel(model_config, mode="inference")
  27. model.build()
  28. return model
  29. def feed_image(self, sess, encoded_image):
  30. initial_state = sess.run(fetches="lstm/initial_state:0",
  31. feed_dict={"image_feed:0": encoded_image})
  32. return initial_state
  33. def inference_step(self, sess, input_feed, state_feed):
  34. softmax_output, state_output = sess.run(
  35. fetches=["softmax:0", "lstm/state:0"],
  36. feed_dict={
  37. "input_feed:0": input_feed,
  38. "lstm/state_feed:0": state_feed,
  39. })
  40. return softmax_output, state_output, None