# cifar10_input.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Routine for decoding the CIFAR-10 binary file format."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

# Process images of this size. Note that this differs from the original CIFAR
# image size of 32 x 32. If one alters this number, then the entire model
# architecture will change and any model would need to be retrained.
IMAGE_SIZE = 24

# Global constants describing the CIFAR-10 data set.
NUM_CLASSES = 10
NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000
NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000


def read_cifar10(filename_queue):
  """Reads and parses examples from CIFAR10 data files.

  Recommendation: if you want N-way read parallelism, call this function
  N times. This will give you N independent Readers reading different
  files & positions within those files, which will give better mixing of
  examples.

  Args:
    filename_queue: A queue of strings with the filenames to read from.

  Returns:
    An object representing a single example, with the following fields:
      height: number of rows in the result (32)
      width: number of columns in the result (32)
      depth: number of color channels in the result (3)
      key: a scalar string Tensor describing the filename & record number
        for this example.
      label: an int32 Tensor with the label in the range 0..9.
      uint8image: a [height, width, depth] uint8 Tensor with the image data
  """

  class CIFAR10Record(object):
    pass
  result = CIFAR10Record()

  # Dimensions of the images in the CIFAR-10 dataset.
  # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the
  # input format.
  label_bytes = 1  # 2 for CIFAR-100
  result.height = 32
  result.width = 32
  result.depth = 3
  image_bytes = result.height * result.width * result.depth
  # Every record consists of a label followed by the image, with a
  # fixed number of bytes for each.
  record_bytes = label_bytes + image_bytes

  # Read a record, getting filenames from the filename_queue. No
  # header or footer in the CIFAR-10 format, so we leave header_bytes
  # and footer_bytes at their default of 0.
  reader = tf.FixedLengthRecordReader(record_bytes=record_bytes)
  result.key, value = reader.read(filename_queue)

  # Convert from a string to a vector of uint8 that is record_bytes long.
  record_bytes = tf.decode_raw(value, tf.uint8)

  # The first bytes represent the label, which we convert from uint8->int32.
  result.label = tf.cast(
      tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32)

  # The remaining bytes after the label represent the image, which we reshape
  # from [depth * height * width] to [depth, height, width].
  depth_major = tf.reshape(
      tf.strided_slice(record_bytes, [label_bytes],
                       [label_bytes + image_bytes]),
      [result.depth, result.height, result.width])
  # Convert from [depth, height, width] to [height, width, depth].
  result.uint8image = tf.transpose(depth_major, [1, 2, 0])

  return result

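
# For reference, the same fixed-length record layout can be decoded outside
# the TensorFlow graph. The sketch below is illustrative only and is not used
# by the pipeline above; `_decode_cifar10_record_numpy` is a hypothetical
# helper name. It assumes `record` is exactly 3073 bytes: one label byte
# followed by 3072 image bytes stored channel-major (depth, height, width).
def _decode_cifar10_record_numpy(record):
  """Decode one raw CIFAR-10 record into (label, HWC uint8 image) with NumPy."""
  import numpy as np  # local import so the sketch stays self-contained
  data = np.frombuffer(record, dtype=np.uint8)
  label = int(data[0])                  # first byte: label in 0..9
  image = data[1:].reshape(3, 32, 32)   # [depth, height, width]
  image = image.transpose(1, 2, 0)      # -> [height, width, depth]
  return label, image
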

def _generate_image_and_label_batch(image, label, min_queue_examples,
                                    batch_size, shuffle):
  """Construct a queued batch of images and labels.

  Args:
    image: 3-D Tensor of [height, width, 3] of type float32.
    label: 1-D Tensor of type int32.
    min_queue_examples: int32, minimum number of samples to retain
      in the queue that provides batches of examples.
    batch_size: Number of images per batch.
    shuffle: boolean indicating whether to use a shuffling queue.

  Returns:
    images: Images. 4D tensor of [batch_size, height, width, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  # Create a queue that shuffles the examples, and then
  # read 'batch_size' images + labels from the example queue.
  num_preprocess_threads = 16
  if shuffle:
    images, label_batch = tf.train.shuffle_batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
  else:
    images, label_batch = tf.train.batch(
        [image, label],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size)

  # Display the training images in the visualizer.
  tf.summary.image('images', images)

  return images, tf.reshape(label_batch, [batch_size])

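
# A note on the sizing above: the helper uses
# capacity = min_queue_examples + 3 * batch_size, which keeps at least
# min_queue_examples elements buffered for mixing while leaving headroom for
# the dequeued batch and the preprocessing threads that refill the queue.
# The tiny function below only restates that arithmetic as an illustrative,
# hypothetical helper; it is not called anywhere else in this file.
def _example_queue_capacity(min_queue_examples, batch_size):
  """Illustrative sketch of the queue capacity used by the batching helper."""
  # e.g. min_queue_examples=20000 and batch_size=128 give a capacity of 20384.
  return min_queue_examples + 3 * batch_size
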

def distorted_inputs(data_dir, batch_size):
  """Construct distorted input for CIFAR training using the Reader ops.

  Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
               for i in xrange(1, 6)]
  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for training the network. Note the many random
  # distortions applied to the image.

  # Randomly crop a [height, width] section of the image.
  distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

  # Randomly flip the image horizontally.
  distorted_image = tf.image.random_flip_left_right(distorted_image)

  # Because these operations are not commutative, consider randomizing
  # the order of their operation.
  distorted_image = tf.image.random_brightness(distorted_image,
                                               max_delta=63)
  distorted_image = tf.image.random_contrast(distorted_image,
                                             lower=0.2, upper=1.8)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(distorted_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the random shuffling has good mixing properties.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                           min_fraction_of_examples_in_queue)
  print('Filling queue with %d CIFAR images before starting to train. '
        'This will take a few minutes.' % min_queue_examples)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=True)

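
# For intuition, tf.image.per_image_standardization (used in both pipelines in
# this file) rescales each image to roughly zero mean and unit variance. The
# NumPy sketch below mirrors its documented TF 1.x behaviour: subtract the
# per-image mean and divide by max(stddev, 1/sqrt(N)), where N is the number
# of elements, so constant images do not divide by zero.
# `_standardize_image_numpy` is a hypothetical helper used only as an example.
def _standardize_image_numpy(image):
  """Approximate per-image standardization of an HWC float array with NumPy."""
  import numpy as np  # local import keeps the sketch self-contained
  image = np.asarray(image, dtype=np.float32)
  num_elements = image.size
  adjusted_stddev = max(image.std(), 1.0 / np.sqrt(num_elements))
  return (image - image.mean()) / adjusted_stddev
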

def inputs(eval_data, data_dir, batch_size):
  """Construct input for CIFAR evaluation using the Reader ops.

  Args:
    eval_data: bool, indicating if one should use the train or eval data set.
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

  Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
  """
  if not eval_data:
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i)
                 for i in xrange(1, 6)]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN
  else:
    filenames = [os.path.join(data_dir, 'test_batch.bin')]
    num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL

  for f in filenames:
    if not tf.gfile.Exists(f):
      raise ValueError('Failed to find file: ' + f)

  # Create a queue that produces the filenames to read.
  filename_queue = tf.train.string_input_producer(filenames)

  # Read examples from files in the filename queue.
  read_input = read_cifar10(filename_queue)
  reshaped_image = tf.cast(read_input.uint8image, tf.float32)

  height = IMAGE_SIZE
  width = IMAGE_SIZE

  # Image processing for evaluation.
  # Crop the central [height, width] of the image.
  resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image,
                                                         height, width)

  # Subtract off the mean and divide by the variance of the pixels.
  float_image = tf.image.per_image_standardization(resized_image)

  # Set the shapes of tensors.
  float_image.set_shape([height, width, 3])
  read_input.label.set_shape([1])

  # Ensure that the example queue stays well filled while batching.
  min_fraction_of_examples_in_queue = 0.4
  min_queue_examples = int(num_examples_per_epoch *
                           min_fraction_of_examples_in_queue)

  # Generate a batch of images and labels by building up a queue of examples.
  return _generate_image_and_label_batch(float_image, read_input.label,
                                         min_queue_examples, batch_size,
                                         shuffle=False)
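

# Minimal usage sketch, assuming the CIFAR-10 binaries have already been
# downloaded and extracted; '/tmp/cifar10_data/cifar-10-batches-bin' is only
# an assumed example location. Queue-based pipelines like the ones above need
# tf.train.start_queue_runners before any batch can be evaluated.
if __name__ == '__main__':
  DATA_DIR = '/tmp/cifar10_data/cifar-10-batches-bin'  # assumed path
  images, labels = distorted_inputs(data_dir=DATA_DIR, batch_size=128)
  with tf.Session() as sess:
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    image_batch, label_batch = sess.run([images, labels])
    print('images:', image_batch.shape, 'labels:', label_batch.shape)
    coord.request_stop()
    coord.join(threads)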