decoder.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128
  1. #!/usr/bin/python
  2. #
  3. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. # ==============================================================================
r"""Neural Network Image Compression Decoder.

Decompresses an image from the NumPy npz-format code file generated by the
encoder and writes one PNG per decoded iteration to the output directory.
The code file is read with `np.load`, regardless of its file extension.

Example usage:
python decoder.py --input_codes=output_codes.pkl --iteration=15 \
--output_directory=/tmp/compression_output/ --model=residual_gru.pb
"""
  22. import io
  23. import os
  24. import numpy as np
  25. import tensorflow as tf
  26. tf.flags.DEFINE_string('input_codes', None, 'Location of binary code file.')
  27. tf.flags.DEFINE_integer('iteration', -1, 'The max quality level of '
  28. 'the images to output. Use -1 to infer from loaded '
  29. ' codes.')
  30. tf.flags.DEFINE_string('output_directory', None, 'Directory to save decoded '
  31. 'images.')
  32. tf.flags.DEFINE_string('model', None, 'Location of compression model.')
  33. FLAGS = tf.flags.FLAGS
  34. def get_input_tensor_names():
  35. name_list = ['GruBinarizer/SignBinarizer/Sign:0']
  36. for i in range(1, 16):
  37. name_list.append('GruBinarizer/SignBinarizer/Sign_{}:0'.format(i))
  38. return name_list
  39. def get_output_tensor_names():
  40. return ['loop_{0:02d}/add:0'.format(i) for i in range(0, 16)]
  41. def main(_):
  42. if (FLAGS.input_codes is None or FLAGS.output_directory is None or
  43. FLAGS.model is None):
  44. print('\nUsage: python decoder.py --input_codes=output_codes.pkl '
  45. '--iteration=15 --output_directory=/tmp/compression_output/ '
  46. '--model=residual_gru.pb\n\n')
  47. return
  48. if FLAGS.iteration < -1 or FLAGS.iteration > 15:
  49. print('\n--iteration must be between 0 and 15 inclusive, or -1 to infer '
  50. 'from file.\n')
  51. return
  52. iteration = FLAGS.iteration
  53. if not tf.gfile.Exists(FLAGS.output_directory):
  54. tf.gfile.MkDir(FLAGS.output_directory)
  55. if not tf.gfile.Exists(FLAGS.input_codes):
  56. print('\nInput codes not found.\n')
  57. return
  58. contents = ''
  59. with tf.gfile.FastGFile(FLAGS.input_codes, 'r') as code_file:
  60. contents = code_file.read()
  61. loaded_codes = np.load(io.BytesIO(contents))
  62. assert ['codes', 'shape'] not in loaded_codes.files
  63. loaded_shape = loaded_codes['shape']
  64. loaded_array = loaded_codes['codes']
  65. # Unpack and recover code shapes.
  66. unpacked_codes = np.reshape(np.unpackbits(loaded_array)
  67. [:np.prod(loaded_shape)],
  68. loaded_shape)
  69. numpy_int_codes = np.split(unpacked_codes, len(unpacked_codes))
  70. if iteration == -1:
  71. iteration = len(unpacked_codes) - 1
  72. # Convert back to float and recover scale.
  73. numpy_codes = [np.squeeze(x.astype(np.float32), 0) * 2 - 1 for x in
  74. numpy_int_codes]
  75. with tf.Graph().as_default() as graph:
  76. # Load the inference model for decoding.
  77. with tf.gfile.FastGFile(FLAGS.model, 'rb') as model_file:
  78. graph_def = tf.GraphDef()
  79. graph_def.ParseFromString(model_file.read())
  80. _ = tf.import_graph_def(graph_def, name='')
  81. # For encoding the tensors into PNGs.
  82. input_image = tf.placeholder(tf.uint8)
  83. encoded_image = tf.image.encode_png(input_image)
  84. input_tensors = [graph.get_tensor_by_name(name) for name in
  85. get_input_tensor_names()][0:iteration+1]
  86. outputs = [graph.get_tensor_by_name(name) for name in
  87. get_output_tensor_names()][0:iteration+1]
  88. feed_dict = {key: value for (key, value) in zip(input_tensors,
  89. numpy_codes)}
  90. with tf.Session(graph=graph) as sess:
  91. results = sess.run(outputs, feed_dict=feed_dict)
  92. for index, result in enumerate(results):
  93. img = np.uint8(np.clip(result + 0.5, 0, 255))
  94. img = img.squeeze()
  95. png_img = sess.run(encoded_image, feed_dict={input_image: img})
  96. with tf.gfile.FastGFile(os.path.join(FLAGS.output_directory,
  97. 'image_{0:02d}.png'.format(index)),
  98. 'w') as output_image:
  99. output_image.write(png_img)
  100. if __name__ == '__main__':
  101. tf.app.run()