download_and_convert_mnist.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
  16. This script downloads the MNIST data, uncompresses it, reads the files
  17. that make up the MNIST data and creates two TFRecord datasets: one for train
  18. and one for test. Each TFRecord dataset is comprised of a set of TF-Example
  19. protocol buffers, each of which contain a single image and label.
  20. The script should take about a minute to run.
  21. Usage:
  22. $ bazel build slim:download_and_convert_mnist
$ ./bazel-bin/slim/download_and_convert_mnist --dataset_dir=[DIRECTORY]
  24. """
  25. from __future__ import absolute_import
  26. from __future__ import division
  27. from __future__ import print_function
  28. import gzip
  29. import os
  30. import sys
  31. import numpy as np
  32. from six.moves import urllib
  33. import tensorflow as tf
  34. from slim.datasets import dataset_utils
# Command-line flag: destination for both the output TFRecords and the
# temporary downloaded MNIST archives. Required; main() raises if unset.
tf.app.flags.DEFINE_string(
    'dataset_dir',
    None,
    'The directory where the output TFRecords and temporary files are saved.')

FLAGS = tf.app.flags.FLAGS
  40. # The URLs where the MNIST data can be downloaded.
  41. _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
  42. _TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
  43. _TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
  44. _TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
  45. _TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
  46. _IMAGE_SIZE = 28
  47. _NUM_CHANNELS = 1
  48. # The names of the classes.
  49. _CLASS_NAMES = [
  50. 'zero',
  51. 'one',
  52. 'two',
  53. 'three',
  54. 'four',
  55. 'five',
  56. 'size',
  57. 'seven',
  58. 'eight',
  59. 'nine',
  60. ]
  61. def _extract_images(filename, num_images):
  62. """Extract the images into a numpy array.
  63. Args:
  64. filename: The path to an MNIST images file.
  65. num_images: The number of images in the file.
  66. Returns:
  67. A numpy array of shape [number_of_images, height, width, channels].
  68. """
  69. print('Extracting images from: ', filename)
  70. with gzip.open(filename) as bytestream:
  71. bytestream.read(16)
  72. buf = bytestream.read(
  73. _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS)
  74. data = np.frombuffer(buf, dtype=np.uint8)
  75. data = data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  76. return data
  77. def _extract_labels(filename, num_labels):
  78. """Extract the labels into a vector of int64 label IDs.
  79. Args:
  80. filename: The path to an MNIST labels file.
  81. num_labels: The number of labels in the file.
  82. Returns:
  83. A numpy array of shape [number_of_labels]
  84. """
  85. print('Extracting labels from: ', filename)
  86. with gzip.open(filename) as bytestream:
  87. bytestream.read(8)
  88. buf = bytestream.read(1 * num_labels)
  89. labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  90. return labels
  91. def _add_to_tfrecord(data_filename, labels_filename, num_images,
  92. tfrecord_writer):
  93. """Loads data from the binary MNIST files and writes files to a TFRecord.
  94. Args:
  95. data_filename: The filename of the MNIST images.
  96. labels_filename: The filename of the MNIST labels.
  97. num_images: The number of images in the dataset.
  98. tfrecord_writer: The TFRecord writer to use for writing.
  99. """
  100. images = _extract_images(data_filename, num_images)
  101. labels = _extract_labels(labels_filename, num_images)
  102. shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  103. with tf.Graph().as_default():
  104. image = tf.placeholder(dtype=tf.uint8, shape=shape)
  105. encoded_png = tf.image.encode_png(image)
  106. with tf.Session('') as sess:
  107. for j in range(num_images):
  108. sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
  109. sys.stdout.flush()
  110. png_string = sess.run(encoded_png, feed_dict={image: images[j]})
  111. example = dataset_utils.image_to_tfexample(
  112. png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, labels[j])
  113. tfrecord_writer.write(example.SerializeToString())
  114. def _get_output_filename(split_name):
  115. """Creates the output filename.
  116. Args:
  117. split_name: The name of the train/test split.
  118. Returns:
  119. An absolute file path.
  120. """
  121. return '%s/mnist_%s.tfrecord' % (FLAGS.dataset_dir, split_name)
  122. def _download_dataset(dataset_dir):
  123. """Downloads MNIST locally.
  124. Args:
  125. dataset_dir: The directory where the temporary files are stored.
  126. """
  127. for filename in [_TRAIN_DATA_FILENAME,
  128. _TRAIN_LABELS_FILENAME,
  129. _TEST_DATA_FILENAME,
  130. _TEST_LABELS_FILENAME]:
  131. filepath = os.path.join(dataset_dir, filename)
  132. if not os.path.exists(filepath):
  133. print('Downloading file %s...' % filename)
  134. def _progress(count, block_size, total_size):
  135. sys.stdout.write('\r>> Downloading %.1f%%' % (
  136. float(count * block_size) / float(total_size) * 100.0))
  137. sys.stdout.flush()
  138. filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
  139. filepath,
  140. _progress)
  141. print()
  142. with tf.gfile.GFile(filepath) as f:
  143. size = f.Size()
  144. print('Successfully downloaded', filename, size, 'bytes.')
  145. def _clean_up_temporary_files(dataset_dir):
  146. """Removes temporary files used to create the dataset.
  147. Args:
  148. dataset_dir: The directory where the temporary files are stored.
  149. """
  150. for filename in [_TRAIN_DATA_FILENAME,
  151. _TRAIN_LABELS_FILENAME,
  152. _TEST_DATA_FILENAME,
  153. _TEST_LABELS_FILENAME]:
  154. filepath = os.path.join(dataset_dir, filename)
  155. tf.gfile.Remove(filepath)
  156. def main(_):
  157. if not FLAGS.dataset_dir:
  158. raise ValueError('You must supply the dataset directory with --dataset_dir')
  159. if not tf.gfile.Exists(FLAGS.dataset_dir):
  160. tf.gfile.MakeDirs(FLAGS.dataset_dir)
  161. _download_dataset(FLAGS.dataset_dir)
  162. # First, process the training data:
  163. output_file = _get_output_filename('train')
  164. with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
  165. data_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_DATA_FILENAME)
  166. labels_filename = os.path.join(FLAGS.dataset_dir, _TRAIN_LABELS_FILENAME)
  167. _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
  168. # Next, process the testing data:
  169. output_file = _get_output_filename('test')
  170. with tf.python_io.TFRecordWriter(output_file) as tfrecord_writer:
  171. data_filename = os.path.join(FLAGS.dataset_dir, _TEST_DATA_FILENAME)
  172. labels_filename = os.path.join(FLAGS.dataset_dir, _TEST_LABELS_FILENAME)
  173. _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
  174. # Finally, write the labels file:
  175. labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  176. dataset_utils.write_label_file(labels_to_class_names, FLAGS.dataset_dir)
  177. _clean_up_temporary_files(FLAGS.dataset_dir)
  178. print('\nFinished converting the MNIST dataset!')
if __name__ == '__main__':
  # Parses command-line flags and dispatches to main() via the TF app runner.
  tf.app.run()