inception_preprocessing.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Provides utilities to preprocess images for the Inception networks."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. from tensorflow.python.ops import control_flow_ops
  def apply_with_random_selector(x, func, num_cases):
    """Applies `func` to `x` with a selector index sampled at graph run time.

    A scalar int32 selector is drawn from [0, num_cases). One branch
    `func(routed_x, i)` is built per case i. `x` is routed through a switch so
    that only the branch whose index equals the sampled selector receives the
    real tensor. The branch outputs are then merged back into a single tensor.

    Args:
      x: input Tensor.
      func: Python function to apply. It is called once per case as
        `func(routed_x, case_index)`, where `case_index` is a Python int.
      num_cases: Python int32, number of cases to sample sel from.

    Returns:
      The result of func(x, sel), where func receives the value of the
      selector as a python integer, but sel is sampled dynamically.
    """
    selector = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    branches = []
    for idx in range(num_cases):
      # switch() forwards x on output [1] only when the predicate is true, so
      # only the selected branch is fed the real tensor at run time.
      routed_x = control_flow_ops.switch(x, tf.equal(selector, idx))[1]
      branches.append(func(routed_x, idx))
    # merge() returns (value, index); keep the value from whichever branch
    # actually produced output.
    merged = control_flow_ops.merge(branches)
    return merged[0]
  def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
    """Distort the color of a Tensor image.

    Each color distortion is non-commutative and thus ordering of the color ops
    matters. Ideally we would randomly permute the ordering of the color ops.
    Rather than adding that level of complication, we select a distinct ordering
    of color ops for each preprocessing thread.

    Args:
      image: 3-D Tensor containing single image in [0, 1].
      color_ordering: Python int, a type of distortion (valid values: 0-3).
        In fast mode only two orderings exist: 0 applies brightness then
        saturation, while 1, 2 and 3 all apply saturation then brightness.
      fast_mode: Avoids slower ops (random_hue and random_contrast).
      scope: Optional scope for name_scope.

    Returns:
      3-D Tensor color-distorted image on range [0, 1].

    Raises:
      ValueError: if color_ordering not in [0, 3].
    """
    # Validate before building any ops so the documented contract also holds
    # in fast mode. Previously any value was silently accepted when
    # fast_mode=True.
    if color_ordering not in (0, 1, 2, 3):
      raise ValueError('color_ordering must be in [0, 3]')
    with tf.name_scope(scope, 'distort_color', [image]):
      if fast_mode:
        if color_ordering == 0:
          image = tf.image.random_brightness(image, max_delta=32. / 255.)
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        else:
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
          image = tf.image.random_brightness(image, max_delta=32. / 255.)
      else:
        if color_ordering == 0:
          image = tf.image.random_brightness(image, max_delta=32. / 255.)
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
          image = tf.image.random_hue(image, max_delta=0.2)
          image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
        elif color_ordering == 1:
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
          image = tf.image.random_brightness(image, max_delta=32. / 255.)
          image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
          image = tf.image.random_hue(image, max_delta=0.2)
        elif color_ordering == 2:
          image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
          image = tf.image.random_hue(image, max_delta=0.2)
          image = tf.image.random_brightness(image, max_delta=32. / 255.)
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
        else:  # color_ordering == 3 (validated above)
          image = tf.image.random_hue(image, max_delta=0.2)
          image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
          image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
          image = tf.image.random_brightness(image, max_delta=32. / 255.)

      # The random_* ops do not necessarily clamp.
      return tf.clip_by_value(image, 0.0, 1.0)
  def distorted_bounding_box_crop(image,
                                  bbox,
                                  min_object_covered=0.1,
                                  aspect_ratio_range=(0.75, 1.33),
                                  area_range=(0.05, 1.0),
                                  max_attempts=100,
                                  scope=None):
    """Generates cropped_image using one of the bboxes randomly distorted.

    See `tf.image.sample_distorted_bounding_box` for more documentation.

    Args:
      image: 3-D Tensor of image. No dtype conversion is performed here; if
        float values in [0, 1] are needed, the caller must convert the image
        beforehand.
      bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
        where each coordinate is in [0, 1] and the coordinates are arranged
        as [ymin, xmin, ymax, xmax]. If num_boxes is 0 then it would use the
        whole image.
      min_object_covered: An optional `float`. Defaults to `0.1`. The cropped
        area of the image must contain at least this fraction of any bounding
        box supplied.
      aspect_ratio_range: An optional list of `floats`. The cropped area of the
        image must have an aspect ratio = width / height within this range.
      area_range: An optional list of `floats`. The cropped area of the image
        must contain a fraction of the supplied image within this range.
      max_attempts: An optional `int`. Number of attempts at generating a
        cropped region of the image of the specified constraints. After
        `max_attempts` failures, return the entire image.
      scope: Optional scope for name_scope.

    Returns:
      A tuple, a 3-D Tensor cropped_image and the distorted bbox.
      Because the crop size is only known at run time, the static shape of
      cropped_image may be lost. Callers may need to restore it with
      `set_shape`.
    """
    with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
      # Each bounding box has shape [1, num_boxes, box coords] and
      # the coordinates are ordered [ymin, xmin, ymax, xmax].

      # A large fraction of image datasets contain a human-annotated bounding
      # box delineating the region of the image containing the object of
      # interest. We choose to create a new bounding box for the object which
      # is a randomly distorted version of the human-annotated bounding box
      # that obeys an allowed range of aspect ratios, sizes and overlap with
      # the human-annotated bounding box. If no box is supplied, then we assume
      # the bounding box is the entire image.
      sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
          tf.shape(image),
          bounding_boxes=bbox,
          min_object_covered=min_object_covered,
          aspect_ratio_range=aspect_ratio_range,
          area_range=area_range,
          max_attempts=max_attempts,
          use_image_if_no_bounding_boxes=True)
      bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

      # Crop the image to the specified bounding box.
      cropped_image = tf.slice(image, bbox_begin, bbox_size)
      return cropped_image, distort_bbox
  def preprocess_for_train(image, height, width, bbox,
                           fast_mode=True,
                           scope=None):
    """Distort one image for training a network.

    Distorting images provides a useful technique for augmenting the data
    set during training in order to make the network invariant to aspects
    of the image that do not affect the label.

    Additionally it creates image summaries to display the different
    transformations applied to the image.

    Args:
      image: 3-D Tensor of image. If dtype is tf.float32 then the range should
        be [0, 1], otherwise it would converted to tf.float32 assuming that the
        range is [0, MAX], where MAX is largest positive representable number
        for int(8/16/32) data type (see `tf.image.convert_image_dtype` for
        details).
      height: integer, output height.
      width: integer, output width.
      bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
        where each coordinate is in [0, 1] and the coordinates are arranged
        as [ymin, xmin, ymax, xmax]. If None, a single box covering the whole
        image is used.
      fast_mode: Optional boolean, if True avoids slower transformations (i.e.
        bi-cubic resizing, random_hue or random_contrast).
      scope: Optional scope for name_scope.

    Returns:
      3-D float Tensor of distorted image used for training with range [-1, 1].
    """
    with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
      if bbox is None:
        bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                           dtype=tf.float32,
                           shape=[1, 1, 4])
      if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
      # Each bounding box has shape [1, num_boxes, box coords] and
      # the coordinates are ordered [ymin, xmin, ymax, xmax].
      image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
                                                    bbox)
      tf.summary.image('image_with_bounding_boxes', image_with_box)

      distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
      # Restore the shape since the dynamic slice based upon the bbox_size
      # loses the third dimension.
      distorted_image.set_shape([None, None, 3])
      image_with_distorted_box = tf.image.draw_bounding_boxes(
          tf.expand_dims(image, 0), distorted_bbox)
      tf.summary.image('images_with_distorted_bounding_box',
                       image_with_distorted_box)

      # This resizing operation may distort the images because the aspect
      # ratio is not respected. The resize method is sampled at random for
      # each image from the first `num_resize_cases` values of
      # tf.image.ResizeMethod (bilinear is 0). Fast mode uses bilinear only.
      num_resize_cases = 1 if fast_mode else 4
      distorted_image = apply_with_random_selector(
          distorted_image,
          lambda x, method: tf.image.resize_images(x, [height, width],
                                                   method=method),
          num_cases=num_resize_cases)

      tf.summary.image('cropped_resized_image',
                       tf.expand_dims(distorted_image, 0))

      # Randomly flip the image horizontally.
      distorted_image = tf.image.random_flip_left_right(distorted_image)

      # Randomly distort the colors. distort_color has 4 distinct orderings in
      # the full mode but only 2 in fast mode (ordering 0, and orderings 1-3,
      # which are identical). Sampling 4 cases in fast mode would apply the
      # second ordering 75% of the time. Sampling 2 cases keeps both equally
      # likely and avoids building duplicate subgraphs.
      num_distort_cases = 2 if fast_mode else 4
      distorted_image = apply_with_random_selector(
          distorted_image,
          lambda x, ordering: distort_color(x, ordering, fast_mode),
          num_cases=num_distort_cases)

      tf.summary.image('final_distorted_image',
                       tf.expand_dims(distorted_image, 0))
      # Rescale from [0, 1] to [-1, 1].
      distorted_image = tf.subtract(distorted_image, 0.5)
      distorted_image = tf.multiply(distorted_image, 2.0)
      return distorted_image
  def preprocess_for_eval(image, height, width,
                          central_fraction=0.875, scope=None):
    """Prepare one image for evaluation.

    If height and width are specified it would output an image with that size
    by applying resize_bilinear.

    If central_fraction is specified it would crop the central fraction of the
    input image.

    Args:
      image: 3-D Tensor of image. If dtype is tf.float32 then the range should
        be [0, 1], otherwise it would converted to tf.float32 assuming that the
        range is [0, MAX], where MAX is largest positive representable number
        for int(8/16/32) data type (see `tf.image.convert_image_dtype` for
        details).
      height: integer. Resizing is applied only when both height and width
        are truthy.
      width: integer. Resizing is applied only when both height and width
        are truthy.
      central_fraction: Optional Float, fraction of the image to crop. A falsy
        value (None or 0) disables cropping.
      scope: Optional scope for name_scope.

    Returns:
      3-D float Tensor of prepared image with range [-1, 1].
    """
    with tf.name_scope(scope, 'eval_image', [image, height, width]):
      if image.dtype != tf.float32:
        image = tf.image.convert_image_dtype(image, dtype=tf.float32)
      # Crop the central region of the image, keeping `central_fraction` of
      # the original image (87.5% with the default value).
      if central_fraction:
        image = tf.image.central_crop(image, central_fraction=central_fraction)

      if height and width:
        # Resize the image to the specified height and width.
        # resize_bilinear expects a batch, so add and then remove a batch dim.
        image = tf.expand_dims(image, 0)
        image = tf.image.resize_bilinear(image, [height, width],
                                         align_corners=False)
        image = tf.squeeze(image, [0])
      # Rescale from [0, 1] to [-1, 1].
      image = tf.subtract(image, 0.5)
      image = tf.multiply(image, 2.0)
      return image
  def preprocess_image(image, height, width,
                       is_training=False,
                       bbox=None,
                       fast_mode=True):
    """Pre-process one image for training or evaluation.

    Args:
      image: 3-D Tensor [height, width, channels] with the image.
      height: integer, image expected height.
      width: integer, image expected width.
      is_training: Boolean. If true it would transform an image for train,
        otherwise it would transform it for evaluation.
      bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
        where each coordinate is in [0, 1] and the coordinates are arranged as
        [ymin, xmin, ymax, xmax]. Only used for training; if None, the
        training path uses the whole image as the bounding box.
      fast_mode: Optional boolean, if True avoids slower transformations.
        Only used for training.

    Returns:
      3-D float Tensor containing an appropriately scaled image.
    """
    if is_training:
      return preprocess_for_train(image, height, width, bbox,
                                  fast_mode=fast_mode)
    # Evaluation uses preprocess_for_eval's default central_fraction.
    return preprocess_for_eval(image, height, width)