123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import tensorflow as tf
- from differential_privacy.multiple_teachers import deep_cnn
- from differential_privacy.multiple_teachers import input
- from differential_privacy.multiple_teachers import metrics
# Dataset selection and storage locations.
tf.flags.DEFINE_string('dataset', 'svhn',
                       'The name of the dataset to use')
tf.flags.DEFINE_integer('nb_labels', 10,
                        'Number of output classes')
tf.flags.DEFINE_string('data_dir', '/tmp',
                       'Temporary storage')
tf.flags.DEFINE_string('train_dir', '/tmp/train_dir',
                       'Where model ckpt are saved')

# Training length and ensemble configuration.
tf.flags.DEFINE_integer('max_steps', 3000,
                        'Number of training steps to run.')
tf.flags.DEFINE_integer('nb_teachers', 50,
                        'Teachers in the ensemble.')
tf.flags.DEFINE_integer('teacher_id', 0,
                        'ID of teacher being trained.')
tf.flags.DEFINE_boolean('deeper', False,
                        'Activate deeper CNN model')

# Global object exposing the parsed values of the flags defined above.
FLAGS = tf.flags.FLAGS
def train_teacher(dataset, nb_teachers, teacher_id):
  """Train one teacher (teacher_id) of an ensemble of nb_teachers models.

  The teacher is trained on its partition of the training data of the given
  dataset. The result is saved to a checkpoint under FLAGS.train_dir and
  then evaluated on the test set.

  :param dataset: string corresponding to dataset (svhn, cifar10, mnist)
  :param nb_teachers: total number of teachers in the ensemble
  :param teacher_id: id of the teacher being trained
  :return: True if everything went well, False if the dataset is unknown
  :raises RuntimeError: if a working directory cannot be created or if
    training reports failure
  """
  # If working directories do not exist, create them. These are explicit
  # checks rather than `assert`s because assertions are stripped under
  # `python -O`, which would silently skip the directory creation.
  for directory in (FLAGS.data_dir, FLAGS.train_dir):
    if not input.create_dir_if_needed(directory):
      raise RuntimeError('Could not create directory: ' + str(directory))

  # Load the dataset
  if dataset == 'svhn':
    train_data, train_labels, test_data, test_labels = input.ld_svhn(
        extended=True)
  elif dataset == 'cifar10':
    train_data, train_labels, test_data, test_labels = input.ld_cifar10()
  elif dataset == 'mnist':
    train_data, train_labels, test_data, test_labels = input.ld_mnist()
  else:
    print("Check value of dataset flag")
    return False

  # Retrieve subset of data for this teacher
  data, labels = input.partition_dataset(train_data,
                                         train_labels,
                                         nb_teachers,
                                         teacher_id)
  print("Length of training data: " + str(len(labels)))

  # Define teacher checkpoint filename and full path
  suffix = '_deep.ckpt' if FLAGS.deeper else '.ckpt'
  filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + suffix
  ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename

  # Perform teacher training. The result is checked explicitly, not via
  # `assert`, so that the training call itself still runs under `python -O`.
  if not deep_cnn.train(data, labels, ckpt_path):
    raise RuntimeError('Training failed for teacher ' + str(teacher_id))

  # Append final step value to checkpoint for evaluation.
  # NOTE(review): assumes deep_cnn.train writes its last checkpoint at step
  # FLAGS.max_steps - 1 -- confirm against deep_cnn.
  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)

  # Retrieve teacher probability estimates on the test data
  teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)

  # Compute teacher accuracy
  precision = metrics.accuracy(teacher_preds, test_labels)
  print('Precision of teacher after training: ' + str(precision))

  return True
def main(argv=None):  # pylint: disable=unused-argument
  """Entry point for tf.app.run: train the teacher selected by the flags.

  :param argv: unused; present for the tf.app.run calling convention
  :raises RuntimeError: if train_teacher reports failure (returns False)
  """
  # Make a call to train_teacher with values specified in flags. The call is
  # deliberately kept out of an `assert`: under `python -O` an `assert`
  # statement is removed entirely, so training would never run at all.
  if not train_teacher(FLAGS.dataset, FLAGS.nb_teachers, FLAGS.teacher_id):
    raise RuntimeError('Teacher training failed')
if __name__ == '__main__':
  # Parse command-line flags, then dispatch explicitly to this module's main().
  tf.app.run(main=main)
|