download_and_convert_mnist.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
  16. This module downloads the MNIST data, uncompresses it, reads the files
  17. that make up the MNIST data and creates two TFRecord datasets: one for train
  18. and one for test. Each TFRecord dataset consists of a set of TF-Example
  19. protocol buffers, each of which contains a single image and its label.
  20. The script should take about a minute to run.
  21. """
  22. from __future__ import absolute_import
  23. from __future__ import division
  24. from __future__ import print_function
  25. import gzip
  26. import os
  27. import sys
  28. import numpy as np
  29. from six.moves import urllib
  30. import tensorflow as tf
  31. from datasets import dataset_utils
  32. # The URLs where the MNIST data can be downloaded.
  33. _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
  34. _TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
  35. _TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
  36. _TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
  37. _TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
  38. _IMAGE_SIZE = 28
  39. _NUM_CHANNELS = 1
  40. # The names of the classes.
  41. _CLASS_NAMES = [
  42. 'zero',
  43. 'one',
  44. 'two',
  45. 'three',
  46. 'four',
  47. 'five',
  48. 'size',
  49. 'seven',
  50. 'eight',
  51. 'nine',
  52. ]
  53. def _extract_images(filename, num_images):
  54. """Extract the images into a numpy array.
  55. Args:
  56. filename: The path to an MNIST images file.
  57. num_images: The number of images in the file.
  58. Returns:
  59. A numpy array of shape [number_of_images, height, width, channels].
  60. """
  61. print('Extracting images from: ', filename)
  62. with gzip.open(filename) as bytestream:
  63. bytestream.read(16)
  64. buf = bytestream.read(
  65. _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS)
  66. data = np.frombuffer(buf, dtype=np.uint8)
  67. data = data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  68. return data
  69. def _extract_labels(filename, num_labels):
  70. """Extract the labels into a vector of int64 label IDs.
  71. Args:
  72. filename: The path to an MNIST labels file.
  73. num_labels: The number of labels in the file.
  74. Returns:
  75. A numpy array of shape [number_of_labels]
  76. """
  77. print('Extracting labels from: ', filename)
  78. with gzip.open(filename) as bytestream:
  79. bytestream.read(8)
  80. buf = bytestream.read(1 * num_labels)
  81. labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
  82. return labels
  83. def _add_to_tfrecord(data_filename, labels_filename, num_images,
  84. tfrecord_writer):
  85. """Loads data from the binary MNIST files and writes files to a TFRecord.
  86. Args:
  87. data_filename: The filename of the MNIST images.
  88. labels_filename: The filename of the MNIST labels.
  89. num_images: The number of images in the dataset.
  90. tfrecord_writer: The TFRecord writer to use for writing.
  91. """
  92. images = _extract_images(data_filename, num_images)
  93. labels = _extract_labels(labels_filename, num_images)
  94. shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
  95. with tf.Graph().as_default():
  96. image = tf.placeholder(dtype=tf.uint8, shape=shape)
  97. encoded_png = tf.image.encode_png(image)
  98. with tf.Session('') as sess:
  99. for j in range(num_images):
  100. sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
  101. sys.stdout.flush()
  102. png_string = sess.run(encoded_png, feed_dict={image: images[j]})
  103. example = dataset_utils.image_to_tfexample(
  104. png_string, 'png'.encode(), _IMAGE_SIZE, _IMAGE_SIZE, labels[j])
  105. tfrecord_writer.write(example.SerializeToString())
  106. def _get_output_filename(dataset_dir, split_name):
  107. """Creates the output filename.
  108. Args:
  109. dataset_dir: The directory where the temporary files are stored.
  110. split_name: The name of the train/test split.
  111. Returns:
  112. An absolute file path.
  113. """
  114. return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)
  115. def _download_dataset(dataset_dir):
  116. """Downloads MNIST locally.
  117. Args:
  118. dataset_dir: The directory where the temporary files are stored.
  119. """
  120. for filename in [_TRAIN_DATA_FILENAME,
  121. _TRAIN_LABELS_FILENAME,
  122. _TEST_DATA_FILENAME,
  123. _TEST_LABELS_FILENAME]:
  124. filepath = os.path.join(dataset_dir, filename)
  125. if not os.path.exists(filepath):
  126. print('Downloading file %s...' % filename)
  127. def _progress(count, block_size, total_size):
  128. sys.stdout.write('\r>> Downloading %.1f%%' % (
  129. float(count * block_size) / float(total_size) * 100.0))
  130. sys.stdout.flush()
  131. filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
  132. filepath,
  133. _progress)
  134. print()
  135. with tf.gfile.GFile(filepath) as f:
  136. size = f.size()
  137. print('Successfully downloaded', filename, size, 'bytes.')
  138. def _clean_up_temporary_files(dataset_dir):
  139. """Removes temporary files used to create the dataset.
  140. Args:
  141. dataset_dir: The directory where the temporary files are stored.
  142. """
  143. for filename in [_TRAIN_DATA_FILENAME,
  144. _TRAIN_LABELS_FILENAME,
  145. _TEST_DATA_FILENAME,
  146. _TEST_LABELS_FILENAME]:
  147. filepath = os.path.join(dataset_dir, filename)
  148. tf.gfile.Remove(filepath)
  149. def run(dataset_dir):
  150. """Runs the download and conversion operation.
  151. Args:
  152. dataset_dir: The dataset directory where the dataset is stored.
  153. """
  154. if not tf.gfile.Exists(dataset_dir):
  155. tf.gfile.MakeDirs(dataset_dir)
  156. training_filename = _get_output_filename(dataset_dir, 'train')
  157. testing_filename = _get_output_filename(dataset_dir, 'test')
  158. if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
  159. print('Dataset files already exist. Exiting without re-creating them.')
  160. return
  161. _download_dataset(dataset_dir)
  162. # First, process the training data:
  163. with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
  164. data_filename = os.path.join(dataset_dir, _TRAIN_DATA_FILENAME)
  165. labels_filename = os.path.join(dataset_dir, _TRAIN_LABELS_FILENAME)
  166. _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
  167. # Next, process the testing data:
  168. with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
  169. data_filename = os.path.join(dataset_dir, _TEST_DATA_FILENAME)
  170. labels_filename = os.path.join(dataset_dir, _TEST_LABELS_FILENAME)
  171. _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
  172. # Finally, write the labels file:
  173. labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  174. dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
  175. _clean_up_temporary_files(dataset_dir)
  176. print('\nFinished converting the MNIST dataset!')