network_units_test.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for network_units."""
  16. import tensorflow as tf
  17. from tensorflow.python.framework import test_util
  18. from tensorflow.python.platform import googletest
  19. from dragnn.protos import spec_pb2
  20. from dragnn.python import network_units
  21. import dragnn.python.load_dragnn_cc_impl
  22. import syntaxnet.load_parser_ops
# Module-level handle to TensorFlow's global command-line flags object.
FLAGS = tf.app.flags.FLAGS
  24. class NetworkUnitsConverterTest(test_util.TensorFlowTestCase):
  25. def testConvertNetworkStateTensorarray(self):
  26. with self.test_session() as session:
  27. ta = tf.TensorArray(
  28. dtype=tf.float32,
  29. size=0,
  30. dynamic_size=True,
  31. clear_after_read=False,
  32. infer_shape=False)
  33. # Create a 3-step x 2-stride x 2-feature-dim source array.
  34. ta = ta.write(0, [[0., 0.]] * 2) # The zeroth step will be removed.
  35. ta = ta.write(1, [[1., 10.]] * 2)
  36. ta = ta.write(2, [[2., 20.]] * 2)
  37. ta = ta.write(3, [[3., 30.]] * 2)
  38. tensor = network_units.convert_network_state_tensorarray(ta)
  39. actual = session.run(tensor)
  40. self.assertEqual(actual.shape, (6, 2))
  41. # The arrangement of the values is expected to be stride * steps.
  42. expected = [[1., 10.], [2., 20.], [3., 30.], [1., 10.], [2., 20.],
  43. [3., 30.]]
  44. self.assertAllEqual(actual, expected)
  45. class MockComponent(object):
  46. def __init__(self, master, component_spec):
  47. self.master = master
  48. self.spec = component_spec
  49. self.name = component_spec.name
  50. self.beam_size = 1
  51. self._attrs = {}
  52. def attr(self, name):
  53. return self._attrs[name]
  54. class MockMaster(object):
  55. def __init__(self):
  56. self.spec = spec_pb2.MasterSpec()
  57. self.hyperparams = spec_pb2.GridPoint()
  58. self.lookup_component = {
  59. 'previous': MockComponent(self, spec_pb2.ComponentSpec())
  60. }
  61. class NetworkUnitsLookupTest(test_util.TensorFlowTestCase):
  62. def setUp(self):
  63. # Clear the graph and all existing variables. Otherwise, variables created
  64. # in different tests may collide with each other.
  65. tf.reset_default_graph()
  66. self._master = MockMaster()
  67. self._master.spec = spec_pb2.MasterSpec()
  68. # Add a component with a linked feature.
  69. component_spec = self._master.spec.component.add()
  70. component_spec.name = 'fake_linked'
  71. component_spec.backend.registered_name = 'FakeComponent'
  72. linked_feature = component_spec.linked_feature.add()
  73. linked_feature.source_component = 'fake_linked'
  74. linked_feature.source_translator = 'identity'
  75. linked_feature.embedding_dim = -1
  76. linked_feature.size = 2
  77. self._linked_component = MockComponent(self._master, component_spec)
  78. # Add a feature with a fixed feature.
  79. component_spec = self._master.spec.component.add()
  80. component_spec.name = 'fake_fixed'
  81. component_spec.backend.registered_name = 'FakeComponent'
  82. fixed_feature = component_spec.fixed_feature.add()
  83. fixed_feature.fml = 'input.word'
  84. fixed_feature.embedding_dim = 1
  85. fixed_feature.size = 1
  86. self._fixed_component = MockComponent(self._master, component_spec)
  87. def testExportFixedFeaturesNetworkWithEnabledEmbeddingMatrix(self):
  88. network = network_units.ExportFixedFeaturesNetwork(self._fixed_component)
  89. self.assertEqual(1, len(network.params))
  90. def testExportFixedFeaturesNetworkWithDisabledEmbeddingMatrix(self):
  91. self._fixed_component.spec.fixed_feature[0].embedding_dim = -1
  92. network = network_units.ExportFixedFeaturesNetwork(self._fixed_component)
  93. self.assertEqual(0, len(network.params))
  94. class GetAttrsWithDefaultsTest(test_util.TensorFlowTestCase):
  95. def MakeAttrs(self, defaults, key=None, value=None):
  96. """Returns attrs based on the |defaults| and one |key|,|value| override."""
  97. spec = spec_pb2.RegisteredModuleSpec()
  98. if key and value:
  99. spec.parameters[key] = value
  100. return network_units.get_attrs_with_defaults(spec.parameters, defaults)
  101. def testFalseValues(self):
  102. def _assert_attr_is_false(value=None):
  103. key = 'foo'
  104. attrs = self.MakeAttrs({key: False}, key, value)
  105. self.assertFalse(attrs[key])
  106. _assert_attr_is_false()
  107. _assert_attr_is_false('false')
  108. _assert_attr_is_false('False')
  109. _assert_attr_is_false('FALSE')
  110. _assert_attr_is_false('no')
  111. _assert_attr_is_false('whatever')
  112. _assert_attr_is_false(' ')
  113. _assert_attr_is_false('')
  114. def testTrueValues(self):
  115. def _assert_attr_is_true(value=None):
  116. key = 'foo'
  117. attrs = self.MakeAttrs({key: False}, key, value)
  118. self.assertTrue(attrs[key])
  119. _assert_attr_is_true('true')
  120. _assert_attr_is_true('True')
  121. _assert_attr_is_true('TRUE')
if __name__ == '__main__':
  # Script entry point: delegate to googletest.main() to run this module's
  # test cases.
  googletest.main()