# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=line-too-long
r"""Evaluation for Domain Separation Networks (DSNs).

To build locally for CPU:
  blaze build -c opt --copt=-mavx \
    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval

To build locally for GPU:
  blaze build -c opt --copt=-mavx --config=cuda_clang \
    third_party/tensorflow_models/domain_adaptation/domain_separation:dsn_eval

To run locally:
  $ ./blaze-bin/third_party/tensorflow_models/domain_adaptation/domain_separation/dsn_eval \
      --alsologtostderr
"""
# pylint: enable=line-too-long
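
# Example invocation with explicit flag values (illustrative; every flag shown
# below defaults to the value given here):
#   dsn_eval --dataset=pose_real --portion=valid --num_examples=1000 \
#     --batch_size=50 --checkpoint_dir=/tmp/da/ --eval_dir=/tmp/da/ \
#     --alsologtostderr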

import math

import google3
import numpy as np
import tensorflow as tf

from google3.robotics.cad_learning.domain_adaptation.fnist import data_provider
from google3.third_party.tensorflow_models.domain_adaptation.domain_separation import models

slim = tf.contrib.slim

FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_integer('batch_size', 50,
                            'The number of images in each batch.')

tf.app.flags.DEFINE_string('master', 'local',
                           'BNS name of the TensorFlow master to use.')

tf.app.flags.DEFINE_string('checkpoint_dir', '/tmp/da/',
                           'Directory where the model was written to.')

tf.app.flags.DEFINE_string(
    'eval_dir', '/tmp/da/',
    'Directory where we should write the tf summaries to.')

tf.app.flags.DEFINE_string(
    'dataset', 'pose_real',
    'Which dataset to test on: "pose_real", "pose_synthetic".')

tf.app.flags.DEFINE_string('portion', 'valid',
                           'Which portion to test on: "valid", "test".')

tf.app.flags.DEFINE_integer('num_examples', 1000, 'Number of test examples.')

tf.app.flags.DEFINE_string('basic_tower', 'pose_mini',
                           'The basic tower building block.')

tf.app.flags.DEFINE_bool('use_logging', False, 'Debugging messages.')


def quaternion_metric(predictions, labels):
  """Computes a streaming mean of the log quaternion distance."""
  product = tf.multiply(predictions, labels)
  internal_dot_products = tf.reduce_sum(product, [1])
  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return tf.contrib.metrics.streaming_mean(logcost)


def to_degrees(predictions, labels):
  """Computes the angle, in degrees, between batches of unit quaternions.

  Args:
    predictions: A batch of predicted unit quaternions.
    labels: A batch of ground-truth unit quaternions.

  Returns:
    A streaming mean of the angle, in degrees, of the implied angle-axis
    representation, as a (value, update_op) pair.
  """
  product = tf.multiply(predictions, labels)
  internal_dot_products = tf.reduce_sum(product, [1])
  log_quaternion_loss = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  angle_loss = tf.acos(-(tf.exp(log_quaternion_loss) - 1)) * 2 * 180 / math.pi
  return tf.contrib.metrics.streaming_mean(angle_loss)
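
# Note on the conversion in to_degrees(): for unit quaternions p and q,
# |<p, q>| = cos(theta / 2), where theta is the rotation angle between the two
# orientations (q and -q encode the same rotation, hence the abs). Undoing the
# log therefore recovers theta = 2 * acos(|<p, q>|) * 180 / pi, up to the 1e-4
# damping term. For example, identical quaternions give a dot product of 1 and
# a log loss of log(1e-4); the damping keeps the recovered angle at roughly
# 1.6 degrees rather than exactly 0.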


def main(_):
  g = tf.Graph()
  with g.as_default():
    images, labels = data_provider.provide(FLAGS.dataset, FLAGS.portion,
                                           FLAGS.batch_size)

    num_classes = labels['classes'].shape[1]

    # Define the model:
    with tf.variable_scope('towers'):
      basic_tower = models.provide(FLAGS.basic_tower)
      predictions, endpoints = basic_tower(
          images, is_training=False, num_classes=num_classes)

    names_to_values = {}
    names_to_updates = {}

    # Define the metrics:
    if 'quaternions' in labels:  # Also have to evaluate pose estimation!
      metric_name = 'Angle Mean Error'
      names_to_values[metric_name], names_to_updates[metric_name] = to_degrees(
          labels['quaternions'], endpoints['quaternion_pred'])

      metric_name = 'Log Quaternion Error'
      names_to_values[metric_name], names_to_updates[metric_name] = (
          quaternion_metric(labels['quaternions'],
                            endpoints['quaternion_pred']))

    metric_name = 'Accuracy'
    names_to_values[metric_name], names_to_updates[metric_name] = (
        tf.contrib.metrics.streaming_accuracy(
            tf.argmax(predictions, 1), tf.argmax(labels['classes'], 1)))
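
    # Each streaming metric above returns a (value, update_op) pair: the
    # update ops are run once per evaluated batch by the evaluation loop
    # below, while the value tensors hold the running aggregates that the
    # summaries read out.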

    # Create the summary ops such that they also print out to std output:
    summary_ops = []
    for metric_name, metric_value in names_to_values.items():
      op = tf.contrib.deprecated.scalar_summary(metric_name, metric_value)
      op = tf.Print(op, [metric_value], metric_name)
      summary_ops.append(op)

    # This ensures that we make a single pass over all of the data.
    num_batches = int(math.ceil(FLAGS.num_examples / float(FLAGS.batch_size)))

    # Setup the global step.
    slim.get_or_create_global_step()

    slim.evaluation.evaluation_loop(
        FLAGS.master,
        checkpoint_dir=FLAGS.checkpoint_dir,
        logdir=FLAGS.eval_dir,
        num_evals=num_batches,
        eval_op=list(names_to_updates.values()),
        summary_op=tf.contrib.deprecated.merge_summary(summary_ops))


if __name__ == '__main__':
  tf.app.run()