parser_eval.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """A program to annotate a conll file with a tensorflow neural net parser."""
  16. import os
  17. import os.path
  18. import time
  19. import tensorflow as tf
  20. from tensorflow.python.platform import gfile
  21. from tensorflow.python.platform import tf_logging as logging
  22. from syntaxnet import sentence_pb2
  23. from syntaxnet import graph_builder
  24. from syntaxnet import structured_graph_builder
  25. from syntaxnet.ops import gen_parser_ops
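
# Command-line flags describing where the model, task context, and input and
# output corpora live; all of them are consumed inside Eval() and main() below.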
flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('task_context', '',
                    'Path to a task context with inputs and parameters for '
                    'feature extractors.')
flags.DEFINE_string('model_path', '', 'Path to model parameters.')
flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
flags.DEFINE_string('graph_builder', 'greedy',
                    'Which graph builder to use, either greedy or structured.')
flags.DEFINE_string('input', 'stdin',
                    'Name of the context input to read data from.')
flags.DEFINE_string('output', 'stdout',
                    'Name of the context input to write data to.')
flags.DEFINE_string('hidden_layer_sizes', '200,200',
                    'Comma separated list of hidden layer sizes.')
flags.DEFINE_integer('batch_size', 32,
                     'Number of sentences to process in parallel.')
flags.DEFINE_integer('beam_size', 8, 'Number of slots for beam parsing.')
flags.DEFINE_integer('max_steps', 1000, 'Max number of steps to take.')
flags.DEFINE_bool('slim_model', False,
                  'Whether to expect only averaged variables.')
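
# An illustrative invocation (the paths and flag values below are
# placeholders, not files shipped with this script):
#
#   python parser_eval.py \
#     --task_context=/path/to/context.pbtxt \
#     --model_path=/path/to/parser-params \
#     --arg_prefix=brain_parser \
#     --graph_builder=structured \
#     --input=stdin \
#     --output=stdout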


def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
  """Builds and evaluates a network.

  Args:
    sess: tensorflow session to use
    num_actions: number of possible golden actions
    feature_sizes: size of each feature vector
    domain_sizes: number of possible feature ids in each feature vector
    embedding_dims: embedding dimension for each feature group
  """
  t = time.time()
  # Parse '200,200' -> [200, 200]; wrap in list() so this also works under
  # Python 3, where map() returns an iterator.
  hidden_layer_sizes = list(map(int, FLAGS.hidden_layer_sizes.split(',')))
  logging.info('Building evaluation network with parameters: feature_sizes: '
               '%s domain_sizes: %s', feature_sizes, domain_sizes)
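  # Build either the purely greedy parser or the structured parser, which
  # decodes with a beam of FLAGS.beam_size slots.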
  if FLAGS.graph_builder == 'greedy':
    parser = graph_builder.GreedyParser(num_actions,
                                        feature_sizes,
                                        domain_sizes,
                                        embedding_dims,
                                        hidden_layer_sizes,
                                        gate_gradients=True,
                                        arg_prefix=FLAGS.arg_prefix)
  else:
    parser = structured_graph_builder.StructuredGraphBuilder(
        num_actions,
        feature_sizes,
        domain_sizes,
        embedding_dims,
        hidden_layer_sizes,
        gate_gradients=True,
        arg_prefix=FLAGS.arg_prefix,
        beam_size=FLAGS.beam_size,
        max_steps=FLAGS.max_steps)
  task_context = FLAGS.task_context
  parser.AddEvaluation(task_context,
                       FLAGS.batch_size,
                       corpus_name=FLAGS.input,
                       evaluation_max_steps=FLAGS.max_steps)

  parser.AddSaver(FLAGS.slim_model)
  # Run all initializer ops, then overwrite the variables with the trained
  # checkpoint; list() keeps sess.run() happy under Python 3 as well.
  sess.run(list(parser.inits.values()))
  parser.saver.restore(sess, FLAGS.model_path)
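
  # The document sink writes the annotated, serialized documents back out
  # through the corpus named by --output in the task context (e.g. stdout).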
  sink_documents = tf.placeholder(tf.string)
  sink = gen_parser_ops.document_sink(sink_documents,
                                      task_context=FLAGS.task_context,
                                      corpus_name=FLAGS.output)

  t = time.time()
  num_epochs = None
  num_tokens = 0
  num_correct = 0
  num_documents = 0
  while True:
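    # Each step pulls one batch: the current epoch counter, the
    # (total tokens, correctly parsed tokens) metrics, and the documents.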
    tf_eval_epochs, tf_eval_metrics, tf_documents = sess.run([
        parser.evaluation['epochs'],
        parser.evaluation['eval_metrics'],
        parser.evaluation['documents'],
    ])

    if len(tf_documents):
      logging.info('Processed %d documents', len(tf_documents))
      num_documents += len(tf_documents)
      sess.run(sink, feed_dict={sink_documents: tf_documents})

    num_tokens += tf_eval_metrics[0]
    num_correct += tf_eval_metrics[1]
    if num_epochs is None:
      num_epochs = tf_eval_epochs
    elif num_epochs < tf_eval_epochs:
      break
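
  # The loop exits once the reader reports a new epoch, i.e. after the input
  # corpus has been read exactly once.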
  logging.info('Total processed documents: %d', num_documents)
  if num_tokens > 0:
    eval_metric = 100.0 * num_correct / num_tokens
    logging.info('num correct tokens: %d', num_correct)
    logging.info('total tokens: %d', num_tokens)
    logging.info('Seconds elapsed in evaluation: %.2f, '
                 'eval metric: %.2f%%', time.time() - t, eval_metric)
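

# main() first runs the feature_size op to read the feature geometry (vector
# sizes, domains, embedding dims) out of the task context, then hands those
# sizes to Eval() to build and run the evaluation network.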
def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  with tf.Session() as sess:
    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
        gen_parser_ops.feature_size(task_context=FLAGS.task_context,
                                    arg_prefix=FLAGS.arg_prefix))

  with tf.Session() as sess:
    Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)


if __name__ == '__main__':
  tf.app.run()