parser_trainer.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """A program to train a tensorflow neural net parser from a conll file."""

import os
import os.path
import random
import time

import tensorflow as tf

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging

from google.protobuf import text_format

from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2
from syntaxnet import sentence_pb2

from dragnn.protos import spec_pb2
from dragnn.python import evaluation
from dragnn.python import graph_builder
from dragnn.python import lexicon
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from dragnn.python import trainer_lib

# Imported for their side effects: registering the DRAGNN and SyntaxNet ops.
import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('tf_master', '',
                    'TensorFlow execution engine to connect to.')
flags.DEFINE_string('resource_path', '', 'Path to constructed resources.')
flags.DEFINE_string('tensorboard_dir', '',
                    'Directory for TensorBoard logs output.')
flags.DEFINE_string('checkpoint_filename', '',
                    'Filename to save the best checkpoint to.')
flags.DEFINE_string('training_corpus_path', '', 'Path to training data.')
flags.DEFINE_string('dev_corpus_path', '', 'Path to development set data.')
flags.DEFINE_bool('compute_lexicon', False,
                  'Whether to compute the lexicon from the training corpus.')
flags.DEFINE_bool('projectivize_training_set', True,
                  'Whether to projectivize non-projective trees in the '
                  'training corpus.')
flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
flags.DEFINE_integer('batch_size', 4, 'Batch size.')
flags.DEFINE_integer('report_every', 200,
                     'Report cost and training accuracy every this many steps.')


def main(unused_argv):
  logging.set_verbosity(logging.INFO)

  if not gfile.IsDirectory(FLAGS.resource_path):
    gfile.MakeDirs(FLAGS.resource_path)

  # Constructs lexical resources for SyntaxNet in the given resource path, from
  # the training data.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    lexicon.build_lexicon(FLAGS.resource_path, FLAGS.training_corpus_path)
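
  # The model is a pipeline of four DRAGNN components: a character LSTM that
  # builds word representations, a right-to-left "lookahead" RNN, a
  # left-to-right tagger, and an arc-standard dependency parser. Each component
  # is declared below as a ComponentSpec.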

  # Construct the "char_lstm" ComponentSpec, which builds word representations
  # from a left-to-right LSTM over the characters of each word.
  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
  char2word.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
  char2word.add_fixed_feature(name='chars', fml='char-input.text-char',
                              embedding_dim=16)
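  # fill_from_resources() completes the spec from the lexicon in
  # FLAGS.resource_path (e.g. it fills in feature vocabulary sizes).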
  char2word.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the "lookahead" ComponentSpec. This is a simple right-to-left RNN
  # sequence model, which encodes the context to the right of each token. It
  # has no loss except for the downstream components.
  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
  lookahead.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  lookahead.set_transition_system(name='shift-only', left_to_right='false')
  lookahead.add_link(source=char2word, fml='input.last-char-focus',
                     embedding_dim=32)
  lookahead.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the ComponentSpec for tagging. This is a simple left-to-right RNN
  # sequence tagger.
  tagger = spec_builder.ComponentSpecBuilder('tagger')
  tagger.set_network_unit(
      name='wrapped_units.LayerNormBasicLSTMNetwork', hidden_layer_sizes='256')
  tagger.set_transition_system(name='tagger')
  tagger.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  tagger.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  # Construct the ComponentSpec for parsing.
  parser = spec_builder.ComponentSpecBuilder('parser')
  parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256',
                          layer_norm_hidden='True')
  parser.set_transition_system(name='arc-standard')
  parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=32)
  parser.add_token_link(
      source=tagger,
      fml='input.focus stack.focus stack(1).focus',
      embedding_dim=32)

  # Recurrent connection for the arc-standard parser. For both tokens on the
  # stack, we connect to the last time step at which that token was SHIFTed or
  # REDUCEd. This allows the parser to build up compositional representations
  # of phrases.
  parser.add_link(
      source=parser,  # recurrent connection
      name='rnn-stack',  # unique identifier
      fml='stack.focus stack(1).focus',  # look for both stack tokens
      source_translator='shift-reduce-step',  # maps token indices -> step
      embedding_dim=32)  # project down to 32 dims
  parser.fill_from_resources(FLAGS.resource_path, FLAGS.tf_master)

  master_spec = spec_pb2.MasterSpec()
  master_spec.component.extend([char2word.spec, lookahead.spec,
                                tagger.spec, parser.spec])
  logging.info('Constructed master spec: %s', str(master_spec))
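
  # Training hyperparameters: Adam with a decaying learning rate, gradient
  # clipping, dropout, and a moving average of the parameters.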
  hyperparam_config = spec_pb2.GridPoint()
  hyperparam_config.decay_steps = 128000
  hyperparam_config.learning_rate = 0.001
  hyperparam_config.learning_method = 'adam'
  hyperparam_config.adam_beta1 = 0.9
  hyperparam_config.adam_beta2 = 0.9
  hyperparam_config.adam_eps = 0.0001
  hyperparam_config.gradient_clip_norm = 1
  hyperparam_config.self_norm_alpha = 1.0
  hyperparam_config.use_moving_average = True
  hyperparam_config.dropout_rate = 0.7
  hyperparam_config.seed = 1

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    component_targets = spec_builder.default_targets_from_spec(master_spec)
    trainers = [
        builder.add_training_from_config(target)
        for target in component_targets
    ]
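    # Only the tagger and parser components have losses, so we expect exactly
    # two training targets.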
    assert len(trainers) == 2
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  training_set = sentence_io.ConllSentenceReader(
      FLAGS.training_corpus_path,
      projectivize=FLAGS.projectivize_training_set).corpus()
  dev_set = sentence_io.ConllSentenceReader(
      FLAGS.dev_corpus_path, projectivize=False).corpus()

  # Ready to train!
  logging.info('Training on %d sentences.', len(training_set))
  logging.info('Tuning on %d sentences.', len(dev_set))
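
  # Training schedule: pretrain the tagger alone for 100 steps, then train the
  # tagger and parser jointly, with roughly eight parser steps for every
  # tagger step.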
  pretrain_steps = [100, 0]
  tagger_steps = 1000
  train_steps = [tagger_steps, 8 * tagger_steps]

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  gfile.MakeDirs(os.path.dirname(FLAGS.checkpoint_filename))
  summary_writer = trainer_lib.get_summary_writer(FLAGS.tensorboard_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
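    # Note that dev_set is passed twice: once as the corpus to annotate and
    # once as the gold standard for evaluation.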
    trainer_lib.run_training(
        sess, trainers, annotator, evaluation.parser_summaries, pretrain_steps,
        train_steps, training_set, dev_set, dev_set, FLAGS.batch_size,
        summary_writer, FLAGS.report_every, builder.saver,
        FLAGS.checkpoint_filename)


if __name__ == '__main__':
  tf.app.run()