batch_reader.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Batch reader to seq2seq attention model, with bucketing support."""
  16. from collections import namedtuple
  17. from random import shuffle
  18. from threading import Thread
  19. import time
  20. import numpy as np
  21. from six.moves import queue as Queue
  22. from six.moves import xrange
  23. import tensorflow as tf
  24. import data
  25. ModelInput = namedtuple('ModelInput',
  26. 'enc_input dec_input target enc_len dec_len '
  27. 'origin_article origin_abstract')
  28. BUCKET_CACHE_BATCH = 100
  29. QUEUE_NUM_BATCH = 100
  30. class Batcher(object):
  31. """Batch reader with shuffling and bucketing support."""
  32. def __init__(self, data_path, vocab, hps,
  33. article_key, abstract_key, max_article_sentences,
  34. max_abstract_sentences, bucketing=True, truncate_input=False):
  35. """Batcher constructor.
  36. Args:
  37. data_path: tf.Example filepattern.
  38. vocab: Vocabulary.
  39. hps: Seq2SeqAttention model hyperparameters.
  40. article_key: article feature key in tf.Example.
  41. abstract_key: abstract feature key in tf.Example.
  42. max_article_sentences: Max number of sentences used from article.
  43. max_abstract_sentences: Max number of sentences used from abstract.
  44. bucketing: Whether bucket articles of similar length into the same batch.
  45. truncate_input: Whether to truncate input that is too long. Alternative is
  46. to discard such examples.
  47. """
  48. self._data_path = data_path
  49. self._vocab = vocab
  50. self._hps = hps
  51. self._article_key = article_key
  52. self._abstract_key = abstract_key
  53. self._max_article_sentences = max_article_sentences
  54. self._max_abstract_sentences = max_abstract_sentences
  55. self._bucketing = bucketing
  56. self._truncate_input = truncate_input
  57. self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
  58. self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
  59. self._input_threads = []
  60. for _ in xrange(16):
  61. self._input_threads.append(Thread(target=self._FillInputQueue))
  62. self._input_threads[-1].daemon = True
  63. self._input_threads[-1].start()
  64. self._bucketing_threads = []
  65. for _ in xrange(4):
  66. self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
  67. self._bucketing_threads[-1].daemon = True
  68. self._bucketing_threads[-1].start()
  69. self._watch_thread = Thread(target=self._WatchThreads)
  70. self._watch_thread.daemon = True
  71. self._watch_thread.start()
  72. def NextBatch(self):
  73. """Returns a batch of inputs for seq2seq attention model.
  74. Returns:
  75. enc_batch: A batch of encoder inputs [batch_size, hps.enc_timestamps].
  76. dec_batch: A batch of decoder inputs [batch_size, hps.dec_timestamps].
  77. target_batch: A batch of targets [batch_size, hps.dec_timestamps].
  78. enc_input_len: encoder input lengths of the batch.
  79. dec_input_len: decoder input lengths of the batch.
  80. loss_weights: weights for loss function, 1 if not padded, 0 if padded.
  81. origin_articles: original article words.
  82. origin_abstracts: original abstract words.
  83. """
  84. enc_batch = np.zeros(
  85. (self._hps.batch_size, self._hps.enc_timesteps), dtype=np.int32)
  86. enc_input_lens = np.zeros(
  87. (self._hps.batch_size), dtype=np.int32)
  88. dec_batch = np.zeros(
  89. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
  90. dec_output_lens = np.zeros(
  91. (self._hps.batch_size), dtype=np.int32)
  92. target_batch = np.zeros(
  93. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
  94. loss_weights = np.zeros(
  95. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.float32)
  96. origin_articles = ['None'] * self._hps.batch_size
  97. origin_abstracts = ['None'] * self._hps.batch_size
  98. buckets = self._bucket_input_queue.get()
  99. for i in xrange(self._hps.batch_size):
  100. (enc_inputs, dec_inputs, targets, enc_input_len, dec_output_len,
  101. article, abstract) = buckets[i]
  102. origin_articles[i] = article
  103. origin_abstracts[i] = abstract
  104. enc_input_lens[i] = enc_input_len
  105. dec_output_lens[i] = dec_output_len
  106. enc_batch[i, :] = enc_inputs[:]
  107. dec_batch[i, :] = dec_inputs[:]
  108. target_batch[i, :] = targets[:]
  109. for j in xrange(dec_output_len):
  110. loss_weights[i][j] = 1
  111. return (enc_batch, dec_batch, target_batch, enc_input_lens, dec_output_lens,
  112. loss_weights, origin_articles, origin_abstracts)
  113. def _FillInputQueue(self):
  114. """Fill input queue with ModelInput."""
  115. start_id = self._vocab.WordToId(data.SENTENCE_START)
  116. end_id = self._vocab.WordToId(data.SENTENCE_END)
  117. pad_id = self._vocab.WordToId(data.PAD_TOKEN)
  118. input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
  119. while True:
  120. (article, abstract) = input_gen.next()
  121. article_sentences = [sent.strip() for sent in
  122. data.ToSentences(article, include_token=False)]
  123. abstract_sentences = [sent.strip() for sent in
  124. data.ToSentences(abstract, include_token=False)]
  125. enc_inputs = []
  126. # Use the <s> as the <GO> symbol for decoder inputs.
  127. dec_inputs = [start_id]
  128. # Convert first N sentences to word IDs, stripping existing <s> and </s>.
  129. for i in xrange(min(self._max_article_sentences,
  130. len(article_sentences))):
  131. enc_inputs += data.GetWordIds(article_sentences[i], self._vocab)
  132. for i in xrange(min(self._max_abstract_sentences,
  133. len(abstract_sentences))):
  134. dec_inputs += data.GetWordIds(abstract_sentences[i], self._vocab)
  135. # Filter out too-short input
  136. if (len(enc_inputs) < self._hps.min_input_len or
  137. len(dec_inputs) < self._hps.min_input_len):
  138. tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d',
  139. len(enc_inputs), len(dec_inputs))
  140. continue
  141. # If we're not truncating input, throw out too-long input
  142. if not self._truncate_input:
  143. if (len(enc_inputs) > self._hps.enc_timesteps or
  144. len(dec_inputs) > self._hps.dec_timesteps):
  145. tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d',
  146. len(enc_inputs), len(dec_inputs))
  147. continue
  148. # If we are truncating input, do so if necessary
  149. else:
  150. if len(enc_inputs) > self._hps.enc_timesteps:
  151. enc_inputs = enc_inputs[:self._hps.enc_timesteps]
  152. if len(dec_inputs) > self._hps.dec_timesteps:
  153. dec_inputs = dec_inputs[:self._hps.dec_timesteps]
  154. # targets is dec_inputs without <s> at beginning, plus </s> at end
  155. targets = dec_inputs[1:]
  156. targets.append(end_id)
  157. # Now len(enc_inputs) should be <= enc_timesteps, and
  158. # len(targets) = len(dec_inputs) should be <= dec_timesteps
  159. enc_input_len = len(enc_inputs)
  160. dec_output_len = len(targets)
  161. # Pad if necessary
  162. while len(enc_inputs) < self._hps.enc_timesteps:
  163. enc_inputs.append(pad_id)
  164. while len(dec_inputs) < self._hps.dec_timesteps:
  165. dec_inputs.append(end_id)
  166. while len(targets) < self._hps.dec_timesteps:
  167. targets.append(end_id)
  168. element = ModelInput(enc_inputs, dec_inputs, targets, enc_input_len,
  169. dec_output_len, ' '.join(article_sentences),
  170. ' '.join(abstract_sentences))
  171. self._input_queue.put(element)
  172. def _FillBucketInputQueue(self):
  173. """Fill bucketed batches into the bucket_input_queue."""
  174. while True:
  175. inputs = []
  176. for _ in xrange(self._hps.batch_size * BUCKET_CACHE_BATCH):
  177. inputs.append(self._input_queue.get())
  178. if self._bucketing:
  179. inputs = sorted(inputs, key=lambda inp: inp.enc_len)
  180. batches = []
  181. for i in xrange(0, len(inputs), self._hps.batch_size):
  182. batches.append(inputs[i:i+self._hps.batch_size])
  183. shuffle(batches)
  184. for b in batches:
  185. self._bucket_input_queue.put(b)
  186. def _WatchThreads(self):
  187. """Watch the daemon input threads and restart if dead."""
  188. while True:
  189. time.sleep(60)
  190. input_threads = []
  191. for t in self._input_threads:
  192. if t.is_alive():
  193. input_threads.append(t)
  194. else:
  195. tf.logging.error('Found input thread dead.')
  196. new_t = Thread(target=self._FillInputQueue)
  197. input_threads.append(new_t)
  198. input_threads[-1].daemon = True
  199. input_threads[-1].start()
  200. self._input_threads = input_threads
  201. bucketing_threads = []
  202. for t in self._bucketing_threads:
  203. if t.is_alive():
  204. bucketing_threads.append(t)
  205. else:
  206. tf.logging.error('Found bucketing thread dead.')
  207. new_t = Thread(target=self._FillBucketInputQueue)
  208. bucketing_threads.append(new_t)
  209. bucketing_threads[-1].daemon = True
  210. bucketing_threads[-1].start()
  211. self._bucketing_threads = bucketing_threads
  212. def _TextGenerator(self, example_gen):
  213. """Generates article and abstract text from tf.Example."""
  214. while True:
  215. e = example_gen.next()
  216. try:
  217. article_text = self._GetExFeatureText(e, self._article_key)
  218. abstract_text = self._GetExFeatureText(e, self._abstract_key)
  219. except ValueError:
  220. tf.logging.error('Failed to get article or abstract from example')
  221. continue
  222. yield (article_text, abstract_text)
  223. def _GetExFeatureText(self, ex, key):
  224. """Extract text for a feature from td.Example.
  225. Args:
  226. ex: tf.Example.
  227. key: key of the feature to be extracted.
  228. Returns:
  229. feature: a feature text extracted.
  230. """
  231. return ex.features.feature[key].bytes_list.value[0]