dp_mnist.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Example differentially private trainer and evaluator for MNIST.
"""
  17. from __future__ import division
  18. import json
  19. import os
  20. import sys
  21. import time
  22. import numpy as np
  23. import tensorflow as tf
  24. from differential_privacy.dp_sgd.dp_optimizer import dp_optimizer
  25. from differential_privacy.dp_sgd.dp_optimizer import dp_pca
  26. from differential_privacy.dp_sgd.dp_optimizer import sanitizer
  27. from differential_privacy.dp_sgd.dp_optimizer import utils
  28. from differential_privacy.privacy_accountant.tf import accountant
  29. # parameters for the training
  30. tf.flags.DEFINE_integer("batch_size", 600,
  31. "The training batch size.")
  32. tf.flags.DEFINE_integer("batches_per_lot", 1,
  33. "Number of batches per lot.")
  34. # Together, batch_size and batches_per_lot determine lot_size.
  35. tf.flags.DEFINE_integer("num_training_steps", 50000,
  36. "The number of training steps."
  37. "This counts number of lots.")
  38. tf.flags.DEFINE_bool("randomize", True,
  39. "If true, randomize the input data; otherwise use a fixed "
  40. "seed and non-randomized input.")
  41. tf.flags.DEFINE_bool("freeze_bottom_layers", False,
  42. "If true, only train on the logit layer.")
  43. tf.flags.DEFINE_bool("save_mistakes", False,
  44. "If true, save the mistakes made during testing.")
  45. tf.flags.DEFINE_float("lr", 0.05, "start learning rate")
  46. tf.flags.DEFINE_float("end_lr", 0.05, "end learning rate")
  47. tf.flags.DEFINE_float("lr_saturate_epochs", 0,
  48. "learning rate saturate epochs; set to 0 for a constant "
  49. "learning rate of --lr.")
  50. # For searching parameters
  51. tf.flags.DEFINE_integer("projection_dimensions", 60,
  52. "PCA projection dimensions, or 0 for no projection.")
  53. tf.flags.DEFINE_integer("num_hidden_layers", 1,
  54. "Number of hidden layers in the network")
  55. tf.flags.DEFINE_integer("hidden_layer_num_units", 1000,
  56. "Number of units per hidden layer")
  57. tf.flags.DEFINE_float("default_gradient_l2norm_bound", 4.0, "norm clipping")
  58. tf.flags.DEFINE_integer("num_conv_layers", 0,
  59. "Number of convolutional layers to use.")
  60. tf.flags.DEFINE_string("training_data_path",
  61. "/tmp/mnist/mnist_train.tfrecord",
  62. "Location of the training data.")
  63. tf.flags.DEFINE_string("eval_data_path",
  64. "/tmp/mnist/mnist_test.tfrecord",
  65. "Location of the eval data.")
  66. tf.flags.DEFINE_integer("eval_steps", 10,
  67. "Evaluate the model every eval_steps")
  68. # Parameters for privacy spending. We allow linearly varying eps during
  69. # training.
  70. tf.flags.DEFINE_string("accountant_type", "Moments", "Moments, Amortized.")
  71. # Flags that control privacy spending during training.
  72. tf.flags.DEFINE_float("eps", 1.0,
  73. "Start privacy spending for one epoch of training, "
  74. "used if accountant_type is Amortized.")
  75. tf.flags.DEFINE_float("end_eps", 1.0,
  76. "End privacy spending for one epoch of training, "
  77. "used if accountant_type is Amortized.")
  78. tf.flags.DEFINE_float("eps_saturate_epochs", 0,
  79. "Stop varying epsilon after eps_saturate_epochs. Set to "
  80. "0 for constant eps of --eps. "
  81. "Used if accountant_type is Amortized.")
  82. tf.flags.DEFINE_float("delta", 1e-5,
  83. "Privacy spending for training. Constant through "
  84. "training, used if accountant_type is Amortized.")
  85. tf.flags.DEFINE_float("sigma", 4.0,
  86. "Noise sigma, used only if accountant_type is Moments")
  87. # Flags that control privacy spending for the pca projection
  88. # (only used if --projection_dimensions > 0).
  89. tf.flags.DEFINE_float("pca_eps", 0.5,
  90. "Privacy spending for PCA, used if accountant_type is "
  91. "Amortized.")
  92. tf.flags.DEFINE_float("pca_delta", 0.005,
  93. "Privacy spending for PCA, used if accountant_type is "
  94. "Amortized.")
  95. tf.flags.DEFINE_float("pca_sigma", 7.0,
  96. "Noise sigma for PCA, used if accountant_type is Moments")
  97. tf.flags.DEFINE_string("target_eps", "0.125,0.25,0.5,1,2,4,8",
  98. "Log the privacy loss for the target epsilon's. Only "
  99. "used when accountant_type is Moments.")
  100. tf.flags.DEFINE_float("target_delta", 1e-5,
  101. "Maximum delta for --terminate_based_on_privacy.")
  102. tf.flags.DEFINE_bool("terminate_based_on_privacy", False,
  103. "Stop training if privacy spent exceeds "
  104. "(max(--target_eps), --target_delta), even "
  105. "if --num_training_steps have not yet been completed.")
  106. tf.flags.DEFINE_string("save_path", "/tmp/mnist_dir",
  107. "Directory for saving model outputs.")
  108. FLAGS = tf.flags.FLAGS
  109. NUM_TRAINING_IMAGES = 60000
  110. NUM_TESTING_IMAGES = 10000
  111. IMAGE_SIZE = 28
  112. def MnistInput(mnist_data_file, batch_size, randomize):
  113. """Create operations to read the MNIST input file.
  114. Args:
  115. mnist_data_file: Path of a file containing the MNIST images to process.
  116. batch_size: size of the mini batches to generate.
  117. randomize: If true, randomize the dataset.
  118. Returns:
  119. images: A tensor with the formatted image data. shape [batch_size, 28*28]
  120. labels: A tensor with the labels for each image. shape [batch_size]
  121. """
  122. file_queue = tf.train.string_input_producer([mnist_data_file])
  123. reader = tf.TFRecordReader()
  124. _, value = reader.read(file_queue)
  125. example = tf.parse_single_example(
  126. value,
  127. features={"image/encoded": tf.FixedLenFeature(shape=(), dtype=tf.string),
  128. "image/class/label": tf.FixedLenFeature([1], tf.int64)})
  129. image = tf.cast(tf.image.decode_png(example["image/encoded"], channels=1),
  130. tf.float32)
  131. image = tf.reshape(image, [IMAGE_SIZE * IMAGE_SIZE])
  132. image /= 255
  133. label = tf.cast(example["image/class/label"], dtype=tf.int32)
  134. label = tf.reshape(label, [])
  135. if randomize:
  136. images, labels = tf.train.shuffle_batch(
  137. [image, label], batch_size=batch_size,
  138. capacity=(batch_size * 100),
  139. min_after_dequeue=(batch_size * 10))
  140. else:
  141. images, labels = tf.train.batch([image, label], batch_size=batch_size)
  142. return images, labels
  143. def Eval(mnist_data_file, network_parameters, num_testing_images,
  144. randomize, load_path, save_mistakes=False):
  145. """Evaluate MNIST for a number of steps.
  146. Args:
  147. mnist_data_file: Path of a file containing the MNIST images to process.
  148. network_parameters: parameters for defining and training the network.
  149. num_testing_images: the number of images we will evaluate on.
  150. randomize: if false, randomize; otherwise, read the testing images
  151. sequentially.
  152. load_path: path where to load trained parameters from.
  153. save_mistakes: save the mistakes if True.
  154. Returns:
  155. The evaluation accuracy as a float.
  156. """
  157. batch_size = 100
  158. # Like for training, we need a session for executing the TensorFlow graph.
  159. with tf.Graph().as_default(), tf.Session() as sess:
  160. # Create the basic Mnist model.
  161. images, labels = MnistInput(mnist_data_file, batch_size, randomize)
  162. logits, _, _ = utils.BuildNetwork(images, network_parameters)
  163. softmax = tf.nn.softmax(logits)
  164. # Load the variables.
  165. ckpt_state = tf.train.get_checkpoint_state(load_path)
  166. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  167. raise ValueError("No model checkpoint to eval at %s\n" % load_path)
  168. saver = tf.train.Saver()
  169. saver.restore(sess, ckpt_state.model_checkpoint_path)
  170. coord = tf.train.Coordinator()
  171. _ = tf.train.start_queue_runners(sess=sess, coord=coord)
  172. total_examples = 0
  173. correct_predictions = 0
  174. image_index = 0
  175. mistakes = []
  176. for _ in xrange((num_testing_images + batch_size - 1) // batch_size):
  177. predictions, label_values = sess.run([softmax, labels])
  178. # Count how many were predicted correctly.
  179. for prediction, label_value in zip(predictions, label_values):
  180. total_examples += 1
  181. if np.argmax(prediction) == label_value:
  182. correct_predictions += 1
  183. elif save_mistakes:
  184. mistakes.append({"index": image_index,
  185. "label": label_value,
  186. "pred": np.argmax(prediction)})
  187. image_index += 1
  188. return (correct_predictions / total_examples,
  189. mistakes if save_mistakes else None)
  190. def Train(mnist_train_file, mnist_test_file, network_parameters, num_steps,
  191. save_path, eval_steps=0):
  192. """Train MNIST for a number of steps.
  193. Args:
  194. mnist_train_file: path of MNIST train data file.
  195. mnist_test_file: path of MNIST test data file.
  196. network_parameters: parameters for defining and training the network.
  197. num_steps: number of steps to run. Here steps = lots
  198. save_path: path where to save trained parameters.
  199. eval_steps: evaluate the model every eval_steps.
  200. Returns:
  201. the result after the final training step.
  202. Raises:
  203. ValueError: if the accountant_type is not supported.
  204. """
  205. batch_size = FLAGS.batch_size
  206. params = {"accountant_type": FLAGS.accountant_type,
  207. "task_id": 0,
  208. "batch_size": FLAGS.batch_size,
  209. "projection_dimensions": FLAGS.projection_dimensions,
  210. "default_gradient_l2norm_bound":
  211. network_parameters.default_gradient_l2norm_bound,
  212. "num_hidden_layers": FLAGS.num_hidden_layers,
  213. "hidden_layer_num_units": FLAGS.hidden_layer_num_units,
  214. "num_examples": NUM_TRAINING_IMAGES,
  215. "learning_rate": FLAGS.lr,
  216. "end_learning_rate": FLAGS.end_lr,
  217. "learning_rate_saturate_epochs": FLAGS.lr_saturate_epochs
  218. }
  219. # Log different privacy parameters dependent on the accountant type.
  220. if FLAGS.accountant_type == "Amortized":
  221. params.update({"flag_eps": FLAGS.eps,
  222. "flag_delta": FLAGS.delta,
  223. "flag_pca_eps": FLAGS.pca_eps,
  224. "flag_pca_delta": FLAGS.pca_delta,
  225. })
  226. elif FLAGS.accountant_type == "Moments":
  227. params.update({"sigma": FLAGS.sigma,
  228. "pca_sigma": FLAGS.pca_sigma,
  229. })
  230. with tf.Graph().as_default(), tf.Session() as sess, tf.device('/cpu:0'):
  231. # Create the basic Mnist model.
  232. images, labels = MnistInput(mnist_train_file, batch_size, FLAGS.randomize)
  233. logits, projection, training_params = utils.BuildNetwork(
  234. images, network_parameters)
  235. cost = tf.nn.softmax_cross_entropy_with_logits(
  236. logits=logits, labels=tf.one_hot(labels, 10))
  237. # The actual cost is the average across the examples.
  238. cost = tf.reduce_sum(cost, [0]) / batch_size
  239. if FLAGS.accountant_type == "Amortized":
  240. priv_accountant = accountant.AmortizedAccountant(NUM_TRAINING_IMAGES)
  241. sigma = None
  242. pca_sigma = None
  243. with_privacy = FLAGS.eps > 0
  244. elif FLAGS.accountant_type == "Moments":
  245. priv_accountant = accountant.GaussianMomentsAccountant(
  246. NUM_TRAINING_IMAGES)
  247. sigma = FLAGS.sigma
  248. pca_sigma = FLAGS.pca_sigma
  249. with_privacy = FLAGS.sigma > 0
  250. else:
  251. raise ValueError("Undefined accountant type, needs to be "
  252. "Amortized or Moments, but got %s" % FLAGS.accountant)
  253. # Note: Here and below, we scale down the l2norm_bound by
  254. # batch_size. This is because per_example_gradients computes the
  255. # gradient of the minibatch loss with respect to each individual
  256. # example, and the minibatch loss (for our model) is the *average*
  257. # loss over examples in the minibatch. Hence, the scale of the
  258. # per-example gradients goes like 1 / batch_size.
  259. gaussian_sanitizer = sanitizer.AmortizedGaussianSanitizer(
  260. priv_accountant,
  261. [network_parameters.default_gradient_l2norm_bound / batch_size, True])
  262. for var in training_params:
  263. if "gradient_l2norm_bound" in training_params[var]:
  264. l2bound = training_params[var]["gradient_l2norm_bound"] / batch_size
  265. gaussian_sanitizer.set_option(var,
  266. sanitizer.ClipOption(l2bound, True))
  267. lr = tf.placeholder(tf.float32)
  268. eps = tf.placeholder(tf.float32)
  269. delta = tf.placeholder(tf.float32)
  270. init_ops = []
  271. if network_parameters.projection_type == "PCA":
  272. with tf.variable_scope("pca"):
  273. # Compute differentially private PCA.
  274. all_data, _ = MnistInput(mnist_train_file, NUM_TRAINING_IMAGES, False)
  275. pca_projection = dp_pca.ComputeDPPrincipalProjection(
  276. all_data, network_parameters.projection_dimensions,
  277. gaussian_sanitizer, [FLAGS.pca_eps, FLAGS.pca_delta], pca_sigma)
  278. assign_pca_proj = tf.assign(projection, pca_projection)
  279. init_ops.append(assign_pca_proj)
  280. # Add global_step
  281. global_step = tf.Variable(0, dtype=tf.int32, trainable=False,
  282. name="global_step")
  283. if with_privacy:
  284. gd_op = dp_optimizer.DPGradientDescentOptimizer(
  285. lr,
  286. [eps, delta],
  287. gaussian_sanitizer,
  288. sigma=sigma,
  289. batches_per_lot=FLAGS.batches_per_lot).minimize(
  290. cost, global_step=global_step)
  291. else:
  292. gd_op = tf.train.GradientDescentOptimizer(lr).minimize(cost)
  293. saver = tf.train.Saver()
  294. coord = tf.train.Coordinator()
  295. _ = tf.train.start_queue_runners(sess=sess, coord=coord)
  296. # We need to maintain the intialization sequence.
  297. for v in tf.trainable_variables():
  298. sess.run(tf.variables_initializer([v]))
  299. sess.run(tf.global_variables_initializer())
  300. sess.run(init_ops)
  301. results = []
  302. start_time = time.time()
  303. prev_time = start_time
  304. filename = "results-0.json"
  305. log_path = os.path.join(save_path, filename)
  306. target_eps = [float(s) for s in FLAGS.target_eps.split(",")]
  307. if FLAGS.accountant_type == "Amortized":
  308. # Only matters if --terminate_based_on_privacy is true.
  309. target_eps = [max(target_eps)]
  310. max_target_eps = max(target_eps)
  311. lot_size = FLAGS.batches_per_lot * FLAGS.batch_size
  312. lots_per_epoch = NUM_TRAINING_IMAGES / lot_size
  313. for step in xrange(num_steps):
  314. epoch = step / lots_per_epoch
  315. curr_lr = utils.VaryRate(FLAGS.lr, FLAGS.end_lr,
  316. FLAGS.lr_saturate_epochs, epoch)
  317. curr_eps = utils.VaryRate(FLAGS.eps, FLAGS.end_eps,
  318. FLAGS.eps_saturate_epochs, epoch)
  319. for _ in xrange(FLAGS.batches_per_lot):
  320. _ = sess.run(
  321. [gd_op], feed_dict={lr: curr_lr, eps: curr_eps, delta: FLAGS.delta})
  322. sys.stderr.write("step: %d\n" % step)
  323. # See if we should stop training due to exceeded privacy budget:
  324. should_terminate = False
  325. terminate_spent_eps_delta = None
  326. if with_privacy and FLAGS.terminate_based_on_privacy:
  327. terminate_spent_eps_delta = priv_accountant.get_privacy_spent(
  328. sess, target_eps=[max_target_eps])[0]
  329. # For the Moments accountant, we should always have
  330. # spent_eps == max_target_eps.
  331. if (terminate_spent_eps_delta.spent_delta > FLAGS.target_delta or
  332. terminate_spent_eps_delta.spent_eps > max_target_eps):
  333. should_terminate = True
  334. if (eval_steps > 0 and (step + 1) % eval_steps == 0) or should_terminate:
  335. if with_privacy:
  336. spent_eps_deltas = priv_accountant.get_privacy_spent(
  337. sess, target_eps=target_eps)
  338. else:
  339. spent_eps_deltas = [accountant.EpsDelta(0, 0)]
  340. for spent_eps, spent_delta in spent_eps_deltas:
  341. sys.stderr.write("spent privacy: eps %.4f delta %.5g\n" % (
  342. spent_eps, spent_delta))
  343. saver.save(sess, save_path=save_path + "/ckpt")
  344. train_accuracy, _ = Eval(mnist_train_file, network_parameters,
  345. num_testing_images=NUM_TESTING_IMAGES,
  346. randomize=True, load_path=save_path)
  347. sys.stderr.write("train_accuracy: %.2f\n" % train_accuracy)
  348. test_accuracy, mistakes = Eval(mnist_test_file, network_parameters,
  349. num_testing_images=NUM_TESTING_IMAGES,
  350. randomize=False, load_path=save_path,
  351. save_mistakes=FLAGS.save_mistakes)
  352. sys.stderr.write("eval_accuracy: %.2f\n" % test_accuracy)
  353. curr_time = time.time()
  354. elapsed_time = curr_time - prev_time
  355. prev_time = curr_time
  356. results.append({"step": step+1, # Number of lots trained so far.
  357. "elapsed_secs": elapsed_time,
  358. "spent_eps_deltas": spent_eps_deltas,
  359. "train_accuracy": train_accuracy,
  360. "test_accuracy": test_accuracy,
  361. "mistakes": mistakes})
  362. loginfo = {"elapsed_secs": curr_time-start_time,
  363. "spent_eps_deltas": spent_eps_deltas,
  364. "train_accuracy": train_accuracy,
  365. "test_accuracy": test_accuracy,
  366. "num_training_steps": step+1, # Steps so far.
  367. "mistakes": mistakes,
  368. "result_series": results}
  369. loginfo.update(params)
  370. if log_path:
  371. with tf.gfile.Open(log_path, "w") as f:
  372. json.dump(loginfo, f, indent=2)
  373. f.write("\n")
  374. f.close()
  375. if should_terminate:
  376. break
  377. def main(_):
  378. network_parameters = utils.NetworkParameters()
  379. # If the ASCII proto isn't specified, then construct a config protobuf based
  380. # on 3 flags.
  381. network_parameters.input_size = IMAGE_SIZE ** 2
  382. network_parameters.default_gradient_l2norm_bound = (
  383. FLAGS.default_gradient_l2norm_bound)
  384. if FLAGS.projection_dimensions > 0 and FLAGS.num_conv_layers > 0:
  385. raise ValueError("Currently you can't do PCA and have convolutions"
  386. "at the same time. Pick one")
  387. # could add support for PCA after convolutions.
  388. # Currently BuildNetwork can build the network with conv followed by
  389. # projection, but the PCA training works on data, rather than data run
  390. # through a few layers. Will need to init the convs before running the
  391. # PCA, and need to change the PCA subroutine to take a network and perhaps
  392. # allow for batched inputs, to handle larger datasets.
  393. if FLAGS.num_conv_layers > 0:
  394. conv = utils.ConvParameters()
  395. conv.name = "conv1"
  396. conv.in_channels = 1
  397. conv.out_channels = 128
  398. conv.num_outputs = 128 * 14 * 14
  399. network_parameters.conv_parameters.append(conv)
  400. # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
  401. # insize 28x28, bias, stddev 0.1, non-trainable.
  402. if FLAGS.num_conv_layers > 1:
  403. conv = network_parameters.ConvParameters()
  404. conv.name = "conv2"
  405. conv.in_channels = 128
  406. conv.out_channels = 128
  407. conv.num_outputs = 128 * 7 * 7
  408. conv.in_size = 14
  409. # defaults for the rest: 5x5,stride 1, relu, maxpool 2x2,stride 2.
  410. # bias, stddev 0.1, non-trainable.
  411. network_parameters.conv_parameters.append(conv)
  412. if FLAGS.num_conv_layers > 2:
  413. raise ValueError("Currently --num_conv_layers must be 0,1 or 2."
  414. "Manually create a network_parameters proto for more.")
  415. if FLAGS.projection_dimensions > 0:
  416. network_parameters.projection_type = "PCA"
  417. network_parameters.projection_dimensions = FLAGS.projection_dimensions
  418. for i in xrange(FLAGS.num_hidden_layers):
  419. hidden = utils.LayerParameters()
  420. hidden.name = "hidden%d" % i
  421. hidden.num_units = FLAGS.hidden_layer_num_units
  422. hidden.relu = True
  423. hidden.with_bias = False
  424. hidden.trainable = not FLAGS.freeze_bottom_layers
  425. network_parameters.layer_parameters.append(hidden)
  426. logits = utils.LayerParameters()
  427. logits.name = "logits"
  428. logits.num_units = 10
  429. logits.relu = False
  430. logits.with_bias = False
  431. network_parameters.layer_parameters.append(logits)
  432. Train(FLAGS.training_data_path,
  433. FLAGS.eval_data_path,
  434. network_parameters,
  435. FLAGS.num_training_steps,
  436. FLAGS.save_path,
  437. eval_steps=FLAGS.eval_steps)
  438. if __name__ == "__main__":
  439. tf.app.run()