# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cPickle
import gzip
import math
import os
import sys
import tarfile

import numpy as np
from scipy.io import loadmat as loadmat
from six.moves import urllib

import tensorflow as tf

FLAGS = tf.flags.FLAGS


def create_dir_if_needed(dest_directory):
  """
  Create directory if it doesn't exist
  :param dest_directory: path of the directory to create
  :return: True if everything went well
  """
  if not tf.gfile.IsDirectory(dest_directory):
    tf.gfile.MakeDirs(dest_directory)

  return True


def maybe_download(file_urls, directory):
  """
  Download a set of files into a temporary local folder
  :param file_urls: list of URLs of the files to download
  :param directory: the directory where to download
  :return: a list of filepaths corresponding to the files given as input
  """
  # Create directory if it doesn't exist
  assert create_dir_if_needed(directory)

  # This list will include all URLs of the local copy of downloaded files
  result = []

  # For each file of the dataset
  for file_url in file_urls:
    # Extract filename
    filename = file_url.split('/')[-1]

    # If downloading from GitHub, remove suffix ?raw=True from local filename
    if filename.endswith("?raw=true"):
      filename = filename[:-9]

    # Deduce local file url
    # filepath = os.path.join(directory, filename)
    filepath = directory + '/' + filename

    # Add to result list
    result.append(filepath)

    # Test if file already exists
    if not tf.gfile.Exists(filepath):
      def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
            float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()
      filepath, _ = urllib.request.urlretrieve(file_url, filepath, _progress)
      print()
      statinfo = os.stat(filepath)
      print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  return result
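

# Illustrative usage sketch (not part of the original pipeline; the target
# directory below is a placeholder): maybe_download() returns the local paths
# of the requested files, downloading only the ones that are missing.
#
#   local_paths = maybe_download(
#       ['http://ufldl.stanford.edu/housenumbers/test_32x32.mat'],
#       '/tmp/svhn')
#   # local_paths == ['/tmp/svhn/test_32x32.mat']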


def image_whitening(data):
  """
  Subtracts mean of image and divides by adjusted standard deviation (for
  stability). Operations are per image but performed for the entire array.
  :param data: 4D array (ID, Height, Width, Channel)
  :return: 4D array (ID, Height, Width, Channel)
  """
  assert len(np.shape(data)) == 4

  # Compute number of pixels in image
  nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]

  # Subtract mean
  mean = np.mean(data, axis=(1, 2, 3))

  ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
  for i in xrange(len(data)):
    data[i, :, :, :] -= mean[i] * ones

  # Compute adjusted standard deviation (lower-bounded for numerical stability)
  adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels),
                           np.std(data, axis=(1, 2, 3)))

  # Divide image
  for i in xrange(len(data)):
    data[i, :, :, :] = data[i, :, :, :] / adj_std_var[i]

  print(np.shape(data))

  return data
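

# Illustrative sketch (hypothetical shapes and values, not part of the
# original code): image_whitening() modifies its float32 input batch in place,
# so each image ends up with roughly zero mean and unit adjusted standard
# deviation.
#
#   batch = np.random.rand(2, 32, 32, 3).astype(np.float32)
#   batch = image_whitening(batch)
#   # np.mean(batch[0]) ~ 0.0 and np.std(batch[0]) ~ 1.0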


def extract_svhn(local_url):
  """
  Extract a MATLAB matrix into two numpy arrays with data and labels
  :param local_url: path to the local copy of the SVHN .mat file
  :return: tuple of (data, labels)
  """
  with tf.gfile.Open(local_url, mode='r') as file_obj:
    # Load MATLAB matrix using scipy IO
    data_dict = loadmat(file_obj)

    # Extract each array (one for data, one for labels)
    data, labels = data_dict["X"], data_dict["y"]

    # Set np type
    data = np.asarray(data, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int32)

    # Transpose data to match TF model input format
    data = data.transpose(3, 0, 1, 2)

    # Fix the SVHN labels which label 0s as 10s
    labels[labels == 10] = 0

    # Fix label dimensions
    labels = labels.reshape(len(labels))

    return data, labels


def unpickle_cifar_dic(file):
  """
  Helper function: unpickles a dictionary (used for loading CIFAR)
  :param file: filename of the pickle
  :return: tuple of (images, labels)
  """
  with open(file, 'rb') as fo:
    data_dict = cPickle.load(fo)
  return data_dict['data'], data_dict['labels']


def extract_cifar10(local_url, data_dir):
  """
  Extracts the CIFAR-10 dataset and returns numpy arrays with the different sets
  :param local_url: where the tar.gz archive is located locally
  :param data_dir: where to extract the archive's files
  :return: a tuple (train data, train labels, test data, test labels)
  """
  # These numpy dumps can be reloaded to avoid performing the pre-processing
  # if they exist in the working directory.
  # Changing the order of this list will ruin the indices below.
  preprocessed_files = ['/cifar10_train.npy',
                        '/cifar10_train_labels.npy',
                        '/cifar10_test.npy',
                        '/cifar10_test_labels.npy']

  all_preprocessed = True
  for file in preprocessed_files:
    if not tf.gfile.Exists(data_dir + file):
      all_preprocessed = False
      break

  if all_preprocessed:
    # Reload pre-processed training data from numpy dumps
    with tf.gfile.Open(data_dir + preprocessed_files[0], mode='r') as file_obj:
      train_data = np.load(file_obj)
    with tf.gfile.Open(data_dir + preprocessed_files[1], mode='r') as file_obj:
      train_labels = np.load(file_obj)

    # Reload pre-processed testing data from numpy dumps
    with tf.gfile.Open(data_dir + preprocessed_files[2], mode='r') as file_obj:
      test_data = np.load(file_obj)
    with tf.gfile.Open(data_dir + preprocessed_files[3], mode='r') as file_obj:
      test_labels = np.load(file_obj)

  else:
    # Do everything from scratch
    # Define lists of all files we should extract
    train_files = ["data_batch_" + str(i) for i in xrange(1, 6)]
    test_file = ["test_batch"]
    cifar10_files = train_files + test_file

    # Check if all files have already been extracted
    need_to_unpack = False
    for file in cifar10_files:
      if not tf.gfile.Exists(file):
        need_to_unpack = True
        break

    # We have to unpack the archive
    if need_to_unpack:
      tarfile.open(local_url, 'r:gz').extractall(data_dir)

    # Load training images and labels
    images = []
    labels = []
    for file in train_files:
      # Construct filename
      filename = data_dir + "/cifar-10-batches-py/" + file

      # Unpickle dictionary and extract images and labels
      images_tmp, labels_tmp = unpickle_cifar_dic(filename)

      # Append to lists
      images.append(images_tmp)
      labels.append(labels_tmp)

    # Convert to numpy arrays and reshape in the expected format
    train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
    train_data = np.swapaxes(train_data, 1, 3)
    train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[0], train_data)
    np.save(data_dir + preprocessed_files[1], train_labels)

    # Construct filename for test file
    filename = data_dir + "/cifar-10-batches-py/" + test_file[0]

    # Load test images and labels
    test_data, test_labels_tmp = unpickle_cifar_dic(filename)

    # Convert to numpy arrays and reshape in the expected format
    test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
    test_data = np.swapaxes(test_data, 1, 3)
    test_labels = np.asarray(test_labels_tmp, dtype=np.int32).reshape(10000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[2], test_data)
    np.save(data_dir + preprocessed_files[3], test_labels)

  return train_data, train_labels, test_data, test_labels


def extract_mnist_data(filename, num_images, image_size, pixel_depth):
  """
  Extract the images into a 4D tensor [image index, y, x, channels].
  Values are rescaled from [0, 255] down to [-0.5, 0.5].
  """
  # if not os.path.exists(file):
  if not tf.gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 16-byte IDX header before reading pixel values
      bytestream.read(16)
      buf = bytestream.read(image_size * image_size * num_images)
      data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
      data = (data - (pixel_depth / 2.0)) / pixel_depth
      data = data.reshape(num_images, image_size, image_size, 1)
      np.save(filename, data)
      return data
  else:
    with tf.gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)


def extract_mnist_labels(filename, num_images):
  """
  Extract the labels into a vector of int32 label IDs.
  """
  # if not os.path.exists(file):
  if not tf.gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 8-byte IDX header before reading label values
      bytestream.read(8)
      buf = bytestream.read(1 * num_images)
      labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int32)
      np.save(filename, labels)
      return labels
  else:
    with tf.gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)


def ld_svhn(extended=False, test_only=False):
  """
  Load the original SVHN data
  :param extended: include extended training data in the returned array
  :param test_only: disables loading of both train and extra -> large speed up
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://ufldl.stanford.edu/housenumbers/train_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/extra_32x32.mat']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract train, test, and extended train data
  if not test_only:
    # Load and apply whitening to train data
    train_data, train_labels = extract_svhn(local_urls[0])
    train_data = image_whitening(train_data)

    # Load and apply whitening to extended train data
    ext_data, ext_labels = extract_svhn(local_urls[2])
    ext_data = image_whitening(ext_data)

  # Load and apply whitening to test data
  test_data, test_labels = extract_svhn(local_urls[1])
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    if extended:
      # Stack train data with the extended training data
      train_data = np.vstack((train_data, ext_data))
      train_labels = np.hstack((train_labels, ext_labels))

      return train_data, train_labels, test_data, test_labels
    else:
      # Return training and extended training data separately
      return train_data, train_labels, test_data, test_labels, ext_data, ext_labels


def ld_cifar10(test_only=False):
  """
  Load the original CIFAR-10 data
  :param test_only: if True, return only the test data and labels
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract archives and return different sets
  dataset = extract_cifar10(local_urls[0], FLAGS.data_dir)

  # Unpack tuple
  train_data, train_labels, test_data, test_labels = dataset

  # Apply whitening to input data
  train_data = image_whitening(train_data)
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels


def ld_mnist(test_only=False):
  """
  Load the MNIST dataset
  :param test_only: if True, return only the test data and labels
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
               ]

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract it into np arrays.
  train_data = extract_mnist_data(local_urls[0], 60000, 28, 1)
  train_labels = extract_mnist_labels(local_urls[1], 60000)
  test_data = extract_mnist_data(local_urls[2], 10000, 28, 1)
  test_labels = extract_mnist_labels(local_urls[3], 10000)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels


def partition_dataset(data, labels, nb_teachers, teacher_id):
  """
  Simple partitioning algorithm that returns the right portion of the data
  needed by a given teacher out of a certain nb of teachers
  :param data: input data to be partitioned
  :param labels: output data to be partitioned
  :param nb_teachers: number of teachers in the ensemble (affects size of each
                      partition)
  :param teacher_id: id of partition to retrieve
  :return: tuple of (partition data, partition labels) for the given teacher
  """
  # Sanity check
  assert len(data) == len(labels)
  assert int(teacher_id) < int(nb_teachers)

  # This will floor the possible number of batches
  batch_len = int(len(data) / nb_teachers)

  # Compute start, end indices of partition
  start = teacher_id * batch_len
  end = (teacher_id + 1) * batch_len

  # Slice partition off
  partition_data = data[start:end]
  partition_labels = labels[start:end]

  return partition_data, partition_labels
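

if __name__ == '__main__':
  # Minimal self-contained sketch of partition_dataset() on synthetic data
  # (the shapes and values below are illustrative assumptions, not part of the
  # original pipeline, which loads MNIST/CIFAR-10/SVHN instead).
  fake_data = np.arange(100 * 2, dtype=np.float32).reshape(100, 2)
  fake_labels = np.arange(100, dtype=np.int32) % 10

  # With 10 teachers, each partition holds int(100 / 10) = 10 examples, so
  # teacher 3 receives examples 30 through 39.
  part_data, part_labels = partition_dataset(fake_data, fake_labels, 10, 3)
  print(part_data.shape)   # (10, 2)
  print(part_labels)       # [0 1 2 3 4 5 6 7 8 9]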