.pipertmp-son4h0-dsn_eval.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. # pylint: disable=line-too-long
  16. r"""Evaluation for Domain Separation Networks (DSNs).
  17. To build locally for CPU:
  18. blaze build -c opt --copt=-mavx \
  19. third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
  20. To build locally for GPU:
  21. blaze build -c opt --copt=-mavx --config=cuda_clang \
  22. third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval
  23. To run locally:
  24. $
  25. ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval
  26. \
  27. --alsologtostderr
  28. """
  29. # pylint: enable=line-too-long
  30. import math
  31. import google3
  32. import numpy as np
  33. import tensorflow as tf
  34. from google3.third_party.tensorflow_models.domain_adaptation.datasets import dataset_factory
  35. from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
  36. from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import losses
  37. from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models
  38. slim = tf.contrib.slim
  39. FLAGS = tf.app.flags.FLAGS
  40. tf.app.flags.DEFINE_integer('batch_size', 32,
  41. 'The number of images in each batch.')
  42. tf.app.flags.DEFINE_string('master', '',
  43. 'BNS name of the TensorFlow master to use.')
  44. tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
  45. 'Directory where the model was written to.')
  46. tf.app.flags.DEFINE_string(
  47. 'eval_dir', '/tmp/da/',
  48. 'Directory where we should write the tf summaries to.')
  49. tf.app.flags.DEFINE_string('dataset_dir', None,
  50. 'The directory where the dataset files are stored.')
  51. tf.app.flags.DEFINE_string('dataset', 'mnist_m',
  52. 'Which dataset to test on: "mnist", "mnist_m".')
  53. tf.app.flags.DEFINE_string('split', 'valid',
  54. 'Which portion to test on: "valid", "test".')
  55. tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')
  56. >>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
  57. tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
  58. ==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
  59. tf.app.flags.DEFINE_string('basic_tower', 'dsn_cropped_linemod',
  60. ==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
  61. tf.app.flags.DEFINE_string('basic_tower', 'dann_mnist',
  62. <<<<
  63. 'The basic tower building block.')
  64. >>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
  65. ==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
  66. tf.app.flags.DEFINE_bool('enable_precision_recall', False,
  67. 'If True, precision and recall for each class will '
  68. 'be added to the metrics.')
  69. ==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
  70. tf.app.flags.DEFINE_bool('enable_precision_recall', False,
  71. 'If True, precision and recall for each class will '
  72. 'be added to the metrics.')
  73. <<<<
  74. tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')
  75. def quaternion_metric(predictions, labels):
  76. params = {'batch_size': FLAGS.batch_size, 'use_logging': False}
  77. logcost = losses.log_quaternion_loss_batch(predictions, labels, params)
  78. return slim.metrics.streaming_mean(logcost)
  79. def angle_diff(true_q, pred_q):
  80. angles = 2 * (
  81. 180.0 /
  82. np.pi) * np.arccos(np.abs(np.sum(np.multiply(pred_q, true_q), axis=1)))
  83. return angles
  84. >>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
  85. Returns:
  86. The angle in degrees of the implied angle-axis representation.
  87. """
  88. product = tf.multiply(predictions, labels)
  89. internal_dot_products = tf.reduce_sum(product, [1])
  90. log_quaternion_loss = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  91. angle_loss = tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
  92. return tf.contrib.metrics.streaming_mean(angle_loss)
  93. ==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
  94. ==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
  95. def provide_batch_fn():
  96. """ The provide_batch function to use. """
  97. return dataset_factory.provide_batch
  98. <<<<
  99. def main(_):
  100. g = tf.Graph()
  101. with g.as_default():
  102. # Load the data.
  103. images, labels = provide_batch_fn()(
  104. FLAGS.dataset, FLAGS.split, FLAGS.dataset_dir, 4, FLAGS.batch_size, 4)
  105. num_classes = labels['classes'].get_shape().as_list()[1]
  106. tf.summary.image('eval_images', images, max_outputs=3)
  107. # Define the model:
  108. with tf.variable_scope('towers'):
  109. basic_tower = getattr(models, FLAGS.basic_tower)
  110. predictions, endpoints = basic_tower(
  111. images,
  112. num_classes=num_classes,
  113. is_training=False,
  114. batch_norm_params=None)
  115. metric_names_to_values = {}
  116. # Define the metrics:
  117. if 'quaternions' in labels: # Also have to evaluate pose estimation!
  118. quaternion_loss = quaternion_metric(labels['quaternions'],
  119. endpoints['quaternion_pred'])
  120. angle_errors, = tf.py_func(
  121. angle_diff, [labels['quaternions'], endpoints['quaternion_pred']],
  122. [tf.float32])
  123. >>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
  124. metric_name = 'Log Quaternion Error'
  125. names_to_values[metric_name], names_to_updates[
  126. metric_name] = quaternion_metric(labels['quaternions'],
  127. endpoints['quaternion_pred'])
  128. metric_name = 'Accuracy'
  129. names_to_values[metric_name], names_to_updates[
  130. metric_name] = tf.contrib.metrics.streaming_accuracy(
  131. tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
  132. ==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
  133. metric_names_to_values[
  134. 'Angular mean error'] = slim.metrics.streaming_mean(angle_errors)
  135. metric_names_to_values['Quaternion Loss'] = quaternion_loss
  136. ==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
  137. metric_names_to_values['Angular mean error'] = slim.metrics.mean(
  138. angle_errors)
  139. metric_names_to_values['Quaternion Loss'] = quaternion_loss
  140. <<<<
  141. accuracy = tf.contrib.metrics.streaming_accuracy(
  142. tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1))
  143. >>>> ORIGINAL //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#5
  144. ==== THEIRS //depot/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py#6
  145. predictions = tf.argmax(predictions, 1)
  146. labels = tf.argmax(labels['classes'], 1)
  147. metric_names_to_values['Accuracy'] = accuracy
  148. names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
  149. metric_names_to_values)
  150. ==== YOURS //konstantinos:opensource:883:citc/google3/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval.py
  151. predictions = tf.argmax(predictions, 1)
  152. labels = tf.argmax(labels['classes'], 1)
  153. metric_names_to_values['Accuracy'] = accuracy
  154. for i in xrange(num_classes):
  155. index_map = tf.one_hot(i, depth=num_classes)
  156. name = 'PR/Precision_{}'.format(i)
  157. metric_names_to_values[name] = slim.metrics.streaming_precision(
  158. tf.gather(index_map, predictions), tf.gather(index_map, labels))
  159. name = 'PR/Recall_{}'.format(i)
  160. metric_names_to_values[name] = slim.metrics.streaming_recall(
  161. tf.gather(index_map, predictions), tf.gather(index_map, labels))
  162. names_to_values, names_to_updates = slim.metrics.aggregate_metric_map(
  163. metric_names_to_values)
  164. <<<<
  165. # Create the summary ops such that they also print out to std output:
  166. summary_ops = []
  167. for metric_name, metric_value in names_to_values.iteritems():
  168. op = tf.summary.scalar(metric_name, metric_value)
  169. op = tf.Print(op, [metric_value], metric_name)
  170. summary_ops.append(op)
  171. # This ensures that we make a single pass over all of the data.
  172. num_batches = math.ceil(FLAGS.num_examples / float(FLAGS.batch_size))
  173. # Setup the global step.
  174. slim.get_or_create_global_step()
  175. slim.evaluation.evaluation_loop(
  176. FLAGS.master,
  177. checkpoint_dir=FLAGS.checkpoint_dir,
  178. logdir=FLAGS.eval_dir,
  179. num_evals=num_batches,
  180. eval_op=names_to_updates.values(),
  181. summary_op=tf.summary.merge(summary_ops))
  182. if __name__ == '__main__':
  183. tf.app.run()