utils.py

# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Auxiliary functions for domain adaptation related losses."""

import math

import tensorflow as tf


def create_summaries(end_points, prefix='', max_images=3, use_op_name=False):
  """Creates a tf summary per end point.

  A 4-dimensional end point is displayed as an image (tiling its feature maps
  if necessary), a 3-dimensional end point is expanded to a single-channel
  image, and a 2-dimensional end point is summarized as a histogram.

  Args:
    end_points: a dictionary of name, tf tensor pairs.
    prefix: an optional string to prefix the summary with.
    max_images: the maximum number of images to display per summary.
    use_op_name: Use the op name as opposed to the shorter end_points key.
  """
  for layer_name in end_points:
    if use_op_name:
      name = end_points[layer_name].op.name
    else:
      name = layer_name
    if len(end_points[layer_name].get_shape().as_list()) == 4:
      # If it is an actual image, do not attempt to reshape it.
      if end_points[layer_name].get_shape().as_list()[-1] == 1 or end_points[
          layer_name].get_shape().as_list()[-1] == 3:
        visualization_image = end_points[layer_name]
      else:
        visualization_image = reshape_feature_maps(end_points[layer_name])
      tf.summary.image(
          '{}/{}'.format(prefix, name),
          visualization_image,
          max_outputs=max_images)
    elif len(end_points[layer_name].get_shape().as_list()) == 3:
      images = tf.expand_dims(end_points[layer_name], 3)
      tf.summary.image(
          '{}/{}'.format(prefix, name),
          images,
          max_outputs=max_images)
    elif len(end_points[layer_name].get_shape().as_list()) == 2:
      tf.summary.histogram('{}/{}'.format(prefix, name),
                           end_points[layer_name])
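

# Illustrative usage sketch (not part of the original module): 4-D end points
# are logged as image summaries (tiled via reshape_feature_maps when they are
# not 1- or 3-channel images), while 2-D end points become histograms. The end
# point names below are made up for demonstration only.
def _example_create_summaries():
  end_points = {
      'images': tf.zeros([4, 28, 28, 1]),  # logged as an image summary.
      'logits': tf.zeros([4, 10]),  # logged as a histogram summary.
  }
  create_summaries(end_points, prefix='source')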


def reshape_feature_maps(features_tensor):
  """Reshapes activations for tf.summary.image visualization.

  Args:
    features_tensor: a tensor of activations with a square number of feature
      maps, e.g. 4, 9, 16, etc.

  Returns:
    A composite image with all the feature maps that can be passed as an
    argument to tf.summary.image.
  """
  assert len(features_tensor.get_shape().as_list()) == 4
  num_filters = features_tensor.get_shape().as_list()[-1]
  assert num_filters > 0
  num_filters_sqrt = math.sqrt(num_filters)
  assert num_filters_sqrt.is_integer(), (
      'Number of filters should be a square number but got {}'.format(
          num_filters))
  num_filters_sqrt = int(num_filters_sqrt)
  # Tile the feature maps into a num_filters_sqrt x num_filters_sqrt grid:
  # each row concatenates num_filters_sqrt maps along the width axis, and the
  # rows are then stacked along the height axis.
  conv_summary = tf.unstack(features_tensor, axis=3)
  conv_final = tf.concat(axis=2, values=conv_summary[0:num_filters_sqrt])
  for ind in range(1, num_filters_sqrt):
    conv_one_row = tf.concat(
        axis=2,
        values=conv_summary[ind * num_filters_sqrt:
                            (ind + 1) * num_filters_sqrt])
    conv_final = tf.concat(
        axis=1, values=[tf.squeeze(conv_final), tf.squeeze(conv_one_row)])
  return tf.expand_dims(conv_final, -1)
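

# Illustrative sketch (not part of the original module): with 16 feature maps
# (a square number) of spatial size 8x8, the maps are tiled into a 4x4 grid,
# so the composite image has shape [batch, 32, 32, 1]. The batch size of 2 is
# arbitrary.
def _example_reshape_feature_maps():
  feature_maps = tf.random_normal([2, 8, 8, 16])
  grid = reshape_feature_maps(feature_maps)  # shape: [2, 32, 32, 1]
  return grid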


def accuracy(predictions, labels):
  """Calculates the classification accuracy.

  Args:
    predictions: the predicted values, a tensor whose size matches 'labels'.
    labels: the ground truth values, a tensor of any size.

  Returns:
    a tensor whose value on evaluation returns the total accuracy.
  """
  return tf.reduce_mean(tf.cast(tf.equal(predictions, labels), tf.float32))
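

# Illustrative sketch (not part of the original module): accuracy() compares
# its two arguments element-wise, so class predictions are typically obtained
# with tf.argmax and the labels cast to the same dtype first.
def _example_accuracy(logits, labels):
  predictions = tf.argmax(logits, 1)
  return accuracy(predictions, tf.cast(labels, predictions.dtype))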


def compute_upsample_values(input_tensor, upsample_height, upsample_width):
  """Compute values for an upsampling op (ops.BatchCropAndResize).

  Args:
    input_tensor: image tensor with shape [batch, height, width, in_channels]
    upsample_height: integer
    upsample_width: integer

  Returns:
    grid_centers: tensor with shape [batch, 2] holding the (row, col) center
      of each image.
    crop_sizes: float tensor with shape [batch, 2] holding the (height, width)
      crop size for each image.
    output_height: integer, height * upsample_height
    output_width: integer, width * upsample_width
  """
  batch, input_height, input_width, _ = input_tensor.shape

  height_half = input_height / 2.
  width_half = input_width / 2.
  grid_centers = tf.constant(batch * [[height_half, width_half]])
  crop_sizes = tf.constant(batch * [[input_height, input_width]])
  output_height = input_height * upsample_height
  output_width = input_width * upsample_width

  return grid_centers, tf.to_float(crop_sizes), output_height, output_width
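

# Illustrative sketch (not part of the original module): for a batch of four
# 16x16 RGB images upsampled 2x in each direction, every crop is centred on
# the middle of its image and covers the full 16x16 extent, and the output
# resolution is 32x32. The tensor values below are assumptions for
# demonstration only.
def _example_compute_upsample_values():
  images = tf.zeros([4, 16, 16, 3])
  centers, sizes, out_height, out_width = compute_upsample_values(
      images, upsample_height=2, upsample_width=2)
  # centers: [4, 2] tensor of (8., 8.); sizes: [4, 2] tensor of (16., 16.);
  # out_height == 32, out_width == 32.
  return centers, sizes, out_height, out_width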


def compute_pairwise_distances(x, y):
  """Computes the squared pairwise Euclidean distances between x and y.

  Args:
    x: a tensor of shape [num_x_samples, num_features]
    y: a tensor of shape [num_y_samples, num_features]

  Returns:
    a distance matrix of dimensions [num_x_samples, num_y_samples].

  Raises:
    ValueError: if the inputs do not match the specified dimensions.
  """

  if not len(x.get_shape()) == len(y.get_shape()) == 2:
    raise ValueError('Both inputs should be matrices.')

  if x.get_shape().as_list()[1] != y.get_shape().as_list()[1]:
    raise ValueError('The number of features should be the same.')

  norm = lambda x: tf.reduce_sum(tf.square(x), 1)

  # x is expanded to shape [num_x_samples, num_features, 1] and y is
  # transposed to shape [num_features, num_y_samples]; broadcasting the
  # subtraction therefore computes the difference between every row of x and
  # every row of y, giving a tensor of shape
  # [num_x_samples, num_features, num_y_samples]. Summing the squares over the
  # feature axis (axis 1) yields the [num_x_samples, num_y_samples] distance
  # matrix.
  return norm(tf.expand_dims(x, 2) - tf.transpose(y))
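

# Illustrative sketch (not part of the original module): a tiny worked example
# of the broadcasting computation above; the two sample points are arbitrary.
def _example_compute_pairwise_distances():
  x = tf.constant([[0., 0.], [3., 4.]])
  y = tf.constant([[0., 0.], [3., 4.]])
  dist = compute_pairwise_distances(x, y)
  # dist evaluates to [[0., 25.], [25., 0.]], since
  # ||(0, 0) - (3, 4)||^2 = 3^2 + 4^2 = 25.
  return dist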


def gaussian_kernel_matrix(x, y, sigmas):
  r"""Computes a Gaussian Radial Basis Kernel between the samples of x and y.

  We create a sum of multiple Gaussian kernels, each having a width sigma_i.

  Args:
    x: a tensor of shape [num_samples, num_features]
    y: a tensor of shape [num_samples, num_features]
    sigmas: a tensor of floats which denote the widths of each of the
      Gaussians in the kernel.

  Returns:
    A tensor of shape [num_samples{x}, num_samples{y}] with the RBF kernel.
  """
  beta = 1. / (2. * (tf.expand_dims(sigmas, 1)))

  dist = compute_pairwise_distances(x, y)

  s = tf.matmul(beta, tf.reshape(dist, (1, -1)))

  return tf.reshape(tf.reduce_sum(tf.exp(-s), 0), tf.shape(dist))
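

# Illustrative sketch (not part of the original module): in MMD-style domain
# adaptation losses this kernel is typically bound to a fixed set of
# bandwidths with functools.partial before being handed to the loss; the sigma
# values and tensor shapes below are assumptions chosen only for
# demonstration.
def _example_gaussian_kernel_matrix():
  import functools
  source_features = tf.random_normal([8, 32])
  target_features = tf.random_normal([8, 32])
  sigmas = tf.constant([1., 5., 10.])
  kernel = functools.partial(gaussian_kernel_matrix, sigmas=sigmas)
  return kernel(source_features, target_features)  # shape: [8, 8]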