# run_inference.py
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Generate captions for images using default beam search parameters."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import math
  20. import os
  21. import tensorflow as tf
  22. from im2txt import configuration
  23. from im2txt import inference_wrapper
  24. from im2txt.inference_utils import caption_generator
  25. from im2txt.inference_utils import vocabulary
  26. FLAGS = tf.flags.FLAGS
  27. tf.flags.DEFINE_string("checkpoint_path", "",
  28. "Model checkpoint file or directory containing a "
  29. "model checkpoint file.")
  30. tf.flags.DEFINE_string("vocab_file", "", "Text file containing the vocabulary.")
  31. tf.flags.DEFINE_string("input_files", "",
  32. "File pattern or comma-separated list of file patterns "
  33. "of image files.")
  34. def main(_):
  35. # Build the inference graph.
  36. g = tf.Graph()
  37. with g.as_default():
  38. model = inference_wrapper.InferenceWrapper()
  39. restore_fn = model.build_graph_from_config(configuration.ModelConfig(),
  40. FLAGS.checkpoint_path)
  41. g.finalize()
  42. # Create the vocabulary.
  43. vocab = vocabulary.Vocabulary(FLAGS.vocab_file)
  44. filenames = []
  45. for file_pattern in FLAGS.input_files.split(","):
  46. filenames.extend(tf.gfile.Glob(file_pattern))
  47. tf.logging.info("Running caption generation on %d files matching %s",
  48. len(filenames), FLAGS.input_files)
  49. with tf.Session(graph=g) as sess:
  50. # Load the model from checkpoint.
  51. restore_fn(sess)
  52. # Prepare the caption generator. Here we are implicitly using the default
  53. # beam search parameters. See caption_generator.py for a description of the
  54. # available beam search parameters.
  55. generator = caption_generator.CaptionGenerator(model, vocab)
  56. for filename in filenames:
  57. with tf.gfile.GFile(filename, "r") as f:
  58. image = f.read()
  59. captions = generator.beam_search(sess, image)
  60. print("Captions for image %s:" % os.path.basename(filename))
  61. for i, caption in enumerate(captions):
  62. # Ignore begin and end words.
  63. sentence = [vocab.id_to_word(w) for w in caption.sentence[1:-1]]
  64. sentence = " ".join(sentence)
  65. print(" %d) %s (p=%f)" % (i, sentence, math.exp(caption.logprob)))
  66. if __name__ == "__main__":
  67. tf.app.run()