cifar10_train.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A binary to train CIFAR-10 using a single GPU.

Accuracy:
cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of
data) as judged by cifar10_eval.py.

Speed (with batch_size 128):

System        | Step Time (sec/batch)  |     Accuracy
------------------------------------------------------------------
1 Tesla K20m  | 0.35-0.60              | ~86% at 60K steps  (5 hours)
1 Tesla K40m  | 0.25-0.35              | ~86% at 100K steps (4 hours)

Usage:
Please see the tutorial and website for how to download the CIFAR-10
data set, compile the program and train the model.

http://tensorflow.org/tutorials/deep_cnn/
"""
  29. from __future__ import absolute_import
  30. from __future__ import division
  31. from __future__ import print_function
  32. from datetime import datetime
  33. import time
  34. import tensorflow as tf
  35. import cifar10
  36. FLAGS = tf.app.flags.FLAGS
  37. tf.app.flags.DEFINE_string('train_dir', '/tmp/cifar10_train',
  38. """Directory where to write event logs """
  39. """and checkpoint.""")
  40. tf.app.flags.DEFINE_integer('max_steps', 1000000,
  41. """Number of batches to run.""")
  42. tf.app.flags.DEFINE_boolean('log_device_placement', False,
  43. """Whether to log device placement.""")
  44. tf.app.flags.DEFINE_integer('log_frequency', 10,
  45. """How often to log results to the console.""")
  46. def train():
  47. """Train CIFAR-10 for a number of steps."""
  48. with tf.Graph().as_default():
  49. global_step = tf.contrib.framework.get_or_create_global_step()
  50. # Get images and labels for CIFAR-10.
  51. images, labels = cifar10.distorted_inputs()
  52. # Build a Graph that computes the logits predictions from the
  53. # inference model.
  54. logits = cifar10.inference(images)
  55. # Calculate loss.
  56. loss = cifar10.loss(logits, labels)
  57. # Build a Graph that trains the model with one batch of examples and
  58. # updates the model parameters.
  59. train_op = cifar10.train(loss, global_step)
  60. class _LoggerHook(tf.train.SessionRunHook):
  61. """Logs loss and runtime."""
  62. def begin(self):
  63. self._step = -1
  64. self._start_time = time.time()
  65. def before_run(self, run_context):
  66. self._step += 1
  67. return tf.train.SessionRunArgs(loss) # Asks for loss value.
  68. def after_run(self, run_context, run_values):
  69. if self._step % FLAGS.log_frequency == 0:
  70. current_time = time.time()
  71. duration = current_time - self._start_time
  72. self._start_time = current_time
  73. loss_value = run_values.results
  74. examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
  75. sec_per_batch = float(duration / FLAGS.log_frequency)
  76. format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
  77. 'sec/batch)')
  78. print (format_str % (datetime.now(), self._step, loss_value,
  79. examples_per_sec, sec_per_batch))
  80. with tf.train.MonitoredTrainingSession(
  81. checkpoint_dir=FLAGS.train_dir,
  82. hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps),
  83. tf.train.NanTensorHook(loss),
  84. _LoggerHook()],
  85. config=tf.ConfigProto(
  86. log_device_placement=FLAGS.log_device_placement)) as mon_sess:
  87. while not mon_sess.should_stop():
  88. mon_sess.run(train_op)
  89. def main(argv=None): # pylint: disable=unused-argument
  90. cifar10.maybe_download_and_extract()
  91. if tf.gfile.Exists(FLAGS.train_dir):
  92. tf.gfile.DeleteRecursively(FLAGS.train_dir)
  93. tf.gfile.MakeDirs(FLAGS.train_dir)
  94. train()
  95. if __name__ == '__main__':
  96. tf.app.run()