vgsl_input.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
"""Image input pipeline: reads TF Examples, decodes PNG images and batches them with their labels."""
  16. import collections
  17. import tensorflow as tf
  18. from tensorflow.python.ops import parsing_ops
  19. # Named tuple for the standard tf image tensor Shape.
  20. # batch_size: Number of images to batch-up for training.
  21. # height: Fixed height of image or None for variable.
  22. # width: Fixed width of image or None for variable.
  23. # depth: Desired depth in bytes per pixel of input images.
  24. ImageShape = collections.namedtuple('ImageTensorDims',
  25. ['batch_size', 'height', 'width', 'depth'])
  26. def ImageInput(input_pattern, num_threads, shape, using_ctc, reader=None):
  27. """Creates an input image tensor from the input_pattern filenames.
  28. TODO(rays) Expand for 2-d labels, 0-d labels, and logistic targets.
  29. Args:
  30. input_pattern: Filenames of the dataset(s) to read.
  31. num_threads: Number of preprocessing threads.
  32. shape: ImageShape with the desired shape of the input.
  33. using_ctc: Take the unpadded_class labels instead of padded.
  34. reader: Function that returns an actual reader to read Examples from
  35. input files. If None, uses tf.TFRecordReader().
  36. Returns:
  37. images: Float Tensor containing the input image scaled to [-1.28, 1.27].
  38. heights: Tensor int64 containing the heights of the images.
  39. widths: Tensor int64 containing the widths of the images.
  40. labels: Serialized SparseTensor containing the int64 labels.
  41. sparse_labels: Serialized SparseTensor containing the int64 labels.
  42. truths: Tensor string of the utf8 truth texts.
  43. Raises:
  44. ValueError: if the optimizer type is unrecognized.
  45. """
  46. data_files = tf.gfile.Glob(input_pattern)
  47. assert data_files, 'no files found for dataset ' + input_pattern
  48. queue_capacity = shape.batch_size * num_threads * 2
  49. filename_queue = tf.train.string_input_producer(
  50. data_files, capacity=queue_capacity)
  51. # Create a subgraph with its own reader (but sharing the
  52. # filename_queue) for each preprocessing thread.
  53. images_and_label_lists = []
  54. for _ in range(num_threads):
  55. image, height, width, labels, text = _ReadExamples(filename_queue, shape,
  56. using_ctc, reader)
  57. images_and_label_lists.append([image, height, width, labels, text])
  58. # Create a queue that produces the examples in batches.
  59. images, heights, widths, labels, truths = tf.train.batch_join(
  60. images_and_label_lists,
  61. batch_size=shape.batch_size,
  62. capacity=16 * shape.batch_size,
  63. dynamic_pad=True)
  64. # Deserialize back to sparse, because the batcher doesn't do sparse.
  65. labels = tf.deserialize_many_sparse(labels, tf.int64)
  66. sparse_labels = tf.cast(labels, tf.int32)
  67. labels = tf.sparse_tensor_to_dense(labels)
  68. labels = tf.reshape(labels, [shape.batch_size, -1], name='Labels')
  69. # Crush the other shapes to just the batch dimension.
  70. heights = tf.reshape(heights, [-1], name='Heights')
  71. widths = tf.reshape(widths, [-1], name='Widths')
  72. truths = tf.reshape(truths, [-1], name='Truths')
  73. # Give the images a nice name as well.
  74. images = tf.identity(images, name='Images')
  75. tf.summary.image('Images', images)
  76. return images, heights, widths, labels, sparse_labels, truths
  77. def _ReadExamples(filename_queue, shape, using_ctc, reader=None):
  78. """Builds network input tensor ops for TF Example.
  79. Args:
  80. filename_queue: Queue of filenames, from tf.train.string_input_producer
  81. shape: ImageShape with the desired shape of the input.
  82. using_ctc: Take the unpadded_class labels instead of padded.
  83. reader: Function that returns an actual reader to read Examples from
  84. input files. If None, uses tf.TFRecordReader().
  85. Returns:
  86. image: Float Tensor containing the input image scaled to [-1.28, 1.27].
  87. height: Tensor int64 containing the height of the image.
  88. width: Tensor int64 containing the width of the image.
  89. labels: Serialized SparseTensor containing the int64 labels.
  90. text: Tensor string of the utf8 truth text.
  91. """
  92. if reader:
  93. reader = reader()
  94. else:
  95. reader = tf.TFRecordReader()
  96. _, example_serialized = reader.read(filename_queue)
  97. example_serialized = tf.reshape(example_serialized, shape=[])
  98. features = tf.parse_single_example(
  99. example_serialized,
  100. {'image/encoded': parsing_ops.FixedLenFeature(
  101. [1], dtype=tf.string, default_value=''),
  102. 'image/text': parsing_ops.FixedLenFeature(
  103. [1], dtype=tf.string, default_value=''),
  104. 'image/class': parsing_ops.VarLenFeature(dtype=tf.int64),
  105. 'image/unpadded_class': parsing_ops.VarLenFeature(dtype=tf.int64),
  106. 'image/height': parsing_ops.FixedLenFeature(
  107. [1], dtype=tf.int64, default_value=1),
  108. 'image/width': parsing_ops.FixedLenFeature(
  109. [1], dtype=tf.int64, default_value=1)})
  110. if using_ctc:
  111. labels = features['image/unpadded_class']
  112. else:
  113. labels = features['image/class']
  114. labels = tf.serialize_sparse(labels)
  115. image = tf.reshape(features['image/encoded'], shape=[], name='encoded')
  116. image = _ImageProcessing(image, shape)
  117. height = tf.reshape(features['image/height'], [-1])
  118. width = tf.reshape(features['image/width'], [-1])
  119. text = tf.reshape(features['image/text'], shape=[])
  120. return image, height, width, labels, text
  121. def _ImageProcessing(image_buffer, shape):
  122. """Convert a PNG string into an input tensor.
  123. We allow for fixed and variable sizes.
  124. Does fixed conversion to floats in the range [-1.28, 1.27].
  125. Args:
  126. image_buffer: Tensor containing a PNG encoded image.
  127. shape: ImageShape with the desired shape of the input.
  128. Returns:
  129. image: Decoded, normalized image in the range [-1.28, 1.27].
  130. """
  131. image = tf.image.decode_png(image_buffer, channels=shape.depth)
  132. image.set_shape([shape.height, shape.width, shape.depth])
  133. image = tf.cast(image, tf.float32)
  134. image = tf.subtract(image, 128.0)
  135. image = tf.multiply(image, 1 / 100.0)
  136. return image