grl_ops_test.py 3.0 KB

  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Tests for grl_ops."""
# Original repo-relative imports, disabled in favor of the local imports below:
#from models.domain_adaptation.domain_separation import grl_op_grads # pylint: disable=unused-import
#from models.domain_adaptation.domain_separation import grl_op_shapes # pylint: disable=unused-import
import tensorflow as tf

# NOTE(review): imported for its side effect — presumably registers the
# gradient for the GRL op (the commented import above carries a
# "pylint: disable=unused-import" hint to that effect) — TODO confirm.
import grl_op_grads
import grl_ops

# Module-level flags handle; not read anywhere in this file.
FLAGS = tf.app.flags.FLAGS
  22. class GRLOpsTest(tf.test.TestCase):
  23. def testGradientReversalOp(self):
  24. with tf.Graph().as_default():
  25. with self.test_session():
  26. # Test that in forward prop, gradient reversal op acts as the
  27. # identity operation.
  28. examples = tf.constant([5.0, 4.0, 3.0, 2.0, 1.0])
  29. output = grl_ops.gradient_reversal(examples)
  30. expected_output = examples
  31. self.assertAllEqual(output.eval(), expected_output.eval())
  32. # Test that shape inference works as expected.
  33. self.assertAllEqual(output.get_shape(), expected_output.get_shape())
  34. # Test that in backward prop, gradient reversal op multiplies
  35. # gradients by -1.
  36. examples = tf.constant([[1.0]])
  37. w = tf.get_variable(name='w', shape=[1, 1])
  38. b = tf.get_variable(name='b', shape=[1])
  39. init_op = tf.global_variables_initializer()
  40. init_op.run()
  41. features = tf.nn.xw_plus_b(examples, w, b)
  42. # Construct two outputs: features layer passes directly to output1, but
  43. # features layer passes through a gradient reversal layer before
  44. # reaching output2.
  45. output1 = features
  46. output2 = grl_ops.gradient_reversal(features)
  47. gold = tf.constant([1.0])
  48. loss1 = gold - output1
  49. loss2 = gold - output2
  50. opt = tf.train.GradientDescentOptimizer(learning_rate=0.01)
  51. grads_and_vars_1 = opt.compute_gradients(loss1,
  52. tf.trainable_variables())
  53. grads_and_vars_2 = opt.compute_gradients(loss2,
  54. tf.trainable_variables())
  55. self.assertAllEqual(len(grads_and_vars_1), len(grads_and_vars_2))
  56. for i in range(len(grads_and_vars_1)):
  57. g1 = grads_and_vars_1[i][0]
  58. g2 = grads_and_vars_2[i][0]
  59. # Verify that gradients of loss1 are the negative of gradients of
  60. # loss2.
  61. self.assertAllEqual(tf.negative(g1).eval(), g2.eval())
# Run the test suite when executed as a script.
if __name__ == '__main__':
  tf.test.main()