segmenter-evaluator.py

  1. r"""Runs a DRAGNN model on a given set of CoNLL-formatted sentences.
  2. Sample invocation:
  3. bazel run -c opt <...>:dragnn_eval -- \
  4. --master_spec="/path/to/master-spec" \
  5. --resource_dir="/path/to/resources/"
  6. --checkpoint_file="/path/to/model/name.checkpoint" \
  7. --input_file="/path/to/input/documents/test.connlu"
  8. """
import os
import re
import time

import tensorflow as tf

from google.protobuf import text_format
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile

from dragnn.protos import spec_pb2
from dragnn.python import evaluation
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check

# Imported for their side effect of loading the DRAGNN and SyntaxNet C++ op
# libraries into TensorFlow.
import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('master_spec', '',
                    'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('resource_dir', '',
                    'Optional base directory for resources in the master spec.')
flags.DEFINE_bool('complete_master_spec', False, 'Whether the master_spec '
                  'needs the lexicon and other resources added to it.')
flags.DEFINE_string('checkpoint_file', '', 'Path to trained model checkpoint.')
flags.DEFINE_string('input_file', '',
                    'File of CoNLL-formatted sentences to read from.')
flags.DEFINE_string('output_file', '',
                    'File path to write annotated sentences to.')
flags.DEFINE_integer('max_batch_size', 2048, 'Maximum batch size to support.')
flags.DEFINE_string('inference_beam_size', '', 'Comma separated list of '
                    'component_name=beam_size pairs.')
flags.DEFINE_string('locally_normalize', '', 'Comma separated list of '
                    'component names to do local normalization on.')
flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
                     'inter-op parallelism.')
flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
                    'If specified, the final iteration of the evaluation loop '
                    'will capture and save a TensorFlow timeline.')


def main(unused_argv):
  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)
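  # For example, --inference_beam_size='parser=8,segmenter=4' (the component
  # names here are illustrative) parses to [('parser', '8'), ('segmenter', '4')];
  # note that re.findall() returns the beam sizes as strings.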

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  # Rewrite resource locations.
  if FLAGS.resource_dir:
    for component in master_spec.component:
      for resource in component.resource:
        for part in resource.part:
          part.file_pattern = os.path.join(FLAGS.resource_dir,
                                           part.file_pattern)
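  # Note: os.path.join() returns the second argument unchanged when it is an
  # absolute path, so only relative file patterns in the spec are rebased onto
  # --resource_dir.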

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(master_spec, None, FLAGS.resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()
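    # `annotator` maps handle names to graph tensors: 'input_batch' is fed with
    # serialized sentences below, and 'annotations' is fetched to obtain the
    # annotated, serialized Sentence protos.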

  tf.logging.info('Reading documents...')
  input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
  with tf.Session(graph=tf.Graph()) as tmp_session:
    char_input = gen_parser_ops.char_token_generator(input_corpus)
    char_corpus = tmp_session.run(char_input)
  check.Eq(len(input_corpus), len(char_corpus))
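  # The segmenter consumes character-level input, so the token-level CoNLL
  # corpus is converted with char_token_generator in a temporary session;
  # check.Eq() verifies that no sentences were dropped in the conversion.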

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())

    tf.logging.info('Loading from checkpoint...')
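    # The saver created by builder.add_saver() is driven by name: the checkpoint
    # path is fed into the saver's filename constant and its restore op is run.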
    sess.run('save/restore_all', {'save/Const:0': FLAGS.checkpoint_file})

    tf.logging.info('Processing sentences...')

    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(char_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(char_corpus))
      feed_dict = {annotator['input_batch']: char_corpus[start:end]}
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
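      # On the final batch, optionally capture a full execution trace and write
      # it out as a Chrome-trace-format timeline for profiling.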
      if FLAGS.timeline_output_file and end == len(char_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(char_corpus), time.time() - start_time)
    evaluation.calculate_segmentation_metrics(input_corpus, processed)

    if FLAGS.output_file:
      with gfile.GFile(FLAGS.output_file, 'w') as f:
        for serialized_sentence in processed:
          sentence = sentence_pb2.Sentence()
          sentence.ParseFromString(serialized_sentence)
          f.write(text_format.MessageToString(sentence) + '\n\n')


if __name__ == '__main__':
  tf.app.run()