resnet_main.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """ResNet Train/Eval module.
  16. """
  17. import sys
  18. import time
  19. import cifar_input
  20. import numpy as np
  21. import resnet_model
  22. import tensorflow as tf
  23. FLAGS = tf.app.flags.FLAGS
  24. tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
  25. tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
  26. tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
  27. tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
  28. tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
  29. tf.app.flags.DEFINE_string('train_dir', '',
  30. 'Directory to keep training outputs.')
  31. tf.app.flags.DEFINE_string('eval_dir', '',
  32. 'Directory to keep eval outputs.')
  33. tf.app.flags.DEFINE_integer('eval_batch_count', 50,
  34. 'Number of batches to eval.')
  35. tf.app.flags.DEFINE_bool('eval_once', False,
  36. 'Whether evaluate the model only once.')
  37. tf.app.flags.DEFINE_string('log_root', '',
  38. 'Directory to keep the checkpoints. Should be a '
  39. 'parent directory of FLAGS.train_dir/eval_dir.')
  40. tf.app.flags.DEFINE_integer('num_gpus', 0,
  41. 'Number of gpus used for training. (0 or 1)')
  42. def train(hps):
  43. """Training loop."""
  44. images, labels = cifar_input.build_input(
  45. FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  46. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  47. model.build_graph()
  48. summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
  49. sv = tf.train.Supervisor(logdir=FLAGS.log_root,
  50. is_chief=True,
  51. summary_op=None,
  52. save_summaries_secs=60,
  53. save_model_secs=300,
  54. global_step=model.global_step)
  55. sess = sv.prepare_or_wait_for_session(
  56. config=tf.ConfigProto(allow_soft_placement=True))
  57. step = 0
  58. lrn_rate = 0.1
  59. while not sv.should_stop():
  60. (_, summaries, loss, predictions, truth, train_step) = sess.run(
  61. [model.train_op, model.summaries, model.cost, model.predictions,
  62. model.labels, model.global_step],
  63. feed_dict={model.lrn_rate: lrn_rate})
  64. if train_step < 40000:
  65. lrn_rate = 0.1
  66. elif train_step < 60000:
  67. lrn_rate = 0.01
  68. elif train_step < 80000:
  69. lrn_rate = 0.001
  70. else:
  71. lrn_rate = 0.0001
  72. truth = np.argmax(truth, axis=1)
  73. predictions = np.argmax(predictions, axis=1)
  74. precision = np.mean(truth == predictions)
  75. step += 1
  76. if step % 100 == 0:
  77. precision_summ = tf.Summary()
  78. precision_summ.value.add(
  79. tag='Precision', simple_value=precision)
  80. summary_writer.add_summary(precision_summ, train_step)
  81. summary_writer.add_summary(summaries, train_step)
  82. tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
  83. summary_writer.flush()
  84. sv.Stop()
  85. def evaluate(hps):
  86. """Eval loop."""
  87. images, labels = cifar_input.build_input(
  88. FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  89. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  90. model.build_graph()
  91. saver = tf.train.Saver()
  92. summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
  93. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  94. tf.train.start_queue_runners(sess)
  95. best_precision = 0.0
  96. while True:
  97. time.sleep(60)
  98. try:
  99. ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
  100. except tf.errors.OutOfRangeError as e:
  101. tf.logging.error('Cannot restore checkpoint: %s', e)
  102. continue
  103. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  104. tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
  105. continue
  106. tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
  107. saver.restore(sess, ckpt_state.model_checkpoint_path)
  108. total_prediction, correct_prediction = 0, 0
  109. for _ in xrange(FLAGS.eval_batch_count):
  110. (summaries, loss, predictions, truth, train_step) = sess.run(
  111. [model.summaries, model.cost, model.predictions,
  112. model.labels, model.global_step])
  113. truth = np.argmax(truth, axis=1)
  114. predictions = np.argmax(predictions, axis=1)
  115. correct_prediction += np.sum(truth == predictions)
  116. total_prediction += predictions.shape[0]
  117. precision = 1.0 * correct_prediction / total_prediction
  118. best_precision = max(precision, best_precision)
  119. precision_summ = tf.Summary()
  120. precision_summ.value.add(
  121. tag='Precision', simple_value=precision)
  122. summary_writer.add_summary(precision_summ, train_step)
  123. best_precision_summ = tf.Summary()
  124. best_precision_summ.value.add(
  125. tag='Best Precision', simple_value=best_precision)
  126. summary_writer.add_summary(best_precision_summ, train_step)
  127. summary_writer.add_summary(summaries, train_step)
  128. tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f\n' %
  129. (loss, precision, best_precision))
  130. summary_writer.flush()
  131. if FLAGS.eval_once:
  132. break
  133. def main(_):
  134. if FLAGS.num_gpus == 0:
  135. dev = '/cpu:0'
  136. elif FLAGS.num_gpus == 1:
  137. dev = '/gpu:0'
  138. else:
  139. raise ValueError('Only support 0 or 1 gpu.')
  140. if FLAGS.mode == 'train':
  141. batch_size = 128
  142. elif FLAGS.mode == 'eval':
  143. batch_size = 100
  144. if FLAGS.dataset == 'cifar10':
  145. num_classes = 10
  146. elif FLAGS.dataset == 'cifar100':
  147. num_classes = 100
  148. hps = resnet_model.HParams(batch_size=batch_size,
  149. num_classes=num_classes,
  150. min_lrn_rate=0.0001,
  151. lrn_rate=0.1,
  152. num_residual_units=5,
  153. use_bottleneck=False,
  154. weight_decay_rate=0.0002,
  155. relu_leakiness=0.1,
  156. optimizer='mom')
  157. with tf.device(dev):
  158. if FLAGS.mode == 'train':
  159. train(hps)
  160. elif FLAGS.mode == 'eval':
  161. evaluate(hps)
  162. if __name__ == '__main__':
  163. tf.app.run()