losses_test.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for DSN losses."""
  16. from functools import partial
  17. import numpy as np
  18. import tensorflow as tf
  19. import losses
  20. import utils
  21. def MaximumMeanDiscrepancySlow(x, y, sigmas):
  22. num_samples = x.get_shape().as_list()[0]
  23. def AverageGaussianKernel(x, y, sigmas):
  24. result = 0
  25. for sigma in sigmas:
  26. dist = tf.reduce_sum(tf.square(x - y))
  27. result += tf.exp((-1.0 / (2.0 * sigma)) * dist)
  28. return result / num_samples**2
  29. total = 0
  30. for i in range(num_samples):
  31. for j in range(num_samples):
  32. total += AverageGaussianKernel(x[i, :], x[j, :], sigmas)
  33. total += AverageGaussianKernel(y[i, :], y[j, :], sigmas)
  34. total += -2 * AverageGaussianKernel(x[i, :], y[j, :], sigmas)
  35. return total
  36. class LogQuaternionLossTest(tf.test.TestCase):
  37. def test_log_quaternion_loss_batch(self):
  38. with self.test_session():
  39. predictions = tf.random_uniform((10, 4), seed=1)
  40. predictions = tf.nn.l2_normalize(predictions, 1)
  41. labels = tf.random_uniform((10, 4), seed=1)
  42. labels = tf.nn.l2_normalize(labels, 1)
  43. params = {'batch_size': 10, 'use_logging': False}
  44. x = losses.log_quaternion_loss_batch(predictions, labels, params)
  45. self.assertTrue(((10,) == tf.shape(x).eval()).all())
  46. class MaximumMeanDiscrepancyTest(tf.test.TestCase):
  47. def test_mmd_name(self):
  48. with self.test_session():
  49. x = tf.random_uniform((2, 3), seed=1)
  50. kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
  51. loss = losses.maximum_mean_discrepancy(x, x, kernel)
  52. self.assertEquals(loss.op.name, 'MaximumMeanDiscrepancy/value')
  53. def test_mmd_is_zero_when_inputs_are_same(self):
  54. with self.test_session():
  55. x = tf.random_uniform((2, 3), seed=1)
  56. kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
  57. self.assertEquals(0, losses.maximum_mean_discrepancy(x, x, kernel).eval())
  58. def test_fast_mmd_is_similar_to_slow_mmd(self):
  59. with self.test_session():
  60. x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
  61. y = tf.constant(np.random.rand(2, 3), tf.float32)
  62. cost_old = MaximumMeanDiscrepancySlow(x, y, [1.]).eval()
  63. kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([1.]))
  64. cost_new = losses.maximum_mean_discrepancy(x, y, kernel).eval()
  65. self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
  66. def test_multiple_sigmas(self):
  67. with self.test_session():
  68. x = tf.constant(np.random.normal(size=(2, 3)), tf.float32)
  69. y = tf.constant(np.random.rand(2, 3), tf.float32)
  70. sigmas = tf.constant([2., 5., 10, 20, 30])
  71. kernel = partial(utils.gaussian_kernel_matrix, sigmas=sigmas)
  72. cost_old = MaximumMeanDiscrepancySlow(x, y, [2., 5., 10, 20, 30]).eval()
  73. cost_new = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
  74. self.assertAlmostEqual(cost_old, cost_new, delta=1e-5)
  75. def test_mmd_is_zero_when_distributions_are_same(self):
  76. with self.test_session():
  77. x = tf.random_uniform((1000, 10), seed=1)
  78. y = tf.random_uniform((1000, 10), seed=3)
  79. kernel = partial(utils.gaussian_kernel_matrix, sigmas=tf.constant([100.]))
  80. loss = losses.maximum_mean_discrepancy(x, y, kernel=kernel).eval()
  81. self.assertAlmostEqual(0, loss, delta=1e-4)
  82. if __name__ == '__main__':
  83. tf.test.main()