# sentence_io.py
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Utilities for reading and writing sentences in dragnn."""
  16. import tensorflow as tf
  17. from syntaxnet.ops import gen_parser_ops
  18. class ConllSentenceReader(object):
  19. """A reader for conll files, with optional projectivizing."""
  20. def __init__(self, filepath, batch_size=32,
  21. projectivize=False, morph_to_pos=False):
  22. self._graph = tf.Graph()
  23. self._session = tf.Session(graph=self._graph)
  24. task_context_str = """
  25. input {
  26. name: 'documents'
  27. record_format: 'conll-sentence'
  28. Part {
  29. file_pattern: '%s'
  30. }
  31. }""" % filepath
  32. if morph_to_pos:
  33. task_context_str += """
  34. Parameter {
  35. name: "join_category_to_pos"
  36. value: "true"
  37. }
  38. Parameter {
  39. name: "add_pos_as_attribute"
  40. value: "true"
  41. }
  42. Parameter {
  43. name: "serialize_morph_to_pos"
  44. value: "true"
  45. }
  46. """
  47. with self._graph.as_default():
  48. self._source, self._is_last = gen_parser_ops.document_source(
  49. task_context_str=task_context_str, batch_size=batch_size)
  50. self._source = gen_parser_ops.well_formed_filter(self._source)
  51. if projectivize:
  52. self._source = gen_parser_ops.projectivize_filter(self._source)
  53. def read(self):
  54. """Reads a single batch of sentences."""
  55. if self._session:
  56. sentences, is_last = self._session.run([self._source, self._is_last])
  57. if is_last:
  58. self._session.close()
  59. self._session = None
  60. else:
  61. sentences, is_last = [], True
  62. return sentences, is_last
  63. def corpus(self):
  64. """Reads the entire corpus, and returns in a list."""
  65. tf.logging.info('Reading corpus...')
  66. corpus = []
  67. while True:
  68. sentences, is_last = self.read()
  69. corpus.extend(sentences)
  70. if is_last:
  71. break
  72. tf.logging.info('Read %d sentences.' % len(corpus))
  73. return corpus