digraph_ops_test.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for digraph ops."""
  16. import tensorflow as tf
  17. from dragnn.python import digraph_ops
  18. class DigraphOpsTest(tf.test.TestCase):
  19. """Testing rig."""
  20. def testArcPotentialsFromTokens(self):
  21. with self.test_session():
  22. # Batch of two, where the second batch item is the reverse of the first.
  23. source_tokens = tf.constant([[[1, 2],
  24. [2, 3],
  25. [3, 4]],
  26. [[3, 4],
  27. [2, 3],
  28. [1, 2]]], tf.float32)
  29. target_tokens = tf.constant([[[4, 5, 6],
  30. [5, 6, 7],
  31. [6, 7, 8]],
  32. [[6, 7, 8],
  33. [5, 6, 7],
  34. [4, 5, 6]]], tf.float32)
  35. weights = tf.constant([[2, 3, 5],
  36. [7, 11, 13]],
  37. tf.float32)
  38. arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
  39. weights)
  40. # For example,
  41. # ((1 * 2 * 4 + 1 * 3 * 5 + 1 * 5 * 6) +
  42. # (2 * 7 * 4 + 2 * 11 * 5 + 2 * 13 * 6)) = 375
  43. self.assertAllEqual(arcs.eval(),
  44. [[[375, 447, 519],
  45. [589, 702, 815],
  46. [803, 957, 1111]],
  47. [[1111, 957, 803], # reflected through the center
  48. [815, 702, 589],
  49. [519, 447, 375]]])
  50. def testArcSourcePotentialsFromTokens(self):
  51. with self.test_session():
  52. tokens = tf.constant([[[4, 5, 6],
  53. [5, 6, 7],
  54. [6, 7, 8]],
  55. [[6, 7, 8],
  56. [5, 6, 7],
  57. [4, 5, 6]]], tf.float32)
  58. weights = tf.constant([2, 3, 5], tf.float32)
  59. arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
  60. self.assertAllEqual(arcs.eval(), [[[53, 53, 53],
  61. [63, 63, 63],
  62. [73, 73, 73]],
  63. [[73, 73, 73],
  64. [63, 63, 63],
  65. [53, 53, 53]]])
  66. def testRootPotentialsFromTokens(self):
  67. with self.test_session():
  68. root = tf.constant([1, 2], tf.float32)
  69. tokens = tf.constant([[[4, 5, 6],
  70. [5, 6, 7],
  71. [6, 7, 8]],
  72. [[6, 7, 8],
  73. [5, 6, 7],
  74. [4, 5, 6]]], tf.float32)
  75. weights = tf.constant([[2, 3, 5],
  76. [7, 11, 13]],
  77. tf.float32)
  78. roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights)
  79. self.assertAllEqual(roots.eval(), [[375, 447, 519],
  80. [519, 447, 375]])
  81. def testCombineArcAndRootPotentials(self):
  82. with self.test_session():
  83. arcs = tf.constant([[[1, 2, 3],
  84. [2, 3, 4],
  85. [3, 4, 5]],
  86. [[3, 4, 5],
  87. [2, 3, 4],
  88. [1, 2, 3]]], tf.float32)
  89. roots = tf.constant([[6, 7, 8],
  90. [8, 7, 6]], tf.float32)
  91. potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
  92. self.assertAllEqual(potentials.eval(), [[[6, 2, 3],
  93. [2, 7, 4],
  94. [3, 4, 8]],
  95. [[8, 4, 5],
  96. [2, 7, 4],
  97. [1, 2, 6]]])
  98. def testLabelPotentialsFromTokens(self):
  99. with self.test_session():
  100. tokens = tf.constant([[[1, 2],
  101. [3, 4],
  102. [5, 6]],
  103. [[6, 5],
  104. [4, 3],
  105. [2, 1]]], tf.float32)
  106. weights = tf.constant([[ 2, 3],
  107. [ 5, 7],
  108. [11, 13]], tf.float32)
  109. labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
  110. self.assertAllEqual(labels.eval(),
  111. [[[ 8, 19, 37],
  112. [ 18, 43, 85],
  113. [ 28, 67, 133]],
  114. [[ 27, 65, 131],
  115. [ 17, 41, 83],
  116. [ 7, 17, 35]]])
  117. def testLabelPotentialsFromTokenPairs(self):
  118. with self.test_session():
  119. sources = tf.constant([[[1, 2],
  120. [3, 4],
  121. [5, 6]],
  122. [[6, 5],
  123. [4, 3],
  124. [2, 1]]], tf.float32)
  125. targets = tf.constant([[[3, 4],
  126. [5, 6],
  127. [7, 8]],
  128. [[8, 7],
  129. [6, 5],
  130. [4, 3]]], tf.float32)
  131. weights = tf.constant([[[ 2, 3],
  132. [ 5, 7]],
  133. [[11, 13],
  134. [17, 19]],
  135. [[23, 29],
  136. [31, 37]]], tf.float32)
  137. labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
  138. weights)
  139. self.assertAllEqual(labels.eval(),
  140. [[[ 104, 339, 667],
  141. [ 352, 1195, 2375],
  142. [ 736, 2531, 5043]],
  143. [[ 667, 2419, 4857],
  144. [ 303, 1115, 2245],
  145. [ 75, 291, 593]]])
  146. if __name__ == "__main__":
  147. tf.test.main()