render_spec_with_graphviz_test.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. """Tests for render_spec_with_graphviz."""
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. from tensorflow.python.platform import googletest
  6. from dragnn.protos import spec_pb2
  7. from dragnn.python import render_spec_with_graphviz
  8. from dragnn.python import spec_builder
  9. def _make_basic_master_spec():
  10. """Constructs a simple spec.
  11. Modified version of nlp/saft/opensource/dragnn/tools/parser_trainer.py
  12. Returns:
  13. spec_pb2.MasterSpec instance.
  14. """
  15. # Construct the "lookahead" ComponentSpec. This is a simple right-to-left RNN
  16. # sequence model, which encodes the context to the right of each token. It has
  17. # no loss except for the downstream components.
  18. lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  19. lookahead.set_network_unit(
  20. name='FeedForwardNetwork', hidden_layer_sizes='256')
  21. lookahead.set_transition_system(name='shift-only', left_to_right='true')
  22. lookahead.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  23. lookahead.add_rnn_link(embedding_dim=-1)
  24. # Construct the ComponentSpec for parsing.
  25. parser = spec_builder.ComponentSpecBuilder('parser')
  26. parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  27. parser.set_transition_system(name='arc-standard')
  28. parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  29. master_spec = spec_pb2.MasterSpec()
  30. master_spec.component.extend([lookahead.spec, parser.spec])
  31. return master_spec
  32. class RenderSpecWithGraphvizTest(googletest.TestCase):
  33. def test_constructs_simple_graph(self):
  34. master_spec = _make_basic_master_spec()
  35. contents = render_spec_with_graphviz.master_spec_graph(master_spec)
  36. self.assertIn('lookahead', contents)
  37. self.assertIn('<polygon', contents)
  38. self.assertIn('roboto, helvetica, arial', contents)
  39. self.assertIn('FeedForwardNetwork', contents)
  40. # Graphviz currently over-escapes hyphens.
  41. self.assertTrue(('arc-standard' in contents) or
  42. ('arc&#45;standard' in contents))
  43. self.assertIn('input.focus', contents)
  44. self.assertTrue('input.word' not in contents,
  45. "We don't yet show fixed features")
  46. if __name__ == '__main__':
  47. googletest.main()