- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from datetime import datetime
- import math
- import numpy as np
- from six.moves import xrange
- import tensorflow as tf
- import time
- from differential_privacy.multiple_teachers import utils
- FLAGS = tf.app.flags.FLAGS
- # Basic model parameters.
- tf.app.flags.DEFINE_integer('dropout_seed', 123, """seed for dropout.""")
- tf.app.flags.DEFINE_integer('batch_size', 128, """Nb of images in a batch.""")
- tf.app.flags.DEFINE_integer('epochs_per_decay', 350, """Nb epochs per decay""")
- tf.app.flags.DEFINE_integer('learning_rate', 5, """100 * learning rate""")
- tf.app.flags.DEFINE_boolean('log_device_placement', False, """see TF doc""")
- # Constants describing the training process.
- MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average.
- LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor.
- def _variable_on_cpu(name, shape, initializer):
- """Helper to create a Variable stored on CPU memory.
- Args:
- name: name of the variable
- shape: list of ints
- initializer: initializer for Variable
- Returns:
- Variable Tensor
- """
- with tf.device('/cpu:0'):
- var = tf.get_variable(name, shape, initializer=initializer)
- return var
- def _variable_with_weight_decay(name, shape, stddev, wd):
- """Helper to create an initialized Variable with weight decay.
- Note that the Variable is initialized with a truncated normal distribution.
- A weight decay is added only if one is specified.
- Args:
- name: name of the variable
- shape: list of ints
- stddev: standard deviation of a truncated Gaussian
- wd: add L2Loss weight decay multiplied by this float. If None, weight
- decay is not added for this Variable.
- Returns:
- Variable Tensor
- """
- var = _variable_on_cpu(name, shape,
- tf.truncated_normal_initializer(stddev=stddev))
- if wd is not None:
- weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
- tf.add_to_collection('losses', weight_decay)
- return var
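- # Illustrative sketch (not part of the original graph construction): the helper
- # above would typically be called inside a variable scope, e.g.
- #   with tf.variable_scope('example'):
- #     w = _variable_with_weight_decay('weights', shape=[5, 5, 3, 64],
- #                                     stddev=1e-4, wd=0.004)
- # which creates the variable on the CPU and, because wd is not None, adds an
- # L2 penalty on it to the 'losses' collection.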
- def inference(images, dropout=False):
- """Build the CNN model.
- Args:
- images: Images returned from distorted_inputs() or inputs().
- dropout: Boolean controlling whether to use dropout or not
- Returns:
- Logits
- """
- if FLAGS.dataset == 'mnist':
- first_conv_shape = [5, 5, 1, 64]
- else:
- first_conv_shape = [5, 5, 3, 64]
- # conv1
- with tf.variable_scope('conv1') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=first_conv_shape,
- stddev=1e-4,
- wd=0.0)
- conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv1 = tf.nn.relu(bias, name=scope.name)
- if dropout:
- conv1 = tf.nn.dropout(conv1, 0.3, seed=FLAGS.dropout_seed)
- # pool1
- pool1 = tf.nn.max_pool(conv1,
- ksize=[1, 3, 3, 1],
- strides=[1, 2, 2, 1],
- padding='SAME',
- name='pool1')
-
- # norm1
- norm1 = tf.nn.lrn(pool1,
- 4,
- bias=1.0,
- alpha=0.001 / 9.0,
- beta=0.75,
- name='norm1')
- # conv2
- with tf.variable_scope('conv2') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 64, 128],
- stddev=1e-4,
- wd=0.0)
- conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [128], tf.constant_initializer(0.1))
- bias = tf.nn.bias_add(conv, biases)
- conv2 = tf.nn.relu(bias, name=scope.name)
- if dropout:
- conv2 = tf.nn.dropout(conv2, 0.3, seed=FLAGS.dropout_seed)
- # norm2
- norm2 = tf.nn.lrn(conv2,
- 4,
- bias=1.0,
- alpha=0.001 / 9.0,
- beta=0.75,
- name='norm2')
-
- # pool2
- pool2 = tf.nn.max_pool(norm2,
- ksize=[1, 3, 3, 1],
- strides=[1, 2, 2, 1],
- padding='SAME',
- name='pool2')
- # local3
- with tf.variable_scope('local3') as scope:
- # Move everything into depth so we can perform a single matrix multiply.
- reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
- dim = reshape.get_shape()[1].value
- weights = _variable_with_weight_decay('weights',
- shape=[dim, 384],
- stddev=0.04,
- wd=0.004)
- biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
- local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
- if dropout:
- local3 = tf.nn.dropout(local3, 0.5, seed=FLAGS.dropout_seed)
- # local4
- with tf.variable_scope('local4') as scope:
- weights = _variable_with_weight_decay('weights',
- shape=[384, 192],
- stddev=0.04,
- wd=0.004)
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
- local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
- if dropout:
- local4 = tf.nn.dropout(local4, 0.5, seed=FLAGS.dropout_seed)
- # compute logits
- with tf.variable_scope('softmax_linear') as scope:
- weights = _variable_with_weight_decay('weights',
- [192, FLAGS.nb_labels],
- stddev=1/192.0,
- wd=0.0)
- biases = _variable_on_cpu('biases',
- [FLAGS.nb_labels],
- tf.constant_initializer(0.0))
- logits = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
- return logits
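- # Hedged usage sketch: given a float32 batch of shape
- # [FLAGS.batch_size, height, width, channels] (as produced by
- # _input_placeholder() below), per-class scores would be obtained with e.g.
- #   logits = inference(train_data_node, dropout=True)
- # The returned tensor has shape [FLAGS.batch_size, FLAGS.nb_labels].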
- def inference_deeper(images, dropout=False):
- """Build a deeper CNN model.
- Args:
- images: Images returned from distorted_inputs() or inputs().
- dropout: Boolean controlling whether to use dropout or not
- Returns:
- Logits
- """
- if FLAGS.dataset == 'mnist':
- first_conv_shape = [3, 3, 1, 96]
- else:
- first_conv_shape = [3, 3, 3, 96]
- # conv1
- with tf.variable_scope('conv1') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=first_conv_shape,
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv1 = tf.nn.relu(bias, name=scope.name)
- # conv2
- with tf.variable_scope('conv2') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[3, 3, 96, 96],
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(conv1, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv2 = tf.nn.relu(bias, name=scope.name)
- # conv3
- with tf.variable_scope('conv3') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[3, 3, 96, 96],
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(conv2, kernel, [1, 2, 2, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv3 = tf.nn.relu(bias, name=scope.name)
- if dropout:
- conv3 = tf.nn.dropout(conv3, 0.5, seed=FLAGS.dropout_seed)
- # conv4
- with tf.variable_scope('conv4') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[3, 3, 96, 192],
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv4 = tf.nn.relu(bias, name=scope.name)
- # conv5
- with tf.variable_scope('conv5') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[3, 3, 192, 192],
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv5 = tf.nn.relu(bias, name=scope.name)
- # conv6
- with tf.variable_scope('conv6') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[3, 3, 192, 192],
- stddev=0.05,
- wd=0.0)
- conv = tf.nn.conv2d(conv5, kernel, [1, 2, 2, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
- bias = tf.nn.bias_add(conv, biases)
- conv6 = tf.nn.relu(bias, name=scope.name)
- if dropout:
- conv6 = tf.nn.dropout(conv6, 0.5, seed=FLAGS.dropout_seed)
- # conv7
- with tf.variable_scope('conv7') as scope:
- kernel = _variable_with_weight_decay('weights',
- shape=[5, 5, 192, 192],
- stddev=1e-4,
- wd=0.0)
- conv = tf.nn.conv2d(conv6, kernel, [1, 1, 1, 1], padding='SAME')
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
- bias = tf.nn.bias_add(conv, biases)
- conv7 = tf.nn.relu(bias, name=scope.name)
- # local1
- with tf.variable_scope('local1') as scope:
- # Move everything into depth so we can perform a single matrix multiply.
- reshape = tf.reshape(conv7, [FLAGS.batch_size, -1])
- dim = reshape.get_shape()[1].value
- weights = _variable_with_weight_decay('weights',
- shape=[dim, 192],
- stddev=0.05,
- wd=0)
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
- local1 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
- # local2
- with tf.variable_scope('local2') as scope:
- weights = _variable_with_weight_decay('weights',
- shape=[192, 192],
- stddev=0.05,
- wd=0)
- biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
- local2 = tf.nn.relu(tf.matmul(local1, weights) + biases, name=scope.name)
- if dropout:
- local2 = tf.nn.dropout(local2, 0.5, seed=FLAGS.dropout_seed)
- # compute logits
- with tf.variable_scope('softmax_linear') as scope:
- weights = _variable_with_weight_decay('weights',
- [192, FLAGS.nb_labels],
- stddev=0.05,
- wd=0.0)
- biases = _variable_on_cpu('biases',
- [FLAGS.nb_labels],
- tf.constant_initializer(0.0))
- logits = tf.add(tf.matmul(local2, weights), biases, name=scope.name)
- return logits
- def loss_fun(logits, labels):
- """Add L2Loss to all the trainable variables.
- Add summary for "Loss" and "Loss/avg".
- Args:
- logits: Logits from inference().
- labels: Labels from distorted_inputs or inputs(). 1-D tensor
- of shape [batch_size]
- Returns:
- Loss tensor of type float.
- """
- # Calculate the cross entropy between labels and predictions
- labels = tf.cast(labels, tf.int64)
- cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
- logits=logits, labels=labels, name='cross_entropy_per_example')
- # Calculate the average cross entropy loss across the batch.
- cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
- # Add to TF collection for losses
- tf.add_to_collection('losses', cross_entropy_mean)
- # The total loss is defined as the cross entropy loss plus all of the weight
- # decay terms (L2 loss).
- return tf.add_n(tf.get_collection('losses'), name='total_loss')
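- # Illustrative example (assumed shapes): given logits of shape
- # [batch_size, nb_labels] and integer labels of shape [batch_size],
- #   total_loss = loss_fun(logits, train_labels_node)
- # returns the mean cross-entropy plus any weight-decay terms previously
- # added to the 'losses' collection by _variable_with_weight_decay().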
- def moving_av(total_loss):
- """
- Generates a moving average for all losses
- Args:
- total_loss: Total loss from loss().
- Returns:
- loss_averages_op: op for generating moving averages of losses.
- """
- # Compute the moving average of all individual losses and the total loss.
- loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
- losses = tf.get_collection('losses')
- loss_averages_op = loss_averages.apply(losses + [total_loss])
- return loss_averages_op
- def train_op_fun(total_loss, global_step):
- """Train model.
- Create an optimizer and apply to all trainable variables. Add moving
- average for all trainable variables.
- Args:
- total_loss: Total loss from loss().
- global_step: Integer Variable counting the number of training steps
- processed.
- Returns:
- train_op: op for training.
- """
- # Variables that affect learning rate.
- nb_ex_per_train_epoch = int(60000 / FLAGS.nb_teachers)
-
- num_batches_per_epoch = nb_ex_per_train_epoch / FLAGS.batch_size
- decay_steps = int(num_batches_per_epoch * FLAGS.epochs_per_decay)
- initial_learning_rate = float(FLAGS.learning_rate) / 100.0
- # Decay the learning rate exponentially based on the number of steps.
- lr = tf.train.exponential_decay(initial_learning_rate,
- global_step,
- decay_steps,
- LEARNING_RATE_DECAY_FACTOR,
- staircase=True)
- tf.summary.scalar('learning_rate', lr)
- # Generate moving averages of all losses and associated summaries.
- loss_averages_op = moving_av(total_loss)
- # Compute gradients.
- with tf.control_dependencies([loss_averages_op]):
- opt = tf.train.GradientDescentOptimizer(lr)
- grads = opt.compute_gradients(total_loss)
- # Apply gradients.
- apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
- # Add histograms for trainable variables.
- for var in tf.trainable_variables():
- tf.summary.histogram(var.op.name, var)
- # Track the moving averages of all trainable variables.
- variable_averages = tf.train.ExponentialMovingAverage(
- MOVING_AVERAGE_DECAY, global_step)
- variables_averages_op = variable_averages.apply(tf.trainable_variables())
- with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
- train_op = tf.no_op(name='train')
- return train_op
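- # Minimal sketch of how the training pieces fit together (assumes a graph
- # that already defines total_loss):
- #   global_step = tf.Variable(0, trainable=False)
- #   train_op = train_op_fun(total_loss, global_step)
- # Running train_op performs one SGD step with the decayed learning rate and
- # updates the moving averages of the trainable variables.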
- def _input_placeholder():
- """
- This helper function declares a TF placeholder for the graph input data
- :return: TF placeholder for the graph input data
- """
- if FLAGS.dataset == 'mnist':
- image_size = 28
- num_channels = 1
- else:
- image_size = 32
- num_channels = 3
- # Declare data placeholder
- train_node_shape = (FLAGS.batch_size, image_size, image_size, num_channels)
- return tf.placeholder(tf.float32, shape=train_node_shape)
- def train(images, labels, ckpt_path, dropout=False):
- """
- This function contains the loop that actually trains the model.
- :param images: a numpy array with the input data
- :param labels: a numpy array with the output labels
- :param ckpt_path: a path (including name) where model checkpoints are saved
- :param dropout: Boolean, whether to use dropout or not
- :return: True if everything went well
- """
- # Check training data
- assert len(images) == len(labels)
- assert images.dtype == np.float32
- assert labels.dtype == np.int32
- # Set default TF graph
- with tf.Graph().as_default():
- global_step = tf.Variable(0, trainable=False)
- # Declare data placeholder
- train_data_node = _input_placeholder()
- # Create a placeholder to hold labels
- train_labels_shape = (FLAGS.batch_size,)
- train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)
- print("Done Initializing Training Placeholders")
- # Build a Graph that computes the logits predictions from the placeholder
- if FLAGS.deeper:
- logits = inference_deeper(train_data_node, dropout=dropout)
- else:
- logits = inference(train_data_node, dropout=dropout)
- # Calculate loss
- loss = loss_fun(logits, train_labels_node)
- # Build a Graph that trains the model with one batch of examples and
- # updates the model parameters.
- train_op = train_op_fun(loss, global_step)
- # Create a saver.
- saver = tf.train.Saver(tf.global_variables())
- print("Graph constructed and saver created")
- # Build an initialization operation to run below.
- init = tf.global_variables_initializer()
- # Create and init sessions
- sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) #NOLINT(long-line)
- sess.run(init)
- print("Session ready, beginning training loop")
- # Initialize the number of batches
- data_length = len(images)
- nb_batches = math.ceil(data_length / FLAGS.batch_size)
- for step in xrange(FLAGS.max_steps):
- # For debugging, record the start time of this step
- start_time = time.time()
- # Current batch number
- batch_nb = step % nb_batches
- # Current batch start and end indices
- start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)
- # Prepare dictionary to feed the session with
- feed_dict = {train_data_node: images[start:end],
- train_labels_node: labels[start:end]}
- # Run training step
- _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
- # Compute duration of training step
- duration = time.time() - start_time
- # Sanity check
- assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
- # Echo loss once in a while
- if step % 100 == 0:
- num_examples_per_step = FLAGS.batch_size
- examples_per_sec = num_examples_per_step / duration
- sec_per_batch = float(duration)
- format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
- 'sec/batch)')
- print (format_str % (datetime.now(), step, loss_value,
- examples_per_sec, sec_per_batch))
- # Save the model checkpoint periodically.
- if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
- saver.save(sess, ckpt_path, global_step=step)
- return True
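- # Hedged usage sketch (the arrays and checkpoint path are hypothetical):
- #   images, labels = ...  # float32 images, int32 labels loaded elsewhere
- #   assert train(images, labels, '/tmp/model.ckpt', dropout=True)
- # Checkpoints are written every 1000 steps and at the final step, so the
- # latest file can later be passed to softmax_preds() below.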
- def softmax_preds(images, ckpt_path, return_logits=False):
- """
- Compute softmax activations (probabilities) with the model saved in the path
- specified as an argument
- :param images: a np array of images
- :param ckpt_path: a TF model checkpoint
- :param return_logits: if set to True, return logits instead of probabilities
- :return: probabilities (or logits if return_logits is set to True)
- """
- # Compute nb samples and deduce nb of batches
- data_length = len(images)
- nb_batches = math.ceil(len(images) / FLAGS.batch_size)
- # Declare data placeholder
- train_data_node = _input_placeholder()
- # Build a Graph that computes the logits predictions from the placeholder
- if FLAGS.deeper:
- logits = inference_deeper(train_data_node)
- else:
- logits = inference(train_data_node)
- if return_logits:
- # We are returning the logits directly (no need to apply softmax)
- output = logits
- else:
- # Add softmax predictions to graph: will return probabilities
- output = tf.nn.softmax(logits)
- # Restore the moving average version of the learned variables for eval.
- variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
- variables_to_restore = variable_averages.variables_to_restore()
- saver = tf.train.Saver(variables_to_restore)
- # Will hold the result
- preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)
- # Create TF session
- with tf.Session() as sess:
- # Restore TF session from checkpoint file
- saver.restore(sess, ckpt_path)
- # Parse data by batch
- for batch_nb in xrange(0, int(nb_batches+1)):
- # Compute batch start and end indices
- start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)
- # Prepare feed dictionary
- feed_dict = {train_data_node: images[start:end]}
- # Run session ([0] because run returns a batch with len 1st dim == 1)
- preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]
- # Reset graph to allow multiple calls
- tf.reset_default_graph()
- return preds
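- # Illustrative end-to-end sketch (the checkpoint path is hypothetical):
- #   preds = softmax_preds(test_images, '/tmp/model.ckpt-2999')
- # preds is a [len(test_images), FLAGS.nb_labels] float32 array of
- # probabilities; pass return_logits=True to get the raw logits instead.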