data.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Data batchers for data described in ..//data_prep/README.md."""
import glob
import random
import struct
import sys

from tensorflow.core.example import example_pb2

# Special tokens
PARAGRAPH_START = '<p>'
PARAGRAPH_END = '</p>'
SENTENCE_START = '<s>'
SENTENCE_END = '</s>'
UNKNOWN_TOKEN = '<UNK>'
PAD_TOKEN = '<PAD>'
DOCUMENT_START = '<d>'
DOCUMENT_END = '</d>'


class Vocab(object):
  """Vocabulary class for mapping words and ids."""

  def __init__(self, vocab_file, max_size):
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0

    with open(vocab_file, 'r') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          sys.stderr.write('Bad line: %s\n' % line)
          continue
        if pieces[0] in self._word_to_id:
          raise ValueError('Duplicated word: %s.' % pieces[0])
        self._word_to_id[pieces[0]] = self._count
        self._id_to_word[self._count] = pieces[0]
        self._count += 1
        if self._count > max_size:
          raise ValueError('Too many words: >%d.' % max_size)

  def CheckVocab(self, word):
    if word not in self._word_to_id:
      return None
    return self._word_to_id[word]

  def WordToId(self, word):
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN]
    return self._word_to_id[word]

  def IdToWord(self, word_id):
    if word_id not in self._id_to_word:
      raise ValueError('id not found in vocab: %d.' % word_id)
    return self._id_to_word[word_id]

  def NumIds(self):
    return self._count
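

# Illustrative sketch (not part of the original file): how a Vocab might be
# built and queried. The vocab file path and max_size are hypothetical; the
# file is assumed to hold one "<word> <count>" pair per line and to contain
# the <UNK> and <PAD> tokens.
def _example_vocab_usage():
  vocab = Vocab('vocab.txt', max_size=200000)  # hypothetical vocab file
  unk_id = vocab.WordToId(UNKNOWN_TOKEN)       # id reserved for <UNK>
  the_id = vocab.WordToId('the')               # falls back to unk_id if absent
  assert vocab.CheckVocab(UNKNOWN_TOKEN) is not None
  return vocab.NumIds(), unk_id, the_id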


def ExampleGen(data_path, num_epochs=None):
  """Generates tf.Examples from path of data files.

  Binary data format: <length><blob>. <length> represents the byte size
  of <blob>. <blob> is a serialized tf.Example proto. The tf.Example contains
  the tokenized article text and summary.

  Args:
    data_path: path to tf.Example data files.
    num_epochs: Number of times to go through the data. None means infinite.

  Yields:
    Deserialized tf.Example.

  If there are multiple files specified, they are accessed in a random order.
  """
  epoch = 0
  while True:
    if num_epochs is not None and epoch >= num_epochs:
      break
    filelist = glob.glob(data_path)
    assert filelist, 'Empty filelist.'
    random.shuffle(filelist)
    for f in filelist:
      with open(f, 'rb') as reader:
        while True:
          len_bytes = reader.read(8)
          if not len_bytes:
            break
          str_len = struct.unpack('q', len_bytes)[0]
          example_str = struct.unpack('%ds' % str_len, reader.read(str_len))[0]
          yield example_pb2.Example.FromString(example_str)
    epoch += 1
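

# Illustrative sketch (not part of the original file): writing a record in the
# <length><blob> binary format that ExampleGen reads back. The output path and
# the feature keys ('article', 'abstract') are assumptions, not fixed by this
# file.
def _example_write_binary_record(path='example.bin'):
  tf_example = example_pb2.Example()
  tf_example.features.feature['article'].bytes_list.value.append(
      b'<d> <p> <s> some article text . </s> </p> </d>')
  tf_example.features.feature['abstract'].bytes_list.value.append(
      b'<d> <p> <s> some summary . </s> </p> </d>')
  example_str = tf_example.SerializeToString()
  with open(path, 'wb') as writer:
    writer.write(struct.pack('q', len(example_str)))                  # 8-byte length
    writer.write(struct.pack('%ds' % len(example_str), example_str))  # blob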


def Pad(ids, pad_id, length):
  """Pad or trim list to the given length.

  Args:
    ids: list of ints to pad
    pad_id: what to pad with
    length: length to pad or trim to

  Returns:
    ids trimmed or padded with pad_id
  """
  assert pad_id is not None
  assert length is not None

  if len(ids) < length:
    a = [pad_id] * (length - len(ids))
    return ids + a
  else:
    return ids[:length]


def GetWordIds(text, vocab, pad_len=None, pad_id=None):
  """Get ids corresponding to words in text.

  Assumes tokens separated by space.

  Args:
    text: a string
    vocab: Vocab object
    pad_len: int, length to pad to
    pad_id: int, word id for pad symbol

  Returns:
    A list of ints representing word ids.
  """
  ids = []
  for w in text.split():
    i = vocab.WordToId(w)
    if i >= 0:
      ids.append(i)
    else:
      ids.append(vocab.WordToId(UNKNOWN_TOKEN))
  if pad_len is not None:
    return Pad(ids, pad_id, pad_len)
  return ids
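

# Illustrative sketch (not part of the original file): turning a tokenized
# sentence into a fixed-length id sequence. The vocab argument and the length
# of 10 are hypothetical.
def _example_get_word_ids(vocab):
  text = '<s> the quick brown fox . </s>'
  pad_id = vocab.WordToId(PAD_TOKEN)
  ids = GetWordIds(text, vocab, pad_len=10, pad_id=pad_id)
  assert len(ids) == 10  # padded (or trimmed) to exactly pad_len entries
  return ids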


def Ids2Words(ids_list, vocab):
  """Get words from ids.

  Args:
    ids_list: list of int32
    vocab: Vocab object

  Returns:
    List of words corresponding to ids.
  """
  assert isinstance(ids_list, list), '%s is not a list' % ids_list
  return [vocab.IdToWord(i) for i in ids_list]


def SnippetGen(text, start_tok, end_tok, inclusive=True):
  """Generates consecutive snippets between start and end tokens.

  Args:
    text: a string
    start_tok: a string denoting the start of snippets
    end_tok: a string denoting the end of snippets
    inclusive: Whether to include the tokens in the returned snippets.

  Yields:
    String snippets
  """
  cur = 0
  while True:
    try:
      start_p = text.index(start_tok, cur)
      end_p = text.index(end_tok, start_p + 1)
      cur = end_p + len(end_tok)
      if inclusive:
        yield text[start_p:cur]
      else:
        yield text[start_p+len(start_tok):end_p]
    except ValueError:
      # No more start/end token pairs; end the generator. (Raising
      # StopIteration inside a generator is an error under PEP 479.)
      return
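

# Illustrative sketch (not part of the original file): using SnippetGen with
# the paragraph tokens to pull <p>...</p> spans out of a document string.
def _example_paragraph_snippets():
  doc = '<d> <p> <s> first . </s> </p> <p> <s> second . </s> </p> </d>'
  # Yields [' <s> first . </s> ', ' <s> second . </s> '] with inclusive=False.
  return list(SnippetGen(doc, PARAGRAPH_START, PARAGRAPH_END, inclusive=False))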


def GetExFeatureText(ex, key):
  """Returns the raw text (bytes) stored under `key` in a tf.Example."""
  return ex.features.feature[key].bytes_list.value[0]


def ToSentences(paragraph, include_token=True):
  """Takes tokens of a paragraph and returns list of sentences.

  Args:
    paragraph: string, text of paragraph
    include_token: Whether to include the sentence separation tokens in the
      result.

  Returns:
    List of sentence strings.
  """
  s_gen = SnippetGen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
  return [s for s in s_gen]
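

# Illustrative sketch (not part of the original file): reading articles back
# from a binary data file and converting each sentence to padded word ids.
# The data path, the 'article' feature key, the .decode() call (Python 3
# bytes), and the length of 120 are assumptions.
def _example_pipeline(vocab, data_path='example.bin'):
  pad_id = vocab.WordToId(PAD_TOKEN)
  for ex in ExampleGen(data_path, num_epochs=1):
    article = GetExFeatureText(ex, 'article').decode('utf-8')
    for sentence in ToSentences(article, include_token=False):
      yield GetWordIds(sentence, vocab, pad_len=120, pad_id=pad_id)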