فهرست منبع

Merge pull request #510 from nickj-google/master

Fix IO handling in compression models.
Rohan Jain 9 سال پیش
والد
کامیت
199db00e33
2فایلهای تغییر یافته به همراه10 افزوده شده و 4 حذف شده
  1. 5 2
      compression/decoder.py
  2. 5 2
      compression/encoder.py

+ 5 - 2
compression/decoder.py

@@ -21,6 +21,7 @@ Example usage:
 python decoder.py --input_codes=output_codes.pkl --iteration=15 \
 --output_directory=/tmp/compression_output/ --model=residual_gru.pb
 """
+import io
 import os
 
 import numpy as np
@@ -69,8 +70,10 @@ def main(_):
     print '\nInput codes not found.\n'
     return
 
-  with tf.gfile.FastGFile(FLAGS.input_codes, 'rb') as code_file:
-    loaded_codes = np.load(code_file)
+  contents = ''
+  with tf.gfile.FastGFile(FLAGS.input_codes, 'r') as code_file:
+    contents = code_file.read()
+    loaded_codes = np.load(io.BytesIO(contents))
     assert ['codes', 'shape'] not in loaded_codes.files
     loaded_shape = loaded_codes['shape']
     loaded_array = loaded_codes['codes']

+ 5 - 2
compression/encoder.py

@@ -23,6 +23,7 @@ Example usage:
 python encoder.py --input_image=/your/image/here.png \
 --output_codes=output_codes.pkl --iteration=15 --model=residual_gru.pb
 """
+import io
 import os
 
 import numpy as np
@@ -94,8 +95,10 @@ def main(_):
   int_codes = (int_codes + 1)/2
   export = np.packbits(int_codes.reshape(-1))
 
-  with tf.gfile.FastGFile(FLAGS.output_codes, 'wb') as code_file:
-    np.savez_compressed(code_file, shape=int_codes.shape, codes=export)
+  output = io.BytesIO()
+  np.savez_compressed(output, shape=int_codes.shape, codes=export)
+  with tf.gfile.FastGFile(FLAGS.output_codes, 'w') as code_file:
+    code_file.write(output.getvalue())
 
 
 if __name__ == '__main__':