data.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Data batchers for data described in ..//data_prep/README.md."""
  16. import glob
  17. import random
  18. import struct
  19. import sys
  20. from tensorflow.core.example import example_pb2
  21. # Special tokens
  22. PARAGRAPH_START = '<p>'
  23. PARAGRAPH_END = '</p>'
  24. SENTENCE_START = '<s>'
  25. SENTENCE_END = '</s>'
  26. UNKNOWN_TOKEN = '<UNK>'
  27. PAD_TOKEN = '<PAD>'
  28. DOCUMENT_START = '<d>'
  29. DOCUMENT_END = '</d>'


class Vocab(object):
  """Vocabulary class for mapping words and ids."""

  def __init__(self, vocab_file, max_size):
    self._word_to_id = {}
    self._id_to_word = {}
    self._count = 0

    with open(vocab_file, 'r') as vocab_f:
      for line in vocab_f:
        pieces = line.split()
        if len(pieces) != 2:
          sys.stderr.write('Bad line: %s\n' % line)
          continue
        if pieces[0] in self._word_to_id:
          raise ValueError('Duplicated word: %s.' % pieces[0])
        self._word_to_id[pieces[0]] = self._count
        self._id_to_word[self._count] = pieces[0]
        self._count += 1
        if self._count > max_size:
          raise ValueError('Too many words: >%d.' % max_size)

  def WordToId(self, word):
    if word not in self._word_to_id:
      return self._word_to_id[UNKNOWN_TOKEN]
    return self._word_to_id[word]

  def IdToWord(self, word_id):
    if word_id not in self._id_to_word:
      raise ValueError('id not found in vocab: %d.' % word_id)
    return self._id_to_word[word_id]

  def NumIds(self):
    return self._count
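
# Example usage (illustrative sketch, not part of the original module; the file
# path and its contents are hypothetical). The vocab file holds one entry per
# line with exactly two whitespace-separated fields, e.g. "word count"; only
# the first field is used, and it must include '<UNK>' for the unknown-word
# fallback in WordToId to work.
#
#   vocab = Vocab('data/vocab.txt', max_size=200000)
#   unk_id = vocab.WordToId('some-word-not-in-vocab')  # falls back to the '<UNK>' id
#   vocab.IdToWord(unk_id)                             # -> '<UNK>'
#   vocab.NumIds()                                     # number of words loaded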


def ExampleGen(recordio_path, num_epochs=None):
  """Generates tf.Examples from path of recordio files.

  Args:
    recordio_path: CNS path to tf.Example recordio
    num_epochs: Number of times to go through the data. None means infinite.

  Yields:
    Deserialized tf.Example.

  If there are multiple files specified, they are accessed in a random order.
  """
  epoch = 0
  while True:
    if num_epochs is not None and epoch >= num_epochs:
      break
    filelist = glob.glob(recordio_path)
    assert filelist, 'Empty filelist.'
    random.shuffle(filelist)
    for f in filelist:
      with open(f, 'rb') as reader:
        # Each record is an 8-byte length followed by that many bytes of
        # serialized tf.Example.
        while True:
          len_bytes = reader.read(8)
          if not len_bytes:
            break
          str_len = struct.unpack('q', len_bytes)[0]
          example_str = struct.unpack('%ds' % str_len,
                                      reader.read(str_len))[0]
          yield example_pb2.Example.FromString(example_str)
    epoch += 1
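
# Sketch of iterating the generator (the glob pattern below is hypothetical):
#
#   for ex in ExampleGen('data/training-*', num_epochs=1):
#     ...  # ex is a deserialized tf.Example proto
#
# With num_epochs=None the loop never terminates on its own, so callers are
# expected to break out of it themselves.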


def Pad(ids, pad_id, length):
  """Pad or trim a list to the given length.

  Args:
    ids: list of ints to pad
    pad_id: what to pad with
    length: length to pad or trim to

  Returns:
    ids trimmed or padded with pad_id
  """
  assert pad_id is not None
  assert length is not None

  if len(ids) < length:
    a = [pad_id] * (length - len(ids))
    return ids + a
  else:
    return ids[:length]
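
# Examples (these follow directly from the code above):
#
#   Pad([1, 2, 3], pad_id=0, length=5)           # -> [1, 2, 3, 0, 0]
#   Pad([1, 2, 3, 4, 5, 6], pad_id=0, length=4)  # -> [1, 2, 3, 4]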


def GetWordIds(text, vocab, pad_len=None, pad_id=None):
  """Get ids corresponding to words in text.

  Assumes tokens separated by space.

  Args:
    text: a string
    vocab: a Vocab object
    pad_len: int, length to pad to
    pad_id: int, word id for pad symbol

  Returns:
    A list of ints representing word ids.
  """
  ids = []
  for w in text.split():
    i = vocab.WordToId(w)
    # WordToId already falls back to the '<UNK>' id, so i is normally >= 0;
    # the else branch is kept as a defensive fallback.
    if i >= 0:
      ids.append(i)
    else:
      ids.append(vocab.WordToId(UNKNOWN_TOKEN))
  if pad_len is not None:
    return Pad(ids, pad_id, pad_len)
  return ids
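
# Sketch (assumes a `vocab` built as in the Vocab example above and containing
# the '<PAD>' token):
#
#   pad_id = vocab.WordToId(PAD_TOKEN)
#   ids = GetWordIds('the quick brown fox', vocab, pad_len=10, pad_id=pad_id)
#   # -> 10 ids; out-of-vocabulary words map to the '<UNK>' id, and the short
#   #    input is padded out with pad_id.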


def Ids2Words(ids_list, vocab):
  """Get words from ids.

  Args:
    ids_list: list of int32
    vocab: a Vocab object

  Returns:
    List of words corresponding to ids.
  """
  assert isinstance(ids_list, list), '%s is not a list' % ids_list
  return [vocab.IdToWord(i) for i in ids_list]
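
# Sketch of the round trip with GetWordIds (same hypothetical `vocab` as above):
#
#   words = Ids2Words(ids, vocab)
#   # in-vocabulary words come back unchanged; padded positions come back as '<PAD>'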


def SnippetGen(text, start_tok, end_tok, inclusive=True):
  """Generates consecutive snippets between start and end tokens.

  Args:
    text: a string
    start_tok: a string denoting the start of snippets
    end_tok: a string denoting the end of snippets
    inclusive: Whether to include the tokens in the returned snippets.

  Yields:
    String snippets
  """
  cur = 0
  while True:
    try:
      start_p = text.index(start_tok, cur)
      end_p = text.index(end_tok, start_p + 1)
      cur = end_p + len(end_tok)
      if inclusive:
        yield text[start_p:cur]
      else:
        yield text[start_p+len(start_tok):end_p]
    except ValueError:
      # No more snippets in text; end the generator. (Raising StopIteration
      # inside a generator is an error under PEP 479 / Python 3.7+.)
      return
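
# Example (follows directly from the code above):
#
#   text = '<s> first sentence </s> <s> second sentence </s>'
#   list(SnippetGen(text, SENTENCE_START, SENTENCE_END, inclusive=False))
#   # -> [' first sentence ', ' second sentence ']
#   list(SnippetGen(text, SENTENCE_START, SENTENCE_END, inclusive=True))
#   # -> ['<s> first sentence </s>', '<s> second sentence </s>']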


def GetExFeatureText(ex, key):
  """Returns the first bytes_list value stored under `key` in a tf.Example."""
  return ex.features.feature[key].bytes_list.value[0]


def ToSentences(paragraph, include_token=True):
  """Takes tokens of a paragraph and returns list of sentences.

  Args:
    paragraph: string, text of paragraph
    include_token: Whether to include the sentence separation tokens in the
      result.

  Returns:
    List of sentence strings.
  """
  s_gen = SnippetGen(paragraph, SENTENCE_START, SENTENCE_END, include_token)
  return [s for s in s_gen]
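
# End-to-end sketch tying the pieces together (the 'article' feature key and
# the UTF-8 decoding are assumptions about the data, not guaranteed by this
# module):
#
#   vocab = Vocab('data/vocab.txt', max_size=200000)
#   for ex in ExampleGen('data/training-*', num_epochs=1):
#     article = GetExFeatureText(ex, 'article')
#     if not isinstance(article, str):   # bytes_list values are bytes in Python 3
#       article = article.decode('utf-8')
#     for sent in ToSentences(article, include_token=False):
#       ids = GetWordIds(sent, vocab, pad_len=120,
#                        pad_id=vocab.WordToId(PAD_TOKEN))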