123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- r"""Downloads and converts MNIST data to TFRecords of TF-Example protos.
- This module downloads the MNIST data, uncompresses it, reads the files
- that make up the MNIST data and creates two TFRecord datasets: one for train
- and one for test. Each TFRecord dataset is comprised of a set of TF-Example
- protocol buffers, each of which contain a single image and label.
- The script should take about a minute to run.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import gzip
- import os
- import sys
- import numpy as np
- from six.moves import urllib
- import tensorflow as tf
- from datasets import dataset_utils
- # The URLs where the MNIST data can be downloaded.
- _DATA_URL = 'http://yann.lecun.com/exdb/mnist/'
- _TRAIN_DATA_FILENAME = 'train-images-idx3-ubyte.gz'
- _TRAIN_LABELS_FILENAME = 'train-labels-idx1-ubyte.gz'
- _TEST_DATA_FILENAME = 't10k-images-idx3-ubyte.gz'
- _TEST_LABELS_FILENAME = 't10k-labels-idx1-ubyte.gz'
- _IMAGE_SIZE = 28
- _NUM_CHANNELS = 1
- # The names of the classes.
- _CLASS_NAMES = [
- 'zero',
- 'one',
- 'two',
- 'three',
- 'four',
- 'five',
- 'size',
- 'seven',
- 'eight',
- 'nine',
- ]
- def _extract_images(filename, num_images):
- """Extract the images into a numpy array.
- Args:
- filename: The path to an MNIST images file.
- num_images: The number of images in the file.
- Returns:
- A numpy array of shape [number_of_images, height, width, channels].
- """
- print('Extracting images from: ', filename)
- with gzip.open(filename) as bytestream:
- bytestream.read(16)
- buf = bytestream.read(
- _IMAGE_SIZE * _IMAGE_SIZE * num_images * _NUM_CHANNELS)
- data = np.frombuffer(buf, dtype=np.uint8)
- data = data.reshape(num_images, _IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
- return data
- def _extract_labels(filename, num_labels):
- """Extract the labels into a vector of int64 label IDs.
- Args:
- filename: The path to an MNIST labels file.
- num_labels: The number of labels in the file.
- Returns:
- A numpy array of shape [number_of_labels]
- """
- print('Extracting labels from: ', filename)
- with gzip.open(filename) as bytestream:
- bytestream.read(8)
- buf = bytestream.read(1 * num_labels)
- labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int64)
- return labels
- def _add_to_tfrecord(data_filename, labels_filename, num_images,
- tfrecord_writer):
- """Loads data from the binary MNIST files and writes files to a TFRecord.
- Args:
- data_filename: The filename of the MNIST images.
- labels_filename: The filename of the MNIST labels.
- num_images: The number of images in the dataset.
- tfrecord_writer: The TFRecord writer to use for writing.
- """
- images = _extract_images(data_filename, num_images)
- labels = _extract_labels(labels_filename, num_images)
- shape = (_IMAGE_SIZE, _IMAGE_SIZE, _NUM_CHANNELS)
- with tf.Graph().as_default():
- image = tf.placeholder(dtype=tf.uint8, shape=shape)
- encoded_png = tf.image.encode_png(image)
- with tf.Session('') as sess:
- for j in range(num_images):
- sys.stdout.write('\r>> Converting image %d/%d' % (j + 1, num_images))
- sys.stdout.flush()
- png_string = sess.run(encoded_png, feed_dict={image: images[j]})
- example = dataset_utils.image_to_tfexample(
- png_string, 'png'.encode(), _IMAGE_SIZE, _IMAGE_SIZE, labels[j])
- tfrecord_writer.write(example.SerializeToString())
- def _get_output_filename(dataset_dir, split_name):
- """Creates the output filename.
- Args:
- dataset_dir: The directory where the temporary files are stored.
- split_name: The name of the train/test split.
- Returns:
- An absolute file path.
- """
- return '%s/mnist_%s.tfrecord' % (dataset_dir, split_name)
- def _download_dataset(dataset_dir):
- """Downloads MNIST locally.
- Args:
- dataset_dir: The directory where the temporary files are stored.
- """
- for filename in [_TRAIN_DATA_FILENAME,
- _TRAIN_LABELS_FILENAME,
- _TEST_DATA_FILENAME,
- _TEST_LABELS_FILENAME]:
- filepath = os.path.join(dataset_dir, filename)
- if not os.path.exists(filepath):
- print('Downloading file %s...' % filename)
- def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %.1f%%' % (
- float(count * block_size) / float(total_size) * 100.0))
- sys.stdout.flush()
- filepath, _ = urllib.request.urlretrieve(_DATA_URL + filename,
- filepath,
- _progress)
- print()
- with tf.gfile.GFile(filepath) as f:
- size = f.size()
- print('Successfully downloaded', filename, size, 'bytes.')
- def _clean_up_temporary_files(dataset_dir):
- """Removes temporary files used to create the dataset.
- Args:
- dataset_dir: The directory where the temporary files are stored.
- """
- for filename in [_TRAIN_DATA_FILENAME,
- _TRAIN_LABELS_FILENAME,
- _TEST_DATA_FILENAME,
- _TEST_LABELS_FILENAME]:
- filepath = os.path.join(dataset_dir, filename)
- tf.gfile.Remove(filepath)
- def run(dataset_dir):
- """Runs the download and conversion operation.
- Args:
- dataset_dir: The dataset directory where the dataset is stored.
- """
- if not tf.gfile.Exists(dataset_dir):
- tf.gfile.MakeDirs(dataset_dir)
- training_filename = _get_output_filename(dataset_dir, 'train')
- testing_filename = _get_output_filename(dataset_dir, 'test')
- if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
- print('Dataset files already exist. Exiting without re-creating them.')
- return
- _download_dataset(dataset_dir)
- # First, process the training data:
- with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
- data_filename = os.path.join(dataset_dir, _TRAIN_DATA_FILENAME)
- labels_filename = os.path.join(dataset_dir, _TRAIN_LABELS_FILENAME)
- _add_to_tfrecord(data_filename, labels_filename, 60000, tfrecord_writer)
- # Next, process the testing data:
- with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
- data_filename = os.path.join(dataset_dir, _TEST_DATA_FILENAME)
- labels_filename = os.path.join(dataset_dir, _TEST_LABELS_FILENAME)
- _add_to_tfrecord(data_filename, labels_filename, 10000, tfrecord_writer)
- # Finally, write the labels file:
- labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
- dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
- _clean_up_temporary_files(dataset_dir)
- print('\nFinished converting the MNIST dataset!')
|