input.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cPickle
import gzip
import math
import numpy as np
import os
from scipy.io import loadmat as loadmat
from six.moves import urllib
import sys
import tarfile

from tensorflow.python.platform import gfile
from tensorflow.python.platform import flags

FLAGS = flags.FLAGS


def create_dir_if_needed(dest_directory):
  """
  Create directory if it doesn't exist
  :param dest_directory: path of the directory to create
  :return: True if everything went well
  """
  if not gfile.IsDirectory(dest_directory):
    gfile.MakeDirs(dest_directory)

  return True


def maybe_download(file_urls, directory):
  """
  Download a set of files into a local directory if they are not already there
  :param file_urls: list of URLs of the files to download
  :param directory: the directory where to download
  :return: a list of filepaths corresponding to the files given as input
  """
  # Create directory if it doesn't exist
  assert create_dir_if_needed(directory)

  # This list will include all URLs of the local copies of downloaded files
  result = []

  # For each file of the dataset
  for file_url in file_urls:
    # Extract filename
    filename = file_url.split('/')[-1]

    # If downloading from GitHub, remove suffix ?raw=True from local filename
    if filename.endswith("?raw=true"):
      filename = filename[:-9]

    # Deduce local file url
    #filepath = os.path.join(directory, filename)
    filepath = directory + '/' + filename

    # Add to result list
    result.append(filepath)

    # Test if file already exists
    if not gfile.Exists(filepath):
      def _progress(count, block_size, total_size):
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
            float(count * block_size) / float(total_size) * 100.0))
        sys.stdout.flush()
      filepath, _ = urllib.request.urlretrieve(file_url, filepath, _progress)
      print()
      statinfo = os.stat(filepath)
      print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')

  return result
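

# Illustrative usage of maybe_download (not part of the original module). The
# URL below is hypothetical and only shows the expected call pattern:
#
#   local_paths = maybe_download(
#       ['http://example.com/data/train_32x32.mat'], FLAGS.data_dir)
#   # local_paths == [FLAGS.data_dir + '/train_32x32.mat']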


def image_whitening(data):
  """
  Subtracts the mean of each image and divides by the adjusted standard
  deviation (for numerical stability). Operations are per image but performed
  for the entire array.
  :param data: 4D array (ID, Height, Width, Channels)
  :return: 4D array (ID, Height, Width, Channels)
  """
  assert len(np.shape(data)) == 4

  # Compute number of pixels in image
  nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]

  # Subtract mean
  mean = np.mean(data, axis=(1, 2, 3))

  ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
  for i in xrange(len(data)):
    data[i, :, :, :] -= mean[i] * ones

  # Compute adjusted standard deviation (floored at 1 / sqrt(nb_pixels))
  adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels),
                           np.std(data, axis=(1, 2, 3)))

  # Divide image
  for i in xrange(len(data)):
    data[i, :, :, :] = data[i, :, :, :] / adj_std_var[i]

  print(np.shape(data))

  return data
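

# The whitening above amounts to per-image standardization: for an image x
# with N pixels, the output is (x - mean(x)) / max(std(x), 1 / sqrt(N)).
# A minimal sketch of the same operation on a single image (illustrative
# only, not used elsewhere in this module):
#
#   def whiten_one(image):
#     n = image.size
#     adj_std = max(np.std(image), 1.0 / math.sqrt(n))
#     return (image - np.mean(image)) / adj_std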


def extract_svhn(local_url):
  """
  Extract a MATLAB matrix into two numpy arrays with data and labels
  :param local_url: path of the SVHN .mat file on local disk
  :return: tuple of (data, labels) numpy arrays
  """
  with gfile.Open(local_url, mode='r') as file_obj:
    # Load MATLAB matrix using scipy IO
    dict = loadmat(file_obj)

    # Extract each dictionary (one for data, one for labels)
    data, labels = dict["X"], dict["y"]

    # Set np type
    data = np.asarray(data, dtype=np.float32)
    labels = np.asarray(labels, dtype=np.int32)

    # Transpose data to match TF model input format
    data = data.transpose(3, 0, 1, 2)

    # Fix the SVHN labels which label 0s as 10s
    labels[labels == 10] = 0

    # Fix label dimensions
    labels = labels.reshape(len(labels))

    return data, labels
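

# Note on the SVHN layout (kept here for reference): the cropped-digit .mat
# files store images as X with shape (32, 32, 3, N) and labels as y with
# values 1..10, where 10 encodes the digit 0; the transpose and relabeling
# above convert this to (N, 32, 32, 3) arrays with labels 0..9.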


def unpickle_cifar_dic(file):
  """
  Helper function: unpickles a dictionary (used for loading CIFAR)
  :param file: filename of the pickle
  :return: tuple of (images, labels)
  """
  fo = open(file, 'rb')
  dict = cPickle.load(fo)
  fo.close()
  return dict['data'], dict['labels']


def extract_cifar10(local_url, data_dir):
  """
  Extracts the CIFAR-10 dataset and returns numpy arrays with the different sets
  :param local_url: where the tar.gz archive is located locally
  :param data_dir: where to extract the archive's files
  :return: a tuple (train data, train labels, test data, test labels)
  """
  # These numpy dumps can be reloaded to avoid performing the pre-processing
  # if they exist in the working directory.
  # Changing the order of this list will ruin the indices below.
  preprocessed_files = ['/cifar10_train.npy',
                        '/cifar10_train_labels.npy',
                        '/cifar10_test.npy',
                        '/cifar10_test_labels.npy']

  all_preprocessed = True
  for file in preprocessed_files:
    if not gfile.Exists(data_dir + file):
      all_preprocessed = False
      break

  if all_preprocessed:
    # Reload pre-processed training data from numpy dumps
    with gfile.Open(data_dir + preprocessed_files[0], mode='r') as file_obj:
      train_data = np.load(file_obj)
    with gfile.Open(data_dir + preprocessed_files[1], mode='r') as file_obj:
      train_labels = np.load(file_obj)

    # Reload pre-processed testing data from numpy dumps
    with gfile.Open(data_dir + preprocessed_files[2], mode='r') as file_obj:
      test_data = np.load(file_obj)
    with gfile.Open(data_dir + preprocessed_files[3], mode='r') as file_obj:
      test_labels = np.load(file_obj)

  else:
    # Do everything from scratch
    # Define lists of all files we should extract
    train_files = ["data_batch_" + str(i) for i in xrange(1, 6)]
    test_file = ["test_batch"]
    cifar10_files = train_files + test_file

    # Check if all files have already been extracted
    need_to_unpack = False
    for file in cifar10_files:
      if not gfile.Exists(file):
        need_to_unpack = True
        break

    # We have to unpack the archive
    if need_to_unpack:
      tarfile.open(local_url, 'r:gz').extractall(data_dir)

    # Load training images and labels
    images = []
    labels = []
    for file in train_files:
      # Construct filename
      filename = data_dir + "/cifar-10-batches-py/" + file

      # Unpickle dictionary and extract images and labels
      images_tmp, labels_tmp = unpickle_cifar_dic(filename)

      # Append to lists
      images.append(images_tmp)
      labels.append(labels_tmp)

    # Convert to numpy arrays and reshape in the expected format
    train_data = np.asarray(images, dtype=np.float32).reshape((50000, 3, 32, 32))
    train_data = np.swapaxes(train_data, 1, 3)
    train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[0], train_data)
    np.save(data_dir + preprocessed_files[1], train_labels)

    # Construct filename for test file
    filename = data_dir + "/cifar-10-batches-py/" + test_file[0]

    # Load test images and labels
    test_data, test_labels_tmp = unpickle_cifar_dic(filename)

    # Convert to numpy arrays and reshape in the expected format
    test_data = np.asarray(test_data, dtype=np.float32).reshape((10000, 3, 32, 32))
    test_data = np.swapaxes(test_data, 1, 3)
    test_labels = np.asarray(test_labels_tmp, dtype=np.int32).reshape(10000)

    # Save so we don't have to do this again
    np.save(data_dir + preprocessed_files[2], test_data)
    np.save(data_dir + preprocessed_files[3], test_labels)

  return train_data, train_labels, test_data, test_labels
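

# Note on the CIFAR-10 python archive (kept here for reference): it contains
# five training batches and one test batch of 10000 images each; every batch
# stores a 10000 x 3072 uint8 'data' array (3072 = 3 * 32 * 32, channels
# first) and a matching 'labels' list, which is why the code above reshapes
# to (N, 3, 32, 32) before swapping axes to channels-last.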


def extract_mnist_data(filename, num_images, image_size, pixel_depth):
  """
  Extract the images into a 4D tensor [image index, y, x, channels].
  Values are rescaled from [0, 255] down to [-0.5, 0.5].
  """
  # if not os.path.exists(file):
  if not gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 16-byte IDX header before reading pixel data
      bytestream.read(16)
      buf = bytestream.read(image_size * image_size * num_images)
      data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
      data = (data - (pixel_depth / 2.0)) / pixel_depth
      data = data.reshape(num_images, image_size, image_size, 1)
      np.save(filename, data)
      return data
  else:
    with gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)


def extract_mnist_labels(filename, num_images):
  """
  Extract the labels into a vector of int32 label IDs.
  """
  # if not os.path.exists(file):
  if not gfile.Exists(filename + ".npy"):
    with gzip.open(filename) as bytestream:
      # Skip the 8-byte IDX header before reading label bytes
      bytestream.read(8)
      buf = bytestream.read(1 * num_images)
      labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int32)
      np.save(filename, labels)
      return labels
  else:
    with gfile.Open(filename + ".npy", mode='r') as file_obj:
      return np.load(file_obj)
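

# Note on the MNIST IDX format (kept here for reference): image files start
# with a 16-byte header (magic number, image count, rows, columns, each a
# 32-bit big-endian integer) followed by raw uint8 pixels, and label files
# start with an 8-byte header (magic number, label count) followed by raw
# uint8 labels; this is why the readers above skip 16 and 8 bytes.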


def ld_svhn(extended=False, test_only=False):
  """
  Load the original SVHN data
  :param extended: include extended training data in the returned array
  :param test_only: disables loading of both train and extra -> large speed up
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://ufldl.stanford.edu/housenumbers/train_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',
               'http://ufldl.stanford.edu/housenumbers/extra_32x32.mat']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract train, test, and extended train data
  if not test_only:
    # Load and apply whitening to train data
    train_data, train_labels = extract_svhn(local_urls[0])
    train_data = image_whitening(train_data)

    # Load and apply whitening to extended train data
    ext_data, ext_labels = extract_svhn(local_urls[2])
    ext_data = image_whitening(ext_data)

  # Load and apply whitening to test data
  test_data, test_labels = extract_svhn(local_urls[1])
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    if extended:
      # Stack train data with the extended training data
      train_data = np.vstack((train_data, ext_data))
      train_labels = np.hstack((train_labels, ext_labels))

      return train_data, train_labels, test_data, test_labels
    else:
      # Return training and extended training data separately
      return (train_data, train_labels, test_data, test_labels,
              ext_data, ext_labels)


def ld_cifar10(test_only=False):
  """
  Load the original CIFAR-10 data
  :param test_only: if True, only the test data and labels are returned
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract archives and return different sets
  dataset = extract_cifar10(local_urls[0], FLAGS.data_dir)

  # Unpack tuple
  train_data, train_labels, test_data, test_labels = dataset

  # Apply whitening to input data
  train_data = image_whitening(train_data)
  test_data = image_whitening(test_data)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels


def ld_mnist(test_only=False):
  """
  Load the MNIST dataset
  :param test_only: if True, only the test data and labels are returned
  :return: tuple of arrays which depend on the parameters
  """
  # Define files to be downloaded
  # WARNING: changing the order of this list will break indices (cf. below)
  file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
               'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
               ]

  # Maybe download data and retrieve local storage urls
  local_urls = maybe_download(file_urls, FLAGS.data_dir)

  # Extract it into np arrays.
  train_data = extract_mnist_data(local_urls[0], 60000, 28, 1)
  train_labels = extract_mnist_labels(local_urls[1], 60000)
  test_data = extract_mnist_data(local_urls[2], 10000, 28, 1)
  test_labels = extract_mnist_labels(local_urls[3], 10000)

  if test_only:
    return test_data, test_labels
  else:
    return train_data, train_labels, test_data, test_labels


def partition_dataset(data, labels, nb_teachers, teacher_id):
  """
  Simple partitioning algorithm that returns the right portion of the data
  needed by a given teacher out of a certain nb of teachers
  :param data: input data to be partitioned
  :param labels: output data to be partitioned
  :param nb_teachers: number of teachers in the ensemble (affects size of each
                      partition)
  :param teacher_id: id of partition to retrieve
  :return: a tuple of (data, labels) for the partition of the given teacher
  """
  # Sanity check
  assert len(data) == len(labels)
  assert int(teacher_id) < int(nb_teachers)

  # This will floor the possible number of batches
  batch_len = int(len(data) / nb_teachers)

  # Compute start, end indices of partition
  start = teacher_id * batch_len
  end = (teacher_id + 1) * batch_len

  # Slice partition off
  partition_data = data[start:end]
  partition_labels = labels[start:end]

  return partition_data, partition_labels
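

if __name__ == '__main__':
  # Minimal illustrative sketch (not part of the original module): exercise
  # image_whitening and partition_dataset on synthetic data so the helpers
  # can be smoke-tested without downloading any dataset.
  fake_data = np.random.rand(100, 32, 32, 3).astype(np.float32)
  fake_labels = np.random.randint(0, 10, size=100).astype(np.int32)

  # Standardize each image and slice out the partition of a hypothetical
  # teacher 3 out of an ensemble of 10 teachers (10 images per teacher).
  fake_data = image_whitening(fake_data)
  part_data, part_labels = partition_dataset(fake_data, fake_labels, 10, 3)
  print('Teacher 3 partition:', np.shape(part_data), np.shape(part_labels))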