lsun_formatting.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""LSUN dataset formatting.
  16. Download and format the LSUN dataset as follows:
  17. git clone https://github.com/fyu/lsun.git
  18. cd lsun
  19. python2.7 download.py -c [CATEGORY]
  20. Then unzip the downloaded .zip files before executing:
  21. python2.7 data.py export [IMAGE_DB_PATH] --out_dir [LSUN_FOLDER] --flat
  22. Then use the script as follows:
  23. python lsun_formatting.py \
  24. --file_out [OUTPUT_FILE_PATH_PREFIX] \
  25. --fn_root [LSUN_FOLDER]
  26. """
  27. import os
  28. import os.path
  29. import numpy
  30. import skimage.transform
  31. from PIL import Image
  32. import tensorflow as tf
  33. tf.flags.DEFINE_string("file_out", "",
  34. "Filename of the output .tfrecords file.")
  35. tf.flags.DEFINE_string("fn_root", "", "Name of root file path.")
  36. FLAGS = tf.flags.FLAGS
  37. def _int64_feature(value):
  38. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  39. def _bytes_feature(value):
  40. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  41. def main():
  42. """Main converter function."""
  43. fn_root = FLAGS.fn_root
  44. img_fn_list = os.listdir(fn_root)
  45. img_fn_list = [img_fn for img_fn in img_fn_list
  46. if img_fn.endswith('.webp')]
  47. num_examples = len(img_fn_list)
  48. n_examples_per_file = 10000
  49. for example_idx, img_fn in enumerate(img_fn_list):
  50. if example_idx % n_examples_per_file == 0:
  51. file_out = "%s_%05d.tfrecords"
  52. file_out = file_out % (FLAGS.file_out,
  53. example_idx // n_examples_per_file)
  54. print "Writing on:", file_out
  55. writer = tf.python_io.TFRecordWriter(file_out)
  56. if example_idx % 1000 == 0:
  57. print example_idx, "/", num_examples
  58. image_raw = numpy.array(Image.open(os.path.join(fn_root, img_fn)))
  59. rows = image_raw.shape[0]
  60. cols = image_raw.shape[1]
  61. depth = image_raw.shape[2]
  62. downscale = min(rows / 96., cols / 96.)
  63. image_raw = skimage.transform.pyramid_reduce(image_raw, downscale)
  64. image_raw *= 255.
  65. image_raw = image_raw.astype("uint8")
  66. rows = image_raw.shape[0]
  67. cols = image_raw.shape[1]
  68. depth = image_raw.shape[2]
  69. image_raw = image_raw.tostring()
  70. example = tf.train.Example(
  71. features=tf.train.Features(
  72. feature={
  73. "height": _int64_feature(rows),
  74. "width": _int64_feature(cols),
  75. "depth": _int64_feature(depth),
  76. "image_raw": _bytes_feature(image_raw)
  77. }
  78. )
  79. )
  80. writer.write(example.SerializeToString())
  81. if example_idx % n_examples_per_file == (n_examples_per_file - 1):
  82. writer.close()
  83. writer.close()
  84. if __name__ == "__main__":
  85. main()