# train_student.py
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import numpy as np
  19. import tensorflow as tf
  20. from differential_privacy.multiple_teachers import aggregation
  21. from differential_privacy.multiple_teachers import deep_cnn
  22. from differential_privacy.multiple_teachers import input
  23. from differential_privacy.multiple_teachers import metrics
  24. FLAGS = tf.flags.FLAGS
  25. tf.flags.DEFINE_string('dataset', 'svhn', 'The name of the dataset to use')
  26. tf.flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')
  27. tf.flags.DEFINE_string('data_dir','/tmp','Temporary storage')
  28. tf.flags.DEFINE_string('train_dir','/tmp/train_dir','Where model chkpt are saved')
  29. tf.flags.DEFINE_string('teachers_dir','/tmp/train_dir',
  30. 'Directory where teachers checkpoints are stored.')
  31. tf.flags.DEFINE_integer('teachers_max_steps', 3000,
  32. 'Number of steps teachers were ran.')
  33. tf.flags.DEFINE_integer('max_steps', 3000, 'Number of steps to run student.')
  34. tf.flags.DEFINE_integer('nb_teachers', 10, 'Teachers in the ensemble.')
  35. tf.flags.DEFINE_integer('stdnt_share', 1000,
  36. 'Student share (last index) of the test data')
  37. tf.flags.DEFINE_integer('lap_scale', 10,
  38. 'Scale of the Laplacian noise added for privacy')
  39. tf.flags.DEFINE_boolean('save_labels', False,
  40. 'Dump numpy arrays of labels and clean teacher votes')
  41. tf.flags.DEFINE_boolean('deeper', False, 'Activate deeper CNN model')
  42. def ensemble_preds(dataset, nb_teachers, stdnt_data):
  43. """
  44. Given a dataset, a number of teachers, and some input data, this helper
  45. function queries each teacher for predictions on the data and returns
  46. all predictions in a single array. (That can then be aggregated into
  47. one single prediction per input using aggregation.py (cf. function
  48. prepare_student_data() below)
  49. :param dataset: string corresponding to mnist, cifar10, or svhn
  50. :param nb_teachers: number of teachers (in the ensemble) to learn from
  51. :param stdnt_data: unlabeled student training data
  52. :return: 3d array (teacher id, sample id, probability per class)
  53. """
  54. # Compute shape of array that will hold probabilities produced by each
  55. # teacher, for each training point, and each output class
  56. result_shape = (nb_teachers, len(stdnt_data), FLAGS.nb_labels)
  57. # Create array that will hold result
  58. result = np.zeros(result_shape, dtype=np.float32)
  59. # Get predictions from each teacher
  60. for teacher_id in xrange(nb_teachers):
  61. # Compute path of checkpoint file for teacher model with ID teacher_id
  62. if FLAGS.deeper:
  63. ckpt_path = FLAGS.teachers_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt-' + str(FLAGS.teachers_max_steps - 1) #NOLINT(long-line)
  64. else:
  65. ckpt_path = FLAGS.teachers_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt-' + str(FLAGS.teachers_max_steps - 1) # NOLINT(long-line)
  66. # Get predictions on our training data and store in result array
  67. result[teacher_id] = deep_cnn.softmax_preds(stdnt_data, ckpt_path)
  68. # This can take a while when there are a lot of teachers so output status
  69. print("Computed Teacher " + str(teacher_id) + " softmax predictions")
  70. return result
  71. def prepare_student_data(dataset, nb_teachers, save=False):
  72. """
  73. Takes a dataset name and the size of the teacher ensemble and prepares
  74. training data for the student model, according to parameters indicated
  75. in flags above.
  76. :param dataset: string corresponding to mnist, cifar10, or svhn
  77. :param nb_teachers: number of teachers (in the ensemble) to learn from
  78. :param save: if set to True, will dump student training labels predicted by
  79. the ensemble of teachers (with Laplacian noise) as npy files.
  80. It also dumps the clean votes for each class (without noise) and
  81. the labels assigned by teachers
  82. :return: pairs of (data, labels) to be used for student training and testing
  83. """
  84. assert input.create_dir_if_needed(FLAGS.train_dir)
  85. # Load the dataset
  86. if dataset == 'svhn':
  87. test_data, test_labels = input.ld_svhn(test_only=True)
  88. elif dataset == 'cifar10':
  89. test_data, test_labels = input.ld_cifar10(test_only=True)
  90. elif dataset == 'mnist':
  91. test_data, test_labels = input.ld_mnist(test_only=True)
  92. else:
  93. print("Check value of dataset flag")
  94. return False
  95. # Make sure there is data leftover to be used as a test set
  96. assert FLAGS.stdnt_share < len(test_data)
  97. # Prepare [unlabeled] student training data (subset of test set)
  98. stdnt_data = test_data[:FLAGS.stdnt_share]
  99. # Compute teacher predictions for student training data
  100. teachers_preds = ensemble_preds(dataset, nb_teachers, stdnt_data)
  101. # Aggregate teacher predictions to get student training labels
  102. if not save:
  103. stdnt_labels = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale)
  104. else:
  105. # Request clean votes and clean labels as well
  106. stdnt_labels, clean_votes, labels_for_dump = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale, return_clean_votes=True) #NOLINT(long-line)
  107. # Prepare filepath for numpy dump of clean votes
  108. filepath = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_student_clean_votes_lap_' + str(FLAGS.lap_scale) + '.npy' # NOLINT(long-line)
  109. # Prepare filepath for numpy dump of clean labels
  110. filepath_labels = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_teachers_labels_lap_' + str(FLAGS.lap_scale) + '.npy' # NOLINT(long-line)
  111. # Dump clean_votes array
  112. with tf.gfile.Open(filepath, mode='w') as file_obj:
  113. np.save(file_obj, clean_votes)
  114. # Dump labels_for_dump array
  115. with tf.gfile.Open(filepath_labels, mode='w') as file_obj:
  116. np.save(file_obj, labels_for_dump)
  117. # Print accuracy of aggregated labels
  118. ac_ag_labels = metrics.accuracy(stdnt_labels, test_labels[:FLAGS.stdnt_share])
  119. print("Accuracy of the aggregated labels: " + str(ac_ag_labels))
  120. # Store unused part of test set for use as a test set after student training
  121. stdnt_test_data = test_data[FLAGS.stdnt_share:]
  122. stdnt_test_labels = test_labels[FLAGS.stdnt_share:]
  123. if save:
  124. # Prepare filepath for numpy dump of labels produced by noisy aggregation
  125. filepath = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_student_labels_lap_' + str(FLAGS.lap_scale) + '.npy' #NOLINT(long-line)
  126. # Dump student noisy labels array
  127. with tf.gfile.Open(filepath, mode='w') as file_obj:
  128. np.save(file_obj, stdnt_labels)
  129. return stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels
  130. def train_student(dataset, nb_teachers):
  131. """
  132. This function trains a student using predictions made by an ensemble of
  133. teachers. The student and teacher models are trained using the same
  134. neural network architecture.
  135. :param dataset: string corresponding to mnist, cifar10, or svhn
  136. :param nb_teachers: number of teachers (in the ensemble) to learn from
  137. :return: True if student training went well
  138. """
  139. assert input.create_dir_if_needed(FLAGS.train_dir)
  140. # Call helper function to prepare student data using teacher predictions
  141. stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=True)
  142. # Unpack the student dataset
  143. stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset
  144. # Prepare checkpoint filename and path
  145. if FLAGS.deeper:
  146. ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student_deeper.ckpt' #NOLINT(long-line)
  147. else:
  148. ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student.ckpt' # NOLINT(long-line)
  149. # Start student training
  150. assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)
  151. # Compute final checkpoint name for student (with max number of steps)
  152. ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
  153. # Compute student label predictions on remaining chunk of test set
  154. student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)
  155. # Compute teacher accuracy
  156. precision = metrics.accuracy(student_preds, stdnt_test_labels)
  157. print('Precision of student after training: ' + str(precision))
  158. return True
def main(argv=None):  # pylint: disable=unused-argument
  """Entry point: train the student according to values specified in flags."""
  # Run student training according to values specified in flags
  assert train_student(FLAGS.dataset, FLAGS.nb_teachers)


if __name__ == '__main__':
  tf.app.run()