
added private learning with multiple teachers (#331)

Nicolas Papernot committed 9 years ago
Commit c711dc707e
8 changed files with 1636 additions and 0 deletions
  1. privacy/README.md (+79, -0)
  2. privacy/aggregation.py (+131, -0)
  3. privacy/deep_cnn.py (+604, -0)
  4. privacy/input.py (+420, -0)
  5. privacy/metrics.py (+49, -0)
  6. privacy/train_student.py (+212, -0)
  7. privacy/train_teachers.py (+106, -0)
  8. privacy/utils.py (+35, -0)

privacy/README.md (+79, -0)

@@ -0,0 +1,79 @@
+# Learning private models with multiple teachers
+
+This repository contains code to train privacy-preserving student models by
+transferring knowledge from an ensemble of teachers trained on disjoint
+subsets of the sensitive data for which privacy guarantees are to be provided.
+
+Knowledge acquired by teachers is transferred to the student in a differentially
+private manner by noisily aggregating the teacher decisions before feeding them
+to the student during training.
+
+A paper describing the approach is in preparation. A link will be added to this 
+README when available.
+
+## Dependencies
+
+This model uses `TensorFlow` to perform the numerical computations associated
+with machine learning models, as well as common Python libraries such as
+`numpy`, `scipy`, and `six`. Instructions to install these can be found in
+their respective documentation.
+
+## How to run
+
+This repository supports the MNIST, CIFAR10, and SVHN datasets. The following
+instructions are given for MNIST but can easily be adapted by replacing the
+flag `--dataset=mnist` with `--dataset=cifar10` or `--dataset=svhn`.
+Training is a two-step process: first we train an ensemble of teacher models,
+and second we train a student using predictions made by this ensemble. Data is
+downloaded automatically when you start the teacher training.
+
+**Training the teachers:** first run the `train_teachers.py` file with at least
+three flags specifying (1) the number of teachers, (2) the ID of the teacher
+you are training among these teachers, and (3) the dataset on which to train. 
+For instance, to train teacher number 10 among an ensemble of 100 teachers for 
+MNIST, use the following command:
+
+```
+python train_teachers.py --nb_teachers=100 --teacher_id=10 --dataset=mnist
+```
+
+Other flags, such as `train_dir` and `data_dir`, can optionally be set to
+point to the directories where model checkpoints and temporary data (such as
+the dataset) are saved, respectively. The flag `max_steps` (default: 3000)
+controls the length of training. See `train_teachers.py` and `deep_cnn.py`
+for the available flags and their descriptions.
+
+**Training the student:** once all the teachers are trained (e.g., teachers
+with IDs `0` to `99` when `nb_teachers=100`), we are ready to train
+the student. The student is trained by labeling some of the test data with 
+predictions from the teachers. The predictions are aggregated by counting the
+votes assigned to each class among the ensemble of teachers, adding Laplacian 
+noise to these votes, and assigning the label with the maximum noisy vote count
+to the sample. This is detailed in function `noisy_max` in the file 
+`aggregation.py`. To learn the student, use the following command:
+
+```
+python train_student.py --nb_teachers=100 --dataset=mnist --stdnt_share=5000
+```
+
+The flag `--stdnt_share=5000` indicates that the student should be able to
+use the first `5000` samples of the dataset's test subset as unlabeled
+training points (they will be labeled using the teacher predictions). The 
+remaining samples are used for evaluation of the student's accuracy, which
+is displayed upon completion of training.
+
+## Alternative deeper convolutional architecture
+
+Note that a deeper convolutional model is available. Both the default and
+deeper model graphs are defined in `deep_cnn.py`, by the functions
+`inference` and `inference_deeper` respectively. Use the flag `--deeper=true` 
+to switch to that model when launching `train_teachers.py` and 
+`train_student.py`. 
+
+## Contact
+
+To ask questions, please email `nicolas@papernot.fr` or open an issue on 
+the `tensorflow/models` issues tracker. Please assign issues to 
+[(@npapernot)](https://github.com/npapernot).
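As a usage illustration only (not part of this commit), the two commands documented above can be wrapped in a small driver script. The following is a hypothetical sketch: it simply shells out to `train_teachers.py` once per teacher and then to `train_student.py`, using the same flag values shown in the README.

```python
# Hypothetical driver script: trains the 100 MNIST teachers one after the
# other, then trains the student on their noisily aggregated predictions.
import subprocess

NB_TEACHERS = 100

for teacher_id in range(NB_TEACHERS):
    subprocess.check_call([
        "python", "train_teachers.py",
        "--nb_teachers=%d" % NB_TEACHERS,
        "--teacher_id=%d" % teacher_id,
        "--dataset=mnist",
    ])

subprocess.check_call([
    "python", "train_student.py",
    "--nb_teachers=%d" % NB_TEACHERS,
    "--dataset=mnist",
    "--stdnt_share=5000",
])
```

In practice the teachers are independent models, so the first loop is usually distributed across machines rather than run sequentially.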

privacy/aggregation.py (+131, -0)

@@ -0,0 +1,131 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def labels_from_probs(probs):
+  """
+  Helper function: computes argmax along the last dimension of the array to
+  obtain labels (the class with maximum probability or logit value).
+  :param probs: numpy array where probabilities or logits are on the last dimension
+  :return: array with the same shape as the input minus the last dimension,
+          containing the resulting labels
+  """
+  # Compute last axis index
+  last_axis = len(np.shape(probs)) - 1
+
+  # Label is argmax over last dimension
+  labels = np.argmax(probs, axis=last_axis)
+
+  # Return as np.int32
+  return np.asarray(labels, dtype=np.int32)
+
+
+def noisy_max(logits, lap_scale, return_clean_votes=False):
+  """
+  This aggregation mechanism takes the softmax/logit output of several models
+  resulting from inference on identical inputs and computes the noisy-max of
+  the votes for candidate classes to select a label for each sample: it
+  adds Laplacian noise to label counts and returns the most frequent label.
+  :param logits: logits or probabilities for each sample
+  :param lap_scale: scale of the Laplacian noise to be added to counts
+  :param return_clean_votes: if set to True, also returns clean votes (without
+                      Laplacian noise). This can be used to perform the
+                      privacy analysis of this aggregation mechanism.
+  :return: the labels resulting from the noisy aggregation and, if
+           return_clean_votes is True, also the clean vote counts for each
+           class per sample and the original labels produced by the teachers.
+  """
+
+  # Compute labels from logits/probs and reshape array properly
+  labels = labels_from_probs(logits)
+  labels_shape = np.shape(labels)
+  labels = labels.reshape((labels_shape[0], labels_shape[1]))
+
+  # Initialize array to hold final labels
+  result = np.zeros(int(labels_shape[1]))
+
+  if return_clean_votes:
+    # Initialize array to hold clean votes for each sample
+    clean_votes = np.zeros((int(labels_shape[1]), 10))
+
+  # Parse each sample
+  for i in xrange(int(labels_shape[1])):
+    # Count number of votes assigned to each class
+    label_counts = np.bincount(labels[:,i], minlength=10)
+
+    if return_clean_votes:
+      # Store vote counts for export
+      clean_votes[i] = label_counts
+
+    # Cast in float32 to prepare before addition of Laplacian noise
+    label_counts = np.asarray(label_counts, dtype=np.float32)
+
+    # Sample independent Laplacian noise for each class
+    for item in xrange(10):
+      label_counts[item] += np.random.laplace(loc=0.0, scale=float(lap_scale))
+
+    # Result is the most frequent label
+    result[i] = np.argmax(label_counts)
+
+  # Cast labels to np.int32 for compatibility with deep_cnn.py feed dictionaries
+  result = np.asarray(result, dtype=np.int32)
+
+  if return_clean_votes:
+    # Return several arrays, which are later saved:
+    # result: labels obtained from the noisy aggregation
+    # clean_votes: the number of teacher votes assigned to each sample and class
+    # labels: the labels assigned by teachers (before the noisy aggregation)
+    return result, clean_votes, labels
+  else:
+    # Only return labels resulting from noisy aggregation
+    return result
+
+
+def aggregation_most_frequent(logits):
+  """
+  This aggregation mechanism takes the softmax/logit output of several models
+  resulting from inference on identical inputs and computes the most frequent
+  label. It is deterministic (no noise injection, unlike noisy_max() above).
+  :param logits: logits or probabilities for each sample
+  :return: the most frequent label for each sample
+  """
+  # Compute labels from logits/probs and reshape array properly
+  labels = labels_from_probs(logits)
+  labels_shape = np.shape(labels)
+  labels = labels.reshape((labels_shape[0], labels_shape[1]))
+
+  # Initialize array to hold final labels
+  result = np.zeros(int(labels_shape[1]))
+
+  # Parse each sample
+  for i in xrange(int(labels_shape[1])):
+    # Count number of votes assigned to each class
+    label_counts = np.bincount(labels[:,i], minlength=10)
+
+    label_counts = np.asarray(label_counts, dtype=np.int32)
+
+    # Result is the most frequent label
+    result[i] = np.argmax(label_counts)
+
+  return np.asarray(result, dtype=np.int32)
+
+
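To make the aggregation interface concrete, here is a small hypothetical usage sketch (not part of the commit, assumed to be run from the `privacy/` directory): it stacks fake per-teacher scores into the 3-D array `noisy_max` expects and recovers one noisily aggregated label per sample.

```python
# Hypothetical example of driving noisy_max() with fake teacher outputs.
import numpy as np

import aggregation

nb_teachers, nb_samples, nb_classes = 5, 3, 10

# Shape expected by noisy_max: (teacher id, sample id, score per class).
rng = np.random.RandomState(0)
fake_logits = rng.randn(nb_teachers, nb_samples, nb_classes).astype(np.float32)

# Laplacian noise of scale 10 is added to each per-class vote count before
# taking the argmax, so labels is a length-3 array of np.int32 class ids.
labels = aggregation.noisy_max(fake_logits, lap_scale=10)
print(labels)
```

Because the noise is sampled independently on every call, the returned labels can vary from run to run; this randomness is precisely what the privacy analysis of the aggregation mechanism relies on.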

privacy/deep_cnn.py (+604, -0)

@@ -0,0 +1,604 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from datetime import datetime
+import math
+import numpy as np
+import tensorflow as tf
+import time
+
+import utils
+
+FLAGS = tf.app.flags.FLAGS
+
+# Basic model parameters.
+tf.app.flags.DEFINE_integer('dropout_seed', 123, """seed for dropout.""")
+tf.app.flags.DEFINE_integer('batch_size', 128, """Nb of images in a batch.""")
+tf.app.flags.DEFINE_integer('epochs_per_decay', 350, """Nb epochs per decay""")
+tf.app.flags.DEFINE_integer('learning_rate', 5, """100 * learning rate""")
+tf.app.flags.DEFINE_boolean('log_device_placement', False, """see TF doc""")
+
+
+# Constants describing the training process.
+MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
+LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.
+
+
+def _variable_on_cpu(name, shape, initializer):
+  """Helper to create a Variable stored on CPU memory.
+
+  Args:
+    name: name of the variable
+    shape: list of ints
+    initializer: initializer for Variable
+
+  Returns:
+    Variable Tensor
+  """
+  with tf.device('/cpu:0'):
+    var = tf.get_variable(name, shape, initializer=initializer)
+  return var
+
+
+def _variable_with_weight_decay(name, shape, stddev, wd):
+  """Helper to create an initialized Variable with weight decay.
+
+  Note that the Variable is initialized with a truncated normal distribution.
+  A weight decay is added only if one is specified.
+
+  Args:
+    name: name of the variable
+    shape: list of ints
+    stddev: standard deviation of a truncated Gaussian
+    wd: add L2Loss weight decay multiplied by this float. If None, weight
+        decay is not added for this Variable.
+
+  Returns:
+    Variable Tensor
+  """
+  var = _variable_on_cpu(name, shape,
+                         tf.truncated_normal_initializer(stddev=stddev))
+  if wd is not None:
+    weight_decay = tf.mul(tf.nn.l2_loss(var), wd, name='weight_loss')
+    tf.add_to_collection('losses', weight_decay)
+  return var
+
+
+def inference(images, dropout=False):
+  """Build the CNN model.
+  Args:
+    images: Images returned from distorted_inputs() or inputs().
+    dropout: Boolean controlling whether to use dropout or not
+  Returns:
+    Logits
+  """
+  if FLAGS.dataset == 'mnist':
+    first_conv_shape = [5, 5, 1, 64]
+  else:
+    first_conv_shape = [5, 5, 3, 64]
+
+  # conv1
+  with tf.variable_scope('conv1') as scope:
+    kernel = _variable_with_weight_decay('weights', 
+                                         shape=first_conv_shape,
+                                         stddev=1e-4, 
+                                         wd=0.0)
+    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv1 = tf.nn.relu(bias, name=scope.name)
+    if dropout:
+      conv1 = tf.nn.dropout(conv1, 0.3, seed=FLAGS.dropout_seed)
+
+
+  # pool1
+  pool1 = tf.nn.max_pool(conv1, 
+                         ksize=[1, 3, 3, 1], 
+                         strides=[1, 2, 2, 1],
+                         padding='SAME', 
+                         name='pool1')
+  
+  # norm1
+  norm1 = tf.nn.lrn(pool1, 
+                    4, 
+                    bias=1.0, 
+                    alpha=0.001 / 9.0, 
+                    beta=0.75,
+                    name='norm1')
+
+  # conv2
+  with tf.variable_scope('conv2') as scope:
+    kernel = _variable_with_weight_decay('weights', 
+                                         shape=[5, 5, 64, 128],
+                                         stddev=1e-4, 
+                                         wd=0.0)
+    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [128], tf.constant_initializer(0.1))
+    bias = tf.nn.bias_add(conv, biases)
+    conv2 = tf.nn.relu(bias, name=scope.name)
+    if dropout:
+      conv2 = tf.nn.dropout(conv2, 0.3, seed=FLAGS.dropout_seed)
+
+
+  # norm2
+  norm2 = tf.nn.lrn(conv2, 
+                    4, 
+                    bias=1.0, 
+                    alpha=0.001 / 9.0, 
+                    beta=0.75,
+                    name='norm2')
+  
+  # pool2
+  pool2 = tf.nn.max_pool(norm2, 
+                         ksize=[1, 3, 3, 1],
+                         strides=[1, 2, 2, 1], 
+                         padding='SAME', 
+                         name='pool2')
+
+  # local3
+  with tf.variable_scope('local3') as scope:
+    # Move everything into depth so we can perform a single matrix multiply.
+    reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
+    dim = reshape.get_shape()[1].value
+    weights = _variable_with_weight_decay('weights', 
+                                          shape=[dim, 384],
+                                          stddev=0.04, 
+                                          wd=0.004)
+    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
+    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
+    if dropout:
+      local3 = tf.nn.dropout(local3, 0.5, seed=FLAGS.dropout_seed)
+
+  # local4
+  with tf.variable_scope('local4') as scope:
+    weights = _variable_with_weight_decay('weights', 
+                                          shape=[384, 192],
+                                          stddev=0.04, 
+                                          wd=0.004)
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
+    if dropout:
+      local4 = tf.nn.dropout(local4, 0.5, seed=FLAGS.dropout_seed)
+
+  # compute logits
+  with tf.variable_scope('softmax_linear') as scope:
+    weights = _variable_with_weight_decay('weights', 
+                                          [192, FLAGS.nb_labels],
+                                          stddev=1/192.0, 
+                                          wd=0.0)
+    biases = _variable_on_cpu('biases', 
+                              [FLAGS.nb_labels],
+                              tf.constant_initializer(0.0))
+    logits = tf.add(tf.matmul(local4, weights), biases, name=scope.name)
+
+  return logits
+
+
+def inference_deeper(images, dropout=False):
+  """Build a deeper CNN model.
+  Args:
+    images: Images returned from distorted_inputs() or inputs().
+    dropout: Boolean controlling whether to use dropout or not
+  Returns:
+    Logits
+  """
+  if FLAGS.dataset == 'mnist':
+    first_conv_shape = [3, 3, 1, 96]
+  else:
+    first_conv_shape = [3, 3, 3, 96]
+
+  # conv1
+  with tf.variable_scope('conv1') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=first_conv_shape,
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv1 = tf.nn.relu(bias, name=scope.name)
+
+  # conv2
+  with tf.variable_scope('conv2') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[3, 3, 96, 96],
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv1, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv2 = tf.nn.relu(bias, name=scope.name)
+
+  # conv3
+  with tf.variable_scope('conv3') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[3, 3, 96, 96],
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv2, kernel, [1, 2, 2, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv3 = tf.nn.relu(bias, name=scope.name)
+    if dropout:
+      conv3 = tf.nn.dropout(conv3, 0.5, seed=FLAGS.dropout_seed)
+
+  # conv4
+  with tf.variable_scope('conv4') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[3, 3, 96, 192],
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv4 = tf.nn.relu(bias, name=scope.name)
+
+  # conv5
+  with tf.variable_scope('conv5') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[3, 3, 192, 192],
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv5 = tf.nn.relu(bias, name=scope.name)
+
+  # conv6
+  with tf.variable_scope('conv6') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[3, 3, 192, 192],
+                                         stddev=0.05,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv5, kernel, [1, 2, 2, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
+    bias = tf.nn.bias_add(conv, biases)
+    conv6 = tf.nn.relu(bias, name=scope.name)
+    if dropout:
+      conv6 = tf.nn.dropout(conv6, 0.5, seed=FLAGS.dropout_seed)
+
+
+  # conv7
+  with tf.variable_scope('conv7') as scope:
+    kernel = _variable_with_weight_decay('weights',
+                                         shape=[5, 5, 192, 192],
+                                         stddev=1e-4,
+                                         wd=0.0)
+    conv = tf.nn.conv2d(conv6, kernel, [1, 1, 1, 1], padding='SAME')
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+    bias = tf.nn.bias_add(conv, biases)
+    conv7 = tf.nn.relu(bias, name=scope.name)
+
+
+  # local1
+  with tf.variable_scope('local1') as scope:
+    # Move everything into depth so we can perform a single matrix multiply.
+    reshape = tf.reshape(conv7, [FLAGS.batch_size, -1])
+    dim = reshape.get_shape()[1].value
+    weights = _variable_with_weight_decay('weights',
+                                          shape=[dim, 192],
+                                          stddev=0.05,
+                                          wd=0)
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+    local1 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
+
+  # local2
+  with tf.variable_scope('local2') as scope:
+    weights = _variable_with_weight_decay('weights',
+                                          shape=[192, 192],
+                                          stddev=0.05,
+                                          wd=0)
+    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
+    local2 = tf.nn.relu(tf.matmul(local1, weights) + biases, name=scope.name)
+    if dropout:
+      local2 = tf.nn.dropout(local2, 0.5, seed=FLAGS.dropout_seed)
+
+  # compute logits
+  with tf.variable_scope('softmax_linear') as scope:
+    weights = _variable_with_weight_decay('weights',
+                                          [192, FLAGS.nb_labels],
+                                          stddev=0.05,
+                                          wd=0.0)
+    biases = _variable_on_cpu('biases',
+                              [FLAGS.nb_labels],
+                              tf.constant_initializer(0.0))
+    logits = tf.add(tf.matmul(local2, weights), biases, name=scope.name)
+
+  return logits
+
+
+def loss_fun(logits, labels):
+  """Add L2Loss to all the trainable variables.
+
+  Add summary for "Loss" and "Loss/avg".
+  Args:
+    logits: Logits from inference().
+    labels: Labels from distorted_inputs or inputs(). 1-D tensor
+            of shape [batch_size]
+
+  Returns:
+    Loss tensor of type float.
+  """
+
+  # Calculate the cross entropy between labels and predictions
+  labels = tf.cast(labels, tf.int64)
+  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
+      logits, labels, name='cross_entropy_per_example')
+
+  # Calculate the average cross entropy loss across the batch.
+  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
+
+  # Add to TF collection for losses
+  tf.add_to_collection('losses', cross_entropy_mean)
+
+  # The total loss is defined as the cross entropy loss plus all of the weight
+  # decay terms (L2 loss).
+  return tf.add_n(tf.get_collection('losses'), name='total_loss')
+
+
+def moving_av(total_loss):
+  """
+  Generates moving average for all losses
+
+  Args:
+    total_loss: Total loss from loss().
+  Returns:
+    loss_averages_op: op for generating moving averages of losses.
+  """
+  # Compute the moving average of all individual losses and the total loss.
+  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
+  losses = tf.get_collection('losses')
+  loss_averages_op = loss_averages.apply(losses + [total_loss])
+
+  return loss_averages_op
+
+
+def train_op_fun(total_loss, global_step):
+  """Train model.
+
+  Create an optimizer and apply to all trainable variables. Add moving
+  average for all trainable variables.
+
+  Args:
+    total_loss: Total loss from loss().
+    global_step: Integer Variable counting the number of training steps
+      processed.
+  Returns:
+    train_op: op for training.
+  """
+  # Variables that affect learning rate.
+  nb_ex_per_train_epoch = int(60000 / FLAGS.nb_teachers)
+  
+  num_batches_per_epoch = nb_ex_per_train_epoch / FLAGS.batch_size
+  decay_steps = int(num_batches_per_epoch * FLAGS.epochs_per_decay)
+
+  initial_learning_rate = float(FLAGS.learning_rate) / 100.0
+
+  # Decay the learning rate exponentially based on the number of steps.
+  lr = tf.train.exponential_decay(initial_learning_rate,
+                                  global_step,
+                                  decay_steps,
+                                  LEARNING_RATE_DECAY_FACTOR,
+                                  staircase=True)
+  tf.scalar_summary('learning_rate', lr)
+
+  # Generate moving averages of all losses and associated summaries.
+  loss_averages_op = moving_av(total_loss)
+
+  # Compute gradients.
+  with tf.control_dependencies([loss_averages_op]):
+    opt = tf.train.GradientDescentOptimizer(lr)
+    grads = opt.compute_gradients(total_loss)
+
+  # Apply gradients.
+  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
+
+  # Add histograms for trainable variables.
+  for var in tf.trainable_variables():
+    tf.histogram_summary(var.op.name, var)
+
+  # Track the moving averages of all trainable variables.
+  variable_averages = tf.train.ExponentialMovingAverage(
+      MOVING_AVERAGE_DECAY, global_step)
+  variables_averages_op = variable_averages.apply(tf.trainable_variables())
+
+  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
+    train_op = tf.no_op(name='train')
+
+  return train_op
+
+
+def _input_placeholder():
+  """
+  This helper function declares a TF placeholder for the graph input data
+  :return: TF placeholder for the graph input data
+  """
+  if FLAGS.dataset == 'mnist':
+    image_size = 28
+    num_channels = 1
+  else:
+    image_size = 32
+    num_channels = 3
+
+  # Declare data placeholder
+  train_node_shape = (FLAGS.batch_size, image_size, image_size, num_channels)
+  return tf.placeholder(tf.float32, shape=train_node_shape)
+
+
+def train(images, labels, ckpt_path, dropout=False):
+  """
+  This function contains the loop that actually trains the model.
+  :param images: a numpy array with the input data
+  :param labels: a numpy array with the output labels
+  :param ckpt_path: a path (including name) where model checkpoints are saved
+  :param dropout: Boolean, whether to use dropout or not
+  :return: True if everything went well
+  """
+
+  # Check training data
+  assert len(images) == len(labels)
+  assert images.dtype == np.float32
+  assert labels.dtype == np.int32
+
+  # Set default TF graph
+  with tf.Graph().as_default():
+    global_step = tf.Variable(0, trainable=False)
+
+    # Declare data placeholder
+    train_data_node = _input_placeholder()
+
+    # Create a placeholder to hold labels
+    train_labels_shape = (FLAGS.batch_size,)
+    train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)
+
+    print("Done Initializing Training Placeholders")
+
+    # Build a Graph that computes the logits predictions from the placeholder
+    if FLAGS.deeper:
+      logits = inference_deeper(train_data_node, dropout=dropout)
+    else:
+      logits = inference(train_data_node, dropout=dropout)
+
+    # Calculate loss
+    loss = loss_fun(logits, train_labels_node)
+
+    # Build a Graph that trains the model with one batch of examples and
+    # updates the model parameters.
+    train_op = train_op_fun(loss, global_step)
+
+    # Create a saver.
+    saver = tf.train.Saver(tf.all_variables())
+
+    print("Graph constructed and saver created")
+
+    # Build an initialization operation to run below.
+    init = tf.initialize_all_variables()
+
+    # Create and init sessions
+    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement)) #NOLINT(long-line)
+    sess.run(init)
+
+    print("Session ready, beginning training loop")
+
+    # Initialize the number of batches
+    data_length = len(images)
+    nb_batches = math.ceil(data_length / FLAGS.batch_size)
+
+    for step in xrange(FLAGS.max_steps):
+      # for debug, save start time
+      start_time = time.time()
+
+      # Current batch number
+      batch_nb = step % nb_batches
+
+      # Current batch start and end indices
+      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)
+
+      # Prepare dictionary to feed the session with
+      feed_dict = {train_data_node: images[start:end],
+                   train_labels_node: labels[start:end]}
+
+      # Run training step
+      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)
+
+      # Compute duration of training step
+      duration = time.time() - start_time
+
+      # Sanity check
+      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'
+
+      # Echo loss once in a while
+      if step % 100 == 0:
+        num_examples_per_step = FLAGS.batch_size
+        examples_per_sec = num_examples_per_step / duration
+        sec_per_batch = float(duration)
+
+        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
+                      'sec/batch)')
+        print (format_str % (datetime.now(), step, loss_value,
+                             examples_per_sec, sec_per_batch))
+
+      # Save the model checkpoint periodically.
+      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
+        saver.save(sess, ckpt_path, global_step=step)
+
+  return True
+
+
+def softmax_preds(images, ckpt_path, return_logits=False):
+  """
+  Compute softmax activations (probabilities) with the model saved in the path
+  specified as an argument
+  :param images: a np array of images
+  :param ckpt_path: a TF model checkpoint
+  :param return_logits: if set to True, return logits instead of probabilities
+  :return: probabilities (or logits if return_logits is set to True)
+  """
+  # Compute nb samples and deduce nb of batches
+  data_length = len(images)
+  nb_batches = math.ceil(len(images) / FLAGS.batch_size)
+
+  # Declare data placeholder
+  train_data_node = _input_placeholder()
+
+  # Build a Graph that computes the logits predictions from the placeholder
+  if FLAGS.deeper:
+    logits = inference_deeper(train_data_node)
+  else:
+    logits = inference(train_data_node)
+
+  if return_logits:
+    # We are returning the logits directly (no need to apply softmax)
+    output = logits
+  else:
+    # Add softmax predictions to graph: will return probabilities
+    output = tf.nn.softmax(logits)
+
+  # Restore the moving average version of the learned variables for eval.
+  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
+  variables_to_restore = variable_averages.variables_to_restore()
+  saver = tf.train.Saver(variables_to_restore)
+
+  # Will hold the result
+  preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)
+
+  # Create TF session
+  with tf.Session() as sess:
+    # Restore TF session from checkpoint file
+    saver.restore(sess, ckpt_path)
+
+    # Parse data by batch
+    for batch_nb in xrange(0, int(nb_batches+1)):
+      # Compute batch start and end indices
+      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)
+
+      # Prepare feed dictionary
+      feed_dict = {train_data_node: images[start:end]}
+
+      # Run session ([0] because run returns a list with one element per fetch)
+      preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]
+
+  # Reset graph to allow multiple calls
+  tf.reset_default_graph()
+
+  return preds
+
+
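To make the learning-rate schedule built by `train_op_fun` concrete, here is a small back-of-the-envelope sketch (not part of the commit). It assumes the default flag values above, `nb_teachers=100`, and the hard-coded 60000 training examples (the MNIST training set size) used in `train_op_fun`.

```python
# Hypothetical recomputation of the staircase learning-rate schedule
# constructed in train_op_fun(), using the default flag values.
from __future__ import division

nb_teachers = 100         # --nb_teachers
batch_size = 128          # --batch_size (default)
epochs_per_decay = 350    # --epochs_per_decay (default)
learning_rate_flag = 5    # --learning_rate (default, i.e. 100 * learning rate)
decay_factor = 0.1        # LEARNING_RATE_DECAY_FACTOR

nb_ex_per_train_epoch = int(60000 / nb_teachers)             # 600 examples
num_batches_per_epoch = nb_ex_per_train_epoch / batch_size   # ~4.69 batches
decay_steps = int(num_batches_per_epoch * epochs_per_decay)  # 1640 steps

initial_lr = float(learning_rate_flag) / 100.0               # 0.05

# Staircase decay, as computed by tf.train.exponential_decay(staircase=True).
for step in (0, 1000, 2000, 2999):
    lr = initial_lr * decay_factor ** (step // decay_steps)
    print(step, lr)  # 0.05 until step 1639, then 0.005
```

With the default `max_steps=3000`, each teacher therefore sees exactly one decay of the learning rate during training.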

privacy/input.py (+420, -0)

@@ -0,0 +1,420 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import cPickle
+import gzip
+import math
+import numpy as np
+import os
+from scipy.io import loadmat as loadmat
+from six.moves import urllib
+import sys
+import tarfile
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import flags
+
+FLAGS = flags.FLAGS
+
+
+def create_dir_if_needed(dest_directory):
+  """
+  Create directory if doesn't exist
+  :param dest_directory:
+  :return: True if everything went well
+  """
+  if not gfile.IsDirectory(dest_directory):
+    gfile.MakeDirs(dest_directory)
+
+  return True
+
+
+def maybe_download(file_urls, directory):
+  """
+  Download a set of files into a temporary local folder
+  :param file_urls: list of URLs of the files to download
+  :param directory: the directory where to download 
+  :return: a list of local filepaths corresponding to the files given as input
+  """
+  # Create directory if doesn't exist
+  assert create_dir_if_needed(directory)
+
+  # This list will include all URLS of the local copy of downloaded files
+  result = []
+
+  # For each file of the dataset
+  for file_url in file_urls:
+    # Extract filename
+    filename = file_url.split('/')[-1]
+
+    # Deduce local file url
+    #filepath = os.path.join(directory, filename)
+    filepath = directory + '/' + filename
+
+    # Add to result list
+    result.append(filepath)
+
+    # Test if file already exists
+    if not gfile.Exists(filepath):
+      def _progress(count, block_size, total_size):
+        sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename,
+            float(count * block_size) / float(total_size) * 100.0))
+        sys.stdout.flush()
+      filepath, _ = urllib.request.urlretrieve(file_url, filepath, _progress)
+      print()
+      statinfo = os.stat(filepath)
+      print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
+
+  return result
+
+
+def image_whitening(data):
+  """
+  Subtracts the mean of each image and divides by the adjusted standard
+  deviation (for stability). Operations are per image but performed for the
+  entire array.
+  :param data: 4D array (ID, Height, Width, Channel)
+  :return: 4D array (ID, Height, Width, Channel)
+  """
+  assert len(np.shape(data)) == 4
+
+  # Compute number of pixels in image
+  nb_pixels = np.shape(data)[1] * np.shape(data)[2] * np.shape(data)[3]
+
+  # Subtract mean
+  mean = np.mean(data, axis=(1,2,3))
+
+  ones = np.ones(np.shape(data)[1:4], dtype=np.float32)
+  for i in xrange(len(data)):
+    data[i, :, :, :] -= mean[i] * ones
+
+  # Compute adjusted standard variance
+  adj_std_var = np.maximum(np.ones(len(data), dtype=np.float32) / math.sqrt(nb_pixels), np.std(data, axis=(1,2,3))) #NOLINT(long-line)
+
+  # Divide image
+  for i in xrange(len(data)):
+    data[i, :, :, :] = data[i, :, :, :] / adj_std_var[i]
+
+  print(np.shape(data))
+
+  return data
+
+
+def extract_svhn(local_url):
+  """
+  Extract a MATLAB matrix into two numpy arrays with data and labels
+  :param local_url: path to the local copy of the SVHN .mat file
+  :return: tuple of (data, labels) as numpy arrays
+  """
+
+  with gfile.Open(local_url, mode='r') as file_obj:
+    # Load MATLAB matrix using scipy IO
+    dict = loadmat(file_obj)
+
+    # Extract each dictionary (one for data, one for labels)
+    data, labels = dict["X"], dict["y"]
+
+    # Set np type
+    data = np.asarray(data, dtype=np.float32)
+    labels = np.asarray(labels, dtype=np.int32)
+
+    # Transpose data to match TF model input format
+    data = data.transpose(3, 0, 1, 2)
+
+    # Fix the SVHN labels which label 0s as 10s
+    labels[labels == 10] = 0
+
+    # Fix label dimensions
+    labels = labels.reshape(len(labels))
+
+    return data, labels
+
+
+def unpickle_cifar_dic(file):
+  """
+  Helper function: unpickles a dictionary (used for loading CIFAR)
+  :param file: filename of the pickle
+  :return: tuple of (images, labels)
+  """
+  fo = open(file, 'rb')
+  dict = cPickle.load(fo)
+  fo.close()
+  return dict['data'], dict['labels']
+
+
+def extract_cifar10(local_url, data_dir):
+  """
+  Extracts the CIFAR-10 dataset and returns numpy arrays with the different sets
+  :param local_url: where the tar.gz archive is located locally
+  :param data_dir: where to extract the archive's file
+  :return: a tuple (train data, train labels, test data, test labels)
+  """
+  # These numpy dumps can be reloaded to avoid performing the pre-processing
+  # if they exist in the working directory.
+  # Changing the order of this list will ruin the indices below.
+  preprocessed_files = ['/cifar10_train.npy',
+                        '/cifar10_train_labels.npy',
+                        '/cifar10_test.npy',
+                        '/cifar10_test_labels.npy']
+
+  all_preprocessed = True
+  for file in preprocessed_files:
+    if not gfile.Exists(data_dir + file):
+      all_preprocessed = False
+      break
+
+  if all_preprocessed:
+    # Reload pre-processed training data from numpy dumps
+    with gfile.Open(data_dir + preprocessed_files[0], mode='r') as file_obj:
+      train_data = np.load(file_obj)
+    with gfile.Open(data_dir + preprocessed_files[1], mode='r') as file_obj:
+      train_labels = np.load(file_obj)
+
+    # Reload pre-processed testing data from numpy dumps
+    with gfile.Open(data_dir + preprocessed_files[2], mode='r') as file_obj:
+      test_data = np.load(file_obj)
+    with gfile.Open(data_dir + preprocessed_files[3], mode='r') as file_obj:
+      test_labels = np.load(file_obj)
+
+  else:
+    # Do everything from scratch
+    # Define lists of all files we should extract
+    train_files = ["data_batch_" + str(i) for i in xrange(1,6)]
+    test_file = ["test_batch"]
+    cifar10_files = train_files + test_file
+
+    # Check if all files have already been extracted
+    need_to_unpack = False
+    for file in cifar10_files:
+      if not gfile.Exists(file):
+        need_to_unpack = True
+        break
+
+    # We have to unpack the archive
+    if need_to_unpack:
+      tarfile.open(local_url, 'r:gz').extractall(data_dir)
+
+    # Load training images and labels
+    images = []
+    labels = []
+    for file in train_files:
+      # Construct filename
+      filename = data_dir + "/cifar-10-batches-py/" + file
+
+      # Unpickle dictionary and extract images and labels
+      images_tmp, labels_tmp = unpickle_cifar_dic(filename)
+
+      # Append to lists
+      images.append(images_tmp)
+      labels.append(labels_tmp)
+
+    # Convert to numpy arrays and reshape in the expected format
+    train_data = np.asarray(images, dtype=np.float32).reshape((50000,3,32,32))
+    train_data = np.swapaxes(train_data, 1, 3)
+    train_labels = np.asarray(labels, dtype=np.int32).reshape(50000)
+
+    # Save so we don't have to do this again
+    np.save(data_dir + preprocessed_files[0], train_data)
+    np.save(data_dir + preprocessed_files[1], train_labels)
+
+    # Construct filename for test file
+    filename = data_dir + "/cifar-10-batches-py/" + test_file[0]
+
+    # Load test images and labels
+    test_data, test_images = unpickle_cifar_dic(filename)
+
+    # Convert to numpy arrays and reshape in the expected format
+    test_data = np.asarray(test_data,dtype=np.float32).reshape((10000,3,32,32))
+    test_data = np.swapaxes(test_data, 1, 3)
+    test_labels = np.asarray(test_images, dtype=np.int32).reshape(10000)
+
+    # Save so we don't have to do this again
+    np.save(data_dir + preprocessed_files[2], test_data)
+    np.save(data_dir + preprocessed_files[3], test_labels)
+
+  return train_data, train_labels, test_data, test_labels
+
+
+def extract_mnist_data(filename, num_images, image_size, pixel_depth):
+  """
+  Extract the images into a 4D tensor [image index, y, x, channels].
+
+  Values are rescaled from [0, 255] down to [-0.5, 0.5].
+  """
+  # if not os.path.exists(file):
+  if not gfile.Exists(filename+".npy"):
+    with gzip.open(filename) as bytestream:
+      bytestream.read(16)
+      buf = bytestream.read(image_size * image_size * num_images)
+      data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
+      data = (data - (pixel_depth / 2.0)) / pixel_depth
+      data = data.reshape(num_images, image_size, image_size, 1)
+      np.save(filename, data)
+      return data
+  else:
+    with gfile.Open(filename+".npy", mode='r') as file_obj:
+      return np.load(file_obj)
+
+
+def extract_mnist_labels(filename, num_images):
+  """
+  Extract the labels into a vector of int64 label IDs.
+  """
+  # if not os.path.exists(file):
+  if not gfile.Exists(filename+".npy"):
+    with gzip.open(filename) as bytestream:
+      bytestream.read(8)
+      buf = bytestream.read(1 * num_images)
+      labels = np.frombuffer(buf, dtype=np.uint8).astype(np.int32)
+      np.save(filename, labels)
+    return labels
+  else:
+    with gfile.Open(filename+".npy", mode='r') as file_obj:
+      return np.load(file_obj)
+
+
+def ld_svhn(extended=False, test_only=False):
+  """
+  Load the original SVHN data
+  :param extended: include extended training data in the returned array
+  :param test_only: disables loading of both train and extra -> large speed up
+  :return: tuple of arrays which depend on the parameters
+  """
+  # Define files to be downloaded
+  # WARNING: changing the order of this list will break indices (cf. below)
+  file_urls = ['http://ufldl.stanford.edu/housenumbers/train_32x32.mat',
+               'http://ufldl.stanford.edu/housenumbers/test_32x32.mat',
+               'http://ufldl.stanford.edu/housenumbers/extra_32x32.mat']
+
+  # Maybe download data and retrieve local storage urls
+  local_urls = maybe_download(file_urls, FLAGS.data_dir)
+
+  # Extract Train, Test, and Extended Train data
+  if not test_only:
+    # Load and apply whitening to train data
+    train_data, train_labels = extract_svhn(local_urls[0])
+    train_data = image_whitening(train_data)
+
+    # Load and apply whitening to extended train data
+    ext_data, ext_labels = extract_svhn(local_urls[2])
+    ext_data = image_whitening(ext_data)
+
+  # Load and apply whitening to test data
+  test_data, test_labels = extract_svhn(local_urls[1])
+  test_data = image_whitening(test_data)
+
+  if test_only:
+    return test_data, test_labels
+  else:
+    if extended:
+      # Stack train data with the extended training data
+      train_data = np.vstack((train_data, ext_data))
+      train_labels = np.hstack((train_labels, ext_labels))
+
+      return train_data, train_labels, test_data, test_labels
+    else:
+      # Return training and extended training data separately
+      return train_data,train_labels, test_data,test_labels, ext_data,ext_labels
+
+
+def ld_cifar10(test_only=False):
+  """
+  Load the original CIFAR10 data
+  :param test_only: if set to True, only load the test data -> large speed up
+  :return: tuple of arrays which depend on the parameters
+  """
+  # Define files to be downloaded
+  file_urls = ['https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz']
+
+  # Maybe download data and retrieve local storage urls
+  local_urls = maybe_download(file_urls, FLAGS.data_dir)
+
+  # Extract archives and return different sets
+  dataset = extract_cifar10(local_urls[0], FLAGS.data_dir)
+
+  # Unpack tuple
+  train_data, train_labels, test_data, test_labels = dataset
+
+  # Apply whitening to input data
+  train_data = image_whitening(train_data)
+  test_data = image_whitening(test_data)
+
+  if test_only:
+    return test_data, test_labels
+  else:
+    return train_data, train_labels, test_data, test_labels
+
+
+def ld_mnist(test_only=False):
+  """
+  Load the MNIST dataset
+  :param test_only: if set to True, only load the test data -> large speed up
+  :return: tuple of arrays which depend on the parameters
+  """
+  # Define files to be downloaded
+  # WARNING: changing the order of this list will break indices (cf. below)
+  file_urls = ['http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz',
+               'http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz',
+               'http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz',
+               'http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz',
+               ]
+
+  # Maybe download data and retrieve local storage urls
+  local_urls = maybe_download(file_urls, FLAGS.data_dir)
+
+  # Extract it into np arrays.
+  train_data = extract_mnist_data(local_urls[0], 60000, 28, 1)
+  train_labels = extract_mnist_labels(local_urls[1], 60000)
+  test_data = extract_mnist_data(local_urls[2], 10000, 28, 1)
+  test_labels = extract_mnist_labels(local_urls[3], 10000)
+
+  if test_only:
+    return test_data, test_labels
+  else:
+    return train_data, train_labels, test_data, test_labels
+
+
+def partition_dataset(data, labels, nb_teachers, teacher_id):
+  """
+  Simple partitioning algorithm that returns the right portion of the data
+  needed by a given teacher out of a certain nb of teachers
+  :param data: input data to be partitioned
+  :param labels: output data to be partitioned
+  :param nb_teachers: number of teachers in the ensemble (affects size of each
+                      partition)
+  :param teacher_id: id of partition to retrieve
+  :return: tuple of (data, labels) corresponding to this teacher's partition
+  """
+
+  # Sanity check
+  assert len(data) == len(labels)
+  assert int(teacher_id) < int(nb_teachers)
+
+  # Size of each partition (floored, so a few trailing examples may be unused)
+  batch_len = int(len(data) / nb_teachers)
+
+  # Compute start, end indices of partition
+  start = teacher_id * batch_len
+  end = (teacher_id+1) * batch_len
+
+  # Slice partition off
+  partition_data = data[start:end]
+  partition_labels = labels[start:end]
+
+  return partition_data, partition_labels
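A quick hypothetical sanity check (not part of the commit) of the partitioning arithmetic above: with the 60000 MNIST training examples and 100 teachers, each teacher receives a disjoint slice of 600 examples.

```python
# Hypothetical check of partition_dataset() slicing on MNIST-sized arrays.
import numpy as np

import input as input_module  # i.e. privacy/input.py

data = np.zeros((60000, 28, 28, 1), dtype=np.float32)
labels = np.zeros(60000, dtype=np.int32)

# With 100 teachers, each gets 600 examples; teacher 10 gets indices 6000:6600.
part_data, part_labels = input_module.partition_dataset(data, labels,
                                                        nb_teachers=100,
                                                        teacher_id=10)
print(part_data.shape, part_labels.shape)  # (600, 28, 28, 1) (600,)
```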

privacy/metrics.py (+49, -0)

@@ -0,0 +1,49 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+
+def accuracy(logits, labels):
+  """
+  Return accuracy of the array of logits (or label predictions) wrt the labels
+  :param logits: this can either be logits, probabilities, or a single label
+  :param labels: the correct labels to match against
+  :return: the accuracy as a float
+  """
+  assert len(logits) == len(labels)
+
+  if len(np.shape(logits)) > 1:
+    # Predicted labels are the argmax over axis 1
+    predicted_labels = np.argmax(logits, axis=1)
+  else:
+    # Input was already labels
+    assert len(np.shape(logits)) == 1
+    predicted_labels = logits
+
+  # Check against correct labels to compute correct guesses
+  correct = np.sum(predicted_labels == labels.reshape(len(labels)))
+
+  # Divide by number of labels to obtain accuracy
+  accuracy = float(correct) / len(labels)
+
+  # Return float value
+  return accuracy
+
+
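For completeness, a tiny hypothetical example (not part of the commit) of the two input forms `accuracy` accepts: a 2-D array of per-class scores, or a 1-D array of already-computed labels.

```python
# Hypothetical illustration of metrics.accuracy() with both input forms.
import numpy as np

import metrics

labels = np.array([1, 1], dtype=np.int32)

# 2-D scores: argmax over axis 1 gives predictions [1, 0] -> accuracy 0.5.
scores = np.array([[0.1, 0.9], [0.8, 0.2]], dtype=np.float32)
print(metrics.accuracy(scores, labels))  # 0.5

# 1-D predicted labels are compared directly -> accuracy 1.0.
preds = np.array([1, 1], dtype=np.int32)
print(metrics.accuracy(preds, labels))   # 1.0
```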

privacy/train_student.py (+212, -0)

@@ -0,0 +1,212 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import app
+
+import aggregation
+import deep_cnn
+import input
+import metrics
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('dataset', 'svhn', 'The name of the dataset to use')
+flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')
+
+flags.DEFINE_string('data_dir','/tmp','Temporary storage')
+flags.DEFINE_string('train_dir','/tmp/train_dir','Where model chkpt are saved')
+flags.DEFINE_string('teachers_dir','/tmp/train_dir',
+                    'Directory where teachers checkpoints are stored.')
+
+flags.DEFINE_integer('teachers_max_steps', 3000,
+                     """Number of steps teachers were run for.""")
+flags.DEFINE_integer('max_steps', 3000, """Number of steps to run student.""")
+flags.DEFINE_integer('nb_teachers', 10, """Teachers in the ensemble.""")
+flags.DEFINE_integer('stdnt_share', 1000,
+                     """Student share (last index) of the test data""")
+flags.DEFINE_integer('lap_scale', 10,
+                     """Scale of the Laplacian noise added for privacy""")
+flags.DEFINE_boolean('save_labels', False,
+                     """Dump numpy arrays of labels and clean teacher votes""")
+
+flags.DEFINE_boolean('deeper', False, """Activate deeper CNN model""")
+
+
+def ensemble_preds(dataset, nb_teachers, stdnt_data):
+  """
+  Given a dataset, a number of teachers, and some input data, this helper
+  function queries each teacher for predictions on the data and returns
+  all predictions in a single array. These can then be aggregated into
+  one single prediction per input using aggregation.py (cf. function
+  prepare_student_data() below).
+  :param dataset: string corresponding to mnist, cifar10, or svhn
+  :param nb_teachers: number of teachers (in the ensemble) to learn from
+  :param stdnt_data: unlabeled student training data
+  :return: 3d array (teacher id, sample id, probability per class)
+  """
+
+  # Compute shape of array that will hold probabilities produced by each
+  # teacher, for each training point, and each output class
+  result_shape = (nb_teachers, len(stdnt_data), FLAGS.nb_labels)
+
+  # Create array that will hold result
+  result = np.zeros(result_shape, dtype=np.float32)
+
+  # Get predictions from each teacher
+  for teacher_id in xrange(nb_teachers):
+    # Compute path of checkpoint file for teacher model with ID teacher_id
+    if FLAGS.deeper:
+      ckpt_path = FLAGS.teachers_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt-' + str(FLAGS.teachers_max_steps - 1) #NOLINT(long-line)
+    else:
+      ckpt_path = FLAGS.teachers_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt-' + str(FLAGS.teachers_max_steps - 1)  # NOLINT(long-line)
+
+    # Get predictions on our training data and store in result array
+    result[teacher_id] = deep_cnn.softmax_preds(stdnt_data, ckpt_path)
+
+    # This can take a while when there are a lot of teachers so output status
+    print("Computed Teacher " + str(teacher_id) + " softmax predictions")
+
+  return result
+
+
+def prepare_student_data(dataset, nb_teachers, save=False):
+  """
+  Takes a dataset name and the size of the teacher ensemble and prepares
+  training data for the student model, according to parameters indicated
+  in flags above.
+  :param dataset: string corresponding to mnist, cifar10, or svhn
+  :param nb_teachers: number of teachers (in the ensemble) to learn from
+  :param save: if set to True, will dump student training labels predicted by
+               the ensemble of teachers (with Laplacian noise) as npy files.
+               It also dumps the clean votes for each class (without noise) and
+               the labels assigned by teachers
+  :return: pairs of (data, labels) to be used for student training and testing
+  """
+  assert input.create_dir_if_needed(FLAGS.train_dir)
+
+  # Load the dataset
+  if dataset == 'svhn':
+    test_data, test_labels = input.ld_svhn(test_only=True)
+  elif dataset == 'cifar10':
+    test_data, test_labels = input.ld_cifar10(test_only=True)
+  elif dataset == 'mnist':
+    test_data, test_labels = input.ld_mnist(test_only=True)
+  else:
+    print("Check value of dataset flag")
+    return False
+
+  # Make sure there is data leftover to be used as a test set
+  assert FLAGS.stdnt_share < len(test_data)
+
+  # Prepare [unlabeled] student training data (subset of test set)
+  stdnt_data = test_data[:FLAGS.stdnt_share]
+
+  # Compute teacher predictions for student training data
+  teachers_preds = ensemble_preds(dataset, nb_teachers, stdnt_data)
+
+  # Aggregate teacher predictions to get student training labels
+  if not save:
+    stdnt_labels = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale)
+  else:
+    # Request clean votes and clean labels as well
+    stdnt_labels, clean_votes, labels_for_dump = aggregation.noisy_max(teachers_preds, FLAGS.lap_scale, return_clean_votes=True) #NOLINT(long-line)
+
+    # Prepare filepath for numpy dump of clean votes
+    filepath = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_student_clean_votes_lap_' + str(FLAGS.lap_scale) + '.npy'  # NOLINT(long-line)
+
+    # Prepare filepath for numpy dump of clean labels
+    filepath_labels = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_teachers_labels_lap_' + str(FLAGS.lap_scale) + '.npy'  # NOLINT(long-line)
+
+    # Dump clean_votes array
+    with gfile.Open(filepath, mode='w') as file_obj:
+      np.save(file_obj, clean_votes)
+
+    # Dump labels_for_dump array
+    with gfile.Open(filepath_labels, mode='w') as file_obj:
+      np.save(file_obj, labels_for_dump)
+
+  # Print accuracy of aggregated labels
+  ac_ag_labels = metrics.accuracy(stdnt_labels, test_labels[:FLAGS.stdnt_share])
+  print("Accuracy of the aggregated labels: " + str(ac_ag_labels))
+
+  # Store unused part of test set for use as a test set after student training
+  stdnt_test_data = test_data[FLAGS.stdnt_share:]
+  stdnt_test_labels = test_labels[FLAGS.stdnt_share:]
+
+  if save:
+    # Prepare filepath for numpy dump of labels produced by noisy aggregation
+    filepath = FLAGS.data_dir + "/" + str(dataset) + '_' + str(nb_teachers) + '_student_labels_lap_' + str(FLAGS.lap_scale) + '.npy' #NOLINT(long-line)
+
+    # Dump student noisy labels array
+    with gfile.Open(filepath, mode='w') as file_obj:
+      np.save(file_obj, stdnt_labels)
+
+  return stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels
+
+
+def train_student(dataset, nb_teachers):
+  """
+  This function trains a student using predictions made by an ensemble of
+  teachers. The student and teacher models are trained using the same
+  neural network architecture.
+  :param dataset: string corresponding to mnist, cifar10, or svhn
+  :param nb_teachers: number of teachers (in the ensemble) to learn from
+  :return: True if student training went well
+  """
+  assert input.create_dir_if_needed(FLAGS.train_dir)
+
+  # Call helper function to prepare student data using teacher predictions
+  stdnt_dataset = prepare_student_data(dataset, nb_teachers, save=True)
+
+  # Unpack the student dataset
+  stdnt_data, stdnt_labels, stdnt_test_data, stdnt_test_labels = stdnt_dataset
+
+  # Prepare checkpoint filename and path
+  if FLAGS.deeper:
+    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student_deeper.ckpt' #NOLINT(long-line)
+  else:
+    ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + str(nb_teachers) + '_student.ckpt'  # NOLINT(long-line)
+
+  # Start student training
+  assert deep_cnn.train(stdnt_data, stdnt_labels, ckpt_path)
+
+  # Compute final checkpoint name for student (with max number of steps)
+  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
+
+  # Compute student label predictions on remaining chunk of test set
+  student_preds = deep_cnn.softmax_preds(stdnt_test_data, ckpt_path_final)
+
+  # Compute student accuracy
+  precision = metrics.accuracy(student_preds, stdnt_test_labels)
+  print('Precision of student after training: ' + str(precision))
+
+  return True
+
+def main(argv=None): # pylint: disable=unused-argument
+  # Run student training according to values specified in flags
+  assert train_student(FLAGS.dataset, FLAGS.nb_teachers)
+
+if __name__ == '__main__':
+  app.run()
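The student relies on finding the teachers' checkpoints under the naming convention used by `train_teachers.py`. Here is a small hypothetical sketch (not part of the commit) of the path that `ensemble_preds` reconstructs for `--dataset=mnist --nb_teachers=100`, assuming the other flags keep their defaults (`teachers_dir=/tmp/train_dir`, `teachers_max_steps=3000`, `deeper` off).

```python
# Hypothetical reconstruction of the teacher checkpoint path expected by
# ensemble_preds(), mirroring the string building in train_teachers.py and
# train_student.py for the non-deeper model.
teachers_dir = '/tmp/train_dir'
dataset = 'mnist'
nb_teachers = 100
teacher_id = 10
teachers_max_steps = 3000

ckpt_path = (teachers_dir + '/' + dataset + '_' + str(nb_teachers) +
             '_teachers_' + str(teacher_id) + '.ckpt-' +
             str(teachers_max_steps - 1))
print(ckpt_path)  # /tmp/train_dir/mnist_100_teachers_10.ckpt-2999
```

If the student cannot find the teachers' checkpoints, mismatched `teachers_dir`, `teachers_max_steps`, or `deeper` flags between the two scripts are the usual culprits.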

privacy/train_teachers.py (+106, -0)

@@ -0,0 +1,106 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import deep_cnn
+import input
+import metrics
+
+from tensorflow.python.platform import app
+from tensorflow.python.platform import flags
+
+
+flags.DEFINE_string('dataset', 'svhn', 'The name of the dataset to use')
+flags.DEFINE_integer('nb_labels', 10, 'Number of output classes')
+
+flags.DEFINE_string('data_dir','/tmp','Temporary storage')
+flags.DEFINE_string('train_dir','/tmp/train_dir','Where model ckpt are saved')
+
+flags.DEFINE_integer('max_steps', 3000, """Number of training steps to run.""")
+flags.DEFINE_integer('nb_teachers', 50, """Teachers in the ensemble.""")
+flags.DEFINE_integer('teacher_id', 0, """ID of teacher being trained.""")
+
+flags.DEFINE_boolean('deeper', False, """Activate deeper CNN model""")
+
+FLAGS = flags.FLAGS
+
+
+def train_teacher(dataset, nb_teachers, teacher_id):
+  """
+  This function trains a teacher (teacher id) among an ensemble of nb_teachers
+  models for the dataset specified.
+  :param dataset: string corresponding to dataset (svhn, cifar10, or mnist)
+  :param nb_teachers: total number of teachers in the ensemble
+  :param teacher_id: id of the teacher being trained
+  :return: True if everything went well
+  """
+  # If working directories do not exist, create them
+  assert input.create_dir_if_needed(FLAGS.data_dir)
+  assert input.create_dir_if_needed(FLAGS.train_dir)
+
+  # Load the dataset
+  if dataset == 'svhn':
+    train_data,train_labels,test_data,test_labels = input.ld_svhn(extended=True)
+  elif dataset == 'cifar10':
+    train_data, train_labels, test_data, test_labels = input.ld_cifar10()
+  elif dataset == 'mnist':
+    train_data, train_labels, test_data, test_labels = input.ld_mnist()
+  else:
+    print("Check value of dataset flag")
+    return False
+    
+  # Retrieve subset of data for this teacher
+  data, labels = input.partition_dataset(train_data, 
+                                         train_labels, 
+                                         nb_teachers, 
+                                         teacher_id)
+
+  print("Length of training data: " + str(len(labels)))
+
+  # Define teacher checkpoint filename and full path
+  if FLAGS.deeper:
+    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '_deep.ckpt'
+  else:
+    filename = str(nb_teachers) + '_teachers_' + str(teacher_id) + '.ckpt'
+  ckpt_path = FLAGS.train_dir + '/' + str(dataset) + '_' + filename
+
+  # Perform teacher training
+  assert deep_cnn.train(data, labels, ckpt_path)
+
+  # Append final step value to checkpoint for evaluation
+  ckpt_path_final = ckpt_path + '-' + str(FLAGS.max_steps - 1)
+
+  # Retrieve teacher probability estimates on the test data
+  teacher_preds = deep_cnn.softmax_preds(test_data, ckpt_path_final)
+
+  # Compute teacher accuracy
+  precision = metrics.accuracy(teacher_preds, test_labels)
+  print('Precision of teacher after training: ' + str(precision))
+
+  return True
+
+
+def main(argv=None):  # pylint: disable=unused-argument
+  # Make a call to train_teachers with values specified in flags
+  assert train_teacher(FLAGS.dataset, FLAGS.nb_teachers, FLAGS.teacher_id)
+
+if __name__ == '__main__':
+  app.run()
+
+

privacy/utils.py (+35, -0)

@@ -0,0 +1,35 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+def batch_indices(batch_nb, data_length, batch_size):
+  """
+  This helper function computes a batch start and end index
+  :param batch_nb: the batch number
+  :param data_length: the total length of the data being parsed by batches
+  :param batch_size: the number of inputs in each batch
+  :return: pair of (start, end) indices
+  """
+  # Batch start and end index
+  start = int(batch_nb * batch_size)
+  end = int((batch_nb + 1) * batch_size)
+
+  # When there are not enough inputs left, we reuse some to complete the batch
+  if end > data_length:
+    shift = end - data_length
+    start -= shift
+    end -= shift
+
+  return start, end
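
A quick hypothetical example (not part of the commit) of the wrap-around behaviour described above: with 1050 inputs and batches of 128, the ninth batch would overrun the data, so its window is shifted back to end exactly at index 1050, reusing some earlier inputs.

```python
# Hypothetical illustration of batch_indices() reusing inputs in the last batch.
from utils import batch_indices

data_length, batch_size = 1050, 128

print(batch_indices(0, data_length, batch_size))  # (0, 128)
print(batch_indices(7, data_length, batch_size))  # (896, 1024)
print(batch_indices(8, data_length, batch_size))  # (922, 1050): shifted back by 102
```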