# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from datetime import datetime
import math
import numpy as np
from six.moves import xrange  # xrange is used in the training and eval loops below
import tensorflow as tf
import time

from differential_privacy.multiple_teachers import utils

FLAGS = tf.app.flags.FLAGS

# Basic model parameters.
tf.app.flags.DEFINE_integer('dropout_seed', 123, """seed for dropout.""")
tf.app.flags.DEFINE_integer('batch_size', 128, """Nb of images in a batch.""")
tf.app.flags.DEFINE_integer('epochs_per_decay', 350, """Nb epochs per decay""")
tf.app.flags.DEFINE_integer('learning_rate', 5, """100 * learning rate""")
tf.app.flags.DEFINE_boolean('log_device_placement', False, """see TF doc""")


# Constants describing the training process.
MOVING_AVERAGE_DECAY = 0.9999     # The decay to use for the moving average.
LEARNING_RATE_DECAY_FACTOR = 0.1  # Learning rate decay factor.


def _variable_on_cpu(name, shape, initializer):
  """Helper to create a Variable stored on CPU memory.

  Args:
    name: name of the variable
    shape: list of ints
    initializer: initializer for Variable

  Returns:
    Variable Tensor
  """
  with tf.device('/cpu:0'):
    var = tf.get_variable(name, shape, initializer=initializer)
  return var


def _variable_with_weight_decay(name, shape, stddev, wd):
  """Helper to create an initialized Variable with weight decay.

  Note that the Variable is initialized with a truncated normal distribution.
  A weight decay is added only if one is specified.

  Args:
    name: name of the variable
    shape: list of ints
    stddev: standard deviation of a truncated Gaussian
    wd: add L2Loss weight decay multiplied by this float. If None, weight
        decay is not added for this Variable.

  Returns:
    Variable Tensor
  """
  var = _variable_on_cpu(name, shape,
                         tf.truncated_normal_initializer(stddev=stddev))
  if wd is not None:
    weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss')
    tf.add_to_collection('losses', weight_decay)
  return var
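

# Note: every variable created above with wd > 0 contributes
# wd * tf.nn.l2_loss(var) (i.e. wd * sum(var ** 2) / 2) to the 'losses'
# collection; loss_fun() below adds the cross-entropy term to the same
# collection and sums everything into the total training loss.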


def inference(images, dropout=False):
  """Build the CNN model.

  Args:
    images: Images returned from distorted_inputs() or inputs().
    dropout: Boolean controlling whether to use dropout or not

  Returns:
    Logits
  """
  if FLAGS.dataset == 'mnist':
    first_conv_shape = [5, 5, 1, 64]
  else:
    first_conv_shape = [5, 5, 3, 64]

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=first_conv_shape,
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv1 = tf.nn.dropout(conv1, 0.3, seed=FLAGS.dropout_seed)

  # pool1
  pool1 = tf.nn.max_pool(conv1,
                         ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1],
                         padding='SAME',
                         name='pool1')

  # norm1
  norm1 = tf.nn.lrn(pool1,
                    4,
                    bias=1.0,
                    alpha=0.001 / 9.0,
                    beta=0.75,
                    name='norm1')

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[5, 5, 64, 128],
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(norm1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [128], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv2 = tf.nn.dropout(conv2, 0.3, seed=FLAGS.dropout_seed)

  # norm2
  norm2 = tf.nn.lrn(conv2,
                    4,
                    bias=1.0,
                    alpha=0.001 / 9.0,
                    beta=0.75,
                    name='norm2')

  # pool2
  pool2 = tf.nn.max_pool(norm2,
                         ksize=[1, 3, 3, 1],
                         strides=[1, 2, 2, 1],
                         padding='SAME',
                         name='pool2')

  # local3
  with tf.variable_scope('local3') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(pool2, [FLAGS.batch_size, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay('weights',
                                          shape=[dim, 384],
                                          stddev=0.04,
                                          wd=0.004)
    biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1))
    local3 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)
    if dropout:
      local3 = tf.nn.dropout(local3, 0.5, seed=FLAGS.dropout_seed)

  # local4
  with tf.variable_scope('local4') as scope:
    weights = _variable_with_weight_decay('weights',
                                          shape=[384, 192],
                                          stddev=0.04,
                                          wd=0.004)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local4 = tf.nn.relu(tf.matmul(local3, weights) + biases, name=scope.name)
    if dropout:
      local4 = tf.nn.dropout(local4, 0.5, seed=FLAGS.dropout_seed)

  # compute logits
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights',
                                          [192, FLAGS.nb_labels],
                                          stddev=1/192.0,
                                          wd=0.0)
    biases = _variable_on_cpu('biases',
                              [FLAGS.nb_labels],
                              tf.constant_initializer(0.0))
    logits = tf.add(tf.matmul(local4, weights), biases, name=scope.name)

  return logits
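

# Shape walkthrough for inference() (a sketch, assuming MNIST 28x28x1 inputs
# and batch size B):
#   conv1 + pool1:  B x 28 x 28 x 64  ->  B x 14 x 14 x 64  (3x3 pool, stride 2)
#   conv2 + pool2:  B x 14 x 14 x 128 ->  B x 7 x 7 x 128
#   local3:         flattened dim = 7 * 7 * 128 = 6272  ->  B x 384
#   local4:         B x 384 -> B x 192
#   softmax_linear: B x 192 -> B x FLAGS.nb_labels (raw logits, no softmax)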


def inference_deeper(images, dropout=False):
  """Build a deeper CNN model.

  Args:
    images: Images returned from distorted_inputs() or inputs().
    dropout: Boolean controlling whether to use dropout or not

  Returns:
    Logits
  """
  if FLAGS.dataset == 'mnist':
    first_conv_shape = [3, 3, 1, 96]
  else:
    first_conv_shape = [3, 3, 3, 96]

  # conv1
  with tf.variable_scope('conv1') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=first_conv_shape,
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(images, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv1 = tf.nn.relu(bias, name=scope.name)

  # conv2
  with tf.variable_scope('conv2') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 96],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv1, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv2 = tf.nn.relu(bias, name=scope.name)

  # conv3
  with tf.variable_scope('conv3') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 96],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv2, kernel, [1, 2, 2, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [96], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv3 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv3 = tf.nn.dropout(conv3, 0.5, seed=FLAGS.dropout_seed)

  # conv4
  with tf.variable_scope('conv4') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 96, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv3, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv4 = tf.nn.relu(bias, name=scope.name)

  # conv5
  with tf.variable_scope('conv5') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 192, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv4, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv5 = tf.nn.relu(bias, name=scope.name)

  # conv6
  with tf.variable_scope('conv6') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[3, 3, 192, 192],
                                         stddev=0.05,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv5, kernel, [1, 2, 2, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.0))
    bias = tf.nn.bias_add(conv, biases)
    conv6 = tf.nn.relu(bias, name=scope.name)
    if dropout:
      conv6 = tf.nn.dropout(conv6, 0.5, seed=FLAGS.dropout_seed)

  # conv7
  with tf.variable_scope('conv7') as scope:
    kernel = _variable_with_weight_decay('weights',
                                         shape=[5, 5, 192, 192],
                                         stddev=1e-4,
                                         wd=0.0)
    conv = tf.nn.conv2d(conv6, kernel, [1, 1, 1, 1], padding='SAME')
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    bias = tf.nn.bias_add(conv, biases)
    conv7 = tf.nn.relu(bias, name=scope.name)

  # local1
  with tf.variable_scope('local1') as scope:
    # Move everything into depth so we can perform a single matrix multiply.
    reshape = tf.reshape(conv7, [FLAGS.batch_size, -1])
    dim = reshape.get_shape()[1].value
    weights = _variable_with_weight_decay('weights',
                                          shape=[dim, 192],
                                          stddev=0.05,
                                          wd=0)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local1 = tf.nn.relu(tf.matmul(reshape, weights) + biases, name=scope.name)

  # local2
  with tf.variable_scope('local2') as scope:
    weights = _variable_with_weight_decay('weights',
                                          shape=[192, 192],
                                          stddev=0.05,
                                          wd=0)
    biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1))
    local2 = tf.nn.relu(tf.matmul(local1, weights) + biases, name=scope.name)
    if dropout:
      local2 = tf.nn.dropout(local2, 0.5, seed=FLAGS.dropout_seed)

  # compute logits
  with tf.variable_scope('softmax_linear') as scope:
    weights = _variable_with_weight_decay('weights',
                                          [192, FLAGS.nb_labels],
                                          stddev=0.05,
                                          wd=0.0)
    biases = _variable_on_cpu('biases',
                              [FLAGS.nb_labels],
                              tf.constant_initializer(0.0))
    logits = tf.add(tf.matmul(local2, weights), biases, name=scope.name)

  return logits


def loss_fun(logits, labels):
  """Add L2Loss to all the trainable variables.

  Add summary for "Loss" and "Loss/avg".

  Args:
    logits: Logits from inference().
    labels: Labels from distorted_inputs or inputs(). 1-D tensor
            of shape [batch_size]

  Returns:
    Loss tensor of type float.
  """
  # Calculate the cross entropy between labels and predictions
  labels = tf.cast(labels, tf.int64)
  cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
      logits=logits, labels=labels, name='cross_entropy_per_example')

  # Calculate the average cross entropy loss across the batch.
  cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')

  # Add to TF collection for losses
  tf.add_to_collection('losses', cross_entropy_mean)

  # The total loss is defined as the cross entropy loss plus all of the weight
  # decay terms (L2 loss).
  return tf.add_n(tf.get_collection('losses'), name='total_loss')


def moving_av(total_loss):
  """
  Generates moving average for all losses

  Args:
    total_loss: Total loss from loss().

  Returns:
    loss_averages_op: op for generating moving averages of losses.
  """
  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  losses = tf.get_collection('losses')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  return loss_averages_op


def train_op_fun(total_loss, global_step):
  """Train model.

  Create an optimizer and apply to all trainable variables. Add moving
  average for all trainable variables.

  Args:
    total_loss: Total loss from loss().
    global_step: Integer Variable counting the number of training steps
      processed.

  Returns:
    train_op: op for training.
  """
  # Variables that affect learning rate.
  nb_ex_per_train_epoch = int(60000 / FLAGS.nb_teachers)

  num_batches_per_epoch = nb_ex_per_train_epoch / FLAGS.batch_size
  decay_steps = int(num_batches_per_epoch * FLAGS.epochs_per_decay)

  initial_learning_rate = float(FLAGS.learning_rate) / 100.0

  # Decay the learning rate exponentially based on the number of steps.
  lr = tf.train.exponential_decay(initial_learning_rate,
                                  global_step,
                                  decay_steps,
                                  LEARNING_RATE_DECAY_FACTOR,
                                  staircase=True)
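  # With staircase=True, this schedule is equivalent to
  #   lr = initial_learning_rate *
  #        LEARNING_RATE_DECAY_FACTOR ** floor(global_step / decay_steps)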
  tf.summary.scalar('learning_rate', lr)

  # Generate moving averages of all losses and associated summaries.
  loss_averages_op = moving_av(total_loss)

  # Compute gradients.
  with tf.control_dependencies([loss_averages_op]):
    opt = tf.train.GradientDescentOptimizer(lr)
    grads = opt.compute_gradients(total_loss)

  # Apply gradients.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

  # Add histograms for trainable variables.
  for var in tf.trainable_variables():
    tf.summary.histogram(var.op.name, var)

  # Track the moving averages of all trainable variables.
  variable_averages = tf.train.ExponentialMovingAverage(
      MOVING_AVERAGE_DECAY, global_step)
  variables_averages_op = variable_averages.apply(tf.trainable_variables())

  with tf.control_dependencies([apply_gradient_op, variables_averages_op]):
    train_op = tf.no_op(name='train')

  return train_op


def _input_placeholder():
  """
  This helper function declares a TF placeholder for the graph input data
  :return: TF placeholder for the graph input data
  """
  if FLAGS.dataset == 'mnist':
    image_size = 28
    num_channels = 1
  else:
    image_size = 32
    num_channels = 3

  # Declare data placeholder
  train_node_shape = (FLAGS.batch_size, image_size, image_size, num_channels)
  return tf.placeholder(tf.float32, shape=train_node_shape)


def train(images, labels, ckpt_path, dropout=False):
  """
  This function contains the loop that actually trains the model.
  :param images: a numpy array with the input data
  :param labels: a numpy array with the output labels
  :param ckpt_path: a path (including name) where model checkpoints are saved
  :param dropout: Boolean, whether to use dropout or not
  :return: True if everything went well
  """

  # Check training data
  assert len(images) == len(labels)
  assert images.dtype == np.float32
  assert labels.dtype == np.int32

  # Set default TF graph
  with tf.Graph().as_default():
    global_step = tf.Variable(0, trainable=False)

    # Declare data placeholder
    train_data_node = _input_placeholder()

    # Create a placeholder to hold labels
    train_labels_shape = (FLAGS.batch_size,)
    train_labels_node = tf.placeholder(tf.int32, shape=train_labels_shape)

    print("Done Initializing Training Placeholders")

    # Build a Graph that computes the logits predictions from the placeholder
    if FLAGS.deeper:
      logits = inference_deeper(train_data_node, dropout=dropout)
    else:
      logits = inference(train_data_node, dropout=dropout)

    # Calculate loss
    loss = loss_fun(logits, train_labels_node)

    # Build a Graph that trains the model with one batch of examples and
    # updates the model parameters.
    train_op = train_op_fun(loss, global_step)

    # Create a saver.
    saver = tf.train.Saver(tf.global_variables())

    print("Graph constructed and saver created")

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Create and init sessions
    sess = tf.Session(config=tf.ConfigProto(log_device_placement=FLAGS.log_device_placement))  # NOLINT(long-line)
    sess.run(init)

    print("Session ready, beginning training loop")

    # Initialize the number of batches
    data_length = len(images)
    nb_batches = math.ceil(data_length / FLAGS.batch_size)

    for step in xrange(FLAGS.max_steps):
      # for debug, save start time
      start_time = time.time()

      # Current batch number
      batch_nb = step % nb_batches

      # Current batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare dictionary to feed the session with
      feed_dict = {train_data_node: images[start:end],
                   train_labels_node: labels[start:end]}

      # Run training step
      _, loss_value = sess.run([train_op, loss], feed_dict=feed_dict)

      # Compute duration of training step
      duration = time.time() - start_time

      # Sanity check
      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      # Echo loss once in a while
      if step % 100 == 0:
        num_examples_per_step = FLAGS.batch_size
        examples_per_sec = num_examples_per_step / duration
        sec_per_batch = float(duration)

        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, sec_per_batch))

      # Save the model checkpoint periodically.
      if step % 1000 == 0 or (step + 1) == FLAGS.max_steps:
        saver.save(sess, ckpt_path, global_step=step)

  return True
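

# Hypothetical usage sketch (not part of this module): a driver script that
# has loaded float32 images of shape [N, H, W, C] and int32 labels could
# train a model with, e.g.,
#   assert train(train_images, train_labels, '/tmp/model.ckpt', dropout=True)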


def softmax_preds(images, ckpt_path, return_logits=False):
  """
  Compute softmax activations (probabilities) with the model saved in the path
  specified as an argument
  :param images: a np array of images
  :param ckpt_path: a TF model checkpoint
  :param return_logits: if set to True, return logits instead of probabilities
  :return: probabilities (or logits if return_logits is set to True)
  """
  # Compute nb samples and deduce nb of batches
  data_length = len(images)
  nb_batches = math.ceil(len(images) / FLAGS.batch_size)

  # Declare data placeholder
  train_data_node = _input_placeholder()

  # Build a Graph that computes the logits predictions from the placeholder
  if FLAGS.deeper:
    logits = inference_deeper(train_data_node)
  else:
    logits = inference(train_data_node)

  if return_logits:
    # We are returning the logits directly (no need to apply softmax)
    output = logits
  else:
    # Add softmax predictions to graph: will return probabilities
    output = tf.nn.softmax(logits)

  # Restore the moving average version of the learned variables for eval.
  variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY)
  variables_to_restore = variable_averages.variables_to_restore()
  saver = tf.train.Saver(variables_to_restore)

  # Will hold the result
  preds = np.zeros((data_length, FLAGS.nb_labels), dtype=np.float32)

  # Create TF session
  with tf.Session() as sess:
    # Restore TF session from checkpoint file
    saver.restore(sess, ckpt_path)

    # Parse data by batch
    for batch_nb in xrange(0, int(nb_batches+1)):
      # Compute batch start and end indices
      start, end = utils.batch_indices(batch_nb, data_length, FLAGS.batch_size)

      # Prepare feed dictionary
      feed_dict = {train_data_node: images[start:end]}

      # Run session ([0] because run returns a batch with len 1st dim == 1)
      preds[start:end, :] = sess.run([output], feed_dict=feed_dict)[0]

  # Reset graph to allow multiple calls
  tf.reset_default_graph()

  return preds
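

# Hypothetical evaluation sketch (assuming a checkpoint written by train()
# above, with hypothetical test_images/test_labels arrays):
#   preds = softmax_preds(test_images, '/tmp/model.ckpt-2999')
#   accuracy = np.mean(np.argmax(preds, axis=1) == test_labels)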