eval.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Generic evaluation script that evaluates a given model on a specified dataset."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import math
  20. import tensorflow as tf
  21. from slim.datasets import dataset_factory
  22. from slim.models import model_factory
  23. from slim.models import preprocessing_factory
  24. slim = tf.contrib.slim
  25. tf.app.flags.DEFINE_integer(
  26. 'batch_size', 100, 'The number of samples in each batch.')
  27. tf.app.flags.DEFINE_integer(
  28. 'max_num_batches', None,
  29. 'Max number of batches to evaluate by default use all.')
  30. tf.app.flags.DEFINE_string(
  31. 'master', '', 'The address of the TensorFlow master to use.')
  32. tf.app.flags.DEFINE_string(
  33. 'checkpoint_path', '/tmp/tfmodel/',
  34. 'The directory where the model was written to or an absolute path to a '
  35. 'checkpoint file.')
  36. tf.app.flags.DEFINE_bool(
  37. 'restore_global_step', True,
  38. 'Whether or not to restore the global step. When evaluating a model '
  39. 'checkpoint containing ONLY weights, set this flag to `False`.')
  40. tf.app.flags.DEFINE_string(
  41. 'eval_dir', '/tmp/tfmodel/', 'Directory where the results are saved to.')
  42. tf.app.flags.DEFINE_integer(
  43. 'num_preprocessing_threads', 4,
  44. 'The number of threads used to create the batches.')
  45. tf.app.flags.DEFINE_string(
  46. 'dataset_name', 'imagenet', 'The name of the dataset to load.')
  47. tf.app.flags.DEFINE_string(
  48. 'dataset_split_name', 'train', 'The name of the train/test split.')
  49. tf.app.flags.DEFINE_string(
  50. 'dataset_dir', None, 'The directory where the dataset files are stored.')
  51. tf.app.flags.MarkFlagAsRequired('dataset_dir')
  52. tf.app.flags.DEFINE_integer(
  53. 'labels_offset', 0,
  54. 'An offset for the labels in the dataset. This flag is primarily used to '
  55. 'evaluate the VGG and ResNet architectures which do not use a background '
  56. 'class for the ImageNet dataset.')
  57. tf.app.flags.DEFINE_string(
  58. 'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
  59. tf.app.flags.DEFINE_string(
  60. 'preprocessing_name', None, 'The name of the preprocessing to use. If left '
  61. 'as `None`, then the model_name flag is used.')
  62. tf.app.flags.DEFINE_float(
  63. 'moving_average_decay', None,
  64. 'The decay to use for the moving average.'
  65. 'If left as None, then moving averages are not used.')
  66. FLAGS = tf.app.flags.FLAGS
  67. def main(_):
  68. with tf.Graph().as_default():
  69. tf_global_step = slim.get_or_create_global_step()
  70. ######################
  71. # Select the dataset #
  72. ######################
  73. dataset = dataset_factory.get_dataset(
  74. FLAGS.dataset_name, FLAGS.dataset_split_name, FLAGS.dataset_dir)
  75. ####################
  76. # Select the model #
  77. ####################
  78. model_fn = model_factory.get_model(
  79. FLAGS.model_name,
  80. num_classes=(dataset.num_classes - FLAGS.labels_offset),
  81. is_training=False)
  82. ##############################################################
  83. # Create a dataset provider that loads data from the dataset #
  84. ##############################################################
  85. provider = slim.dataset_data_provider.DatasetDataProvider(
  86. dataset,
  87. shuffle=False,
  88. common_queue_capacity=2 * FLAGS.batch_size,
  89. common_queue_min=FLAGS.batch_size)
  90. [image, label] = provider.get(['image', 'label'])
  91. label -= FLAGS.labels_offset
  92. #####################################
  93. # Select the preprocessing function #
  94. #####################################
  95. preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
  96. image_preprocessing_fn = preprocessing_factory.get_preprocessing(
  97. preprocessing_name,
  98. is_training=False)
  99. image = image_preprocessing_fn(image,
  100. height=model_fn.default_image_size,
  101. width=model_fn.default_image_size)
  102. images, labels = tf.train.batch(
  103. [image, label],
  104. batch_size=FLAGS.batch_size,
  105. num_threads=FLAGS.num_preprocessing_threads,
  106. capacity=5 * FLAGS.batch_size)
  107. ####################
  108. # Define the model #
  109. ####################
  110. logits, _ = model_fn(images)
  111. if FLAGS.moving_average_decay:
  112. variable_averages = tf.train.ExponentialMovingAverage(
  113. FLAGS.moving_average_decay, tf_global_step)
  114. variables_to_restore = variable_averages.variables_to_restore(
  115. slim.get_model_variables())
  116. if FLAGS.restore_global_step:
  117. variables_to_restore[tf_global_step.op.name] = tf_global_step
  118. else:
  119. exclude = None if FLAGS.restore_global_step else ['global_step']
  120. variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
  121. predictions = tf.argmax(logits, 1)
  122. labels = tf.squeeze(labels)
  123. # Define the metrics:
  124. names_to_values, names_to_updates = slim.metrics.aggregate_metric_map({
  125. 'Accuracy': slim.metrics.streaming_accuracy(predictions, labels),
  126. 'Recall@5': slim.metrics.streaming_recall_at_k(
  127. logits, labels, 5),
  128. })
  129. # Print the summaries to screen.
  130. for name, value in names_to_values.iteritems():
  131. summary_name = 'eval/%s' % name
  132. op = tf.scalar_summary(summary_name, value, collections=[])
  133. op = tf.Print(op, [value], summary_name)
  134. tf.add_to_collection(tf.GraphKeys.SUMMARIES, op)
  135. # TODO(sguada) use num_epochs=1
  136. if FLAGS.max_num_batches:
  137. num_batches = FLAGS.max_num_batches
  138. else:
  139. # This ensures that we make a single pass over all of the data.
  140. num_batches = math.ceil(dataset.num_samples / float(FLAGS.batch_size))
  141. if tf.gfile.IsDirectory(FLAGS.checkpoint_path):
  142. checkpoint_path = tf.train.latest_checkpoint(FLAGS.checkpoint_path)
  143. else:
  144. checkpoint_path = FLAGS.checkpoint_path
  145. tf.logging.info('Evaluating %s' % checkpoint_path)
  146. slim.evaluation.evaluate_once(
  147. FLAGS.master,
  148. checkpoint_path,
  149. logdir=FLAGS.eval_dir,
  150. num_evals=num_batches,
  151. eval_op=names_to_updates.values(),
  152. variables_to_restore=variables_to_restore)
  153. if __name__ == '__main__':
  154. tf.app.run()