# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import math
import os
import sys
import tarfile

import numpy as np
from scipy.io import loadmat as loadmat
from six.moves import cPickle
from six.moves import urllib
from six.moves import xrange

from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile

FLAGS = flags.FLAGS


def create_dir_if_needed(dest_directory):
  """
  Create directory if it doesn't exist
  :param dest_directory: directory to create
  :return: True if everything went well
  """
  if not gfile.IsDirectory(dest_directory):
    gfile.MakeDirs(dest_directory)

  return True


def maybe_download(file_urls, directory):
  """
  Download a set of files into a temporary local folder
  :param file_urls: list of URLs of the files to download
  :param directory: the directory where to download
  :return: a list of filepaths corresponding to the files given as input
  """
  # Create directory if it doesn't exist
  assert create_dir_if_needed(directory)

  # This list will include all URLs of the local copies of downloaded files
  result = []

  # For each file of the dataset
  for file_url in file_urls:
    # Extract filename
    filename = file_url.split('/')[-1]

    # Deduce local file url
    #filepath = os.path.join(directory, filename)
    filepath = directory + '/' + filename

    # Add to result list
    result.append(filepath)

    # Test if file already exists
    if not gfile.Exists(filepath):
      def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
            float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()
      filepath, _ = urllib.request.urlretrieve(file_url, filepath, _progress)
      print()
      statinfo = os.stat(filepath)
      print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  return result
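
# Minimal usage sketch for maybe_download (illustrative only; the URL and
# directory below are hypothetical, not part of this module):
#
#   urls = ['http://example.com/data/train.mat']
#   paths = maybe_download(urls, '/tmp/dataset')
#   # paths == ['/tmp/dataset/train.mat']; each file is only fetched if it
#   # is not already present locally.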


def image_whitening(data):
  """
  Subtracts the mean of each image and divides by the adjusted standard
  deviation (for stability). Operations are per image but performed for the
  entire array.
  :param data: 4D array (ID, Height, Width, Channel)
  :return: 4D array (ID, Height, Width, Channel)
  """
  assert len(np.shape(data)) == 4

  # Compute number of pixels in image
  nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]

  # Subtract mean
  mean = np.mean(data, axis=(1, 2, 3))

  ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
  for i in xrange(len(data)):
    data[i, :, :, :] -= mean[i] * ones

  # Compute adjusted standard deviation (lower-bounded by 1/sqrt(nb_pixels))
  adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32)
                           / math.sqrt(nb_pixels),
                           np.std(data, axis=(1, 2, 3)))

  # Divide image
  for i in xrange(len(data)):
    data[i, :, :, :] = data[i, :, :, :] / adj_std_var[i]

  print(np.shape(data))

  return data
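
# A minimal sketch of what image_whitening computes for one image x with N
# pixels (the shapes below are illustrative, not from this file):
#
#   x = np.random.rand(1, 32, 32, 3).astype(np.float32)   # N = 32*32*3
#   whitened = image_whitening(x.copy())
#   # Equivalent per-image formula:
#   #   (x - mean(x)) / max(std(x), 1 / sqrt(N))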


def extract_svhn(local_url):
  """
  Extract a MATLAB matrix into two numpy arrays with data and labels
  :param local_url: path to the local copy of the SVHN .mat file
  :return: tuple of (data, labels) numpy arrays
  """

  with gfile.Open(local_url, mode='r') as file_obj:
    # Load MATLAB matrix using scipy IO
    dict = loadmat(file_obj)

    # Extract each dictionary (one for data, one for labels)
    data, labels = dict["X"], dict["y"]

    # Set np type
    data = np.asarray(data, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int32)

    # Transpose data to match TF model input format
    data = data.transpose(3, 0, 1, 2)

    # Fix the SVHN labels which label 0s as 10s
    labels[labels == 10] = 0

    # Fix label dimensions
    labels = labels.reshape(len(labels))

    return data, labels
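
# For reference, the SVHN .mat files store images as an array of shape
# (32, 32, 3, N) under key "X" and labels of shape (N, 1) under key "y",
# with the digit 0 encoded as label 10; extract_svhn transposes the images
# to (N, 32, 32, 3) and remaps label 10 back to 0.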


def unpickle_cifar_dic(file):
  """
  Helper function: unpickles a dictionary (used for loading CIFAR)
  :param file: filename of the pickle
  :return: tuple of (images, labels)
  """
  with open(file, 'rb') as fo:
    data_dict = cPickle.load(fo)
  return data_dict['data'], data_dict['labels']
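
# For reference, each CIFAR-10 python batch pickle holds 'data' as a
# (10000, 3072) uint8 array (per row: 1024 red, then 1024 green, then 1024
# blue values of a 32x32 image) and 'labels' as a list of 10000 integers in
# [0, 9], which is why extract_cifar10 below reshapes to (N, 3, 32, 32) first.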


def extract_cifar10(local_url, data_dir):
  """
  Extracts the CIFAR-10 dataset and returns numpy arrays with the different sets
  :param local_url: where the tar.gz archive is located locally
  :param data_dir: where to extract the archive's files
  :return: a tuple (train data, train labels, test data, test labels)
  """
  # These numpy dumps can be reloaded to avoid performing the pre-processing
  # if they exist in the working directory.
  # Changing the order of this list will ruin the indices below.
  preprocessed_files = ['/cifar10_train.npy',
                        '/cifar10_train_labels.npy',
                        '/cifar10_test.npy',
                        '/cifar10_test_labels.npy']

  all_preprocessed = True
  for file in preprocessed_files:
    if not gfile.Exists(data_dir + file):
      all_preprocessed = False
      break

  if all_preprocessed:
    # Reload pre-processed training data from numpy dumps
    with gfile.Open(data_dir + preprocessed_files[0], mode='r') as file_obj:
      train_data = np.load(file_obj)
    with gfile.Open(data_dir + preprocessed_files[1], mode='r') as file_obj:
      train_labels = np.load(file_obj)

    # Reload pre-processed testing data from numpy dumps
    with gfile.Open(data_dir + preprocessed_files[2], mode='r') as file_obj:
      test_data = np.load(file_obj)
    with gfile.Open(data_dir + preprocessed_files[3], mode='r') as file_obj:
      test_labels = np.load(file_obj)
  else:
    # Do everything from scratch
    # Define lists of all files we should extract
    train_files = ["data_batch_" + str(i) for i in xrange(1, 6)]
    test_file = ["test_batch"]
    cifar10_files = train_files + test_file

    # Check if all files have already been extracted
    need_to_unpack = False
    for file in cifar10_files:
      if not gfile.Exists(file):
        need_to_unpack = True
        break

    # We have to unpack the archive
    if need_to_unpack:
      tarfile.open(local_url, 'r:gz').extractall(data_dir)

    # Load training images and labels
    images = []
    labels = []
    for file in train_files:
      # Construct filename
      filename = data_dir + "/cifar-10-batches-py/" + file

      # Unpickle dictionary and extract images and labels
      images_tmp, labels_tmp = unpickle_cifar_dic(filename)

      # Append to lists
      images.append(images_tmp)
      labels.append(labels_tmp)

    # Convert to numpy arrays and reshape in the expected format
    train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
    train_data = np.swapaxes(train_data, 1, 3)
    train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[0], train_data)
    np.save(data_dir + preprocessed_files[1], train_labels)

    # Construct filename for test file
    filename = data_dir + "/cifar-10-batches-py/" + test_file[0]

    # Load test images and labels
    test_data, test_labels_tmp = unpickle_cifar_dic(filename)

    # Convert to numpy arrays and reshape in the expected format
    test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
    test_data = np.swapaxes(test_data, 1, 3)
    test_labels = np.asarray(test_labels_tmp, dtype=np.int32).reshape(10000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[2], test_data)
    np.save(data_dir + preprocessed_files[3], test_labels)

  return train_data, train_labels, test_data, test_labels
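
# Shape sketch for the conversion above (illustrative values only):
#
#   raw = np.zeros((50000, 3, 32, 32), dtype=np.float32)   # channels first
#   nhwc = np.swapaxes(raw, 1, 3)
#   # nhwc.shape == (50000, 32, 32, 3), i.e. the channels-last layout used
#   # elsewhere in this module (e.g. by image_whitening).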


def extract_mnist_data(filename, num_images, image_size, pixel_depth):
  """
  Extract the images into a 4D tensor [image index, y, x, channels].
  Values are rescaled from [0, 255] down to [-0.5, 0.5].
  """
  # if not os.path.exists(file):
  if not gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 16-byte IDX header before reading pixel values
      bytestream.read(16)
      buf = bytestream.read(image_size * image_size * num_images)
      data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
      data = (data - (pixel_depth / 2.0)) / pixel_depth
      data = data.reshape(num_images, image_size, image_size, 1)
      np.save(filename, data)
      return data
  else:
    with gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)
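
# Worked example of the rescaling above with pixel_depth = 255:
#   pixel   0 -> (0   - 127.5) / 255 = -0.5
#   pixel 255 -> (255 - 127.5) / 255 =  0.5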


def extract_mnist_labels(filename, num_images):
  """
  Extract the labels into a vector of int32 label IDs.
  """
  # if not os.path.exists(file):
  if not gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 8-byte IDX header before reading label values
      bytestream.read(8)
      buf = bytestream.read(1 * num_images)
      labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int32)
      np.save(filename, labels)
      return labels
  else:
    with gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)


def ld_svhn(extended=False, test_only=False):
  """
  Load the original SVHN data
  :param extended: include extended training data in the returned array
  :param test_only: disables loading of both train and extra -> large speed up
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://ufldl.stanford.edu/housenumbers/train_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/extra_32x32.mat']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract Train, Test, and Extended Train data
  if not test_only:
    # Load and apply whitening to train data
    train_data, train_labels = extract_svhn(local_urls[0])
    train_data = image_whitening(train_data)

    # Load and apply whitening to extended train data
    ext_data, ext_labels = extract_svhn(local_urls[2])
    ext_data = image_whitening(ext_data)

  # Load and apply whitening to test data
  test_data, test_labels = extract_svhn(local_urls[1])
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    if extended:
      # Stack train data with the extended training data
      train_data = np.vstack((train_data, ext_data))
      train_labels = np.hstack((train_labels, ext_labels))

      return train_data, train_labels, test_data, test_labels
    else:
      # Return training and extended training data separately
      return (train_data, train_labels, test_data, test_labels,
              ext_data, ext_labels)
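
# Minimal usage sketch (assumes FLAGS.data_dir has been set by the caller,
# e.g. via the training script's flags):
#
#   train_data, train_labels, test_data, test_labels = ld_svhn(extended=True)
#   # With extended=False the extra set is returned separately as a 6-tuple;
#   # with test_only=True only (test_data, test_labels) is returned.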


def ld_cifar10(test_only=False):
  """
  Load the original CIFAR10 data
  :param test_only: only return the test data and labels
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract archives and return different sets
  dataset = extract_cifar10(local_urls[0], FLAGS.data_dir)

  # Unpack tuple
  train_data, train_labels, test_data, test_labels = dataset

  # Apply whitening to input data
  train_data = image_whitening(train_data)
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels


def ld_mnist(test_only=False):
  """
  Load the MNIST dataset
  :param test_only: only return the test data and labels
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
               ]

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract it into np arrays.
  train_data = extract_mnist_data(local_urls[0], 60000, 28, 1)
  train_labels = extract_mnist_labels(local_urls[1], 60000)
  test_data = extract_mnist_data(local_urls[2], 10000, 28, 1)
  test_labels = extract_mnist_labels(local_urls[3], 10000)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels
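
# Minimal usage sketch (assumes FLAGS.data_dir has been set, e.g. to '/tmp'):
#
#   train_data, train_labels, test_data, test_labels = ld_mnist()
#   # train_data.shape == (60000, 28, 28, 1)
#   # test_data.shape  == (10000, 28, 28, 1)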


def partition_dataset(data, labels, nb_teachers, teacher_id):
  """
  Simple partitioning algorithm that returns the right portion of the data
  needed by a given teacher out of a certain nb of teachers
  :param data: input data to be partitioned
  :param labels: output data to be partitioned
  :param nb_teachers: number of teachers in the ensemble (affects size of each
                      partition)
  :param teacher_id: id of partition to retrieve
  :return: tuple of (data, labels) for the partition assigned to this teacher
  """
  # Sanity check
  assert len(data) == len(labels)
  assert int(teacher_id) < int(nb_teachers)

  # This will floor the possible number of batches
  batch_len = int(len(data) / nb_teachers)

  # Compute start, end indices of partition
  start = teacher_id * batch_len
  end = (teacher_id + 1) * batch_len

  # Slice partition off
  partition_data = data[start:end]
  partition_labels = labels[start:end]

  return partition_data, partition_labels
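
# Worked example of the partition arithmetic (illustrative numbers):
# with len(data) = 60000 and nb_teachers = 250, batch_len = 240, so
# teacher_id = 3 receives the slices data[720:960] and labels[720:960].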