data_utils.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
  16. """Data loading and other utilities.
  17. Use this file to first copy over and pre-process the Omniglot dataset.
  18. Simply call
  19. python data_utils.py
  20. """

import cPickle as pickle
import logging
import os
import subprocess

import numpy as np
from scipy.misc import imresize
from scipy.misc import imrotate
from scipy.ndimage import imread
import tensorflow as tf

MAIN_DIR = ''
REPO_LOCATION = 'https://github.com/brendenlake/omniglot.git'
REPO_DIR = os.path.join(MAIN_DIR, 'omniglot')
DATA_DIR = os.path.join(REPO_DIR, 'python')
TRAIN_DIR = os.path.join(DATA_DIR, 'images_background')
TEST_DIR = os.path.join(DATA_DIR, 'images_evaluation')
DATA_FILE_FORMAT = os.path.join(MAIN_DIR, '%s_omni.pkl')

TRAIN_ROTATIONS = True  # augment training data with rotations
TEST_ROTATIONS = False  # augment testing data with rotations
IMAGE_ORIGINAL_SIZE = 105
IMAGE_NEW_SIZE = 28


def get_data():
  """Get data in a form suitable for episodic training.

  Returns:
    Train and test data as dictionaries mapping
    label to list of examples.
  """
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'train') as f:
    processed_train_data = pickle.load(f)
  with tf.gfile.GFile(DATA_FILE_FORMAT % 'test') as f:
    processed_test_data = pickle.load(f)

  train_data = {}
  test_data = {}

  for data, processed_data in zip([train_data, test_data],
                                  [processed_train_data,
                                   processed_test_data]):
    for image, label in zip(processed_data['images'],
                            processed_data['labels']):
      if label not in data:
        data[label] = []
      # Store each image as a flat float32 vector, grouped by label.
      data[label].append(image.reshape([-1]).astype('float32'))

  # Sanity checks: labels must be disjoint across splits, and every
  # Omniglot character class has exactly 20 drawings.
  intersection = set(train_data.keys()) & set(test_data.keys())
  assert not intersection, 'Train and test data intersect.'
  ok_num_examples = [len(ll) == 20 for _, ll in train_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in train data.'
  ok_num_examples = [len(ll) == 20 for _, ll in test_data.iteritems()]
  assert all(ok_num_examples), 'Bad number of examples in test data.'

  logging.info('Number of labels in train data: %d.', len(train_data))
  logging.info('Number of labels in test data: %d.', len(test_data))

  return train_data, test_data
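
# A minimal usage sketch (an illustration, not called anywhere in this file);
# it assumes preprocess_omniglot() has already written train_omni.pkl and
# test_omni.pkl:
#
#   train_data, test_data = get_data()
#   label = train_data.keys()[0]
#   print(len(train_data[label]))      # 20 examples per character class
#   print(train_data[label][0].shape)  # (IMAGE_NEW_SIZE ** 2,)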


def crawl_directory(directory, augment_with_rotations=False, first_label=0):
  """Crawls the data directory and returns images, labels and file paths.

  Args:
    directory: Directory of alphabet sub-directories.
    augment_with_rotations: Whether to augment the dataset with rotations.
    first_label: Label to start with.

  Returns:
    A tuple (images, labels, info) of image arrays, integer labels and
    source file paths.
  """
  label_idx = first_label
  images = []
  labels = []
  info = []

  # Traverse the root directory; each leaf directory is one character class.
  for root, _, files in os.walk(directory):
    logging.info('Reading files from %s', root)
    fileflag = 0
    for file_name in files:
      full_file_name = os.path.join(root, file_name)
      img = imread(full_file_name, flatten=True)
      # Each rotation of a character is treated as a distinct class.
      for i, angle in enumerate([0, 90, 180, 270]):
        if not augment_with_rotations and i > 0:
          break
        images.append(imrotate(img, angle))
        labels.append(label_idx + i)
        info.append(full_file_name)
      fileflag = 1
    if fileflag:
      # Advance the label counter only for directories that contained images.
      label_idx += 4 if augment_with_rotations else 1
  return images, labels, info
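
# Label layout sketch (follows directly from the loop above): with
# augment_with_rotations=True and first_label=0, the first character class
# occupies labels 0-3 (the 0/90/180/270 degree variants), the second 4-7,
# and so on; without augmentation each class takes a single consecutive id.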


def resize_images(images, new_width, new_height):
  """Resize images to new dimensions."""
  resized_images = np.zeros([images.shape[0], new_width, new_height],
                            dtype=np.float32)

  for i in range(images.shape[0]):
    resized_images[i, :, :] = imresize(images[i, :, :],
                                       [new_width, new_height],
                                       interp='bilinear',
                                       mode=None)
  return resized_images
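
# Note: scipy.misc.imresize and imrotate were removed in SciPy 1.3. On a
# modern stack, a rough Pillow-based substitute for the resize above (an
# assumption, not part of the original pipeline) would be:
#
#   from PIL import Image
#   resized = np.array(Image.fromarray(images[i]).resize(
#       (new_width, new_height), resample=Image.BILINEAR))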


def write_datafiles(directory, write_file,
                    resize=True, rotate=False,
                    new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                    first_label=0):
  """Load and preprocess images from a directory and write them to a file.

  Args:
    directory: Directory of alphabet sub-directories.
    write_file: Filename to write to.
    resize: Whether to resize the images.
    rotate: Whether to augment the dataset with rotations.
    new_width: New resize width.
    new_height: New resize height.
    first_label: Label to start with.

  Returns:
    Number of new labels created.
  """
  # These are the default sizes for Omniglot:
  imgwidth = IMAGE_ORIGINAL_SIZE
  imgheight = IMAGE_ORIGINAL_SIZE

  logging.info('Reading the data.')
  images, labels, info = crawl_directory(directory,
                                         augment_with_rotations=rotate,
                                         first_label=first_label)

  images_np = np.zeros([len(images), imgwidth, imgheight], dtype=np.bool)
  labels_np = np.zeros([len(labels)], dtype=np.uint32)
  for i in xrange(len(images)):
    images_np[i, :, :] = images[i]
    labels_np[i] = labels[i]

  if resize:
    logging.info('Resizing images.')
    resized_images = resize_images(images_np, new_width, new_height)
    logging.info('Writing resized data in float32 format.')
    data = {'images': resized_images,
            'labels': labels_np,
            'info': info}
  else:
    logging.info('Writing original sized data in boolean format.')
    data = {'images': images_np,
            'labels': labels_np,
            'info': info}
  with tf.gfile.GFile(write_file, 'w') as f:
    pickle.dump(data, f)
  return len(np.unique(labels_np))
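
# Output sketch: each pickle written above holds a dict with keys 'images'
# (a num_images x new_width x new_height float32 array, or the original
# 105x105 boolean array when resize=False), 'labels' (a uint32 vector) and
# 'info' (a list of source file paths).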


def maybe_download_data():
  """Download the Omniglot repo if it does not exist."""
  if os.path.exists(REPO_DIR):
    logging.info('It appears that the Git repo already exists.')
  else:
    logging.info('It appears that the Git repo does not exist.')
    logging.info('Cloning now.')
    subprocess.check_output('git clone %s' % REPO_LOCATION, shell=True)

  if os.path.exists(TRAIN_DIR):
    logging.info('It appears that the train data has already been unzipped.')
  else:
    logging.info('It appears that the train data has not been unzipped.')
    logging.info('Unzipping now.')
    subprocess.check_output('unzip %s.zip -d %s' % (TRAIN_DIR, DATA_DIR),
                            shell=True)

  if os.path.exists(TEST_DIR):
    logging.info('It appears that the test data has already been unzipped.')
  else:
    logging.info('It appears that the test data has not been unzipped.')
    logging.info('Unzipping now.')
    subprocess.check_output('unzip %s.zip -d %s' % (TEST_DIR, DATA_DIR),
                            shell=True)
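
# After this step the layout (relative to MAIN_DIR, per the constants at the
# top of this file) should look roughly like:
#
#   omniglot/python/images_background/   # train alphabets (TRAIN_DIR)
#   omniglot/python/images_evaluation/   # test alphabets (TEST_DIR)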


def preprocess_omniglot():
  """Download and prepare raw Omniglot data.

  Downloads the data from GitHub if it does not exist. Then loads the
  images, augments them with rotations if desired, resizes them, and
  writes them to pickle files.
  """
  maybe_download_data()

  directory = TRAIN_DIR
  write_file = DATA_FILE_FORMAT % 'train'
  num_labels = write_datafiles(
      directory, write_file, resize=True, rotate=TRAIN_ROTATIONS,
      new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE)

  directory = TEST_DIR
  write_file = DATA_FILE_FORMAT % 'test'
  # Starting the test labels at num_labels keeps the two label sets
  # disjoint, which get_data() later asserts.
  write_datafiles(directory, write_file, resize=True, rotate=TEST_ROTATIONS,
                  new_width=IMAGE_NEW_SIZE, new_height=IMAGE_NEW_SIZE,
                  first_label=num_labels)
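
# End-to-end sketch: running `python data_utils.py` should leave two files,
# train_omni.pkl and test_omni.pkl (per DATA_FILE_FORMAT), ready for
# get_data() above.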


def main(unused_argv):
  logging.basicConfig(level=logging.INFO)
  preprocess_omniglot()


if __name__ == '__main__':
  tf.app.run()