# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. r"""Downloads and converts Flowers data to TFRecords of TF-Example protos.
  16. This module downloads the Flowers data, uncompresses it, reads the files
  17. that make up the Flowers data and creates two TFRecord datasets: one for train
  18. and one for test. Each TFRecord dataset is comprised of a set of TF-Example
  19. protocol buffers, each of which contain a single image and label.
  20. The script should take about a minute to run.
  21. """
  22. from __future__ import absolute_import
  23. from __future__ import division
  24. from __future__ import print_function
  25. import math
  26. import os
  27. import random
  28. import sys
  29. import tensorflow as tf
  30. from datasets import dataset_utils
# The URL where the Flowers data can be downloaded.
_DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz'

# The number of images held out for the validation split; the rest are train.
_NUM_VALIDATION = 350

# Seed for repeatability of the train/validation shuffle.
_RANDOM_SEED = 0

# The number of TFRecord shards written per dataset split.
_NUM_SHARDS = 5
  39. class ImageReader(object):
  40. """Helper class that provides TensorFlow image coding utilities."""
  41. def __init__(self):
  42. # Initializes function that decodes RGB JPEG data.
  43. self._decode_jpeg_data = tf.placeholder(dtype=tf.string)
  44. self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3)
  45. def read_image_dims(self, sess, image_data):
  46. image = self.decode_jpeg(sess, image_data)
  47. return image.shape[0], image.shape[1]
  48. def decode_jpeg(self, sess, image_data):
  49. image = sess.run(self._decode_jpeg,
  50. feed_dict={self._decode_jpeg_data: image_data})
  51. assert len(image.shape) == 3
  52. assert image.shape[2] == 3
  53. return image
  54. def _get_filenames_and_classes(dataset_dir):
  55. """Returns a list of filenames and inferred class names.
  56. Args:
  57. dataset_dir: A directory containing a set of subdirectories representing
  58. class names. Each subdirectory should contain PNG or JPG encoded images.
  59. Returns:
  60. A list of image file paths, relative to `dataset_dir` and the list of
  61. subdirectories, representing class names.
  62. """
  63. flower_root = os.path.join(dataset_dir, 'flower_photos')
  64. directories = []
  65. class_names = []
  66. for filename in os.listdir(flower_root):
  67. path = os.path.join(flower_root, filename)
  68. if os.path.isdir(path):
  69. directories.append(path)
  70. class_names.append(filename)
  71. photo_filenames = []
  72. for directory in directories:
  73. for filename in os.listdir(directory):
  74. path = os.path.join(directory, filename)
  75. photo_filenames.append(path)
  76. return photo_filenames, sorted(class_names)
  77. def _get_dataset_filename(dataset_dir, split_name, shard_id):
  78. output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % (
  79. split_name, shard_id, _NUM_SHARDS)
  80. return os.path.join(dataset_dir, output_filename)
  81. def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir):
  82. """Converts the given filenames to a TFRecord dataset.
  83. Args:
  84. split_name: The name of the dataset, either 'train' or 'validation'.
  85. filenames: A list of absolute paths to png or jpg images.
  86. class_names_to_ids: A dictionary from class names (strings) to ids
  87. (integers).
  88. dataset_dir: The directory where the converted datasets are stored.
  89. """
  90. assert split_name in ['train', 'validation']
  91. num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS)))
  92. with tf.Graph().as_default():
  93. image_reader = ImageReader()
  94. with tf.Session('') as sess:
  95. for shard_id in range(_NUM_SHARDS):
  96. output_filename = _get_dataset_filename(
  97. dataset_dir, split_name, shard_id)
  98. with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
  99. start_ndx = shard_id * num_per_shard
  100. end_ndx = min((shard_id+1) * num_per_shard, len(filenames))
  101. for i in range(start_ndx, end_ndx):
  102. sys.stdout.write('\r>> Converting image %d/%d shard %d' % (
  103. i+1, len(filenames), shard_id))
  104. sys.stdout.flush()
  105. # Read the filename:
  106. image_data = tf.gfile.FastGFile(filenames[i], 'r').read()
  107. height, width = image_reader.read_image_dims(sess, image_data)
  108. class_name = os.path.basename(os.path.dirname(filenames[i]))
  109. class_id = class_names_to_ids[class_name]
  110. example = dataset_utils.image_to_tfexample(
  111. image_data, 'jpg', height, width, class_id)
  112. tfrecord_writer.write(example.SerializeToString())
  113. sys.stdout.write('\n')
  114. sys.stdout.flush()
  115. def _clean_up_temporary_files(dataset_dir):
  116. """Removes temporary files used to create the dataset.
  117. Args:
  118. dataset_dir: The directory where the temporary files are stored.
  119. """
  120. filename = _DATA_URL.split('/')[-1]
  121. filepath = os.path.join(dataset_dir, filename)
  122. tf.gfile.Remove(filepath)
  123. tmp_dir = os.path.join(dataset_dir, 'flower_photos')
  124. tf.gfile.DeleteRecursively(tmp_dir)
  125. def _dataset_exists(dataset_dir):
  126. for split_name in ['train', 'validation']:
  127. for shard_id in range(_NUM_SHARDS):
  128. output_filename = _get_dataset_filename(
  129. dataset_dir, split_name, shard_id)
  130. if not tf.gfile.Exists(output_filename):
  131. return False
  132. return True
  133. def run(dataset_dir):
  134. """Runs the download and conversion operation.
  135. Args:
  136. dataset_dir: The dataset directory where the dataset is stored.
  137. """
  138. if not tf.gfile.Exists(dataset_dir):
  139. tf.gfile.MakeDirs(dataset_dir)
  140. if _dataset_exists(dataset_dir):
  141. print('Dataset files already exist. Exiting without re-creating them.')
  142. return
  143. dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
  144. photo_filenames, class_names = _get_filenames_and_classes(dataset_dir)
  145. class_names_to_ids = dict(zip(class_names, range(len(class_names))))
  146. # Divide into train and test:
  147. random.seed(_RANDOM_SEED)
  148. random.shuffle(photo_filenames)
  149. training_filenames = photo_filenames[_NUM_VALIDATION:]
  150. validation_filenames = photo_filenames[:_NUM_VALIDATION]
  151. # First, convert the training and validation sets.
  152. _convert_dataset('train', training_filenames, class_names_to_ids,
  153. dataset_dir)
  154. _convert_dataset('validation', validation_filenames, class_names_to_ids,
  155. dataset_dir)
  156. # Finally, write the labels file:
  157. labels_to_class_names = dict(zip(range(len(class_names)), class_names))
  158. dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  159. _clean_up_temporary_files(dataset_dir)
  160. print('\nFinished converting the Flowers dataset!')