# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Data batchers for data described in ..//data_prep/README.md."""

import glob
import random
import struct
import sys

from tensorflow.core.example import example_pb2

# Special tokens
PARAGRAPH_START = '<p>'
PARAGRAPH_END = '</p>'
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
UNKNOWN_TOKEN = '<UNK>'
PAD_TOKEN = '<PAD>'
DOCUMENT_START = '<d>'
DOCUMENT_END = '</d>'


class Vocab(object):
  """Vocabulary class for mapping between words and integer ids."""

  def __init__(self, vocab_file, max_size):
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0

    with open(vocab_file, 'r') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          sys.stderr.write('Bad line: %s\n' % line)
          continue
        if pieces[0] in self._word_to_id:
          raise ValueError('Duplicated word: %s.' % pieces[0])
        self._word_to_id[pieces[0]] = self._count
        self._id_to_word[self._count] = pieces[0]
        self._count += 1
        if self._count > max_size:
          raise ValueError('Too many words: >%d.' % max_size)

  def CheckVocab(self, word):
    if word not in self._word_to_id:
      return None
    return self._word_to_id[word]

  def WordToId(self, word):
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN]
    return self._word_to_id[word]

  def IdToWord(self, word_id):
    if word_id not in self._id_to_word:
      raise ValueError('id not found in vocab: %d.' % word_id)
    return self._id_to_word[word_id]

  def NumIds(self):
    return self._count


def ExampleGen(data_path, num_epochs=None):
  """Generates tf.Examples from path of data files.

  Binary data format: <length><blob>. <length> represents the byte size of
  <blob>. <blob> is a serialized tf.Example proto. The tf.Example contains
  the tokenized article text and summary.

  Args:
    data_path: path (glob pattern) to tf.Example data files.
    num_epochs: Number of times to go through the data. None means infinite.

  Yields:
    Deserialized tf.Example.

  If there are multiple files specified, they are accessed in a random order.
  """
  epoch = 0
  while True:
    if num_epochs is not None and epoch >= num_epochs:
      break
    filelist = glob.glob(data_path)
    assert filelist, 'Empty filelist.'
    random.shuffle(filelist)
    for f in filelist:
      reader = open(f, 'rb')
      while True:
        len_bytes = reader.read(8)
        if not len_bytes:
          break
        str_len = struct.unpack('q', len_bytes)[0]
        example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
        yield example_pb2.Example.FromString(example_str)

    epoch += 1


def Pad(ids, pad_id, length):
  """Pad or trim a list of ids to the given length.

  Args:
    ids: list of ints to pad
    pad_id: what to pad with
    length: length to pad or trim to

  Returns:
    ids trimmed or padded with pad_id
  """
  assert pad_id is not None
  assert length is not None

  if len(ids) < length:
    a = [pad_id] * (length - len(ids))
    return ids + a
  else:
    return ids[:length]


def GetWordIds(text, vocab, pad_len=None, pad_id=None):
  """Get ids corresponding to words in text.

  Assumes tokens are separated by spaces.

  Args:
    text: a string
    vocab: TextVocabularyFile object
    pad_len: int, length to pad to
    pad_id: int, word id for pad symbol

  Returns:
    A list of ints representing word ids.
  """
  ids = []
  for w in text.split():
    i = vocab.WordToId(w)
    if i >= 0:
      ids.append(i)
    else:
      ids.append(vocab.WordToId(UNKNOWN_TOKEN))
  if pad_len is not None:
    return Pad(ids, pad_id, pad_len)
  return ids


def Ids2Words(ids_list, vocab):
  """Get words from ids.

  Args:
    ids_list: list of int32
    vocab: TextVocabulary object

  Returns:
    List of words corresponding to ids.
  """
  assert isinstance(ids_list, list), '%s is not a list' % ids_list
  return [vocab.IdToWord(i) for i in ids_list]


def SnippetGen(text, start_tok, end_tok, inclusive=True):
  """Generates consecutive snippets between start and end tokens.

  Args:
    text: a string
    start_tok: a string denoting the start of snippets
    end_tok: a string denoting the end of snippets
    inclusive: Whether to include the tokens in the returned snippets.

  Yields:
    String snippets.
  """
  cur = 0
  while True:
    try:
      start_p = text.index(start_tok, cur)
      end_p = text.index(end_tok, start_p + 1)
      cur = end_p + len(end_tok)
      if inclusive:
        yield text[start_p:cur]
      else:
        yield text[start_p + len(start_tok):end_p]
    except ValueError:
      # No further start/end token pair found: end the generator. Returning
      # here (instead of raising StopIteration) keeps the generator valid
      # under PEP 479 on Python 3.7+.
      return


def GetExFeatureText(ex, key):
  """Returns the first bytes value stored under feature `key` of a tf.Example."""
  return ex.features.feature[key].bytes_list.value[0]


def ToSentences(paragraph, include_token=True):
  """Takes tokens of a paragraph and returns list of sentences.

  Args:
    paragraph: string, text of paragraph
    include_token: Whether to include the sentence separation tokens in the
      result.

  Returns:
    List of sentence strings.
  """
  s_gen = SnippetGen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
  return [s for s in s_gen]
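

# ---------------------------------------------------------------------------
# Illustrative write-side sketch: one way to produce a data file in the
# <length><blob> record format that ExampleGen reads. The on-disk layout
# (an 8-byte 'q' length prefix followed by a serialized tf.Example blob) is
# inferred from the unpacking logic in ExampleGen; the helper name, the
# feature keys ('article', 'abstract') and the byte-string inputs are
# assumptions made only for this sketch.
def _WriteExampleRecord(writer, article_bytes, abstract_bytes):
  """Appends one <length><blob> record holding a serialized tf.Example."""
  tf_example = example_pb2.Example()
  tf_example.features.feature['article'].bytes_list.value.extend(
      [article_bytes])
  tf_example.features.feature['abstract'].bytes_list.value.extend(
      [abstract_bytes])
  blob = tf_example.SerializeToString()
  writer.write(struct.pack('q', len(blob)))            # 8-byte length prefix
  writer.write(struct.pack('%ds' % len(blob), blob))   # serialized proto blob


# Typical read-side flow, assuming a whitespace-delimited "word count" vocab
# file and a data glob exist (paths, vocab size, and pad length below are
# illustrative values, not defaults of this module):
#
#   vocab = Vocab('vocab.txt', 200000)
#   pad_id = vocab.WordToId(PAD_TOKEN)
#   for ex in ExampleGen('data/train-*', num_epochs=1):
#     article = GetExFeatureText(ex, 'article')
#     for sentence in ToSentences(article, include_token=False):
#       ids = GetWordIds(sentence, vocab, pad_len=120, pad_id=pad_id)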