parse-to-conll.py

  1. r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
  2. """

import re
import time

import tensorflow as tf

from google.protobuf import text_format
from tensorflow.python.client import timeline
from tensorflow.python.platform import gfile

from dragnn.protos import spec_pb2
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from syntaxnet import sentence_pb2
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check
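
# Imported for their side effects: these modules load the DRAGNN and SyntaxNet
# C++ op libraries so the ops used below are registered.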
import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('parser_master_spec', '',
                    'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('parser_checkpoint_file', '',
                    'Path to trained model checkpoint.')
flags.DEFINE_string('parser_resource_dir', '',
                    'Optional base directory for resources in the master spec.')
flags.DEFINE_string('segmenter_master_spec', '',
                    'Path to text file containing a DRAGNN master spec to run.')
flags.DEFINE_string('segmenter_checkpoint_file', '',
                    'Path to trained model checkpoint.')
flags.DEFINE_string('segmenter_resource_dir', '',
                    'Optional base directory for resources in the master spec.')
flags.DEFINE_bool('complete_master_spec', True, 'Whether the master specs '
                  'need the lexicon and other resources added to them.')
flags.DEFINE_string('input_file', '',
                    'File of CoNLL-formatted sentences to read from.')
flags.DEFINE_string('output_file', '',
                    'File path to write annotated sentences to.')
flags.DEFINE_integer('max_batch_size', 2048, 'Maximum batch size to support.')
flags.DEFINE_string('inference_beam_size', '', 'Comma separated list of '
                    'component_name=beam_size pairs.')
flags.DEFINE_string('locally_normalize', '', 'Comma separated list of '
                    'component names to do local normalization on.')
flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
                     'inter-op parallelism.')
flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
                    'If specified, the final iteration of the evaluation loop '
                    'will capture and save a TensorFlow timeline.')
flags.DEFINE_bool('use_gold_segmentation', False,
                  'Whether or not to use gold segmentation.')


def main(unused_argv):
  # Parse the flags containing lists, using regular expressions.
  # This matches and extracts key=value pairs.
  component_beam_sizes = re.findall(r'([^=,]+)=(\d+)',
                                    FLAGS.inference_beam_size)
  # This matches strings separated by a comma. Does not return any empty
  # strings.
  components_to_locally_normalize = re.findall(r'[^,]+',
                                               FLAGS.locally_normalize)
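  # For illustration (these flag values are hypothetical):
  #   --inference_beam_size='parser=8,tagger=4' --locally_normalize='tagger'
  # would yield
  #   component_beam_sizes = [('parser', '8'), ('tagger', '4')]
  #   components_to_locally_normalize = ['tagger']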

  ## SEGMENTATION ##

  if not FLAGS.use_gold_segmentation:
    # Reads master spec.
    master_spec = spec_pb2.MasterSpec()
    with gfile.FastGFile(FLAGS.segmenter_master_spec) as fin:
      text_format.Parse(fin.read(), master_spec)

    if FLAGS.complete_master_spec:
      spec_builder.complete_master_spec(
          master_spec, None, FLAGS.segmenter_resource_dir)

    # Graph building.
    tf.logging.info('Building the graph')
    g = tf.Graph()
    with g.as_default(), tf.device('/device:CPU:0'):
      hyperparam_config = spec_pb2.GridPoint()
      hyperparam_config.use_moving_average = True
      builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
      annotator = builder.add_annotation()
      builder.add_saver()

    tf.logging.info('Reading documents...')
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()
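    # The segmenter consumes character-level input rather than word tokens, so
    # first convert each sentence with the char_token_generator op in a
    # throwaway session.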
    with tf.Session(graph=tf.Graph()) as tmp_session:
      char_input = gen_parser_ops.char_token_generator(input_corpus)
      char_corpus = tmp_session.run(char_input)
    check.Eq(len(input_corpus), len(char_corpus))

    session_config = tf.ConfigProto(
        log_device_placement=False,
        intra_op_parallelism_threads=FLAGS.threads,
        inter_op_parallelism_threads=FLAGS.threads)

    with tf.Session(graph=g, config=session_config) as sess:
      tf.logging.info('Initializing variables...')
      sess.run(tf.global_variables_initializer())
      tf.logging.info('Loading from checkpoint...')
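      # The saver added by builder.add_saver() exposes a 'save/restore_all' op
      # whose filename input is the 'save/Const:0' tensor; running it by name
      # restores the trained segmenter weights from the checkpoint.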
      sess.run('save/restore_all',
               {'save/Const:0': FLAGS.segmenter_checkpoint_file})

      tf.logging.info('Processing sentences...')
      processed = []
      start_time = time.time()
      run_metadata = tf.RunMetadata()
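      # Annotate the corpus in batches of at most max_batch_size sentences. If
      # a timeline file was requested, trace the final batch and dump it in
      # Chrome trace format.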
      for start in range(0, len(char_corpus), FLAGS.max_batch_size):
        end = min(start + FLAGS.max_batch_size, len(char_corpus))
        feed_dict = {annotator['input_batch']: char_corpus[start:end]}
        if FLAGS.timeline_output_file and end == len(char_corpus):
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict,
              options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
              run_metadata=run_metadata)
          trace = timeline.Timeline(step_stats=run_metadata.step_stats)
          with open(FLAGS.timeline_output_file, 'w') as trace_file:
            trace_file.write(trace.generate_chrome_trace_format())
        else:
          serialized_annotations = sess.run(
              annotator['annotations'], feed_dict=feed_dict)
        processed.extend(serialized_annotations)

      tf.logging.info('Processed %d documents in %.2f seconds.',
                      len(char_corpus), time.time() - start_time)
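
    # The segmenter's serialized output sentences become the parser's input.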
    input_corpus = processed
  else:
    input_corpus = sentence_io.ConllSentenceReader(FLAGS.input_file).corpus()

  ## PARSING ##

  # Reads master spec.
  master_spec = spec_pb2.MasterSpec()
  with gfile.FastGFile(FLAGS.parser_master_spec) as fin:
    text_format.Parse(fin.read(), master_spec)

  if FLAGS.complete_master_spec:
    spec_builder.complete_master_spec(
        master_spec, None, FLAGS.parser_resource_dir)

  # Graph building.
  tf.logging.info('Building the graph')
  g = tf.Graph()
  with g.as_default(), tf.device('/device:CPU:0'):
    hyperparam_config = spec_pb2.GridPoint()
    hyperparam_config.use_moving_average = True
    builder = graph_builder.MasterBuilder(master_spec, hyperparam_config)
    annotator = builder.add_annotation()
    builder.add_saver()

  tf.logging.info('Reading documents...')

  session_config = tf.ConfigProto(
      log_device_placement=False,
      intra_op_parallelism_threads=FLAGS.threads,
      inter_op_parallelism_threads=FLAGS.threads)

  with tf.Session(graph=g, config=session_config) as sess:
    tf.logging.info('Initializing variables...')
    sess.run(tf.global_variables_initializer())
    tf.logging.info('Loading from checkpoint...')
    sess.run('save/restore_all', {'save/Const:0': FLAGS.parser_checkpoint_file})

    tf.logging.info('Processing sentences...')
    processed = []
    start_time = time.time()
    run_metadata = tf.RunMetadata()
    for start in range(0, len(input_corpus), FLAGS.max_batch_size):
      end = min(start + FLAGS.max_batch_size, len(input_corpus))
      feed_dict = {annotator['input_batch']: input_corpus[start:end]}
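      # Each named DRAGNN component exposes InferenceBeamSize and
      # LocallyNormalize tensors in its subgraph; feed the values parsed from
      # the flags to override the beam size and enable local normalization per
      # component.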
      for comp, beam_size in component_beam_sizes:
        feed_dict['%s/InferenceBeamSize:0' % comp] = beam_size
      for comp in components_to_locally_normalize:
        feed_dict['%s/LocallyNormalize:0' % comp] = True
      if FLAGS.timeline_output_file and end == len(input_corpus):
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict,
            options=tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE),
            run_metadata=run_metadata)
        trace = timeline.Timeline(step_stats=run_metadata.step_stats)
        with open(FLAGS.timeline_output_file, 'w') as trace_file:
          trace_file.write(trace.generate_chrome_trace_format())
      else:
        serialized_annotations = sess.run(
            annotator['annotations'], feed_dict=feed_dict)
      processed.extend(serialized_annotations)

    tf.logging.info('Processed %d documents in %.2f seconds.',
                    len(input_corpus), time.time() - start_time)
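
  # Write the annotated sentences in a minimal 10-column CoNLL format: only the
  # ID, FORM, HEAD, and DEPREL columns are filled in; the rest are left as '_'.
  # token.head is 0-based with -1 for the root, so adding 1 gives the CoNLL
  # convention where the root's head is 0.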
  if FLAGS.output_file:
    with gfile.GFile(FLAGS.output_file, 'w') as f:
      for serialized_sentence in processed:
        sentence = sentence_pb2.Sentence()
        sentence.ParseFromString(serialized_sentence)
        f.write('#' + sentence.text.encode('utf-8') + '\n')
        for i, token in enumerate(sentence.token):
          head = token.head + 1
          f.write('%s\t%s\t_\t_\t_\t_\t%d\t%s\t_\t_\n' % (
              i + 1, token.word.encode('utf-8'), head,
              token.label.encode('utf-8')))
        f.write('\n\n')


if __name__ == '__main__':
  tf.app.run()