losses.py

# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Domain Adaptation Loss Functions.
  16. The following domain adaptation loss functions are defined:
  17. - Maximum Mean Discrepancy (MMD).
  18. Relevant paper:
  19. Gretton, Arthur, et al.,
  20. "A kernel two-sample test."
  21. The Journal of Machine Learning Research, 2012
  22. - Correlation Loss on a batch.
  23. """
from functools import partial

import tensorflow as tf

import grl_op_grads  # pylint: disable=unused-import
import grl_op_shapes  # pylint: disable=unused-import
import grl_ops
import utils

slim = tf.contrib.slim

################################################################################
# SIMILARITY LOSS
################################################################################
def maximum_mean_discrepancy(x, y, kernel=utils.gaussian_kernel_matrix):
  r"""Computes the Maximum Mean Discrepancy (MMD) of two samples: x and y.

  Maximum Mean Discrepancy (MMD) is a distance-measure between the samples of
  the distributions of x and y. Here we use the kernel two-sample estimate
  using the empirical mean of the two distributions.

  MMD^2(P, Q) = || \E{\phi(x)} - \E{\phi(y)} ||^2
              = \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) },

  where K = <\phi(x), \phi(y)> is the desired kernel function, in this case a
  radial basis kernel.

  Args:
    x: a tensor of shape [num_samples, num_features]
    y: a tensor of shape [num_samples, num_features]
    kernel: a function which computes the kernel in MMD. Defaults to the
        GaussianKernelMatrix.

  Returns:
    a scalar denoting the squared maximum mean discrepancy loss.
  """
  with tf.name_scope('MaximumMeanDiscrepancy'):
    # \E{ K(x, x) } + \E{ K(y, y) } - 2 \E{ K(x, y) }
    cost = tf.reduce_mean(kernel(x, x))
    cost += tf.reduce_mean(kernel(y, y))
    cost -= 2 * tf.reduce_mean(kernel(x, y))

    # We do not allow the loss to become negative.
    cost = tf.where(cost > 0, cost, 0, name='value')
  return cost

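# A minimal usage sketch, assuming `fx` and `fy` are [batch, features]
# activations and that utils.gaussian_kernel_matrix accepts a `sigmas` tensor
# as in mmd_loss below; the bandwidths shown here are illustrative only:
#
#   rbf = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1., 10.]))
#   mmd2 = maximum_mean_discrepancy(fx, fy, kernel=rbf)  # scalar MMD^2 >= 0
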
def mmd_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the MMD between two representations.

  This Maximum Mean Discrepancy (MMD) loss is calculated with a number of
  different Gaussian kernels.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the MMD loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the MMD loss value.
  """
  sigmas = [
      1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1, 5, 10, 15, 20, 25, 30, 35, 100,
      1e3, 1e4, 1e5, 1e6
  ]
  gaussian_kernel = partial(
      utils.gaussian_kernel_matrix, sigmas=tf.constant(sigmas))

  loss_value = maximum_mean_discrepancy(
      source_samples, target_samples, kernel=gaussian_kernel)
  loss_value = tf.maximum(1e-4, loss_value) * weight
  assert_op = tf.Assert(tf.is_finite(loss_value), [loss_value])
  with tf.control_dependencies([assert_op]):
    tag = 'MMD Loss'
    if scope:
      tag = scope + tag
    tf.contrib.deprecated.scalar_summary(tag, loss_value)
    tf.losses.add_loss(loss_value)

  return loss_value

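# A minimal usage sketch, assuming `source_features` and `target_features` are
# [batch, features] encoder outputs and that the similarity weight is a
# hyperparameter chosen by the caller. Because mmd_loss calls
# tf.losses.add_loss, the term is also picked up by tf.losses.get_total_loss():
#
#   sim_loss = mmd_loss(source_features, target_features, weight=0.25,
#                       scope='adapt/')
#   total_loss = tf.losses.get_total_loss()  # includes the MMD term above
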
def correlation_loss(source_samples, target_samples, weight, scope=None):
  """Adds a similarity loss term, the correlation between two representations.

  Args:
    source_samples: a tensor of shape [num_samples, num_features]
    target_samples: a tensor of shape [num_samples, num_features]
    weight: a scalar weight for the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the correlation loss value.
  """
  with tf.name_scope('corr_loss'):
    source_samples -= tf.reduce_mean(source_samples, 0)
    target_samples -= tf.reduce_mean(target_samples, 0)

    source_samples = tf.nn.l2_normalize(source_samples, 1)
    target_samples = tf.nn.l2_normalize(target_samples, 1)

    source_cov = tf.matmul(tf.transpose(source_samples), source_samples)
    target_cov = tf.matmul(tf.transpose(target_samples), target_samples)

    corr_loss = tf.reduce_mean(tf.square(source_cov - target_cov)) * weight

  assert_op = tf.Assert(tf.is_finite(corr_loss), [corr_loss])
  with tf.control_dependencies([assert_op]):
    tag = 'Correlation Loss'
    if scope:
      tag = scope + tag
    tf.contrib.deprecated.scalar_summary(tag, corr_loss)
    tf.losses.add_loss(corr_loss)

  return corr_loss

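# A brief note on what is penalized above: after centering and row-wise l2
# normalization, source_cov and target_cov are [num_features, num_features]
# correlation-style matrices, so the loss is the mean squared entry-wise gap
# between the two batches' second-order statistics. A hedged usage sketch,
# assuming [batch, features] feature tensors:
#
#   corr = correlation_loss(source_features, target_features, weight=1.0)
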
def dann_loss(source_samples, target_samples, weight, scope=None):
  """Adds the domain adversarial (DANN) loss.

  Args:
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the loss.
    scope: optional name scope for summary tags.

  Returns:
    a scalar tensor representing the domain adversarial loss value.
  """
  with tf.variable_scope('dann'):
    batch_size = tf.shape(source_samples)[0]
    samples = tf.concat([source_samples, target_samples], 0)
    samples = slim.flatten(samples)

    domain_selection_mask = tf.concat(
        [tf.zeros((batch_size, 1)), tf.ones((batch_size, 1))], 0)

    # Perform the gradient reversal and be careful with the shape.
    grl = grl_ops.gradient_reversal(samples)
    grl = tf.reshape(grl, (-1, samples.get_shape().as_list()[1]))

    grl = slim.fully_connected(grl, 100, scope='fc1')
    logits = slim.fully_connected(grl, 1, activation_fn=None, scope='fc2')

    domain_predictions = tf.sigmoid(logits)

  domain_loss = tf.losses.log_loss(
      domain_selection_mask, domain_predictions, weights=weight)

  domain_accuracy = utils.accuracy(
      tf.round(domain_predictions), domain_selection_mask)

  assert_op = tf.Assert(tf.is_finite(domain_loss), [domain_loss])
  with tf.control_dependencies([assert_op]):
    tag_loss = 'losses/Domain Loss'
    tag_accuracy = 'losses/Domain Accuracy'
    if scope:
      tag_loss = scope + tag_loss
      tag_accuracy = scope + tag_accuracy

    tf.contrib.deprecated.scalar_summary(
        tag_loss, domain_loss, name='domain_loss_summary')
    tf.contrib.deprecated.scalar_summary(
        tag_accuracy, domain_accuracy, name='domain_accuracy_summary')

  return domain_loss

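# A minimal usage sketch, assuming two [batch, features] feature tensors with
# equal batch sizes. dann_loss builds a small domain classifier under the
# 'dann' variable scope and routes the features through the gradient-reversal
# op, so the feature extractor receives reversed domain-classifier gradients:
#
#   adv_loss = dann_loss(source_features, target_features, weight=1.0)
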
################################################################################
# DIFFERENCE LOSS
################################################################################
def difference_loss(private_samples, shared_samples, weight=1.0, name=''):
  """Adds the difference loss between the private and shared representations.

  Args:
    private_samples: a tensor of shape [num_samples, num_features].
    shared_samples: a tensor of shape [num_samples, num_features].
    weight: the weight of the incoherence loss.
    name: the name of the tf summary.
  """
  private_samples -= tf.reduce_mean(private_samples, 0)
  shared_samples -= tf.reduce_mean(shared_samples, 0)

  private_samples = tf.nn.l2_normalize(private_samples, 1)
  shared_samples = tf.nn.l2_normalize(shared_samples, 1)

  correlation_matrix = tf.matmul(
      private_samples, shared_samples, transpose_a=True)

  cost = tf.reduce_mean(tf.square(correlation_matrix)) * weight
  cost = tf.where(cost > 0, cost, 0, name='value')

  tf.contrib.deprecated.scalar_summary(
      'losses/Difference Loss {}'.format(name), cost)
  assert_op = tf.Assert(tf.is_finite(cost), [cost])
  with tf.control_dependencies([assert_op]):
    tf.losses.add_loss(cost)

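# A minimal usage sketch for a domain-separation setup, assuming `private_h`
# and `shared_h` are the [batch, features] private and shared encodings of the
# same batch; note the loss is registered via tf.losses.add_loss rather than
# returned:
#
#   difference_loss(private_h, shared_h, weight=0.05, name='source')
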
################################################################################
# TASK LOSS
################################################################################
def log_quaternion_loss_batch(predictions, labels, params):
  """A helper function to compute the error between quaternions.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size [batch_size], denoting the error between the quaternions.
  """
  use_logging = params['use_logging']
  assertions = []
  if use_logging:
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(predictions), [1]) - 1),
                    1e-4)),
            ['The l2 norm of each prediction quaternion vector should be 1.']))
    assertions.append(
        tf.Assert(
            tf.reduce_all(
                tf.less(
                    tf.abs(tf.reduce_sum(tf.square(labels), [1]) - 1), 1e-4)),
            ['The l2 norm of each label quaternion vector should be 1.']))

  with tf.control_dependencies(assertions):
    product = tf.multiply(predictions, labels)
    internal_dot_products = tf.reduce_sum(product, [1])

  if use_logging:
    internal_dot_products = tf.Print(
        internal_dot_products,
        [internal_dot_products, tf.shape(internal_dot_products)],
        'internal_dot_products:')

  logcost = tf.log(1e-4 + 1 - tf.abs(internal_dot_products))
  return logcost

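# Worked values for the expression above, assuming unit quaternions: a perfect
# match gives |<q_pred, q_label>| = 1, so the per-example cost is
# log(1e-4) ~= -9.21, while orthogonal quaternions give a dot product of 0 and
# a cost of log(1 + 1e-4) ~= 1e-4; the 1e-4 offset keeps the log finite at a
# perfect fit.
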
def log_quaternion_loss(predictions, labels, params):
  """A helper function to compute the mean error between batches of
  quaternions.

  The caller is expected to add the loss to the graph.

  Args:
    predictions: A Tensor of size [batch_size, 4].
    labels: A Tensor of size [batch_size, 4].
    params: A dictionary of parameters. Expecting 'use_logging', 'batch_size'.

  Returns:
    A Tensor of size 1, denoting the mean error between batches of quaternions.
  """
  use_logging = params['use_logging']
  logcost = log_quaternion_loss_batch(predictions, labels, params)
  logcost = tf.reduce_sum(logcost, [0])
  batch_size = params['batch_size']
  logcost = tf.multiply(logcost, 1.0 / batch_size, name='log_quaternion_loss')
  if use_logging:
    logcost = tf.Print(
        logcost, [logcost], '[logcost]', name='log_quaternion_loss_print')
  return logcost
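
# A minimal usage sketch, assuming `quat_pred` and `quat_gt` are unit-norm
# [batch_size, 4] tensors; the 'batch_size' entry must match the tensors'
# leading dimension, and the caller adds the result to the graph losses:
#
#   params = {'use_logging': False, 'batch_size': 32}
#   pose_loss = log_quaternion_loss(quat_pred, quat_gt, params)
#   tf.losses.add_loss(pose_loss)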