# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
- """Data loading and other utilities.
- Use this file to first copy over and pre-process the Omniglot dataset.
- Simply call
- python data_utils.py
- """

# NOTE: this script targets Python 2 and its original 2017 dependencies:
# cPickle, dict.iteritems and xrange are Python 2 only, and scipy's imread,
# imresize and imrotate were removed in later SciPy releases.
import cPickle as pickle
import logging
import os
import subprocess

import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf

MAIN_DIR = ''
REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'

REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')

TRAIN_ROTATIONS = True  # augment training data with rotations
TEST_ROTATIONS = False  # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28
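
# For orientation, the directory layout this script assumes once the repo is
# cloned and the archives unzipped (inferred from the brendenlake/omniglot
# repository; exact alphabet and character names vary):
#
#   omniglot/python/images_background/<alphabet>/<character>/<image>.png
#   omniglot/python/images_evaluation/<alphabet>/<character>/<image>.png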


def get_data():
  """Get data in a form suitable for episodic training.

  Returns:
    Train and test data as dictionaries mapping
    label to list of examples.
  """
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
    processed_train_data = pickle.load(f)
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
    processed_test_data = pickle.load(f)

  train_data = {}
  test_data = {}
  for data, processed_data in zip([train_data, test_data],
                                  [processed_train_data, processed_test_data]):
    for image, label in zip(processed_data['images'],
                            processed_data['labels']):
      if label not in data:
        data[label] = []
      data[label].append(image.reshape([-1]).astype('float32'))

  intersection = set(train_data.keys()) & set(test_data.keys())
  assert not intersection, 'Train and test data intersect.'
  ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in train data.'
  ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in test data.'

  logging.info('Number of labels in train data: %d.', len(train_data))
  logging.info('Number of labels in test data: %d.', len(test_data))
  return train_data, test_data
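

# Illustrative sketch (not part of the original pipeline): one way to draw a
# single N-way episode from the dictionaries returned by get_data(). The
# helper name sample_episode and its parameters are hypothetical.
def sample_episode(data, num_classes=5, num_examples_per_class=2):
  """Samples an episode of `num_classes` labels with a few examples each."""
  import random
  labels = random.sample(list(data.keys()), num_classes)
  episode = []
  for label in labels:
    for example in random.sample(data[label], num_examples_per_class):
      episode.append((example, label))
  random.shuffle(episode)
  return episode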


def crawl_directory(directory, augment_with_rotations=False, first_label=0):
  """Crawls a data directory and returns images, labels and file paths.

  Each character directory gets its own integer label; with rotation
  augmentation, each of the four rotations of a character is treated as a
  separate class, so each directory consumes four consecutive labels.

  Returns:
    images, labels, info: lists of images, integer labels and file paths.
  """
  label_idx = first_label
  images = []
  labels = []
  info = []

  # traverse root directory
  for root, _, files in os.walk(directory):
    logging.info('Reading files from %s', root)

    for file_name in files:
      full_file_name = os.path.join(root, file_name)
      img = imread(full_file_name, flatten=True)
      for i, angle in enumerate([0, 90, 180, 270]):
        if not augment_with_rotations and i > 0:
          break
        images.append(imrotate(img, angle))
        labels.append(label_idx + i)
        info.append(full_file_name)

    # Only advance the label counter if this directory actually held images.
    if files:
      label_idx += 4 if augment_with_rotations else 1

  return images, labels, info
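

# Worked example of the labeling scheme above: with augment_with_rotations=True
# and first_label=0, every image in the first character directory yields four
# entries labeled 0 (0 deg), 1 (90 deg), 2 (180 deg) and 3 (270 deg), and the
# second directory starts at label 4. Without augmentation, each directory
# consumes a single label.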


def resize_images(images, new_width, new_height):
  """Resize images to new dimensions."""
  resized_images = np.zeros([images.shape[0], new_width, new_height],
                            dtype=np.float32)

  for i in range(images.shape[0]):
    resized_images[i, :, :] = imresize(images[i, :, :],
                                       [new_width, new_height],
                                       interp='bilinear',
                                       mode=None)
  return resized_images
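

# A minimal sketch of the same resize using Pillow, for environments where
# scipy.misc.imresize is no longer available (it was removed from SciPy).
# The helper name is hypothetical. Assumes boolean input as produced by
# write_datafiles below; output values may differ slightly from imresize,
# which rescaled intensities internally.
def resize_images_pil(images, new_width, new_height):
  from PIL import Image
  resized_images = np.zeros([images.shape[0], new_width, new_height],
                            dtype=np.float32)
  for i in range(images.shape[0]):
    img = Image.fromarray(np.uint8(images[i]) * 255)  # bool -> 0/255, mode 'L'
    img = img.resize((new_height, new_width), Image.BILINEAR)  # PIL is (w, h)
    resized_images[i, :, :] = np.asarray(img, dtype=np.float32)
  return resized_images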


def write_datafiles(directory, write_file,
                    resize=True, rotate=False,
                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                    first_label=0):
  """Load and preprocess images from a directory and write them to a file.

  Args:
    directory: Directory of alphabet sub-directories.
    write_file: Filename to write to.
    resize: Whether to resize the images.
    rotate: Whether to augment the dataset with rotations.
    new_width: New resize width.
    new_height: New resize height.
    first_label: Label to start with.

  Returns:
    Number of new labels created.
  """
  # these are the default sizes for Omniglot:
  imgwidth = IMAGE_ORIGINAL_SIZE
  imgheight = IMAGE_ORIGINAL_SIZE

  logging.info('Reading the data.')
  images, labels, info = crawl_directory(directory,
                                         augment_with_rotations=rotate,
                                         first_label=first_label)

  images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
  labels_np = np.zeros([len(labels)], dtype=np.uint32)
  for i in xrange(len(images)):
    images_np[i, :, :] = images[i]
    labels_np[i] = labels[i]

  if resize:
    logging.info('Resizing images.')
    images_np = resize_images(images_np, new_width, new_height)
    logging.info('Writing resized data in float32 format.')
  else:
    logging.info('Writing original sized data in boolean format.')

  data = {'images': images_np,
          'labels': labels_np,
          'info': info}
  with tf.gfile.GFile(write_file, 'w') as f:
    pickle.dump(data, f)
  return len(np.unique(labels_np))


def maybe_download_data():
  """Download the Omniglot repo if it does not exist."""
  if os.path.exists(REPO_DIR):
    logging.info('It appears that Git repo already exists.')
  else:
    logging.info('It appears that Git repo does not exist.')
    logging.info('Cloning now.')
    # Clones into the current working directory, creating REPO_DIR.
    subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)

  if os.path.exists(TRAIN_DIR):
    logging.info('It appears that train data has already been unzipped.')
  else:
    logging.info('It appears that train data has not been unzipped.')
    logging.info('Unzipping now.')
    # The archive sits next to the target directory and shares its name,
    # e.g. images_background.zip unzips to images_background/.
    subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
                            shell=True)

  if os.path.exists(TEST_DIR):
    logging.info('It appears that test data has already been unzipped.')
  else:
    logging.info('It appears that test data has not been unzipped.')
    logging.info('Unzipping now.')
    subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
                            shell=True)
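

# A minimal sketch of a stdlib alternative to shelling out to `unzip`, using
# zipfile; the helper name and signature are my own, not part of the original.
def _unzip_with_stdlib(target_dir, out_dir):
  """Extracts target_dir + '.zip' into out_dir, mirroring the unzip call."""
  import zipfile
  with zipfile.ZipFile(target_dir + '.zip') as zf:
    zf.extractall(out_dir)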


def preprocess_omniglot():
  """Download and prepare the raw Omniglot data.

  Downloads the data from GitHub if it does not exist, then loads the
  images, optionally augments them with rotations, resizes them and
  writes them to pickle files.
  """
  maybe_download_data()

  directory = TRAIN_DIR
  write_file = DATA_FILE_FORMAT % 'train'
  num_labels = write_datafiles(
      directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
      new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

  # Test labels continue where the train labels left off, so the two
  # label sets are disjoint (get_data() asserts this).
  directory = TEST_DIR
  write_file = DATA_FILE_FORMAT % 'test'
  write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
                  new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                  first_label=num_labels)
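

# For reference (assuming the standard Omniglot splits of 964 background and
# 659 evaluation characters): with TRAIN_ROTATIONS=True this writes
# 964 * 4 = 3856 train labels, and with TEST_ROTATIONS=False, 659 test labels.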


def main(unused_argv):
  logging.basicConfig(level=logging.INFO)
  preprocess_omniglot()


if __name__ == '__main__':
  tf.app.run()