1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980 |
- # Copyright 2017 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Utilities for reading and writing sentences in dragnn."""
- import tensorflow as tf
- from syntaxnet.ops import gen_parser_ops
class ConllSentenceReader(object):
    """A reader for conll files, with optional projectivizing.

    The reader owns a private `tf.Graph` and a `tf.Session` bound to it. The
    session is closed automatically once a batch flagged as the last one has
    been read. Callers that stop earlier should call `close()` or use the
    reader as a context manager so the session does not stay open:

        with ConllSentenceReader(path) as reader:
            sentences, is_last = reader.read()
    """

    def __init__(self, filepath, batch_size=32,
                 projectivize=False, morph_to_pos=False):
        """Builds the graph that reads sentences from `filepath`.

        Args:
          filepath: Substituted as the `file_pattern` of the 'documents'
            input in the task context handed to `document_source`.
          batch_size: Forwarded to `document_source` as `batch_size`.
          projectivize: If true, the source is additionally passed through
            `gen_parser_ops.projectivize_filter`.
          morph_to_pos: If true, the `join_category_to_pos`,
            `add_pos_as_attribute` and `serialize_morph_to_pos` parameters
            are set to "true" in the task context.
        """
        self._graph = tf.Graph()
        self._session = tf.Session(graph=self._graph)
        task_context_str = """
          input {
            name: 'documents'
            record_format: 'conll-sentence'
            Part {
             file_pattern: '%s'
            }
          }""" % filepath
        if morph_to_pos:
            task_context_str += """
          Parameter {
            name: "join_category_to_pos"
            value: "true"
          }
          Parameter {
            name: "add_pos_as_attribute"
            value: "true"
          }
          Parameter {
            name: "serialize_morph_to_pos"
            value: "true"
          }
          """
        with self._graph.as_default():
            self._source, self._is_last = gen_parser_ops.document_source(
                task_context_str=task_context_str, batch_size=batch_size)
            self._source = gen_parser_ops.well_formed_filter(self._source)
            if projectivize:
                self._source = gen_parser_ops.projectivize_filter(
                    self._source)

    def __enter__(self):
        """Returns the reader itself so it can be used in a `with` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Closes the session on leaving the `with` block.

        Returns None, so an exception raised inside the block propagates.
        """
        self.close()

    def close(self):
        """Releases the underlying session.

        Safe to call more than once, and safe to call after the reader has
        already closed itself on reaching the last batch. After closing,
        `read()` returns `([], True)`.
        """
        # Detach before closing so the reader is marked as closed even if
        # the session's own close() raises.
        session, self._session = self._session, None
        if session is not None:
            session.close()

    def read(self):
        """Reads a single batch of sentences.

        Returns:
          A `(sentences, is_last)` pair obtained by running the source and
          is-last tensors in the reader's session. When `is_last` is truthy
          the session is closed. Once the session is closed, every further
          call returns `([], True)` without touching TensorFlow.
        """
        if self._session is None:
            return [], True
        sentences, is_last = self._session.run([self._source, self._is_last])
        if is_last:
            self.close()
        return sentences, is_last

    def corpus(self):
        """Reads the entire corpus, and returns in a list.

        Returns:
          A list holding the sentences of every batch, in read order, from
          the reader's current position up to and including the last batch.
        """
        tf.logging.info('Reading corpus...')
        corpus = []
        is_last = False
        while not is_last:
            sentences, is_last = self.read()
            corpus.extend(sentences)
        # Lazy logging arguments: formatting happens only if the record is
        # actually emitted.
        tf.logging.info('Read %d sentences.', len(corpus))
        return corpus
|