decoder.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. #!/usr/bin/python
  2. #
  3. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ==============================================================================
r"""Neural Network Image Compression Decoder.

Decompresses images from the NumPy .npz archive generated by the encoder and
writes one PNG per quality level (iteration).

Example usage:
python decoder.py --input_codes=output_codes.npz --iteration=15 \
    --output_directory=/tmp/compression_output/ --model=residual_gru.pb
"""
  22. import os
  23. import numpy as np
  24. import tensorflow as tf
  25. tf.flags.DEFINE_string('input_codes', None, 'Location of binary code file.')
  26. tf.flags.DEFINE_integer('iteration', -1, 'The max quality level of '
  27. 'the images to output. Use -1 to infer from loaded '
  28. ' codes.')
  29. tf.flags.DEFINE_string('output_directory', None, 'Directory to save decoded '
  30. 'images.')
  31. tf.flags.DEFINE_string('model', None, 'Location of compression model.')
  32. FLAGS = tf.flags.FLAGS
  33. def get_input_tensor_names():
  34. name_list = ['GruBinarizer/SignBinarizer/Sign:0']
  35. for i in xrange(1, 16):
  36. name_list.append('GruBinarizer/SignBinarizer/Sign_{}:0'.format(i))
  37. return name_list
  38. def get_output_tensor_names():
  39. return ['loop_{0:02d}/add:0'.format(i) for i in xrange(0, 16)]
  40. def main(_):
  41. if (FLAGS.input_codes is None or FLAGS.output_directory is None or
  42. FLAGS.model is None):
  43. print ('\nUsage: python decoder.py --input_codes=output_codes.pkl '
  44. '--iteration=15 --output_directory=/tmp/compression_output/ '
  45. '--model=residual_gru.pb\n\n')
  46. return
  47. if FLAGS.iteration < -1 or FLAGS.iteration > 15:
  48. print ('\n--iteration must be between 0 and 15 inclusive, or -1 to infer '
  49. 'from file.\n')
  50. return
  51. iteration = FLAGS.iteration
  52. if not tf.gfile.Exists(FLAGS.output_directory):
  53. tf.gfile.MkDir(FLAGS.output_directory)
  54. if not tf.gfile.Exists(FLAGS.input_codes):
  55. print '\nInput codes not found.\n'
  56. return
  57. with tf.gfile.FastGFile(FLAGS.input_codes, 'rb') as code_file:
  58. loaded_codes = np.load(code_file)
  59. assert ['codes', 'shape'] not in loaded_codes.files
  60. loaded_shape = loaded_codes['shape']
  61. loaded_array = loaded_codes['codes']
  62. # Unpack and recover code shapes.
  63. unpacked_codes = np.reshape(np.unpackbits(loaded_array)
  64. [:np.prod(loaded_shape)],
  65. loaded_shape)
  66. numpy_int_codes = np.split(unpacked_codes, len(unpacked_codes))
  67. if iteration == -1:
  68. iteration = len(unpacked_codes) - 1
  69. # Convert back to float and recover scale.
  70. numpy_codes = [np.squeeze(x.astype(np.float32), 0) * 2 - 1 for x in
  71. numpy_int_codes]
  72. with tf.Graph().as_default() as graph:
  73. # Load the inference model for decoding.
  74. with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
  75. graph_def = tf.GraphDef()
  76. graph_def.ParseFromString(model_file.read())
  77. _ = tf.import_graph_def(graph_def, name='')
  78. # For encoding the tensors into PNGs.
  79. input_image = tf.placeholder(tf.uint8)
  80. encoded_image = tf.image.encode_png(input_image)
  81. input_tensors = [graph.get_tensor_by_name(name) for name in
  82. get_input_tensor_names()][0:iteration+1]
  83. outputs = [graph.get_tensor_by_name(name) for name in
  84. get_output_tensor_names()][0:iteration+1]
  85. feed_dict = {key: value for (key, value) in zip(input_tensors,
  86. numpy_codes)}
  87. with tf.Session(graph=graph) as sess:
  88. results = sess.run(outputs, feed_dict=feed_dict)
  89. for index, result in enumerate(results):
  90. img = np.uint8(np.clip(result + 0.5, 0, 255))
  91. img = img.squeeze()
  92. png_img = sess.run(encoded_image, feed_dict={input_image: img})
  93. with tf.gfile.FastGFile(os.path.join(FLAGS.output_directory,
  94. 'image_{0:02d}.png'.format(index)),
  95. 'w') as output_image:
  96. output_image.write(png_img)
  97. if __name__ == '__main__':
  98. tf.app.run()