inputs.py
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Input ops."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. def parse_sequence_example(serialized, image_feature, caption_feature):
  21. """Parses a tensorflow.SequenceExample into an image and caption.
  22. Args:
  23. serialized: A scalar string Tensor; a single serialized SequenceExample.
  24. image_feature: Name of SequenceExample context feature containing image
  25. data.
  26. caption_feature: Name of SequenceExample feature list containing integer
  27. captions.
  28. Returns:
  29. encoded_image: A scalar string Tensor containing a JPEG encoded image.
  30. caption: A 1-D uint64 Tensor with dynamically specified length.
  31. """
  32. context, sequence = tf.parse_single_sequence_example(
  33. serialized,
  34. context_features={
  35. image_feature: tf.FixedLenFeature([], dtype=tf.string)
  36. },
  37. sequence_features={
  38. caption_feature: tf.FixedLenSequenceFeature([], dtype=tf.int64),
  39. })
  40. encoded_image = context[image_feature]
  41. caption = sequence[caption_feature]
  42. return encoded_image, caption
  43. def prefetch_input_data(reader,
  44. file_pattern,
  45. is_training,
  46. batch_size,
  47. values_per_shard,
  48. input_queue_capacity_factor=16,
  49. num_reader_threads=1,
  50. shard_queue_name="filename_queue",
  51. value_queue_name="input_queue"):
  52. """Prefetches string values from disk into an input queue.
  53. In training the capacity of the queue is important because a larger queue
  54. means better mixing of training examples between shards. The minimum number of
  55. values kept in the queue is values_per_shard * input_queue_capacity_factor,
  56. where input_queue_memory factor should be chosen to trade-off better mixing
  57. with memory usage.
  58. Args:
  59. reader: Instance of tf.ReaderBase.
  60. file_pattern: Comma-separated list of file patterns (e.g.
  61. /tmp/train_data-?????-of-00100).
  62. is_training: Boolean; whether prefetching for training or eval.
  63. batch_size: Model batch size used to determine queue capacity.
  64. values_per_shard: Approximate number of values per shard.
  65. input_queue_capacity_factor: Minimum number of values to keep in the queue
  66. in multiples of values_per_shard. See comments above.
  67. num_reader_threads: Number of reader threads to fill the queue.
  68. shard_queue_name: Name for the shards filename queue.
  69. value_queue_name: Name for the values input queue.
  70. Returns:
  71. A Queue containing prefetched string values.
  72. """
  73. data_files = []
  74. for pattern in file_pattern.split(","):
  75. data_files.extend(tf.gfile.Glob(pattern))
  76. if not data_files:
  77. tf.logging.fatal("Found no input files matching %s", file_pattern)
  78. else:
  79. tf.logging.info("Prefetching values from %d files matching %s",
  80. len(data_files), file_pattern)
  81. if is_training:
  82. filename_queue = tf.train.string_input_producer(
  83. data_files, shuffle=True, capacity=16, name=shard_queue_name)
  84. min_queue_examples = values_per_shard * input_queue_capacity_factor
  85. capacity = min_queue_examples + 100 * batch_size
  86. values_queue = tf.RandomShuffleQueue(
  87. capacity=capacity,
  88. min_after_dequeue=min_queue_examples,
  89. dtypes=[tf.string],
  90. name="random_" + value_queue_name)
  91. else:
  92. filename_queue = tf.train.string_input_producer(
  93. data_files, shuffle=False, capacity=1, name=shard_queue_name)
  94. capacity = values_per_shard + 3 * batch_size
  95. values_queue = tf.FIFOQueue(
  96. capacity=capacity, dtypes=[tf.string], name="fifo_" + value_queue_name)
  97. enqueue_ops = []
  98. for _ in range(num_reader_threads):
  99. _, value = reader.read(filename_queue)
  100. enqueue_ops.append(values_queue.enqueue([value]))
  101. tf.train.queue_runner.add_queue_runner(tf.train.queue_runner.QueueRunner(
  102. values_queue, enqueue_ops))
  103. tf.scalar_summary(
  104. "queue/%s/fraction_of_%d_full" % (values_queue.name, capacity),
  105. tf.cast(values_queue.size(), tf.float32) * (1. / capacity))
  106. return values_queue
  107. def batch_with_dynamic_pad(images_and_captions,
  108. batch_size,
  109. queue_capacity,
  110. add_summaries=True):
  111. """Batches input images and captions.
  112. This function splits the caption into an input sequence and a target sequence,
  113. where the target sequence is the input sequence right-shifted by 1. Input and
  114. target sequences are batched and padded up to the maximum length of sequences
  115. in the batch. A mask is created to distinguish real words from padding words.
  116. Example:
  117. Actual captions in the batch ('-' denotes padded character):
  118. [
  119. [ 1 2 5 4 5 ],
  120. [ 1 2 3 4 - ],
  121. [ 1 2 3 - - ],
  122. ]
  123. input_seqs:
  124. [
  125. [ 1 2 3 4 ],
  126. [ 1 2 3 - ],
  127. [ 1 2 - - ],
  128. ]
  129. target_seqs:
  130. [
  131. [ 2 3 4 5 ],
  132. [ 2 3 4 - ],
  133. [ 2 3 - - ],
  134. ]
  135. mask:
  136. [
  137. [ 1 1 1 1 ],
  138. [ 1 1 1 0 ],
  139. [ 1 1 0 0 ],
  140. ]
  141. Args:
  142. images_and_captions: A list of pairs [image, caption], where image is a
  143. Tensor of shape [height, width, channels] and caption is a 1-D Tensor of
  144. any length. Each pair will be processed and added to the queue in a
  145. separate thread.
  146. batch_size: Batch size.
  147. queue_capacity: Queue capacity.
  148. add_summaries: If true, add caption length summaries.
  149. Returns:
  150. images: A Tensor of shape [batch_size, height, width, channels].
  151. input_seqs: An int32 Tensor of shape [batch_size, padded_length].
  152. target_seqs: An int32 Tensor of shape [batch_size, padded_length].
  153. mask: An int32 0/1 Tensor of shape [batch_size, padded_length].
  154. """
  155. enqueue_list = []
  156. for image, caption in images_and_captions:
  157. caption_length = tf.shape(caption)[0]
  158. input_length = tf.expand_dims(tf.sub(caption_length, 1), 0)
  159. input_seq = tf.slice(caption, [0], input_length)
  160. target_seq = tf.slice(caption, [1], input_length)
  161. indicator = tf.ones(input_length, dtype=tf.int32)
  162. enqueue_list.append([image, input_seq, target_seq, indicator])
  163. images, input_seqs, target_seqs, mask = tf.train.batch_join(
  164. enqueue_list,
  165. batch_size=batch_size,
  166. capacity=queue_capacity,
  167. dynamic_pad=True,
  168. name="batch_and_pad")
  169. if add_summaries:
  170. lengths = tf.add(tf.reduce_sum(mask, 1), 1)
  171. tf.scalar_summary("caption_length/batch_min", tf.reduce_min(lengths))
  172. tf.scalar_summary("caption_length/batch_max", tf.reduce_max(lengths))
  173. tf.scalar_summary("caption_length/batch_mean", tf.reduce_mean(lengths))
  174. return images, input_seqs, target_seqs, mask