# spec_builder_test.py
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Tests for the DRAGNN spec builder."""
  16. import os.path
  17. import tempfile
  18. import tensorflow as tf
  19. from google.protobuf import text_format
  20. from dragnn.protos import spec_pb2
  21. from dragnn.python import spec_builder
  22. # Imported for FLAGS.tf_master, which is used in the lexicon module.
  23. from syntaxnet import parser_trainer
  24. import syntaxnet.load_parser_ops
  25. FLAGS = tf.app.flags.FLAGS
  26. if not hasattr(FLAGS, 'test_srcdir'):
  27. FLAGS.test_srcdir = ''
  28. if not hasattr(FLAGS, 'test_tmpdir'):
  29. FLAGS.test_tmpdir = tf.test.get_temp_dir()
  30. class SpecBuilderTest(tf.test.TestCase):
  31. def assertSpecEqual(self, expected_spec_text, spec):
  32. expected_spec = spec_pb2.ComponentSpec()
  33. text_format.Parse(expected_spec_text, expected_spec)
  34. self.assertProtoEquals(expected_spec, spec)
  35. def testComponentSpecBuilderEmpty(self):
  36. builder = spec_builder.ComponentSpecBuilder('test')
  37. self.assertSpecEqual(r"""
  38. name: "test"
  39. backend { registered_name: "SyntaxNetComponent" }
  40. component_builder { registered_name: "DynamicComponentBuilder" }
  41. """, builder.spec)
  42. def testComponentSpecBuilderFixedFeature(self):
  43. builder = spec_builder.ComponentSpecBuilder('test')
  44. builder.set_network_unit('FeedForwardNetwork', hidden_layer_sizes='64,64')
  45. builder.set_transition_system('shift-only')
  46. builder.add_fixed_feature(name='words', fml='input.word', embedding_dim=16)
  47. self.assertSpecEqual(r"""
  48. name: "test"
  49. fixed_feature { name: "words" fml: "input.word" embedding_dim: 16 }
  50. backend { registered_name: "SyntaxNetComponent" }
  51. component_builder { registered_name: "DynamicComponentBuilder" }
  52. network_unit { registered_name: "FeedForwardNetwork"
  53. parameters { key: "hidden_layer_sizes" value: "64,64" } }
  54. transition_system { registered_name: "shift-only" }
  55. """, builder.spec)
  56. def testComponentSpecBuilderLinkedFeature(self):
  57. builder1 = spec_builder.ComponentSpecBuilder('test1')
  58. builder1.set_transition_system('shift-only')
  59. builder1.add_fixed_feature(name='words', fml='input.word', embedding_dim=16)
  60. builder2 = spec_builder.ComponentSpecBuilder('test2')
  61. builder2.set_network_unit('IdentityNetwork')
  62. builder2.set_transition_system('tagger')
  63. builder2.add_token_link(
  64. source=builder1,
  65. source_layer='words',
  66. fml='input.focus',
  67. embedding_dim=-1)
  68. self.assertSpecEqual(r"""
  69. name: "test2"
  70. linked_feature { name: "test1" source_component: "test1" source_layer: "words"
  71. source_translator: "identity" fml: "input.focus"
  72. embedding_dim: -1 }
  73. backend { registered_name: "SyntaxNetComponent" }
  74. component_builder { registered_name: "DynamicComponentBuilder" }
  75. network_unit { registered_name: "IdentityNetwork" }
  76. transition_system { registered_name: "tagger" }
  77. """, builder2.spec)
  78. def testFillsTaggerTransitions(self):
  79. lexicon_dir = tempfile.mkdtemp()
  80. def write_lines(filename, lines):
  81. with open(os.path.join(lexicon_dir, filename), 'w') as f:
  82. f.write(''.join('{}\n'.format(line) for line in lines))
  83. # Label map is required, even though it isn't used
  84. write_lines('label-map', ['0'])
  85. write_lines('word-map', ['2', 'miranda 1', 'rights 1'])
  86. write_lines('tag-map', ['2', 'NN 1', 'NNP 1'])
  87. write_lines('tag-to-category', ['NN\tNOUN', 'NNP\tNOUN'])
  88. tagger = spec_builder.ComponentSpecBuilder('tagger')
  89. tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  90. tagger.set_transition_system(name='tagger')
  91. tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  92. tagger.add_rnn_link(embedding_dim=-1)
  93. tagger.fill_from_resources(lexicon_dir)
  94. fixed_feature, = tagger.spec.fixed_feature
  95. linked_feature, = tagger.spec.linked_feature
  96. self.assertEqual(fixed_feature.vocabulary_size, 5)
  97. self.assertEqual(fixed_feature.size, 1)
  98. self.assertEqual(fixed_feature.size, 1)
  99. self.assertEqual(linked_feature.size, 1)
  100. self.assertEqual(tagger.spec.num_actions, 2)
if __name__ == '__main__':
  # Runs all test methods in this module via the TensorFlow test runner.
  tf.test.main()