Sfoglia il codice sorgente

make large files downloadable

Nicolas Papernot 9 anni fa
parent
commit
86c1405c0b
2 ha cambiato i file con 25 aggiunte e 0 eliminazioni
  1. 21 0
      privacy/analysis.py
  2. 4 0
      privacy/input.py

+ 21 - 0
privacy/analysis.py

@@ -23,10 +23,13 @@ python analysis.py
   --max_examples=1000
   --delta=1e-6
 """
+import os
 import math
 import numpy as np
 import tensorflow as tf
 
+from input import maybe_download
+
 # These parameters can be changed to compute bounds for different failure rates
 # or different model predictions.
 
@@ -194,6 +197,24 @@ def smoothed_sens(counts, noise_eps, l, beta):
 
 
 def main(unused_argv):
+  ##################################################################
+  # If we are reproducing results from paper https://arxiv.org/abs/1610.05755,
+  # download the required binaries with label information.
+  ##################################################################
+  
+  # Binaries for MNIST results
+  paper_binaries_mnist = \
+    ["https://github.com/npapernot/multiple-teachers-for-privacy/blob/master/mnist_250_teachers_labels.npy?raw=true", 
+    "https://github.com/npapernot/multiple-teachers-for-privacy/blob/master/mnist_250_teachers_100_indices_used_by_student.npy?raw=true"]
+  if FLAGS.counts_file == "mnist_250_teachers_labels.npy" \
+    or FLAGS.indices_file == "mnist_250_teachers_100_indices_used_by_student.npy":
+    maybe_download(paper_binaries_mnist, os.getcwd())
+
+  # Binaries for SVHN results
+  paper_binaries_svhn = ["https://github.com/npapernot/multiple-teachers-for-privacy/blob/master/svhn_250_teachers_labels.npy?raw=true"]
+  if FLAGS.counts_file == "svhn_250_teachers_labels.npy":
+    maybe_download(paper_binaries_svhn, os.getcwd())
+
   input_mat = np.load(FLAGS.counts_file)
   if FLAGS.input_is_counts:
     counts_mat = input_mat

+ 4 - 0
privacy/input.py

@@ -62,6 +62,10 @@ def maybe_download(file_urls, directory):
     # Extract filename
     filename = file_url.split('/')[-1]
 
+    # If downloading from GitHub, remove suffix ?raw=True from local filename
+    if filename.endswith("?raw=true"):
+      filename = filename[:-9]
+
     # Deduce local file url
     #filepath = os.path.join(directory, filename)
     filepath = directory + '/' + filename