# train_teachers.py (3.9 KB)
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from __future__ import absolute_import
  16. from __future__ import division
  17. from __future__ import print_function
  18. import tensorflow as tf
  19. from differential_privacy.multiple_teachers import deep_cnn
  20. from differential_privacy.multiple_teachers import input
  21. from differential_privacy.multiple_teachers import metrics
  22. tf.flags.DEFINE_string('dataset', 'svhn', 'The name of the dataset to use')
  23. tf.flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')
  24. tf.flags.DEFINE_string('data_dir','/tmp','Temporary storage')
  25. tf.flags.DEFINE_string('train_dir','/tmp/train_dir',
  26. 'Where model ckpt are saved')
  27. tf.flags.DEFINE_integer('max_steps', 3000, 'Number of training steps to run.')
  28. tf.flags.DEFINE_integer('nb_teachers', 50, 'Teachers in the ensemble.')
  29. tf.flags.DEFINE_integer('teacher_id', 0, 'ID of teacher being trained.')
  30. tf.flags.DEFINE_boolean('deeper', False, 'Activate deeper CNN model')
  31. FLAGS = tf.flags.FLAGS
  32. def train_teacher(dataset, nb_teachers, teacher_id):
  33. """
  34. This function trains a teacher (teacher id) among an ensemble of nb_teachers
  35. models for the dataset specified.
  36. :param dataset: string corresponding to dataset (svhn, cifar10)
  37. :param nb_teachers: total number of teachers in the ensemble
  38. :param teacher_id: id of the teacher being trained
  39. :return: True if everything went well
  40. """
  41. # If working directories do not exist, create them
  42. assert input.create_dir_if_needed(FLAGS.data_dir)
  43. assert input.create_dir_if_needed(FLAGS.train_dir)
  44. # Load the dataset
  45. if dataset == 'svhn':
  46. train_data,train_labels,test_data,test_labels = input.ld_svhn(extended=True)
  47. elif dataset == 'cifar10':
  48. train_data, train_labels, test_data, test_labels = input.ld_cifar10()
  49. elif dataset == 'mnist':
  50. train_data, train_labels, test_data, test_labels = input.ld_mnist()
  51. else:
  52. print("Check value of dataset flag")
  53. return False
  54. # Retrieve subset of data for this teacher
  55. data, labels = input.partition_dataset(train_data,
  56. train_labels,
  57. nb_teachers,
  58. teacher_id)
  59. print("Length of training data: " + str(len(labels)))
  60. # Define teacher checkpoint filename and full path
  61. if FLAGS.deeper:
  62. filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt'
  63. else:
  64. filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
  65. ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename
  66. # Perform teacher training
  67. assert deep_cnn.train(data, labels, ckpt_path)
  68. # Append final step value to checkpoint for evaluation
  69. ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
  70. # Retrieve teacher probability estimates on the test data
  71. teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)
  72. # Compute teacher accuracy
  73. precision = metrics.accuracy(teacher_preds, test_labels)
  74. print('Precision of teacher after training: ' + str(precision))
  75. return True
  76. def main(argv=None): # pylint: disable=unused-argument
  77. # Make a call to train_teachers with values specified in flags
  78. assert train_teacher(FLAGS.dataset, FLAGS.nb_teachers, FLAGS.teacher_id)
  79. if __name__ == '__main__':
  80. tf.app.run()