# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
- r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
- This module downloads the MNIST-M data, uncompresses it, reads the files
- that make up the MNIST-M data and creates two TFRecord datasets: one for train
- and one for test. Each TFRecord dataset is comprised of a set of TF-Example
- protocol buffers, each of which contain a single image and label.
- The script should take about a minute to run.
- """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import sys

import google3

import numpy as np
from six.moves import urllib
import tensorflow as tf

from google3.third_party.tensorflow_models.slim.datasets import dataset_utils

tf.app.flags.DEFINE_string(
    'dataset_dir', None,
    'The directory where the output TFRecords and temporary files are saved.')

FLAGS = tf.app.flags.FLAGS

# The URLs where the MNIST-M data can be downloaded.
_DATA_URL = 'http://yann.lecun.com/exdb/mnist/'

_TRAIN_DATA_DIR = 'mnist_m_train'
_TRAIN_LABELS_FILENAME = 'mnist_m_train_labels'
_TEST_DATA_DIR = 'mnist_m_test'
_TEST_LABELS_FILENAME = 'mnist_m_test_labels'

_IMAGE_SIZE = 32
_NUM_CHANNELS = 3

# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001

# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000

# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001

# Seed for repeatability.
_RANDOM_SEED = 0

# The names of the classes.
_CLASS_NAMES = [
    'zero',
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
]

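# With the constants above, the 59001 MNIST-M training images are split into
# 58001 training and 1000 validation examples, and class id i is simply the
# name of digit i (e.g. label 3 -> 'three').

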
class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB PNG data.
    self._decode_png_data = tf.placeholder(dtype=tf.string)
    self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_png(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_png(self, sess, image_data):
    image = sess.run(
        self._decode_png, feed_dict={self._decode_png_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image

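# Note: ImageReader builds its decode op once in the constructor and reuses it
# for every image by feeding the raw PNG bytes through the placeholder, which
# avoids adding a new decode node to the graph for each file.

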
def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset, either 'train', 'valid' or 'test'.
    filenames: A list of PNG image filenames, relative to the split's image
      directory.
    filename_to_class_id: A dictionary from filenames (strings) to class ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  print('Converting the {} split.'.format(split_name))
  # Train and validation splits are both in the train directory.
  if split_name in ['train', 'valid']:
    png_directory = os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train')
  elif split_name == 'test':
    png_directory = os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test')

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:
      output_filename = _get_output_filename(dataset_dir, split_name)

      with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
        for filename in filenames:
          # Read the image file:
          image_data = tf.gfile.FastGFile(
              os.path.join(png_directory, filename), 'rb').read()
          height, width = image_reader.read_image_dims(sess, image_data)
          class_id = filename_to_class_id[filename]
          example = dataset_utils.image_to_tfexample(image_data, 'png', height,
                                                     width, class_id)
          tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()

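# Each record written above is a TF-Example proto; with the slim dataset_utils
# version this script was written against, image_to_tfexample is expected to
# store the encoded PNG bytes, image format, height, width and class label
# under 'image/...' feature keys (e.g. 'image/encoded', 'image/class/label').

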
def _extract_labels(label_filename):
  """Extracts the labels into a dict of filenames to int labels.

  Args:
    label_filename: The filename of the MNIST-M labels.

  Returns:
    A dictionary of filenames to int labels.
  """
  print('Extracting labels from: ', label_filename)
  label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
  label_lines = [line.rstrip('\n').split() for line in label_file]
  labels = {}
  for line in label_lines:
    assert len(line) == 2
    labels[line[0]] = int(line[1])
  return labels

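# Each line of the label file is expected to hold a "<filename> <label>" pair,
# for example (illustrative filenames):
#
#   00000000.png 5
#   00000001.png 0

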
def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.

  Args:
    dataset_dir: The directory where the temporary files are stored.
    split_name: The name of the train/test split.

  Returns:
    An absolute file path.
  """
  return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)

def _get_filenames(dataset_dir):
  """Returns a list of image filenames.

  Args:
    dataset_dir: A directory containing a set of PNG encoded MNIST-M images.

  Returns:
    A list of image filenames, relative to `dataset_dir`.
  """
  photo_filenames = []
  for filename in os.listdir(dataset_dir):
    photo_filenames.append(filename)
  return photo_filenames

def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  train_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  # TODO(konstantinos): Add download and cleanup functionality

  train_validation_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train'))
  test_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test'))

  # Divide into train and validation:
  random.seed(_RANDOM_SEED)
  random.shuffle(train_validation_filenames)
  train_filenames = train_validation_filenames[_NUM_VALIDATION:]
  validation_filenames = train_validation_filenames[:_NUM_VALIDATION]

  train_validation_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_train_labels.txt'))
  test_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist-m', 'mnist_m_test_labels.txt'))

  # Convert the train, validation, and test sets.
  _convert_dataset('train', train_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('valid', validation_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('test', test_filenames, test_filenames_to_class_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  print('\nFinished converting the MNIST-M dataset!')

def main(_):
  run(FLAGS.dataset_dir)


if __name__ == '__main__':
  tf.app.run()
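
# A minimal sketch (assuming TF 1.x APIs and the 'image/class/label' feature
# key noted above) for inspecting a resulting TFRecord file:
#
#   for record in tf.python_io.tf_record_iterator('/tmp/mnist-m/mnist_m_valid.tfrecord'):
#     example = tf.train.Example.FromString(record)
#     print(example.features.feature['image/class/label'].int64_list.value)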