resnet_main.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """ResNet Train/Eval module.
  16. """
  17. import sys
  18. import time
  19. import cifar_input
  20. import numpy as np
  21. import resnet_model
  22. import tensorflow as tf
  23. FLAGS = tf.app.flags.FLAGS
  24. tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
  25. tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
  26. tf.app.flags.DEFINE_string('train_data_path', '', 'Filename for training data.')
  27. tf.app.flags.DEFINE_string('eval_data_path', '', 'Filename for eval data')
  28. tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
  29. tf.app.flags.DEFINE_string('train_dir', '',
  30. 'Directory to keep training outputs.')
  31. tf.app.flags.DEFINE_string('eval_dir', '',
  32. 'Directory to keep eval outputs.')
  33. tf.app.flags.DEFINE_integer('eval_batch_count', 50,
  34. 'Number of batches to eval.')
  35. tf.app.flags.DEFINE_bool('eval_once', False,
  36. 'Whether evaluate the model only once.')
  37. tf.app.flags.DEFINE_string('log_root', '',
  38. 'Directory to keep the checkpoints. Should be a '
  39. 'parent directory of FLAGS.train_dir/eval_dir.')
  40. tf.app.flags.DEFINE_integer('num_gpus', 0,
  41. 'Number of gpus used for training. (0 or 1)')
  42. def train(hps):
  43. """Training loop."""
  44. images, labels = cifar_input.build_input(
  45. FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  46. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  47. model.build_graph()
  48. summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
  49. sv = tf.train.Supervisor(logdir=FLAGS.log_root,
  50. is_chief=True,
  51. summary_op=None,
  52. save_summaries_secs=60,
  53. save_model_secs=300,
  54. global_step=model.global_step)
  55. sess = sv.prepare_or_wait_for_session()
  56. step = 0
  57. lrn_rate = 0.1
  58. while not sv.should_stop():
  59. (_, summaries, loss, predictions, truth, train_step) = sess.run(
  60. [model.train_op, model.summaries, model.cost, model.predictions,
  61. model.labels, model.global_step],
  62. feed_dict={model.lrn_rate: lrn_rate})
  63. if train_step < 40000:
  64. lrn_rate = 0.1
  65. elif train_step < 60000:
  66. lrn_rate = 0.01
  67. elif train_step < 80000:
  68. lrn_rate = 0.001
  69. else:
  70. lrn_rate = 0.0001
  71. truth = np.argmax(truth, axis=1)
  72. predictions = np.argmax(predictions, axis=1)
  73. precision = np.mean(truth == predictions)
  74. step += 1
  75. if step % 100 == 0:
  76. precision_summ = tf.Summary()
  77. precision_summ.value.add(
  78. tag='Precision', simple_value=precision)
  79. summary_writer.add_summary(precision_summ, train_step)
  80. summary_writer.add_summary(summaries, train_step)
  81. tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
  82. summary_writer.flush()
  83. sv.Stop()
  84. def evaluate(hps):
  85. """Eval loop."""
  86. images, labels = cifar_input.build_input(
  87. FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  88. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  89. model.build_graph()
  90. saver = tf.train.Saver()
  91. summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
  92. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  93. tf.train.start_queue_runners(sess)
  94. best_precision = 0.0
  95. while True:
  96. time.sleep(60)
  97. try:
  98. ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
  99. except tf.errors.OutOfRangeError as e:
  100. tf.logging.error('Cannot restore checkpoint: %s', e)
  101. continue
  102. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  103. tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
  104. continue
  105. tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
  106. saver.restore(sess, ckpt_state.model_checkpoint_path)
  107. total_prediction, correct_prediction = 0, 0
  108. for _ in xrange(FLAGS.eval_batch_count):
  109. (summaries, loss, predictions, truth, train_step) = sess.run(
  110. [model.summaries, model.cost, model.predictions,
  111. model.labels, model.global_step])
  112. truth = np.argmax(truth, axis=1)
  113. predictions = np.argmax(predictions, axis=1)
  114. correct_prediction += np.sum(truth == predictions)
  115. total_prediction += predictions.shape[0]
  116. precision = 1.0 * correct_prediction / total_prediction
  117. best_precision = max(precision, best_precision)
  118. precision_summ = tf.Summary()
  119. precision_summ.value.add(
  120. tag='Precision', simple_value=precision)
  121. summary_writer.add_summary(precision_summ, train_step)
  122. best_precision_summ = tf.Summary()
  123. best_precision_summ.value.add(
  124. tag='Best Precision', simple_value=best_precision)
  125. summary_writer.add_summary(best_precision_summ, train_step)
  126. summary_writer.add_summary(summaries, train_step)
  127. tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f\n' %
  128. (loss, precision, best_precision))
  129. summary_writer.flush()
  130. if FLAGS.eval_once:
  131. break
  132. def main(_):
  133. if FLAGS.num_gpus == 0:
  134. dev = '/cpu:0'
  135. elif FLAGS.num_gpus == 1:
  136. dev = '/gpu:0'
  137. else:
  138. raise ValueError('Only support 0 or 1 gpu.')
  139. if FLAGS.mode == 'train':
  140. batch_size = 128
  141. elif FLAGS.mode == 'eval':
  142. batch_size = 100
  143. if FLAGS.dataset == 'cifar10':
  144. num_classes = 10
  145. elif FLAGS.dataset == 'cifar100':
  146. num_classes = 100
  147. hps = resnet_model.HParams(batch_size=batch_size,
  148. num_classes=num_classes,
  149. min_lrn_rate=0.0001,
  150. lrn_rate=0.1,
  151. num_residual_units=5,
  152. use_bottleneck=False,
  153. weight_decay_rate=0.0002,
  154. relu_leakiness=0.1,
  155. optimizer='mom')
  156. with tf.device(dev):
  157. if FLAGS.mode == 'train':
  158. train(hps)
  159. elif FLAGS.mode == 'eval':
  160. evaluate(hps)
  161. if __name__ == '__main__':
  162. tf.app.run()