inception_train.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library to train Inception using multiple GPUs with synchronous updates.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
from datetime import datetime
import os.path
import re
import time

import numpy as np
import tensorflow as tf

from inception import image_processing
from inception import inception_model as inception
from inception.slim import slim


FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string('train_dir', '/tmp/imagenet_train',
                           """Directory where to write event logs """
                           """and checkpoint.""")
tf.app.flags.DEFINE_integer('max_steps', 10000000,
                            """Number of batches to run.""")
tf.app.flags.DEFINE_string('subset', 'train',
                           """Either 'train' or 'validation'.""")

# Flags governing the hardware employed for running TensorFlow.
tf.app.flags.DEFINE_integer('num_gpus', 1,
                            """How many GPUs to use.""")
tf.app.flags.DEFINE_boolean('log_device_placement', False,
                            """Whether to log device placement.""")

# Flags governing the type of training.
tf.app.flags.DEFINE_boolean('fine_tune', False,
                            """If set, randomly initialize the final layer """
                            """of weights in order to train the network on a """
                            """new task.""")
tf.app.flags.DEFINE_string('pretrained_model_checkpoint_path', '',
                           """If specified, restore this pretrained model """
                           """before beginning any training.""")

# **IMPORTANT**
# Please note that this learning rate schedule is heavily dependent on the
# hardware architecture, batch size and any changes to the model architecture
# specification. Selecting a finely tuned learning rate schedule is an
# empirical process that requires some experimentation. Please see README.md
# for more guidance and discussion.
#
# With 8 Tesla K40's and a batch size = 256, the following setup achieves
# precision@1 = 73.5% after 100 hours and 100K steps (20 epochs).
# Learning rate decay factor selected from http://arxiv.org/abs/1404.5997.
tf.app.flags.DEFINE_float('initial_learning_rate', 0.1,
                          """Initial learning rate.""")
tf.app.flags.DEFINE_float('num_epochs_per_decay', 30.0,
                          """Epochs after which learning rate decays.""")
tf.app.flags.DEFINE_float('learning_rate_decay_factor', 0.16,
                          """Learning rate decay factor.""")

# Constants dictating the RMSProp optimizer configuration.
RMSPROP_DECAY = 0.9                # Decay term for RMSProp.
RMSPROP_MOMENTUM = 0.9             # Momentum in RMSProp.
RMSPROP_EPSILON = 1.0              # Epsilon term for RMSProp.
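
# Note on the resulting schedule: with staircase=True in
# tf.train.exponential_decay (see train() below), the learning rate is
#   initial_learning_rate * learning_rate_decay_factor ** floor(global_step / decay_steps)
# where decay_steps = (examples_per_epoch / batch_size) * num_epochs_per_decay.
# With the default flags this gives lr = 0.1 for the first 30 epochs, then
# roughly 0.016, then roughly 0.00256, and so on.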


def _tower_loss(images, labels, num_classes, scope, reuse_variables=None):
  """Calculate the total loss on a single tower running the ImageNet model.

  We perform 'batch splitting'. This means that we cut up a batch across
  multiple GPUs. For instance, if the batch size = 32 and num_gpus = 2,
  then each tower will operate on a batch of 16 images.

  Args:
    images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
      FLAGS.image_size, 3].
    labels: 1-D integer Tensor of [batch_size].
    num_classes: number of classes
    scope: unique prefix string identifying the ImageNet tower, e.g.
      'tower_0'.
    reuse_variables: whether to reuse the model variables; None for the first
      tower, True for every subsequent tower.

  Returns:
    Tensor of shape [] containing the total loss for a batch of data
  """
  # When fine-tuning a model, we do not restore the logits but instead we
  # randomly initialize the logits. The number of classes in the output of the
  # logit is the number of classes in the specified Dataset.
  restore_logits = not FLAGS.fine_tune

  # Build inference Graph.
  with tf.variable_scope(tf.get_variable_scope(), reuse=reuse_variables):
    logits = inception.inference(images, num_classes, for_training=True,
                                 restore_logits=restore_logits,
                                 scope=scope)

  # Build the portion of the Graph calculating the losses. Note that we will
  # assemble the total_loss using a custom function below.
  split_batch_size = images.get_shape().as_list()[0]
  inception.loss(logits, labels, batch_size=split_batch_size)

  # Assemble all of the losses for the current tower only.
  losses = tf.get_collection(slim.losses.LOSSES_COLLECTION, scope)

  # Calculate the total loss for the current tower.
  regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
  total_loss = tf.add_n(losses + regularization_losses, name='total_loss')

  # Compute the moving average of all individual losses and the total loss.
  loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg')
  loss_averages_op = loss_averages.apply(losses + [total_loss])

  # Attach a scalar summary to all individual losses and the total loss; do
  # the same for the averaged version of the losses.
  for l in losses + [total_loss]:
    # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training
    # session. This helps the clarity of presentation on TensorBoard.
    loss_name = re.sub('%s_[0-9]*/' % inception.TOWER_NAME, '', l.op.name)
    # Name each loss as '(raw)' and name the moving average version of the
    # loss as the original loss name.
    tf.scalar_summary(loss_name + ' (raw)', l)
    tf.scalar_summary(loss_name, loss_averages.average(l))
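
  # The control dependency below ensures the moving-average update
  # (loss_averages_op) runs whenever the returned total_loss is evaluated,
  # keeping the averaged loss summaries above current.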
  with tf.control_dependencies([loss_averages_op]):
    total_loss = tf.identity(total_loss)
  return total_loss


def _average_gradients(tower_grads):
  """Calculate the average gradient for each shared variable across all towers.

  Note that this function provides a synchronization point across all towers.

  Args:
    tower_grads: List of lists of (gradient, variable) tuples. The outer list
      is over towers; the inner list is over the individual gradient
      calculations within each tower.

  Returns:
    List of pairs of (gradient, variable) where the gradient has been averaged
    across all towers.
  """
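  # For illustration, with 2 towers and 2 shared variables the input and
  # output look like (names are schematic):
  #   tower_grads = [[(g0_t0, v0), (g1_t0, v1)],   # tower 0
  #                  [(g0_t1, v0), (g1_t1, v1)]]   # tower 1
  #   result      = [(mean(g0_t0, g0_t1), v0), (mean(g1_t0, g1_t1), v1)]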
  average_grads = []
  for grad_and_vars in zip(*tower_grads):
    # Note that each grad_and_vars looks like the following:
    #   ((grad0_gpu0, var0_gpu0), ... , (grad0_gpuN, var0_gpuN))
    grads = []
    for g, _ in grad_and_vars:
      # Add 0 dimension to the gradients to represent the tower.
      expanded_g = tf.expand_dims(g, 0)

      # Append on a 'tower' dimension which we will average over below.
      grads.append(expanded_g)

    # Average over the 'tower' dimension.
    grad = tf.concat(0, grads)
    grad = tf.reduce_mean(grad, 0)

    # Keep in mind that the Variables are redundant because they are shared
    # across towers. So .. we will just return the first tower's pointer to
    # the Variable.
    v = grad_and_vars[0][1]
    grad_and_var = (grad, v)
    average_grads.append(grad_and_var)
  return average_grads


def train(dataset):
  """Train on dataset for a number of steps."""
  with tf.Graph().as_default(), tf.device('/cpu:0'):
    # Create a variable to count the number of train() calls. This equals the
    # number of batches processed * FLAGS.num_gpus.
    global_step = tf.get_variable(
        'global_step', [],
        initializer=tf.constant_initializer(0), trainable=False)

    # Calculate the learning rate schedule.
    num_batches_per_epoch = (dataset.num_examples_per_epoch() /
                             FLAGS.batch_size)
    decay_steps = int(num_batches_per_epoch * FLAGS.num_epochs_per_decay)

    # Decay the learning rate exponentially based on the number of steps.
    lr = tf.train.exponential_decay(FLAGS.initial_learning_rate,
                                    global_step,
                                    decay_steps,
                                    FLAGS.learning_rate_decay_factor,
                                    staircase=True)

    # Create an optimizer that performs gradient descent.
    opt = tf.train.RMSPropOptimizer(lr, RMSPROP_DECAY,
                                    momentum=RMSPROP_MOMENTUM,
                                    epsilon=RMSPROP_EPSILON)

    # Get images and labels for ImageNet and split the batch across GPUs.
    assert FLAGS.batch_size % FLAGS.num_gpus == 0, (
        'Batch size must be divisible by number of GPUs')
    split_batch_size = int(FLAGS.batch_size / FLAGS.num_gpus)

    # Override the number of preprocessing threads to account for the increased
    # number of GPU towers.
    num_preprocess_threads = FLAGS.num_preprocess_threads * FLAGS.num_gpus
    images, labels = image_processing.distorted_inputs(
        dataset,
        num_preprocess_threads=num_preprocess_threads)

    input_summaries = copy.copy(tf.get_collection(tf.GraphKeys.SUMMARIES))

    # Number of classes in the Dataset label set plus 1.
    # Label 0 is reserved for an (unused) background class.
    num_classes = dataset.num_classes() + 1

    # Split the batch of images and labels for towers.
    images_splits = tf.split(0, FLAGS.num_gpus, images)
    labels_splits = tf.split(0, FLAGS.num_gpus, labels)
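
    # Each split has leading dimension batch_size / num_gpus; e.g. with
    # --batch_size=32 and --num_gpus=2, each images_splits[i] has shape
    # [16, image_size, image_size, 3], matching the 'batch splitting' example
    # in the _tower_loss docstring.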

    # Calculate the gradients for each model tower.
    tower_grads = []
    reuse_variables = None
    for i in range(FLAGS.num_gpus):
      with tf.device('/gpu:%d' % i):
        with tf.name_scope('%s_%d' % (inception.TOWER_NAME, i)) as scope:
          # Force all Variables to reside on the CPU.
          with slim.arg_scope([slim.variables.variable], device='/cpu:0'):
            # Calculate the loss for one tower of the ImageNet model. This
            # function constructs the entire ImageNet model but shares the
            # variables across all towers.
            loss = _tower_loss(images_splits[i], labels_splits[i], num_classes,
                               scope, reuse_variables)

          # Reuse variables for the next tower.
          reuse_variables = True

          # Retain the summaries from the final tower.
          summaries = tf.get_collection(tf.GraphKeys.SUMMARIES, scope)

          # Retain the Batch Normalization updates operations only from the
          # final tower. Ideally, we should grab the updates from all towers
          # but these stats accumulate extremely fast so we can ignore the
          # other stats from the other towers without significant detriment.
          batchnorm_updates = tf.get_collection(slim.ops.UPDATE_OPS_COLLECTION,
                                                scope)

          # Calculate the gradients for the batch of data on this ImageNet
          # tower.
          grads = opt.compute_gradients(loss)

          # Keep track of the gradients across all towers.
          tower_grads.append(grads)

    # We must calculate the mean of each gradient. Note that this is the
    # synchronization point across all towers.
    grads = _average_gradients(tower_grads)

    # Add summaries for the input processing and global_step.
    summaries.extend(input_summaries)

    # Add a summary to track the learning rate.
    summaries.append(tf.scalar_summary('learning_rate', lr))

    # Add histograms for gradients.
    for grad, var in grads:
      if grad is not None:
        summaries.append(
            tf.histogram_summary(var.op.name + '/gradients', grad))

    # Apply the gradients to adjust the shared variables.
    apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
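    # apply_gradients also increments global_step by one, which in turn
    # advances the exponential_decay learning rate schedule above.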

    # Add histograms for trainable variables.
    for var in tf.trainable_variables():
      summaries.append(tf.histogram_summary(var.op.name, var))

    # Track the moving averages of all trainable variables.
    # Note that we maintain a "double-average" of the BatchNormalization
    # global statistics. This is more complicated than need be but we employ
    # this for backward-compatibility with our previous models.
    variable_averages = tf.train.ExponentialMovingAverage(
        inception.MOVING_AVERAGE_DECAY, global_step)

    # Another possibility is to use tf.slim.get_variables().
    variables_to_average = (tf.trainable_variables() +
                            tf.moving_average_variables())
    variables_averages_op = variable_averages.apply(variables_to_average)

    # Group all updates into a single train op.
    batchnorm_updates_op = tf.group(*batchnorm_updates)
    train_op = tf.group(apply_gradient_op, variables_averages_op,
                        batchnorm_updates_op)
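
    # Running train_op therefore performs one synchronous step: it applies the
    # averaged gradients, updates the variable moving averages, and applies the
    # batch-norm statistics updates collected from the last tower.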

    # Create a saver.
    saver = tf.train.Saver(tf.all_variables())

    # Build the summary operation from the last tower summaries.
    summary_op = tf.merge_summary(summaries)

    # Build an initialization operation to run below.
    init = tf.global_variables_initializer()

    # Start running operations on the Graph. allow_soft_placement must be set
    # to True to build towers on GPU, as some of the ops do not have GPU
    # implementations.
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True,
        log_device_placement=FLAGS.log_device_placement))
    sess.run(init)

    if FLAGS.pretrained_model_checkpoint_path:
      assert tf.gfile.Exists(FLAGS.pretrained_model_checkpoint_path)
      variables_to_restore = tf.get_collection(
          slim.variables.VARIABLES_TO_RESTORE)
      restorer = tf.train.Saver(variables_to_restore)
      restorer.restore(sess, FLAGS.pretrained_model_checkpoint_path)
      print('%s: Pre-trained model restored from %s' %
            (datetime.now(), FLAGS.pretrained_model_checkpoint_path))

    # Start the queue runners.
    tf.train.start_queue_runners(sess=sess)

    summary_writer = tf.train.SummaryWriter(
        FLAGS.train_dir,
        graph_def=sess.graph.as_graph_def(add_shapes=True))
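
    # The events written here can be inspected with TensorBoard, e.g.
    #   tensorboard --logdir=/tmp/imagenet_train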

    for step in range(FLAGS.max_steps):
      start_time = time.time()
      _, loss_value = sess.run([train_op, loss])
      duration = time.time() - start_time

      assert not np.isnan(loss_value), 'Model diverged with loss = NaN'

      if step % 10 == 0:
        examples_per_sec = FLAGS.batch_size / float(duration)
        format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                      'sec/batch)')
        print(format_str % (datetime.now(), step, loss_value,
                            examples_per_sec, duration))

      if step % 100 == 0:
        summary_str = sess.run(summary_op)
        summary_writer.add_summary(summary_str, step)

      # Save the model checkpoint periodically.
      if step % 5000 == 0 or (step + 1) == FLAGS.max_steps:
        checkpoint_path = os.path.join(FLAGS.train_dir, 'model.ckpt')
        saver.save(sess, checkpoint_path, global_step=step)
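
# A minimal driver sketch (not part of this library): a companion script is
# expected to build a dataset and call train(). The names below (ImagenetData
# and its module path) are assumptions based on the companion code in this
# repo, shown here only for orientation:
#
#   from inception.imagenet_data import ImagenetData
#   dataset = ImagenetData(subset=FLAGS.subset)
#   if not tf.gfile.Exists(FLAGS.train_dir):
#     tf.gfile.MakeDirs(FLAGS.train_dir)
#   train(dataset)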