imnet_formatting.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Downsampled Imagenet dataset formatting.
  16. Download and extract the downsampled Imagenet dataset as follows:
  17. mkdir [IMAGENET_PATH]
  18. cd [IMAGENET_PATH]
  19. for FILENAME in train_32x32.tar valid_32x32.tar train_64x64.tar valid_64x64.tar
  20. do
  21. curl -O http://image-net.org/small/$FILENAME
  22. tar -xvf $FILENAME
  23. done
  24. Then run the script as follows:
  25. for DIRNAME in train_32x32 valid_32x32 train_64x64 valid_64x64
  26. do
  27. python imnet_formatting.py \
  28. --file_out $DIRNAME \
  29. --fn_root $DIRNAME
  30. done
  31. """
  32. import os
  33. import os.path
  34. import scipy.io
  35. import scipy.io.wavfile
  36. import scipy.ndimage
  37. import tensorflow as tf
  38. tf.flags.DEFINE_string("file_out", "",
  39. "Filename of the output .tfrecords file.")
  40. tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
  41. FLAGS = tf.flags.FLAGS
  42. def _int64_feature(value):
  43. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  44. def _bytes_feature(value):
  45. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  46. def main():
  47. """Main converter function."""
  48. # LSUN
  49. fn_root = FLAGS.fn_root
  50. img_fn_list = os.listdir(fn_root)
  51. img_fn_list = [img_fn for img_fn in img_fn_list
  52. if img_fn.endswith('.png')]
  53. num_examples = len(img_fn_list)
  54. n_examples_per_file = 10000
  55. for example_idx, img_fn in enumerate(img_fn_list):
  56. if example_idx % n_examples_per_file == 0:
  57. file_out = "%s_%05d.tfrecords"
  58. file_out = file_out % (FLAGS.file_out,
  59. example_idx // n_examples_per_file)
  60. print "Writing on:", file_out
  61. writer = tf.python_io.TFRecordWriter(file_out)
  62. if example_idx % 1000 == 0:
  63. print example_idx, "/", num_examples
  64. image_raw = scipy.ndimage.imread(os.path.join(fn_root, img_fn))
  65. rows = image_raw.shape[0]
  66. cols = image_raw.shape[1]
  67. depth = image_raw.shape[2]
  68. image_raw = image_raw.astype("uint8")
  69. image_raw = image_raw.tostring()
  70. example = tf.train.Example(
  71. features=tf.train.Features(
  72. feature={
  73. "height": _int64_feature(rows),
  74. "width": _int64_feature(cols),
  75. "depth": _int64_feature(depth),
  76. "image_raw": _bytes_feature(image_raw)
  77. }
  78. )
  79. )
  80. writer.write(example.SerializeToString())
  81. if example_idx % n_examples_per_file == (n_examples_per_file - 1):
  82. writer.close()
  83. writer.close()
  84. if __name__ == "__main__":
  85. main()