digraph_ops_test.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. """Tests for digraph ops."""
  2. import tensorflow as tf
  3. from dragnn.python import digraph_ops
  4. class DigraphOpsTest(tf.test.TestCase):
  5. """Testing rig."""
  6. def testArcPotentialsFromTokens(self):
  7. with self.test_session():
  8. # Batch of two, where the second batch item is the reverse of the first.
  9. source_tokens = tf.constant([[[1, 2],
  10. [2, 3],
  11. [3, 4]],
  12. [[3, 4],
  13. [2, 3],
  14. [1, 2]]], tf.float32)
  15. target_tokens = tf.constant([[[4, 5, 6],
  16. [5, 6, 7],
  17. [6, 7, 8]],
  18. [[6, 7, 8],
  19. [5, 6, 7],
  20. [4, 5, 6]]], tf.float32)
  21. weights = tf.constant([[2, 3, 5],
  22. [7, 11, 13]],
  23. tf.float32)
  24. arcs = digraph_ops.ArcPotentialsFromTokens(source_tokens, target_tokens,
  25. weights)
  26. # For example,
  27. # ((1 * 2 * 4 + 1 * 3 * 5 + 1 * 5 * 6) +
  28. # (2 * 7 * 4 + 2 * 11 * 5 + 2 * 13 * 6)) = 375
  29. self.assertAllEqual(arcs.eval(),
  30. [[[375, 447, 519],
  31. [589, 702, 815],
  32. [803, 957, 1111]],
  33. [[1111, 957, 803], # reflected through the center
  34. [815, 702, 589],
  35. [519, 447, 375]]])
  36. def testArcSourcePotentialsFromTokens(self):
  37. with self.test_session():
  38. tokens = tf.constant([[[4, 5, 6],
  39. [5, 6, 7],
  40. [6, 7, 8]],
  41. [[6, 7, 8],
  42. [5, 6, 7],
  43. [4, 5, 6]]], tf.float32)
  44. weights = tf.constant([2, 3, 5], tf.float32)
  45. arcs = digraph_ops.ArcSourcePotentialsFromTokens(tokens, weights)
  46. self.assertAllEqual(arcs.eval(), [[[53, 53, 53],
  47. [63, 63, 63],
  48. [73, 73, 73]],
  49. [[73, 73, 73],
  50. [63, 63, 63],
  51. [53, 53, 53]]])
  52. def testRootPotentialsFromTokens(self):
  53. with self.test_session():
  54. root = tf.constant([1, 2], tf.float32)
  55. tokens = tf.constant([[[4, 5, 6],
  56. [5, 6, 7],
  57. [6, 7, 8]],
  58. [[6, 7, 8],
  59. [5, 6, 7],
  60. [4, 5, 6]]], tf.float32)
  61. weights = tf.constant([[2, 3, 5],
  62. [7, 11, 13]],
  63. tf.float32)
  64. roots = digraph_ops.RootPotentialsFromTokens(root, tokens, weights)
  65. self.assertAllEqual(roots.eval(), [[375, 447, 519],
  66. [519, 447, 375]])
  67. def testCombineArcAndRootPotentials(self):
  68. with self.test_session():
  69. arcs = tf.constant([[[1, 2, 3],
  70. [2, 3, 4],
  71. [3, 4, 5]],
  72. [[3, 4, 5],
  73. [2, 3, 4],
  74. [1, 2, 3]]], tf.float32)
  75. roots = tf.constant([[6, 7, 8],
  76. [8, 7, 6]], tf.float32)
  77. potentials = digraph_ops.CombineArcAndRootPotentials(arcs, roots)
  78. self.assertAllEqual(potentials.eval(), [[[6, 2, 3],
  79. [2, 7, 4],
  80. [3, 4, 8]],
  81. [[8, 4, 5],
  82. [2, 7, 4],
  83. [1, 2, 6]]])
  84. def testLabelPotentialsFromTokens(self):
  85. with self.test_session():
  86. tokens = tf.constant([[[1, 2],
  87. [3, 4],
  88. [5, 6]],
  89. [[6, 5],
  90. [4, 3],
  91. [2, 1]]], tf.float32)
  92. weights = tf.constant([[ 2, 3],
  93. [ 5, 7],
  94. [11, 13]], tf.float32)
  95. labels = digraph_ops.LabelPotentialsFromTokens(tokens, weights)
  96. self.assertAllEqual(labels.eval(),
  97. [[[ 8, 19, 37],
  98. [ 18, 43, 85],
  99. [ 28, 67, 133]],
  100. [[ 27, 65, 131],
  101. [ 17, 41, 83],
  102. [ 7, 17, 35]]])
  103. def testLabelPotentialsFromTokenPairs(self):
  104. with self.test_session():
  105. sources = tf.constant([[[1, 2],
  106. [3, 4],
  107. [5, 6]],
  108. [[6, 5],
  109. [4, 3],
  110. [2, 1]]], tf.float32)
  111. targets = tf.constant([[[3, 4],
  112. [5, 6],
  113. [7, 8]],
  114. [[8, 7],
  115. [6, 5],
  116. [4, 3]]], tf.float32)
  117. weights = tf.constant([[[ 2, 3],
  118. [ 5, 7]],
  119. [[11, 13],
  120. [17, 19]],
  121. [[23, 29],
  122. [31, 37]]], tf.float32)
  123. labels = digraph_ops.LabelPotentialsFromTokenPairs(sources, targets,
  124. weights)
  125. self.assertAllEqual(labels.eval(),
  126. [[[ 104, 339, 667],
  127. [ 352, 1195, 2375],
  128. [ 736, 2531, 5043]],
  129. [[ 667, 2419, 4857],
  130. [ 303, 1115, 2245],
  131. [ 75, 291, 593]]])
  132. if __name__ == "__main__":
  133. tf.test.main()