resnet_preprocessing.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides utilities to preprocess images for the ResNet networks."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from tensorflow.contrib.slim import nets
  21. from tensorflow.python.ops import control_flow_ops
  22. slim = tf.contrib.slim
  23. _R_MEAN = 123.68
  24. _G_MEAN = 116.78
  25. _B_MEAN = 103.94
  26. _CROP_HEIGHT = nets.resnet_v1.resnet_v1.default_image_size
  27. _CROP_WIDTH = nets.resnet_v1.resnet_v1.default_image_size
  28. _RESIZE_SIDE = 256
  29. def _mean_image_subtraction(image, means):
  30. """Subtracts the given means from each image channel.
  31. For example:
  32. means = [123.68, 116.779, 103.939]
  33. image = _mean_image_subtraction(image, means)
  34. Note that the rank of `image` must be known.
  35. Args:
  36. image: a tensor of size [height, width, C].
  37. means: a C-vector of values to subtract from each channel.
  38. Returns:
  39. the centered image.
  40. Raises:
  41. ValueError: If the rank of `image` is unknown, if `image` has a rank other
  42. than three or if the number of channels in `image` doesn't match the
  43. number of values in `means`.
  44. """
  45. if image.get_shape().ndims != 3:
  46. raise ValueError('Input must be of size [height, width, C>0]')
  47. num_channels = image.get_shape().as_list()[-1]
  48. if len(means) != num_channels:
  49. raise ValueError('len(means) must match the number of channels')
  50. channels = tf.split(2, num_channels, image)
  51. for i in range(num_channels):
  52. channels[i] -= means[i]
  53. return tf.concat(2, channels)
  54. def _smallest_size_at_least(height, width, smallest_side):
  55. """Computes new shape with the smallest side equal to `smallest_side`.
  56. Computes new shape with the smallest side equal to `smallest_side` while
  57. preserving the original aspect ratio.
  58. Args:
  59. height: an int32 scalar tensor indicating the current height.
  60. width: an int32 scalar tensor indicating the current width.
  61. smallest_side: an python integer indicating the smallest side of the new
  62. shape.
  63. Returns:
  64. new_height: an int32 scalar tensor indicating the new height.
  65. new_width: and int32 scalar tensor indicating the new width.
  66. """
  67. height = tf.to_float(height)
  68. width = tf.to_float(width)
  69. smallest_side = float(smallest_side)
  70. scale = tf.cond(tf.greater(height, width),
  71. lambda: smallest_side / width,
  72. lambda: smallest_side / height)
  73. new_height = tf.to_int32(height * scale)
  74. new_width = tf.to_int32(width * scale)
  75. return new_height, new_width
  76. def _aspect_preserving_resize(image, smallest_side):
  77. """Resize images preserving the original aspect ratio.
  78. Args:
  79. image: a 3-D image tensor.
  80. smallest_side: a python integer indicating the size of the smallest side
  81. after resize.
  82. Returns:
  83. resized_image: a 3-D tensor containing the resized image.
  84. """
  85. shape = tf.shape(image)
  86. height = shape[0]
  87. width = shape[1]
  88. new_height, new_width = _smallest_size_at_least(height, width, smallest_side)
  89. image = tf.expand_dims(image, 0)
  90. resized_image = tf.image.resize_bilinear(image, [new_height, new_width],
  91. align_corners=False)
  92. resized_image = tf.squeeze(resized_image)
  93. resized_image.set_shape([None, None, 3])
  94. return resized_image
  95. def _crop(image, offset_height, offset_width, crop_height, crop_width):
  96. """Crops the given image using the provided offsets and sizes.
  97. Note that the method doesn't assume we know the input image size but it does
  98. assume we know the input image rank.
  99. Args:
  100. image: an image of shape [height, width, channels].
  101. offset_height: a scalar tensor indicating the height offset.
  102. offset_width: a scalar tensor indicating the width offset.
  103. crop_height: the height of the cropped image.
  104. crop_width: the width of the cropped image.
  105. Returns:
  106. the cropped (and resized) image.
  107. Raises:
  108. InvalidArgumentError: if the rank is not 3 or if the image dimensions are
  109. less than the crop size.
  110. """
  111. original_shape = tf.shape(image)
  112. rank_assertion = tf.Assert(
  113. tf.equal(tf.rank(image), 3),
  114. ['Rank of image must be equal to 3.'])
  115. cropped_shape = control_flow_ops.with_dependencies(
  116. [rank_assertion],
  117. tf.pack([crop_height, crop_width, original_shape[2]]))
  118. size_assertion = tf.Assert(
  119. tf.logical_and(
  120. tf.greater_equal(original_shape[0], crop_height),
  121. tf.greater_equal(original_shape[1], crop_width)),
  122. ['Crop size greater than the image size.'])
  123. offsets = tf.to_int32(tf.pack([offset_height, offset_width, 0]))
  124. # Use tf.slice instead of crop_to_bounding box as it accepts tensors to
  125. # define the crop size.
  126. image = control_flow_ops.with_dependencies(
  127. [size_assertion],
  128. tf.slice(image, offsets, cropped_shape))
  129. return tf.reshape(image, cropped_shape)
  130. def _central_crop(image_list, crop_height, crop_width):
  131. """Performs central crops of the given image list.
  132. Args:
  133. image_list: a list of image tensors of the same dimension but possibly
  134. varying channel.
  135. crop_height: the height of the image following the crop.
  136. crop_width: the width of the image following the crop.
  137. Returns:
  138. the list of cropped images.
  139. """
  140. outputs = []
  141. for image in image_list:
  142. image_height = tf.shape(image)[0]
  143. image_width = tf.shape(image)[1]
  144. offset_height = (image_height - crop_height) / 2
  145. offset_width = (image_width - crop_width) / 2
  146. outputs.append(_crop(image, offset_height, offset_width,
  147. crop_height, crop_width))
  148. return outputs
  149. def preprocess_image(image,
  150. height=_CROP_HEIGHT,
  151. width=_CROP_WIDTH,
  152. is_training=False, # pylint: disable=unused-argument
  153. resize_side=_RESIZE_SIDE):
  154. image = _aspect_preserving_resize(image, resize_side)
  155. image = _central_crop([image], height, width)[0]
  156. image.set_shape([height, width, 3])
  157. image = tf.to_float(image)
  158. image = _mean_image_subtraction(image, [_R_MEAN, _G_MEAN, _B_MEAN])
  159. return image