image_processing.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Helper functions for image preprocessing."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. def distort_image(image, thread_id):
  21. """Perform random distortions on an image.
  22. Args:
  23. image: A float32 Tensor of shape [height, width, 3] with values in [0, 1).
  24. thread_id: Preprocessing thread id used to select the ordering of color
  25. distortions. There should be a multiple of 2 preprocessing threads.
  26. Returns:
  27. distorted_image: A float32 Tensor of shape [height, width, 3] with values in
  28. [0, 1].
  29. """
  30. # Randomly flip horizontally.
  31. with tf.name_scope("flip_horizontal", values=[image]):
  32. image = tf.image.random_flip_left_right(image)
  33. # Randomly distort the colors based on thread id.
  34. color_ordering = thread_id % 2
  35. with tf.name_scope("distort_color", values=[image]):
  36. if color_ordering == 0:
  37. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  38. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  39. image = tf.image.random_hue(image, max_delta=0.032)
  40. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  41. elif color_ordering == 1:
  42. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  43. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  44. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  45. image = tf.image.random_hue(image, max_delta=0.032)
  46. # The random_* ops do not necessarily clamp.
  47. image = tf.clip_by_value(image, 0.0, 1.0)
  48. return image
  49. def process_image(encoded_image,
  50. is_training,
  51. height,
  52. width,
  53. resize_height=346,
  54. resize_width=346,
  55. thread_id=0,
  56. image_format="jpeg"):
  57. """Decode an image, resize and apply random distortions.
  58. In training, images are distorted slightly differently depending on thread_id.
  59. Args:
  60. encoded_image: String Tensor containing the image.
  61. is_training: Boolean; whether preprocessing for training or eval.
  62. height: Height of the output image.
  63. width: Width of the output image.
  64. resize_height: If > 0, resize height before crop to final dimensions.
  65. resize_width: If > 0, resize width before crop to final dimensions.
  66. thread_id: Preprocessing thread id used to select the ordering of color
  67. distortions. There should be a multiple of 2 preprocessing threads.
  68. image_format: "jpeg" or "png".
  69. Returns:
  70. A float32 Tensor of shape [height, width, 3] with values in [-1, 1].
  71. Raises:
  72. ValueError: If image_format is invalid.
  73. """
  74. # Helper function to log an image summary to the visualizer. Summaries are
  75. # only logged in thread 0.
  76. def image_summary(name, image):
  77. if not thread_id:
  78. tf.image_summary(name, tf.expand_dims(image, 0))
  79. # Decode image into a float32 Tensor of shape [?, ?, 3] with values in [0, 1).
  80. with tf.name_scope("decode", values=[encoded_image]):
  81. if image_format == "jpeg":
  82. image = tf.image.decode_jpeg(encoded_image, channels=3)
  83. elif image_format == "png":
  84. image = tf.image.decode_png(encoded_image, channels=3)
  85. else:
  86. raise ValueError("Invalid image format: %s" % image_format)
  87. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  88. image_summary("original_image", image)
  89. # Resize image.
  90. assert (resize_height > 0) == (resize_width > 0)
  91. if resize_height:
  92. image = tf.image.resize_images(image,
  93. new_height=resize_height,
  94. new_width=resize_width,
  95. method=tf.image.ResizeMethod.BILINEAR)
  96. # Crop to final dimensions.
  97. if is_training:
  98. image = tf.random_crop(image, [height, width, 3])
  99. else:
  100. # Central crop, assuming resize_height > height, resize_width > width.
  101. image = tf.image.resize_image_with_crop_or_pad(image, height, width)
  102. image_summary("resized_image", image)
  103. # Randomly distort the image.
  104. if is_training:
  105. image = distort_image(image, thread_id)
  106. image_summary("final_image", image)
  107. # Rescale to [-1,1] instead of [0, 1]
  108. image = tf.sub(image, 0.5)
  109. image = tf.mul(image, 2.0)
  110. return image