# tutorial_2.py — DRAGNN tutorial, second example.
"""Second example: separate tagger and parser."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os.path

import tensorflow as tf

from google.protobuf import text_format
from dragnn.protos import spec_pb2
from dragnn.python import graph_builder
from dragnn.python import lexicon
from dragnn.python import spec_builder
from dragnn.python import visualization
from syntaxnet import sentence_pb2

# These imports are needed only for their side effect of registering the
# DRAGNN / SyntaxNet C++ ops with TensorFlow; the module names are unused.
import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

# Tutorial data ships in a 'tutorial_data' directory next to this script.
data_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'tutorial_data')
lexicon_dir = '/tmp/tutorial/lexicon'
training_sentence = os.path.join(data_dir, 'sentence.prototext')

# NOTE(review): runs at import time (module-level side effect). The
# isdir-then-makedirs pair is race-prone in general, but harmless for a
# single-user tutorial script.
if not os.path.isdir(lexicon_dir):
  os.makedirs(lexicon_dir)
  22. def main(argv):
  23. del argv # unused
  24. # Constructs lexical resources for SyntaxNet in the given resource path, from
  25. # the training data.
  26. lexicon.build_lexicon(
  27. lexicon_dir,
  28. training_sentence,
  29. training_corpus_format='sentence-prototext')
  30. # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  31. # sequence tagger.
  32. tagger = spec_builder.ComponentSpecBuilder('tagger')
  33. tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  34. tagger.set_transition_system(name='tagger')
  35. tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  36. tagger.add_rnn_link(embedding_dim=-1)
  37. tagger.fill_from_resources(lexicon_dir)
  38. # Construct the ComponentSpec for parsing.
  39. parser = spec_builder.ComponentSpecBuilder('parser')
  40. parser.set_network_unit(
  41. name='FeedForwardNetwork',
  42. hidden_layer_sizes='256',
  43. layer_norm_hidden='True')
  44. parser.set_transition_system(name='arc-standard')
  45. parser.add_token_link(
  46. source=tagger,
  47. fml='input.focus stack.focus stack(1).focus',
  48. embedding_dim=32,
  49. source_layer='logits')
  50. # Recurrent connection for the arc-standard parser. For both tokens on the
  51. # stack, we connect to the last time step to either SHIFT or REDUCE that
  52. # token. This allows the parser to build up compositional representations of
  53. # phrases.
  54. parser.add_link(
  55. source=parser, # recurrent connection
  56. name='rnn-stack', # unique identifier
  57. fml='stack.focus stack(1).focus', # look for both stack tokens
  58. source_translator='shift-reduce-step', # maps token indices -> step
  59. embedding_dim=32) # project down to 32 dims
  60. parser.fill_from_resources(lexicon_dir)
  61. master_spec = spec_pb2.MasterSpec()
  62. master_spec.component.extend([tagger.spec, parser.spec])
  63. hyperparam_config = spec_pb2.GridPoint()
  64. # Build the TensorFlow graph.
  65. graph = tf.Graph()
  66. with graph.as_default():
  67. builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
  68. target = spec_pb2.TrainTarget()
  69. target.name = 'all'
  70. target.unroll_using_oracle.extend([True, True])
  71. dry_run = builder.add_training_from_config(target, trace_only=True)
  72. # Read in serialized protos from training data.
  73. sentence = sentence_pb2.Sentence()
  74. text_format.Merge(open(training_sentence).read(), sentence)
  75. training_set = [sentence.SerializeToString()]
  76. with tf.Session(graph=graph) as sess:
  77. # Make sure to re-initialize all underlying state.
  78. sess.run(tf.initialize_all_variables())
  79. traces = sess.run(
  80. dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})
  81. with open('dragnn_tutorial_2.html', 'w') as f:
  82. f.write(
  83. visualization.trace_html(
  84. traces[0], height='400px', master_spec=master_spec).encode('utf-8'))
if __name__ == '__main__':
  # tf.app.run() parses command-line flags and then invokes main(argv).
  tf.app.run()