# batch_reader.py (10 KB)
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Batch reader to seq2seq attention model, with bucketing support."""
  16. from collections import namedtuple
  17. import Queue
  18. from random import shuffle
  19. from threading import Thread
  20. import time
  21. import numpy as np
  22. import tensorflow as tf
  23. import data
  24. ModelInput = namedtuple('ModelInput',
  25. 'enc_input dec_input target enc_len dec_len '
  26. 'origin_article origin_abstract')
  27. BUCKET_CACHE_BATCH = 100
  28. QUEUE_NUM_BATCH = 100
  29. class Batcher(object):
  30. """Batch reader with shuffling and bucketing support."""
  31. def __init__(self, data_path, vocab, hps,
  32. article_key, abstract_key, max_article_sentences,
  33. max_abstract_sentences, bucketing=True, truncate_input=False):
  34. """Batcher constructor.
  35. Args:
  36. data_path: tf.Example filepattern.
  37. vocab: Vocabulary.
  38. hps: Seq2SeqAttention model hyperparameters.
  39. article_key: article feature key in tf.Example.
  40. abstract_key: abstract feature key in tf.Example.
  41. max_article_sentences: Max number of sentences used from article.
  42. max_abstract_sentences: Max number of sentences used from abstract.
  43. bucketing: Whether bucket articles of similar length into the same batch.
  44. truncate_input: Whether to truncate input that is too long. Alternative is
  45. to discard such examples.
  46. """
  47. self._data_path = data_path
  48. self._vocab = vocab
  49. self._hps = hps
  50. self._article_key = article_key
  51. self._abstract_key = abstract_key
  52. self._max_article_sentences = max_article_sentences
  53. self._max_abstract_sentences = max_abstract_sentences
  54. self._bucketing = bucketing
  55. self._truncate_input = truncate_input
  56. self._input_queue = Queue.Queue(QUEUE_NUM_BATCH * self._hps.batch_size)
  57. self._bucket_input_queue = Queue.Queue(QUEUE_NUM_BATCH)
  58. self._input_threads = []
  59. for _ in xrange(16):
  60. self._input_threads.append(Thread(target=self._FillInputQueue))
  61. self._input_threads[-1].daemon = True
  62. self._input_threads[-1].start()
  63. self._bucketing_threads = []
  64. for _ in xrange(4):
  65. self._bucketing_threads.append(Thread(target=self._FillBucketInputQueue))
  66. self._bucketing_threads[-1].daemon = True
  67. self._bucketing_threads[-1].start()
  68. self._watch_thread = Thread(target=self._WatchThreads)
  69. self._watch_thread.daemon = True
  70. self._watch_thread.start()
  71. def NextBatch(self):
  72. """Returns a batch of inputs for seq2seq attention model.
  73. Returns:
  74. enc_batch: A batch of encoder inputs [batch_size, hps.enc_timestamps].
  75. dec_batch: A batch of decoder inputs [batch_size, hps.dec_timestamps].
  76. target_batch: A batch of targets [batch_size, hps.dec_timestamps].
  77. enc_input_len: encoder input lengths of the batch.
  78. dec_input_len: decoder input lengths of the batch.
  79. loss_weights: weights for loss function, 1 if not padded, 0 if padded.
  80. origin_articles: original article words.
  81. origin_abstracts: original abstract words.
  82. """
  83. enc_batch = np.zeros(
  84. (self._hps.batch_size, self._hps.enc_timesteps), dtype=np.int32)
  85. enc_input_lens = np.zeros(
  86. (self._hps.batch_size), dtype=np.int32)
  87. dec_batch = np.zeros(
  88. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
  89. dec_output_lens = np.zeros(
  90. (self._hps.batch_size), dtype=np.int32)
  91. target_batch = np.zeros(
  92. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.int32)
  93. loss_weights = np.zeros(
  94. (self._hps.batch_size, self._hps.dec_timesteps), dtype=np.float32)
  95. origin_articles = ['None'] * self._hps.batch_size
  96. origin_abstracts = ['None'] * self._hps.batch_size
  97. buckets = self._bucket_input_queue.get()
  98. for i in xrange(self._hps.batch_size):
  99. (enc_inputs, dec_inputs, targets, enc_input_len, dec_output_len,
  100. article, abstract) = buckets[i]
  101. origin_articles[i] = article
  102. origin_abstracts[i] = abstract
  103. enc_input_lens[i] = enc_input_len
  104. dec_output_lens[i] = dec_output_len
  105. enc_batch[i, :] = enc_inputs[:]
  106. dec_batch[i, :] = dec_inputs[:]
  107. target_batch[i, :] = targets[:]
  108. for j in xrange(dec_output_len):
  109. loss_weights[i][j] = 1
  110. return (enc_batch, dec_batch, target_batch, enc_input_lens, dec_output_lens,
  111. loss_weights, origin_articles, origin_abstracts)
  112. def _FillInputQueue(self):
  113. """Fill input queue with ModelInput."""
  114. start_id = self._vocab.WordToId(data.SENTENCE_START)
  115. end_id = self._vocab.WordToId(data.SENTENCE_END)
  116. pad_id = self._vocab.WordToId(data.PAD_TOKEN)
  117. input_gen = self._TextGenerator(data.ExampleGen(self._data_path))
  118. while True:
  119. (article, abstract) = input_gen.next()
  120. article_sentences = [sent.strip() for sent in
  121. data.ToSentences(article, include_token=False)]
  122. abstract_sentences = [sent.strip() for sent in
  123. data.ToSentences(abstract, include_token=False)]
  124. enc_inputs = []
  125. # Use the <s> as the <GO> symbol for decoder inputs.
  126. dec_inputs = [start_id]
  127. # Convert first N sentences to word IDs, stripping existing <s> and </s>.
  128. for i in xrange(min(self._max_article_sentences,
  129. len(article_sentences))):
  130. enc_inputs += data.GetWordIds(article_sentences[i], self._vocab)
  131. for i in xrange(min(self._max_abstract_sentences,
  132. len(abstract_sentences))):
  133. dec_inputs += data.GetWordIds(abstract_sentences[i], self._vocab)
  134. # Filter out too-short input
  135. if (len(enc_inputs) < self._hps.min_input_len or
  136. len(dec_inputs) < self._hps.min_input_len):
  137. tf.logging.warning('Drop an example - too short.\nenc:%d\ndec:%d',
  138. len(enc_inputs), len(dec_inputs))
  139. continue
  140. # If we're not truncating input, throw out too-long input
  141. if not self._truncate_input:
  142. if (len(enc_inputs) > self._hps.enc_timesteps or
  143. len(dec_inputs) > self._hps.dec_timesteps):
  144. tf.logging.warning('Drop an example - too long.\nenc:%d\ndec:%d',
  145. len(enc_inputs), len(dec_inputs))
  146. continue
  147. # If we are truncating input, do so if necessary
  148. else:
  149. if len(enc_inputs) > self._hps.enc_timesteps:
  150. enc_inputs = enc_inputs[:self._hps.enc_timesteps]
  151. if len(dec_inputs) > self._hps.dec_timesteps:
  152. dec_inputs = dec_inputs[:self._hps.dec_timesteps]
  153. # targets is dec_inputs without <s> at beginning, plus </s> at end
  154. targets = dec_inputs[1:]
  155. targets.append(end_id)
  156. # Now len(enc_inputs) should be <= enc_timesteps, and
  157. # len(targets) = len(dec_inputs) should be <= dec_timesteps
  158. enc_input_len = len(enc_inputs)
  159. dec_output_len = len(targets)
  160. # Pad if necessary
  161. while len(enc_inputs) < self._hps.enc_timesteps:
  162. enc_inputs.append(pad_id)
  163. while len(dec_inputs) < self._hps.dec_timesteps:
  164. dec_inputs.append(end_id)
  165. while len(targets) < self._hps.dec_timesteps:
  166. targets.append(end_id)
  167. element = ModelInput(enc_inputs, dec_inputs, targets, enc_input_len,
  168. dec_output_len, ' '.join(article_sentences),
  169. ' '.join(abstract_sentences))
  170. self._input_queue.put(element)
  171. def _FillBucketInputQueue(self):
  172. """Fill bucketed batches into the bucket_input_queue."""
  173. while True:
  174. inputs = []
  175. for _ in xrange(self._hps.batch_size * BUCKET_CACHE_BATCH):
  176. inputs.append(self._input_queue.get())
  177. if self._bucketing:
  178. inputs = sorted(inputs, key=lambda inp: inp.enc_len)
  179. batches = []
  180. for i in xrange(0, len(inputs), self._hps.batch_size):
  181. batches.append(inputs[i:i+self._hps.batch_size])
  182. shuffle(batches)
  183. for b in batches:
  184. self._bucket_input_queue.put(b)
  185. def _WatchThreads(self):
  186. """Watch the daemon input threads and restart if dead."""
  187. while True:
  188. time.sleep(60)
  189. input_threads = []
  190. for t in self._input_threads:
  191. if t.is_alive():
  192. input_threads.append(t)
  193. else:
  194. tf.logging.error('Found input thread dead.')
  195. new_t = Thread(target=self._FillInputQueue)
  196. input_threads.append(new_t)
  197. input_threads[-1].daemon = True
  198. input_threads[-1].start()
  199. self._input_threads = input_threads
  200. bucketing_threads = []
  201. for t in self._bucketing_threads:
  202. if t.is_alive():
  203. bucketing_threads.append(t)
  204. else:
  205. tf.logging.error('Found bucketing thread dead.')
  206. new_t = Thread(target=self._FillBucketInputQueue)
  207. bucketing_threads.append(new_t)
  208. bucketing_threads[-1].daemon = True
  209. bucketing_threads[-1].start()
  210. self._bucketing_threads = bucketing_threads
  211. def _TextGenerator(self, example_gen):
  212. """Generates article and abstract text from tf.Example."""
  213. while True:
  214. e = example_gen.next()
  215. try:
  216. article_text = self._GetExFeatureText(e, self._article_key)
  217. abstract_text = self._GetExFeatureText(e, self._abstract_key)
  218. except ValueError:
  219. tf.logging.error('Failed to get article or abstract from example')
  220. continue
  221. yield (article_text, abstract_text)
  222. def _GetExFeatureText(self, ex, key):
  223. """Extract text for a feature from td.Example.
  224. Args:
  225. ex: tf.Example.
  226. key: key of the feature to be extracted.
  227. Returns:
  228. feature: a feature text extracted.
  229. """
  230. return ex.features.feature[key].bytes_list.value[0]