image_processing.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Read and preprocess image data.
  16. Image processing occurs on a single image at a time. Image are read and
  17. preprocessed in pararllel across mulitple threads. The resulting images
  18. are concatenated together to form a single batch for training or evaluation.
  19. -- Provide processed image data for a network:
  20. inputs: Construct batches of evaluation examples of images.
  21. distorted_inputs: Construct batches of training examples of images.
  22. batch_inputs: Construct batches of training or evaluation examples of images.
  23. -- Data processing:
  24. parse_example_proto: Parses an Example proto containing a training example
  25. of an image.
  26. -- Image decoding:
  27. decode_jpeg: Decode a JPEG encoded string into a 3-D float32 Tensor.
  28. -- Image preprocessing:
  29. image_preprocessing: Decode and preprocess one image for evaluation or training
  30. distort_image: Distort one image for training a network.
  31. eval_image: Prepare one image for evaluation.
  32. distort_color: Distort the color in one image for training.
  33. """
  34. from __future__ import absolute_import
  35. from __future__ import division
  36. from __future__ import print_function
  37. import tensorflow as tf
  38. FLAGS = tf.app.flags.FLAGS
  39. tf.app.flags.DEFINE_integer('batch_size', 32,
  40. """Number of images to process in a batch.""")
  41. tf.app.flags.DEFINE_integer('image_size', 299,
  42. """Provide square images of this size.""")
  43. tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
  44. """Number of preprocessing threads per tower. """
  45. """Please make this a multiple of 4.""")
  46. # Images are preprocessed asynchronously using multiple threads specifed by
  47. # --num_preprocss_threads and the resulting processed images are stored in a
  48. # random shuffling queue. The shuffling queue dequeues --batch_size images
  49. # for processing on a given Inception tower. A larger shuffling queue guarantees
  50. # better mixing across examples within a batch and results in slightly higher
  51. # predictive performance in a trained model. Empirically,
  52. # --input_queue_memory_factor=16 works well. A value of 16 implies a queue size
  53. # of 1024*16 images. Assuming RGB 299x299 images, this implies a queue size of
  54. # 16GB. If the machine is memory limited, then decrease this factor to
  55. # decrease the CPU memory footprint, accordingly.
  56. tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16,
  57. """Size of the queue of preprocessed images. """
  58. """Default is ideal but try smaller values, e.g. """
  59. """4, 2 or 1, if host memory is constrained. See """
  60. """comments in code for more details.""")
  61. def inputs(dataset, batch_size=None, num_preprocess_threads=None):
  62. """Generate batches of ImageNet images for evaluation.
  63. Use this function as the inputs for evaluating a network.
  64. Note that some (minimal) image preprocessing occurs during evaluation
  65. including central cropping and resizing of the image to fit the network.
  66. Args:
  67. dataset: instance of Dataset class specifying the dataset.
  68. batch_size: integer, number of examples in batch
  69. num_preprocess_threads: integer, total number of preprocessing threads but
  70. None defaults to FLAGS.num_preprocess_threads.
  71. Returns:
  72. images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
  73. image_size, 3].
  74. labels: 1-D integer Tensor of [FLAGS.batch_size].
  75. """
  76. if not batch_size:
  77. batch_size = FLAGS.batch_size
  78. # Force all input processing onto CPU in order to reserve the GPU for
  79. # the forward inference and back-propagation.
  80. with tf.device('/cpu:0'):
  81. images, labels = batch_inputs(
  82. dataset, batch_size, train=False,
  83. num_preprocess_threads=num_preprocess_threads)
  84. return images, labels
  85. def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
  86. """Generate batches of distorted versions of ImageNet images.
  87. Use this function as the inputs for training a network.
  88. Distorting images provides a useful technique for augmenting the data
  89. set during training in order to make the network invariant to aspects
  90. of the image that do not effect the label.
  91. Args:
  92. dataset: instance of Dataset class specifying the dataset.
  93. batch_size: integer, number of examples in batch
  94. num_preprocess_threads: integer, total number of preprocessing threads but
  95. None defaults to FLAGS.num_preprocess_threads.
  96. Returns:
  97. images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
  98. FLAGS.image_size, 3].
  99. labels: 1-D integer Tensor of [batch_size].
  100. """
  101. if not batch_size:
  102. batch_size = FLAGS.batch_size
  103. # Force all input processing onto CPU in order to reserve the GPU for
  104. # the forward inference and back-propagation.
  105. with tf.device('/cpu:0'):
  106. images, labels = batch_inputs(
  107. dataset, batch_size, train=True,
  108. num_preprocess_threads=num_preprocess_threads)
  109. return images, labels
  110. def decode_jpeg(image_buffer, scope=None):
  111. """Decode a JPEG string into one 3-D float image Tensor.
  112. Args:
  113. image_buffer: scalar string Tensor.
  114. scope: Optional scope for op_scope.
  115. Returns:
  116. 3-D float Tensor with values ranging from [0, 1).
  117. """
  118. with tf.op_scope([image_buffer], scope, 'decode_jpeg'):
  119. # Decode the string as an RGB JPEG.
  120. # Note that the resulting image contains an unknown height and width
  121. # that is set dynamically by decode_jpeg. In other words, the height
  122. # and width of image is unknown at compile-time.
  123. image = tf.image.decode_jpeg(image_buffer, channels=3)
  124. # After this point, all image pixels reside in [0,1)
  125. # until the very end, when they're rescaled to (-1, 1). The various
  126. # adjust_* ops all require this range for dtype float.
  127. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  128. return image
  129. def distort_color(image, thread_id=0, scope=None):
  130. """Distort the color of the image.
  131. Each color distortion is non-commutative and thus ordering of the color ops
  132. matters. Ideally we would randomly permute the ordering of the color ops.
  133. Rather then adding that level of complication, we select a distinct ordering
  134. of color ops for each preprocessing thread.
  135. Args:
  136. image: Tensor containing single image.
  137. thread_id: preprocessing thread ID.
  138. scope: Optional scope for op_scope.
  139. Returns:
  140. color-distorted image
  141. """
  142. with tf.op_scope([image], scope, 'distort_color'):
  143. color_ordering = thread_id % 2
  144. if color_ordering == 0:
  145. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  146. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  147. image = tf.image.random_hue(image, max_delta=0.2)
  148. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  149. elif color_ordering == 1:
  150. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  151. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  152. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  153. image = tf.image.random_hue(image, max_delta=0.2)
  154. # The random_* ops do not necessarily clamp.
  155. image = tf.clip_by_value(image, 0.0, 1.0)
  156. return image
  157. def distort_image(image, height, width, bbox, thread_id=0, scope=None):
  158. """Distort one image for training a network.
  159. Distorting images provides a useful technique for augmenting the data
  160. set during training in order to make the network invariant to aspects
  161. of the image that do not effect the label.
  162. Args:
  163. image: 3-D float Tensor of image
  164. height: integer
  165. width: integer
  166. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  167. where each coordinate is [0, 1) and the coordinates are arranged
  168. as [ymin, xmin, ymax, xmax].
  169. thread_id: integer indicating the preprocessing thread.
  170. scope: Optional scope for op_scope.
  171. Returns:
  172. 3-D float Tensor of distorted image used for training.
  173. """
  174. with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
  175. # Each bounding box has shape [1, num_boxes, box coords] and
  176. # the coordinates are ordered [ymin, xmin, ymax, xmax].
  177. # Display the bounding box in the first thread only.
  178. if not thread_id:
  179. image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
  180. bbox)
  181. tf.image_summary('image_with_bounding_boxes', image_with_box)
  182. # A large fraction of image datasets contain a human-annotated bounding
  183. # box delineating the region of the image containing the object of interest.
  184. # We choose to create a new bounding box for the object which is a randomly
  185. # distorted version of the human-annotated bounding box that obeys an allowed
  186. # range of aspect ratios, sizes and overlap with the human-annotated
  187. # bounding box. If no box is supplied, then we assume the bounding box is
  188. # the entire image.
  189. sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
  190. tf.shape(image),
  191. bounding_boxes=bbox,
  192. min_object_covered=0.1,
  193. aspect_ratio_range=[0.75, 1.33],
  194. area_range=[0.05, 1.0],
  195. max_attempts=100,
  196. use_image_if_no_bounding_boxes=True)
  197. bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
  198. if not thread_id:
  199. image_with_distorted_box = tf.image.draw_bounding_boxes(
  200. tf.expand_dims(image, 0), distort_bbox)
  201. tf.image_summary('images_with_distorted_bounding_box',
  202. image_with_distorted_box)
  203. # Crop the image to the specified bounding box.
  204. distorted_image = tf.slice(image, bbox_begin, bbox_size)
  205. # This resizing operation may distort the images because the aspect
  206. # ratio is not respected. We select a resize method in a round robin
  207. # fashion based on the thread number.
  208. # Note that ResizeMethod contains 4 enumerated resizing methods.
  209. resize_method = thread_id % 4
  210. distorted_image = tf.image.resize_images(distorted_image, height, width,
  211. resize_method)
  212. # Restore the shape since the dynamic slice based upon the bbox_size loses
  213. # the third dimension.
  214. distorted_image.set_shape([height, width, 3])
  215. if not thread_id:
  216. tf.image_summary('cropped_resized_image',
  217. tf.expand_dims(distorted_image, 0))
  218. # Randomly flip the image horizontally.
  219. distorted_image = tf.image.random_flip_left_right(distorted_image)
  220. # Randomly distort the colors.
  221. distorted_image = distort_color(distorted_image, thread_id)
  222. if not thread_id:
  223. tf.image_summary('final_distorted_image',
  224. tf.expand_dims(distorted_image, 0))
  225. return distorted_image
  226. def eval_image(image, height, width, scope=None):
  227. """Prepare one image for evaluation.
  228. Args:
  229. image: 3-D float Tensor
  230. height: integer
  231. width: integer
  232. scope: Optional scope for op_scope.
  233. Returns:
  234. 3-D float Tensor of prepared image.
  235. """
  236. with tf.op_scope([image, height, width], scope, 'eval_image'):
  237. # Crop the central region of the image with an area containing 87.5% of
  238. # the original image.
  239. image = tf.image.central_crop(image, central_fraction=0.875)
  240. # Resize the image to the original height and width.
  241. image = tf.expand_dims(image, 0)
  242. image = tf.image.resize_bilinear(image, [height, width],
  243. align_corners=False)
  244. image = tf.squeeze(image, [0])
  245. return image
  246. def image_preprocessing(image_buffer, bbox, train, thread_id=0):
  247. """Decode and preprocess one image for evaluation or training.
  248. Args:
  249. image_buffer: JPEG encoded string Tensor
  250. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  251. where each coordinate is [0, 1) and the coordinates are arranged as
  252. [ymin, xmin, ymax, xmax].
  253. train: boolean
  254. thread_id: integer indicating preprocessing thread
  255. Returns:
  256. 3-D float Tensor containing an appropriately scaled image
  257. Raises:
  258. ValueError: if user does not provide bounding box
  259. """
  260. if bbox is None:
  261. raise ValueError('Please supply a bounding box.')
  262. image = decode_jpeg(image_buffer)
  263. height = FLAGS.image_size
  264. width = FLAGS.image_size
  265. if train:
  266. image = distort_image(image, height, width, bbox, thread_id)
  267. else:
  268. image = eval_image(image, height, width)
  269. # Finally, rescale to [-1,1] instead of [0, 1)
  270. image = tf.sub(image, 0.5)
  271. image = tf.mul(image, 2.0)
  272. return image
  273. def parse_example_proto(example_serialized):
  274. """Parses an Example proto containing a training example of an image.
  275. The output of the build_image_data.py image preprocessing script is a dataset
  276. containing serialized Example protocol buffers. Each Example proto contains
  277. the following fields:
  278. image/height: 462
  279. image/width: 581
  280. image/colorspace: 'RGB'
  281. image/channels: 3
  282. image/class/label: 615
  283. image/class/synset: 'n03623198'
  284. image/class/text: 'knee pad'
  285. image/object/bbox/xmin: 0.1
  286. image/object/bbox/xmax: 0.9
  287. image/object/bbox/ymin: 0.2
  288. image/object/bbox/ymax: 0.6
  289. image/object/bbox/label: 615
  290. image/format: 'JPEG'
  291. image/filename: 'ILSVRC2012_val_00041207.JPEG'
  292. image/encoded: <JPEG encoded string>
  293. Args:
  294. example_serialized: scalar Tensor tf.string containing a serialized
  295. Example protocol buffer.
  296. Returns:
  297. image_buffer: Tensor tf.string containing the contents of a JPEG file.
  298. label: Tensor tf.int32 containing the label.
  299. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  300. where each coordinate is [0, 1) and the coordinates are arranged as
  301. [ymin, xmin, ymax, xmax].
  302. text: Tensor tf.string containing the human-readable label.
  303. """
  304. # Dense features in Example proto.
  305. feature_map = {
  306. 'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
  307. default_value=''),
  308. 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
  309. default_value=-1),
  310. 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
  311. default_value=''),
  312. }
  313. sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
  314. # Sparse features in Example proto.
  315. feature_map.update(
  316. {k: sparse_float32 for k in ['image/object/bbox/xmin',
  317. 'image/object/bbox/ymin',
  318. 'image/object/bbox/xmax',
  319. 'image/object/bbox/ymax']})
  320. features = tf.parse_single_example(example_serialized, feature_map)
  321. label = tf.cast(features['image/class/label'], dtype=tf.int32)
  322. xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
  323. ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
  324. xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
  325. ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
  326. # Note that we impose an ordering of (y, x) just to make life difficult.
  327. bbox = tf.concat(0, [ymin, xmin, ymax, xmax])
  328. # Force the variable number of bounding boxes into the shape
  329. # [1, num_boxes, coords].
  330. bbox = tf.expand_dims(bbox, 0)
  331. bbox = tf.transpose(bbox, [0, 2, 1])
  332. return features['image/encoded'], label, bbox, features['image/class/text']
  333. def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
  334. """Contruct batches of training or evaluation examples from the image dataset.
  335. Args:
  336. dataset: instance of Dataset class specifying the dataset.
  337. See dataset.py for details.
  338. batch_size: integer
  339. train: boolean
  340. num_preprocess_threads: integer, total number of preprocessing threads
  341. Returns:
  342. images: 4-D float Tensor of a batch of images
  343. labels: 1-D integer Tensor of [batch_size].
  344. Raises:
  345. ValueError: if data is not found
  346. """
  347. with tf.name_scope('batch_processing'):
  348. data_files = dataset.data_files()
  349. if data_files is None:
  350. raise ValueError('No data files found for this dataset')
  351. filename_queue = tf.train.string_input_producer(data_files, capacity=16)
  352. if num_preprocess_threads is None:
  353. num_preprocess_threads = FLAGS.num_preprocess_threads
  354. if num_preprocess_threads % 4:
  355. raise ValueError('Please make num_preprocess_threads a multiple '
  356. 'of 4 (%d % 4 != 0).', num_preprocess_threads)
  357. # Create a subgraph with its own reader (but sharing the
  358. # filename_queue) for each preprocessing thread.
  359. images_and_labels = []
  360. for thread_id in range(num_preprocess_threads):
  361. reader = dataset.reader()
  362. _, example_serialized = reader.read(filename_queue)
  363. # Parse a serialized Example proto to extract the image and metadata.
  364. image_buffer, label_index, bbox, _ = parse_example_proto(
  365. example_serialized)
  366. image = image_preprocessing(image_buffer, bbox, train, thread_id)
  367. images_and_labels.append([image, label_index])
  368. # Approximate number of examples per shard.
  369. examples_per_shard = 1024
  370. # Size the random shuffle queue to balance between good global
  371. # mixing (more examples) and memory use (fewer examples).
  372. # 1 image uses 299*299*3*4 bytes = 1MB
  373. # The default input_queue_memory_factor is 16 implying a shuffling queue
  374. # size: examples_per_shard * 16 * 1MB = 17.6GB
  375. min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
  376. # Create a queue that produces the examples in batches after shuffling.
  377. if train:
  378. images, label_index_batch = tf.train.shuffle_batch_join(
  379. images_and_labels,
  380. batch_size=batch_size,
  381. capacity=min_queue_examples + 3 * batch_size,
  382. min_after_dequeue=min_queue_examples)
  383. else:
  384. images, label_index_batch = tf.train.batch_join(
  385. images_and_labels,
  386. batch_size=batch_size,
  387. capacity=min_queue_examples + 3 * batch_size)
  388. # Reshape images into these desired dimensions.
  389. height = FLAGS.image_size
  390. width = FLAGS.image_size
  391. depth = 3
  392. images = tf.cast(images, tf.float32)
  393. images = tf.reshape(images, shape=[batch_size, height, width, depth])
  394. # Display the training images in the visualizer.
  395. tf.image_summary('images', images)
  396. return images, tf.reshape(label_index_batch, [batch_size])