render_spec_with_graphviz_test.py 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Tests for render_spec_with_graphviz."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from tensorflow.python.platform import googletest
  20. from dragnn.protos import spec_pb2
  21. from dragnn.python import render_spec_with_graphviz
  22. from dragnn.python import spec_builder
  23. def _make_basic_master_spec():
  24. """Constructs a simple spec.
  25. Modified version of nlp/saft/opensource/dragnn/tools/parser_trainer.py
  26. Returns:
  27. spec_pb2.MasterSpec instance.
  28. """
  29. # Construct the "lookahead" ComponentSpec. This is a simple right-to-left RNN
  30. # sequence model, which encodes the context to the right of each token. It has
  31. # no loss except for the downstream components.
  32. lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  33. lookahead.set_network_unit(
  34. name='FeedForwardNetwork', hidden_layer_sizes='256')
  35. lookahead.set_transition_system(name='shift-only', left_to_right='true')
  36. lookahead.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  37. lookahead.add_rnn_link(embedding_dim=-1)
  38. # Construct the ComponentSpec for parsing.
  39. parser = spec_builder.ComponentSpecBuilder('parser')
  40. parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  41. parser.set_transition_system(name='arc-standard')
  42. parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  43. master_spec = spec_pb2.MasterSpec()
  44. master_spec.component.extend([lookahead.spec, parser.spec])
  45. return master_spec
  46. class RenderSpecWithGraphvizTest(googletest.TestCase):
  47. def test_constructs_simple_graph(self):
  48. master_spec = _make_basic_master_spec()
  49. contents = render_spec_with_graphviz.master_spec_graph(master_spec)
  50. self.assertIn('lookahead', contents)
  51. self.assertIn('<polygon', contents)
  52. self.assertIn('roboto, helvetica, arial', contents)
  53. self.assertIn('FeedForwardNetwork', contents)
  54. # Graphviz currently over-escapes hyphens.
  55. self.assertTrue(('arc-standard' in contents) or
  56. ('arc&#45;standard' in contents))
  57. self.assertIn('input.focus', contents)
  58. self.assertTrue('input.word' not in contents,
  59. "We don't yet show fixed features")
  60. if __name__ == '__main__':
  61. googletest.main()