conll2tree.py 3.2 KB

  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """A program to generate ASCII trees from conll files."""
  16. import collections
  17. import re
  18. import asciitree
  19. import tensorflow as tf
  20. import syntaxnet.load_parser_ops
  21. from tensorflow.python.platform import tf_logging as logging
  22. from syntaxnet import sentence_pb2
  23. from syntaxnet.ops import gen_parser_ops
  24. flags = tf.app.flags
  25. FLAGS = flags.FLAGS
  26. flags.DEFINE_string('task_context',
  27. 'syntaxnet/models/parsey_mcparseface/context.pbtxt',
  28. 'Path to a task context with inputs and parameters for '
  29. 'feature extractors.')
  30. flags.DEFINE_string('corpus_name', 'stdin-conll',
  31. 'Path to a task context with inputs and parameters for '
  32. 'feature extractors.')
  33. def to_dict(sentence):
  34. """Builds a dictionary representing the parse tree of a sentence.
  35. Note that the suffix "@id" (where 'id' is a number) is appended to each
  36. element to handle the sentence that has multiple elements with identical
  37. representation. Those suffix needs to be removed after the asciitree is
  38. rendered.
  39. Args:
  40. sentence: Sentence protocol buffer to represent.
  41. Returns:
  42. Dictionary mapping tokens to children.
  43. """
  44. token_str = list()
  45. children = [[] for token in sentence.token]
  46. root = -1
  47. for i in range(0, len(sentence.token)):
  48. token = sentence.token[i]
  49. token_str.append('%s %s %s @%d' %
  50. (token.word, token.tag, token.label, (i+1)))
  51. if token.head == -1:
  52. root = i
  53. else:
  54. children[token.head].append(i)
  55. def _get_dict(i):
  56. d = collections.OrderedDict()
  57. for c in children[i]:
  58. d[token_str[c]] = _get_dict(c)
  59. return d
  60. tree = collections.OrderedDict()
  61. tree[token_str[root]] = _get_dict(root)
  62. return tree
  63. def main(unused_argv):
  64. logging.set_verbosity(logging.INFO)
  65. with tf.Session() as sess:
  66. src = gen_parser_ops.document_source(batch_size=32,
  67. corpus_name=FLAGS.corpus_name,
  68. task_context=FLAGS.task_context)
  69. sentence = sentence_pb2.Sentence()
  70. while True:
  71. documents, finished = sess.run(src)
  72. logging.info('Read %d documents', len(documents))
  73. for d in documents:
  74. sentence.ParseFromString(d)
  75. tr = asciitree.LeftAligned()
  76. d = to_dict(sentence)
  77. print 'Input: %s' % sentence.text
  78. print 'Parse:'
  79. tr_str = tr(d)
  80. pat = re.compile(r'\s*@\d+$')
  81. for tr_ln in tr_str.splitlines():
  82. print pat.sub('', tr_ln)
  83. if finished:
  84. break
# Script entry point: tf.app.run() parses command-line flags, then calls main().
if __name__ == '__main__':
  tf.app.run()