dsn_test.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for DSN model assembly functions."""
  16. import numpy as np
  17. import tensorflow as tf
  18. import dsn
  19. class HelperFunctionsTest(tf.test.TestCase):
  20. def testBasicDomainSeparationStartPoint(self):
  21. with self.test_session() as sess:
  22. # Test for when global_step < domain_separation_startpoint
  23. step = tf.contrib.slim.get_or_create_global_step()
  24. sess.run(tf.global_variables_initializer()) # global_step = 0
  25. params = {'domain_separation_startpoint': 2}
  26. weight = dsn.dsn_loss_coefficient(params)
  27. weight_np = sess.run(weight)
  28. self.assertAlmostEqual(weight_np, 1e-10)
  29. step_op = tf.assign_add(step, 1)
  30. step_np = sess.run(step_op) # global_step = 1
  31. weight = dsn.dsn_loss_coefficient(params)
  32. weight_np = sess.run(weight)
  33. self.assertAlmostEqual(weight_np, 1e-10)
  34. # Test for when global_step >= domain_separation_startpoint
  35. step_np = sess.run(step_op) # global_step = 2
  36. tf.logging.info(step_np)
  37. weight = dsn.dsn_loss_coefficient(params)
  38. weight_np = sess.run(weight)
  39. self.assertAlmostEqual(weight_np, 1.0)
  40. class DsnModelAssemblyTest(tf.test.TestCase):
  41. def _testBuildDefaultModel(self):
  42. images = tf.to_float(np.random.rand(32, 28, 28, 1))
  43. labels = {}
  44. labels['classes'] = tf.one_hot(
  45. tf.to_int32(np.random.randint(0, 9, (32))), 10)
  46. params = {
  47. 'use_separation': True,
  48. 'layers_to_regularize': 'fc3',
  49. 'weight_decay': 0.0,
  50. 'ps_tasks': 1,
  51. 'domain_separation_startpoint': 1,
  52. 'alpha_weight': 1,
  53. 'beta_weight': 1,
  54. 'gamma_weight': 1,
  55. 'recon_loss_name': 'sum_of_squares',
  56. 'decoder_name': 'small_decoder',
  57. 'encoder_name': 'default_encoder',
  58. }
  59. return images, labels, params
  60. def testBuildModelDann(self):
  61. images, labels, params = self._testBuildDefaultModel()
  62. with self.test_session():
  63. dsn.create_model(images, labels,
  64. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  65. 'dann_loss', params, 'dann_mnist')
  66. loss_tensors = tf.contrib.losses.get_losses()
  67. self.assertEqual(len(loss_tensors), 6)
  68. def testBuildModelDannSumOfPairwiseSquares(self):
  69. images, labels, params = self._testBuildDefaultModel()
  70. with self.test_session():
  71. dsn.create_model(images, labels,
  72. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  73. 'dann_loss', params, 'dann_mnist')
  74. loss_tensors = tf.contrib.losses.get_losses()
  75. self.assertEqual(len(loss_tensors), 6)
  76. def testBuildModelDannMultiPSTasks(self):
  77. images, labels, params = self._testBuildDefaultModel()
  78. params['ps_tasks'] = 10
  79. with self.test_session():
  80. dsn.create_model(images, labels,
  81. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  82. 'dann_loss', params, 'dann_mnist')
  83. loss_tensors = tf.contrib.losses.get_losses()
  84. self.assertEqual(len(loss_tensors), 6)
  85. def testBuildModelMmd(self):
  86. images, labels, params = self._testBuildDefaultModel()
  87. with self.test_session():
  88. dsn.create_model(images, labels,
  89. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  90. 'mmd_loss', params, 'dann_mnist')
  91. loss_tensors = tf.contrib.losses.get_losses()
  92. self.assertEqual(len(loss_tensors), 6)
  93. def testBuildModelCorr(self):
  94. images, labels, params = self._testBuildDefaultModel()
  95. with self.test_session():
  96. dsn.create_model(images, labels,
  97. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  98. 'correlation_loss', params, 'dann_mnist')
  99. loss_tensors = tf.contrib.losses.get_losses()
  100. self.assertEqual(len(loss_tensors), 6)
  101. def testBuildModelNoDomainAdaptation(self):
  102. images, labels, params = self._testBuildDefaultModel()
  103. params['use_separation'] = False
  104. with self.test_session():
  105. dsn.create_model(images, labels,
  106. tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
  107. params, 'dann_mnist')
  108. loss_tensors = tf.contrib.losses.get_losses()
  109. self.assertEqual(len(loss_tensors), 1)
  110. self.assertEqual(len(tf.contrib.losses.get_regularization_losses()), 0)
  111. def testBuildModelNoAdaptationWeightDecay(self):
  112. images, labels, params = self._testBuildDefaultModel()
  113. params['use_separation'] = False
  114. params['weight_decay'] = 1e-5
  115. with self.test_session():
  116. dsn.create_model(images, labels,
  117. tf.cast(tf.ones([32,]), tf.bool), images, labels, 'none',
  118. params, 'dann_mnist')
  119. loss_tensors = tf.contrib.losses.get_losses()
  120. self.assertEqual(len(loss_tensors), 1)
  121. self.assertTrue(len(tf.contrib.losses.get_regularization_losses()) >= 1)
  122. def testBuildModelNoSeparation(self):
  123. images, labels, params = self._testBuildDefaultModel()
  124. params['use_separation'] = False
  125. with self.test_session():
  126. dsn.create_model(images, labels,
  127. tf.cast(tf.ones([32,]), tf.bool), images, labels,
  128. 'dann_loss', params, 'dann_mnist')
  129. loss_tensors = tf.contrib.losses.get_losses()
  130. self.assertEqual(len(loss_tensors), 2)
  131. if __name__ == '__main__':
  132. tf.test.main()