tutorial_1.py

  1. """First example--RNN POS tagger."""
  2. from __future__ import absolute_import
  3. from __future__ import division
  4. from __future__ import print_function
  5. import os.path
  6. import tensorflow as tf
  7. from google.protobuf import text_format
  8. from dragnn.protos import spec_pb2
  9. from dragnn.python import graph_builder
  10. from dragnn.python import lexicon
  11. from dragnn.python import spec_builder
  12. from dragnn.python import visualization
  13. from syntaxnet import sentence_pb2
  14. import dragnn.python.load_dragnn_cc_impl
  15. import syntaxnet.load_parser_ops
data_dir = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'tutorial_data')
lexicon_dir = '/tmp/tutorial/lexicon'
training_sentence = os.path.join(data_dir, 'sentence.prototext')

if not os.path.isdir(lexicon_dir):
  os.makedirs(lexicon_dir)
def main(argv):
  del argv  # unused

  # Constructs lexical resources for SyntaxNet in the given resource path,
  # from the training data.
  lexicon.build_lexicon(
      lexicon_dir,
      training_sentence,
      training_corpus_format='sentence-prototext')

  # Construct the ComponentSpec for tagging. This is a simple left-to-right
  # RNN sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_fixed_feature(name='words', fml='input.word', embedding_dim=64)
  tagger.add_rnn_link(embedding_dim=-1)
  tagger.fill_from_resources(lexicon_dir)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([tagger.spec])
  hyperparam_config = spec_pb2.GridPoint()
  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    target = spec_pb2.TrainTarget()
    target.name = 'all'
    target.unroll_using_oracle.extend([True])
    # With trace_only=True, the returned ops produce traces for visualization
    # instead of applying training updates; see how dry_run is used below.
    dry_run = builder.add_training_from_config(target, trace_only=True)

  # Read in serialized protos from training data.
  sentence = sentence_pb2.Sentence()
  text_format.Merge(open(training_sentence).read(), sentence)
  training_set = [sentence.SerializeToString()]
  with tf.Session(graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    traces = sess.run(
        dry_run['traces'], feed_dict={dry_run['input_batch']: training_set})

  # Write the rendered trace as UTF-8 bytes.
  with open('dragnn_tutorial_1.html', 'wb') as f:
    f.write(
        visualization.trace_html(traces[0], height='300px').encode('utf-8'))

if __name__ == '__main__':
  tf.app.run()
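
Usage note: assuming the DRAGNN and SyntaxNet Python packages (with their compiled C++ ops) are importable, and that tutorial_data/sentence.prototext sits next to this script, running

    python tutorial_1.py

builds the lexicon under /tmp/tutorial/lexicon and writes dragnn_tutorial_1.html; open that file in a browser to inspect the traced tagger network.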