| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- r"""Downloads and converts cifar10 data to TFRecords of TF-Example protos.
- This module downloads the cifar10 data, uncompresses it, reads the files
- that make up the cifar10 data and creates two TFRecord datasets: one for train
- and one for test. Each TFRecord dataset is comprised of a set of TF-Example
- protocol buffers, each of which contain a single image and label.
- The script should take several minutes to run.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import cPickle
- import os
- import sys
- import tarfile
- import numpy as np
- from six.moves import urllib
- import tensorflow as tf
- from datasets import dataset_utils
- # The URL where the CIFAR data can be downloaded.
- _DATA_URL = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
- # The number of training files.
- _NUM_TRAIN_FILES = 5
- # The height and width of each image.
- _IMAGE_SIZE = 32
- # The names of the classes.
- _CLASS_NAMES = [
- 'airplane',
- 'automobile',
- 'bird',
- 'cat',
- 'deer',
- 'dog',
- 'frog',
- 'horse',
- 'ship',
- 'truck',
- ]
- def _add_to_tfrecord(filename, tfrecord_writer, offset=0):
- """Loads data from the cifar10 pickle files and writes files to a TFRecord.
- Args:
- filename: The filename of the cifar10 pickle file.
- tfrecord_writer: The TFRecord writer to use for writing.
- offset: An offset into the absolute number of images previously written.
- Returns:
- The new offset.
- """
- with tf.gfile.Open(filename, 'r') as f:
- data = cPickle.load(f)
- images = data['data']
- num_images = images.shape[0]
- images = images.reshape((num_images, 3, 32, 32))
- labels = data['labels']
- with tf.Graph().as_default():
- image_placeholder = tf.placeholder(dtype=tf.uint8)
- encoded_image = tf.image.encode_png(image_placeholder)
- with tf.Session('') as sess:
- for j in range(num_images):
- sys.stdout.write('\r>> Reading file [%s] image %d/%d' % (
- filename, offset + j + 1, offset + num_images))
- sys.stdout.flush()
- image = np.squeeze(images[j]).transpose((1, 2, 0))
- label = labels[j]
- png_string = sess.run(encoded_image,
- feed_dict={image_placeholder: image})
- example = dataset_utils.image_to_tfexample(
- png_string, 'png', _IMAGE_SIZE, _IMAGE_SIZE, label)
- tfrecord_writer.write(example.SerializeToString())
- return offset + num_images
- def _get_output_filename(dataset_dir, split_name):
- """Creates the output filename.
- Args:
- dataset_dir: The dataset directory where the dataset is stored.
- split_name: The name of the train/test split.
- Returns:
- An absolute file path.
- """
- return '%s/cifar10_%s.tfrecord' % (dataset_dir, split_name)
- def _download_and_uncompress_dataset(dataset_dir):
- """Downloads cifar10 and uncompresses it locally.
- Args:
- dataset_dir: The directory where the temporary files are stored.
- """
- filename = _DATA_URL.split('/')[-1]
- filepath = os.path.join(dataset_dir, filename)
- if not os.path.exists(filepath):
- def _progress(count, block_size, total_size):
- sys.stdout.write('\r>> Downloading %s %.1f%%' % (
- filename, float(count * block_size) / float(total_size) * 100.0))
- sys.stdout.flush()
- filepath, _ = urllib.request.urlretrieve(_DATA_URL, filepath, _progress)
- print()
- statinfo = os.stat(filepath)
- print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
- tarfile.open(filepath, 'r:gz').extractall(dataset_dir)
- def _clean_up_temporary_files(dataset_dir):
- """Removes temporary files used to create the dataset.
- Args:
- dataset_dir: The directory where the temporary files are stored.
- """
- filename = _DATA_URL.split('/')[-1]
- filepath = os.path.join(dataset_dir, filename)
- tf.gfile.Remove(filepath)
- tmp_dir = os.path.join(dataset_dir, 'cifar-10-batches-py')
- tf.gfile.DeleteRecursively(tmp_dir)
- def run(dataset_dir):
- """Runs the download and conversion operation.
- Args:
- dataset_dir: The dataset directory where the dataset is stored.
- """
- if not tf.gfile.Exists(dataset_dir):
- tf.gfile.MakeDirs(dataset_dir)
- training_filename = _get_output_filename(dataset_dir, 'train')
- testing_filename = _get_output_filename(dataset_dir, 'test')
- if tf.gfile.Exists(training_filename) and tf.gfile.Exists(testing_filename):
- print('Dataset files already exist. Exiting without re-creating them.')
- return
- dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir)
- # First, process the training data:
- with tf.python_io.TFRecordWriter(training_filename) as tfrecord_writer:
- offset = 0
- for i in range(_NUM_TRAIN_FILES):
- filename = os.path.join(dataset_dir,
- 'cifar-10-batches-py',
- 'data_batch_%d' % (i + 1)) # 1-indexed.
- offset = _add_to_tfrecord(filename, tfrecord_writer, offset)
- # Next, process the testing data:
- with tf.python_io.TFRecordWriter(testing_filename) as tfrecord_writer:
- filename = os.path.join(dataset_dir,
- 'cifar-10-batches-py',
- 'test_batch')
- _add_to_tfrecord(filename, tfrecord_writer)
- # Finally, write the labels file:
- labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
- dataset_utils.write_label_file(labels_to_class_names, dataset_dir)
- _clean_up_temporary_files(dataset_dir)
- print('\nFinished converting the Cifar10 dataset!')
|