浏览代码

added semi-supervised training of the student using improved-gan (#655)

Nicolas Papernot 8 年之前
父节点
当前提交
a66b9e13c9

+ 14 - 0
differential_privacy/multiple_teachers/README.md

@@ -63,6 +63,20 @@ training points (they will be labeled using the teacher predictions). The
 remaining samples are used for evaluation of the student's accuracy, which
 is displayed upon completion of training.
 
+## Using semi-supervised GANs to train the student
+
+In the paper, we describe how to train the student in a semi-supervised 
+fashion using Generative Adversarial Networks. This can be reproduced for MNIST 
+by cloning the [improved-gan](https://github.com/openai/improved-gan)
+repository and adding to your `PATH` variable before running the shell
+script `train_student_mnist_250_lap_20_count_50_epochs_600.sh`.
+
+```
+export PATH="/path/to/improved-gan/mnist_svhn_cifar10":$PATH
+sh train_student_mnist_250_lap_20_count_50_epochs_600.sh
+```
+
+
 ## Alternative deeper convolutional architecture
 
 Note that a deeper convolutional model is available. Both the default and

+ 9 - 0
differential_privacy/multiple_teachers/train_student_mnist_250_lap_20_count_50_epochs_600.sh

@@ -0,0 +1,9 @@
+# Be sure to clone https://github.com/openai/improved-gan
+# and add improved-gan/mnist_svhn_cifar10 to your PATH variable
+
+# Download labels used to train the student
+wget https://github.com/npapernot/multiple-teachers-for-privacy/blob/master/mnist_250_student_labels_lap_20.npy
+
+# Train the student using improved-gan 
+THEANO_FLAGS='floatX=float32,device=gpu,lib.cnmem=1' train_mnist_fm_custom_labels.py --labels mnist_250_student_labels_lap_20.npy --count 50 --epochs 600
+