data_utils.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library for loading the 1B word benchmark dataset."""

import random

import numpy as np
import tensorflow as tf


class Vocabulary(object):
  """Class that holds a vocabulary for the dataset."""

  def __init__(self, filename):
    """Initialize vocabulary.

    Args:
      filename: Vocabulary file name.
    """
    self._id_to_word = []
    self._word_to_id = {}
    self._unk = -1
    self._bos = -1
    self._eos = -1

    with tf.gfile.Open(filename) as f:
      idx = 0
      for line in f:
        word_name = line.strip()
        if word_name == '<S>':
          self._bos = idx
        elif word_name == '</S>':
          self._eos = idx
        elif word_name == '<UNK>':
          self._unk = idx
        # A '!!!MAXTERMID' line is treated as a sentinel and is not assigned
        # an id.
        if word_name == '!!!MAXTERMID':
          continue

        self._id_to_word.append(word_name)
        self._word_to_id[word_name] = idx
        idx += 1

  @property
  def bos(self):
    return self._bos

  @property
  def eos(self):
    return self._eos

  @property
  def unk(self):
    return self._unk

  @property
  def size(self):
    return len(self._id_to_word)

  def word_to_id(self, word):
    if word in self._word_to_id:
      return self._word_to_id[word]
    return self.unk

  def id_to_word(self, cur_id):
    if cur_id < self.size:
      return self._id_to_word[cur_id]
    return 'ERROR'

  def decode(self, cur_ids):
    """Convert a list of ids to a sentence, with space inserted."""
    return ' '.join([self.id_to_word(cur_id) for cur_id in cur_ids])

  def encode(self, sentence):
    """Convert a sentence to a list of ids, with special tokens added."""
    word_ids = [self.word_to_id(cur_word) for cur_word in sentence.split()]
    return np.array([self.bos] + word_ids + [self.eos], dtype=np.int32)
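
# A minimal usage sketch for Vocabulary; the path is a placeholder and the ids
# are illustrative, assuming the vocabulary file contains '<S>', '</S>' and
# '<UNK>' entries:
#
#   vocab = Vocabulary('/path/to/vocab.txt')
#   ids = vocab.encode('hello world')
#   # -> np.int32 array: [vocab.bos, id('hello'), id('world'), vocab.eos]
#   vocab.decode(ids)
#   # -> '<S> hello world </S>'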


class CharsVocabulary(Vocabulary):
  """Vocabulary containing character-level information."""

  def __init__(self, filename, max_word_length):
    super(CharsVocabulary, self).__init__(filename)
    self._max_word_length = max_word_length
    chars_set = set()

    for word in self._id_to_word:
      chars_set |= set(word)

    # Pick five character codes (0-255) that never occur in the vocabulary to
    # act as begin/end-of-sentence, begin/end-of-word and padding characters.
    free_ids = []
    for i in range(256):
      if chr(i) in chars_set:
        continue
      free_ids.append(chr(i))

    if len(free_ids) < 5:
      raise ValueError('Not enough free char ids: %d' % len(free_ids))

    self.bos_char = free_ids[0]  # <begin sentence>
    self.eos_char = free_ids[1]  # <end sentence>
    self.bow_char = free_ids[2]  # <begin word>
    self.eow_char = free_ids[3]  # <end word>
    self.pad_char = free_ids[4]  # <padding>

    chars_set |= {self.bos_char, self.eos_char, self.bow_char, self.eow_char,
                  self.pad_char}

    self._char_set = chars_set
    num_words = len(self._id_to_word)

    self._word_char_ids = np.zeros([num_words, max_word_length], dtype=np.int32)

    self.bos_chars = self._convert_word_to_char_ids(self.bos_char)
    self.eos_chars = self._convert_word_to_char_ids(self.eos_char)

    for i, word in enumerate(self._id_to_word):
      self._word_char_ids[i] = self._convert_word_to_char_ids(word)

  @property
  def word_char_ids(self):
    return self._word_char_ids

  @property
  def max_word_length(self):
    return self._max_word_length

  def _convert_word_to_char_ids(self, word):
    # Encode a word as a fixed-length row of character codes: begin-of-word
    # character, the (possibly truncated) word, end-of-word character, then
    # padding up to max_word_length.
    code = np.zeros([self.max_word_length], dtype=np.int32)
    code[:] = ord(self.pad_char)

    if len(word) > self.max_word_length - 2:
      word = word[:self.max_word_length - 2]
    cur_word = self.bow_char + word + self.eow_char
    for j in range(len(cur_word)):
      code[j] = ord(cur_word[j])
    return code

  def word_to_char_ids(self, word):
    if word in self._word_to_id:
      return self._word_char_ids[self._word_to_id[word]]
    else:
      return self._convert_word_to_char_ids(word)

  def encode_chars(self, sentence):
    """Convert a sentence into a [num_words, max_word_length] char-id matrix."""
    chars_ids = [self.word_to_char_ids(cur_word)
                 for cur_word in sentence.split()]
    return np.vstack([self.bos_chars] + chars_ids + [self.eos_chars])
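
# A minimal sketch of the character-level encoding; the path and
# max_word_length are placeholders. Each word becomes one fixed-length row of
# character codes, with extra rows for the sentence boundary markers:
#
#   char_vocab = CharsVocabulary('/path/to/vocab.txt', max_word_length=50)
#   char_ids = char_vocab.encode_chars('hello world')
#   # char_ids.shape == (4, 50): rows for <S>, 'hello', 'world', </S>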


def get_batch(generator, batch_size, num_steps, max_word_length, pad=False):
  """Read batches of input."""
  cur_stream = [None] * batch_size

  inputs = np.zeros([batch_size, num_steps], np.int32)
  char_inputs = np.zeros([batch_size, num_steps, max_word_length], np.int32)
  global_word_ids = np.zeros([batch_size, num_steps], np.int32)
  targets = np.zeros([batch_size, num_steps], np.int32)
  weights = np.ones([batch_size, num_steps], np.float32)

  no_more_data = False
  while True:
    inputs[:] = 0
    char_inputs[:] = 0
    global_word_ids[:] = 0
    targets[:] = 0
    weights[:] = 0.0

    for i in range(batch_size):
      cur_pos = 0

      while cur_pos < num_steps:
        # Refill stream i with the next sentence when it is (nearly) empty.
        if cur_stream[i] is None or len(cur_stream[i][0]) <= 1:
          try:
            cur_stream[i] = list(next(generator))
          except StopIteration:
            # No more data, exhaust current streams and quit.
            no_more_data = True
            break

        # Copy as many words as fit; targets are the inputs shifted by one.
        how_many = min(len(cur_stream[i][0]) - 1, num_steps - cur_pos)
        next_pos = cur_pos + how_many

        inputs[i, cur_pos:next_pos] = cur_stream[i][0][:how_many]
        char_inputs[i, cur_pos:next_pos] = cur_stream[i][1][:how_many]
        global_word_ids[i, cur_pos:next_pos] = cur_stream[i][2][:how_many]
        targets[i, cur_pos:next_pos] = cur_stream[i][0][1:how_many + 1]
        weights[i, cur_pos:next_pos] = 1.0

        cur_pos = next_pos

        cur_stream[i][0] = cur_stream[i][0][how_many:]
        cur_stream[i][1] = cur_stream[i][1][how_many:]
        cur_stream[i][2] = cur_stream[i][2][how_many:]

        if pad:
          break

    if no_more_data and np.sum(weights) == 0:
      # There is no more data and this is an empty batch. Done!
      break

    yield inputs, char_inputs, global_word_ids, targets, weights
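
# A sketch of consuming the generator above; sizes are illustrative and
# `sentence_generator` stands for any iterator of (ids, char_ids,
# global_word_ids) tuples, such as the one produced by LM1BDataset below.
# `targets` are the inputs shifted by one word and `weights` is a 0/1 mask
# over valid positions:
#
#   for inputs, char_inputs, global_ids, targets, weights in get_batch(
#       sentence_generator, batch_size=32, num_steps=20, max_word_length=50):
#     pass  # inputs.shape == (32, 20), char_inputs.shape == (32, 20, 50)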


class LM1BDataset(object):
  """Utility class for 1B word benchmark dataset.

  The current implementation reads the data from the tokenized text files.
  """

  def __init__(self, filepattern, vocab):
    """Initialize LM1BDataset reader.

    Args:
      filepattern: Dataset file pattern.
      vocab: Vocabulary.
    """
    self._vocab = vocab
    self._all_shards = tf.gfile.Glob(filepattern)
    tf.logging.info('Found %d shards at %s', len(self._all_shards), filepattern)

  def _load_random_shard(self):
    """Randomly select a file and read it."""
    return self._load_shard(random.choice(self._all_shards))

  def _load_shard(self, shard_name):
    """Read one file and convert to ids.

    Args:
      shard_name: file path.

    Returns:
      list of (id, char_id, global_word_id) tuples.
    """
    tf.logging.info('Loading data from: %s', shard_name)
    with tf.gfile.Open(shard_name) as f:
      sentences = f.readlines()
    chars_ids = [self.vocab.encode_chars(sentence) for sentence in sentences]
    ids = [self.vocab.encode(sentence) for sentence in sentences]

    global_word_ids = []
    current_idx = 0
    for word_ids in ids:
      current_size = len(word_ids) - 1  # without <BOS> symbol
      cur_ids = np.arange(current_idx, current_idx + current_size)
      global_word_ids.append(cur_ids)
      current_idx += current_size

    tf.logging.info('Loaded %d words.', current_idx)
    tf.logging.info('Finished loading')
    return zip(ids, chars_ids, global_word_ids)

  def _get_sentence(self, forever=True):
    while True:
      ids = self._load_random_shard()
      for current_ids in ids:
        yield current_ids
      if not forever:
        break

  def get_batch(self, batch_size, num_steps, pad=False, forever=True):
    return get_batch(self._get_sentence(forever), batch_size, num_steps,
                     self.vocab.max_word_length, pad=pad)

  @property
  def vocab(self):
    return self._vocab
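

# A minimal end-to-end sketch tying the pieces together. Both paths below are
# placeholders for the 1B word benchmark vocabulary file and tokenized
# training shards; they are not shipped with this module.
if __name__ == '__main__':
  vocab_path = '/path/to/vocab.txt'  # placeholder
  data_pattern = '/path/to/training_shards/*'  # placeholder

  char_vocab = CharsVocabulary(vocab_path, max_word_length=50)
  dataset = LM1BDataset(data_pattern, char_vocab)

  # Pull a single batch and report its shapes.
  batch_iter = dataset.get_batch(batch_size=4, num_steps=10, forever=False)
  inputs, char_inputs, global_ids, targets, weights = next(batch_iter)
  print(inputs.shape, char_inputs.shape, targets.shape, weights.shape)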