resnet_main.py 7.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """ResNet Train/Eval module.
  16. """
  17. import time
  18. import six
  19. import sys
  20. import cifar_input
  21. import numpy as np
  22. import resnet_model
  23. import tensorflow as tf
  24. FLAGS = tf.app.flags.FLAGS
  25. tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
  26. tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
  27. tf.app.flags.DEFINE_string('train_data_path', '',
  28. 'Filepattern for training data.')
  29. tf.app.flags.DEFINE_string('eval_data_path', '',
  30. 'Filepattern for eval data')
  31. tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
  32. tf.app.flags.DEFINE_string('train_dir', '',
  33. 'Directory to keep training outputs.')
  34. tf.app.flags.DEFINE_string('eval_dir', '',
  35. 'Directory to keep eval outputs.')
  36. tf.app.flags.DEFINE_integer('eval_batch_count', 50,
  37. 'Number of batches to eval.')
  38. tf.app.flags.DEFINE_bool('eval_once', False,
  39. 'Whether evaluate the model only once.')
  40. tf.app.flags.DEFINE_string('log_root', '',
  41. 'Directory to keep the checkpoints. Should be a '
  42. 'parent directory of FLAGS.train_dir/eval_dir.')
  43. tf.app.flags.DEFINE_integer('num_gpus', 0,
  44. 'Number of gpus used for training. (0 or 1)')
  45. def train(hps):
  46. """Training loop."""
  47. images, labels = cifar_input.build_input(
  48. FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  49. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  50. model.build_graph()
  51. param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
  52. tf.get_default_graph(),
  53. tfprof_options=tf.contrib.tfprof.model_analyzer.
  54. TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
  55. sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
  56. tf.contrib.tfprof.model_analyzer.print_model_analysis(
  57. tf.get_default_graph(),
  58. tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
  59. truth = tf.argmax(model.labels, axis=1)
  60. predictions = tf.argmax(model.predictions, axis=1)
  61. precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
  62. summary_hook = tf.train.SummarySaverHook(
  63. save_steps=100,
  64. output_dir=FLAGS.train_dir,
  65. summary_op=tf.summary.merge([model.summaries,
  66. tf.summary.scalar('Precision', precision)]))
  67. logging_hook = tf.train.LoggingTensorHook(
  68. tensors={'step': model.global_step,
  69. 'loss': model.cost,
  70. 'precision': precision},
  71. every_n_iter=100)
  72. class _LearningRateSetterHook(tf.train.SessionRunHook):
  73. """Sets learning_rate based on global step."""
  74. def begin(self):
  75. self._lrn_rate = 0.1
  76. def before_run(self, run_context):
  77. return tf.train.SessionRunArgs(
  78. model.global_step, # Asks for global step value.
  79. feed_dict={model.lrn_rate: self._lrn_rate}) # Sets learning rate
  80. def after_run(self, run_context, run_values):
  81. train_step = run_values.results
  82. if train_step < 40000:
  83. self._lrn_rate = 0.1
  84. elif train_step < 60000:
  85. self._lrn_rate = 0.01
  86. elif train_step < 80000:
  87. self._lrn_rate = 0.001
  88. else:
  89. self._lrn_rate = 0.0001
  90. with tf.train.MonitoredTrainingSession(
  91. checkpoint_dir=FLAGS.log_root,
  92. hooks=[logging_hook, _LearningRateSetterHook()],
  93. chief_only_hooks=[summary_hook],
  94. # Since we provide a SummarySaverHook, we need to disable default
  95. # SummarySaverHook. To do that we set save_summaries_steps to 0.
  96. save_summaries_steps=0,
  97. config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
  98. while not mon_sess.should_stop():
  99. mon_sess.run(model.train_op)
  100. def evaluate(hps):
  101. """Eval loop."""
  102. images, labels = cifar_input.build_input(
  103. FLAGS.dataset, FLAGS.eval_data_path, hps.batch_size, FLAGS.mode)
  104. model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  105. model.build_graph()
  106. saver = tf.train.Saver()
  107. summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
  108. sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  109. tf.train.start_queue_runners(sess)
  110. best_precision = 0.0
  111. while True:
  112. try:
  113. ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
  114. except tf.errors.OutOfRangeError as e:
  115. tf.logging.error('Cannot restore checkpoint: %s', e)
  116. continue
  117. if not (ckpt_state and ckpt_state.model_checkpoint_path):
  118. tf.logging.info('No model to eval yet at %s', FLAGS.log_root)
  119. continue
  120. tf.logging.info('Loading checkpoint %s', ckpt_state.model_checkpoint_path)
  121. saver.restore(sess, ckpt_state.model_checkpoint_path)
  122. total_prediction, correct_prediction = 0, 0
  123. for _ in six.moves.range(FLAGS.eval_batch_count):
  124. (summaries, loss, predictions, truth, train_step) = sess.run(
  125. [model.summaries, model.cost, model.predictions,
  126. model.labels, model.global_step])
  127. truth = np.argmax(truth, axis=1)
  128. predictions = np.argmax(predictions, axis=1)
  129. correct_prediction += np.sum(truth == predictions)
  130. total_prediction += predictions.shape[0]
  131. precision = 1.0 * correct_prediction / total_prediction
  132. best_precision = max(precision, best_precision)
  133. precision_summ = tf.Summary()
  134. precision_summ.value.add(
  135. tag='Precision', simple_value=precision)
  136. summary_writer.add_summary(precision_summ, train_step)
  137. best_precision_summ = tf.Summary()
  138. best_precision_summ.value.add(
  139. tag='Best Precision', simple_value=best_precision)
  140. summary_writer.add_summary(best_precision_summ, train_step)
  141. summary_writer.add_summary(summaries, train_step)
  142. tf.logging.info('loss: %.3f, precision: %.3f, best precision: %.3f' %
  143. (loss, precision, best_precision))
  144. summary_writer.flush()
  145. if FLAGS.eval_once:
  146. break
  147. time.sleep(60)
  148. def main(_):
  149. if FLAGS.num_gpus == 0:
  150. dev = '/cpu:0'
  151. elif FLAGS.num_gpus == 1:
  152. dev = '/gpu:0'
  153. else:
  154. raise ValueError('Only support 0 or 1 gpu.')
  155. if FLAGS.mode == 'train':
  156. batch_size = 128
  157. elif FLAGS.mode == 'eval':
  158. batch_size = 100
  159. if FLAGS.dataset == 'cifar10':
  160. num_classes = 10
  161. elif FLAGS.dataset == 'cifar100':
  162. num_classes = 100
  163. hps = resnet_model.HParams(batch_size=batch_size,
  164. num_classes=num_classes,
  165. min_lrn_rate=0.0001,
  166. lrn_rate=0.1,
  167. num_residual_units=5,
  168. use_bottleneck=False,
  169. weight_decay_rate=0.0002,
  170. relu_leakiness=0.1,
  171. optimizer='mom')
  172. with tf.device(dev):
  173. if FLAGS.mode == 'train':
  174. train(hps)
  175. elif FLAGS.mode == 'eval':
  176. evaluate(hps)
  177. if __name__ == '__main__':
  178. tf.logging.set_verbosity(tf.logging.INFO)
  179. tf.app.run()