image_processing.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Read and preprocess image data.
  16. Image processing occurs on a single image at a time. Images are read and
  17. preprocessed in parallel across multiple threads. The resulting images
  18. are concatenated together to form a single batch for training or evaluation.
  19. -- Provide processed image data for a network:
  20. inputs: Construct batches of evaluation examples of images.
  21. distorted_inputs: Construct batches of training examples of images.
  22. batch_inputs: Construct batches of training or evaluation examples of images.
  23. -- Data processing:
  24. parse_example_proto: Parses an Example proto containing a training example
  25. of an image.
  26. -- Image decoding:
  27. decode_jpeg: Decode a JPEG encoded string into a 3-D float32 Tensor.
  28. -- Image preprocessing:
  29. image_preprocessing: Decode and preprocess one image for evaluation or training
  30. distort_image: Distort one image for training a network.
  31. eval_image: Prepare one image for evaluation.
  32. distort_color: Distort the color in one image for training.
  33. """
  34. from __future__ import absolute_import
  35. from __future__ import division
  36. from __future__ import print_function
  37. import tensorflow as tf
  38. FLAGS = tf.app.flags.FLAGS
  39. tf.app.flags.DEFINE_integer('batch_size', 32,
  40. """Number of images to process in a batch.""")
  41. tf.app.flags.DEFINE_integer('image_size', 299,
  42. """Provide square images of this size.""")
  43. tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
  44. """Number of preprocessing threads per tower. """
  45. """Please make this a multiple of 4.""")
  46. tf.app.flags.DEFINE_integer('num_readers', 4,
  47. """Number of parallel readers during train.""")
  48. # Images are preprocessed asynchronously using multiple threads specified by
  49. # --num_preprocss_threads and the resulting processed images are stored in a
  50. # random shuffling queue. The shuffling queue dequeues --batch_size images
  51. # for processing on a given Inception tower. A larger shuffling queue guarantees
  52. # better mixing across examples within a batch and results in slightly higher
  53. # predictive performance in a trained model. Empirically,
  54. # --input_queue_memory_factor=16 works well. A value of 16 implies a queue size
  55. # of 1024*16 images. Assuming RGB 299x299 images, this implies a queue size of
  56. # 16GB. If the machine is memory limited, then decrease this factor to
  57. # decrease the CPU memory footprint, accordingly.
  58. tf.app.flags.DEFINE_integer('input_queue_memory_factor', 16,
  59. """Size of the queue of preprocessed images. """
  60. """Default is ideal but try smaller values, e.g. """
  61. """4, 2 or 1, if host memory is constrained. See """
  62. """comments in code for more details.""")
  63. def inputs(dataset, batch_size=None, num_preprocess_threads=None):
  64. """Generate batches of ImageNet images for evaluation.
  65. Use this function as the inputs for evaluating a network.
  66. Note that some (minimal) image preprocessing occurs during evaluation
  67. including central cropping and resizing of the image to fit the network.
  68. Args:
  69. dataset: instance of Dataset class specifying the dataset.
  70. batch_size: integer, number of examples in batch
  71. num_preprocess_threads: integer, total number of preprocessing threads but
  72. None defaults to FLAGS.num_preprocess_threads.
  73. Returns:
  74. images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
  75. image_size, 3].
  76. labels: 1-D integer Tensor of [FLAGS.batch_size].
  77. """
  78. if not batch_size:
  79. batch_size = FLAGS.batch_size
  80. # Force all input processing onto CPU in order to reserve the GPU for
  81. # the forward inference and back-propagation.
  82. with tf.device('/cpu:0'):
  83. images, labels = batch_inputs(
  84. dataset, batch_size, train=False,
  85. num_preprocess_threads=num_preprocess_threads,
  86. num_readers=1)
  87. return images, labels
  88. def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
  89. """Generate batches of distorted versions of ImageNet images.
  90. Use this function as the inputs for training a network.
  91. Distorting images provides a useful technique for augmenting the data
  92. set during training in order to make the network invariant to aspects
  93. of the image that do not effect the label.
  94. Args:
  95. dataset: instance of Dataset class specifying the dataset.
  96. batch_size: integer, number of examples in batch
  97. num_preprocess_threads: integer, total number of preprocessing threads but
  98. None defaults to FLAGS.num_preprocess_threads.
  99. Returns:
  100. images: Images. 4D tensor of size [batch_size, FLAGS.image_size,
  101. FLAGS.image_size, 3].
  102. labels: 1-D integer Tensor of [batch_size].
  103. """
  104. if not batch_size:
  105. batch_size = FLAGS.batch_size
  106. # Force all input processing onto CPU in order to reserve the GPU for
  107. # the forward inference and back-propagation.
  108. with tf.device('/cpu:0'):
  109. images, labels = batch_inputs(
  110. dataset, batch_size, train=True,
  111. num_preprocess_threads=num_preprocess_threads,
  112. num_readers=FLAGS.num_readers)
  113. return images, labels
  114. def decode_jpeg(image_buffer, scope=None):
  115. """Decode a JPEG string into one 3-D float image Tensor.
  116. Args:
  117. image_buffer: scalar string Tensor.
  118. scope: Optional scope for op_scope.
  119. Returns:
  120. 3-D float Tensor with values ranging from [0, 1).
  121. """
  122. with tf.op_scope([image_buffer], scope, 'decode_jpeg'):
  123. # Decode the string as an RGB JPEG.
  124. # Note that the resulting image contains an unknown height and width
  125. # that is set dynamically by decode_jpeg. In other words, the height
  126. # and width of image is unknown at compile-time.
  127. image = tf.image.decode_jpeg(image_buffer, channels=3)
  128. # After this point, all image pixels reside in [0,1)
  129. # until the very end, when they're rescaled to (-1, 1). The various
  130. # adjust_* ops all require this range for dtype float.
  131. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  132. return image
  133. def distort_color(image, thread_id=0, scope=None):
  134. """Distort the color of the image.
  135. Each color distortion is non-commutative and thus ordering of the color ops
  136. matters. Ideally we would randomly permute the ordering of the color ops.
  137. Rather then adding that level of complication, we select a distinct ordering
  138. of color ops for each preprocessing thread.
  139. Args:
  140. image: Tensor containing single image.
  141. thread_id: preprocessing thread ID.
  142. scope: Optional scope for op_scope.
  143. Returns:
  144. color-distorted image
  145. """
  146. with tf.op_scope([image], scope, 'distort_color'):
  147. color_ordering = thread_id % 2
  148. if color_ordering == 0:
  149. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  150. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  151. image = tf.image.random_hue(image, max_delta=0.2)
  152. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  153. elif color_ordering == 1:
  154. image = tf.image.random_brightness(image, max_delta=32. / 255.)
  155. image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
  156. image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
  157. image = tf.image.random_hue(image, max_delta=0.2)
  158. # The random_* ops do not necessarily clamp.
  159. image = tf.clip_by_value(image, 0.0, 1.0)
  160. return image
  161. def distort_image(image, height, width, bbox, thread_id=0, scope=None):
  162. """Distort one image for training a network.
  163. Distorting images provides a useful technique for augmenting the data
  164. set during training in order to make the network invariant to aspects
  165. of the image that do not effect the label.
  166. Args:
  167. image: 3-D float Tensor of image
  168. height: integer
  169. width: integer
  170. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  171. where each coordinate is [0, 1) and the coordinates are arranged
  172. as [ymin, xmin, ymax, xmax].
  173. thread_id: integer indicating the preprocessing thread.
  174. scope: Optional scope for op_scope.
  175. Returns:
  176. 3-D float Tensor of distorted image used for training.
  177. """
  178. with tf.op_scope([image, height, width, bbox], scope, 'distort_image'):
  179. # Each bounding box has shape [1, num_boxes, box coords] and
  180. # the coordinates are ordered [ymin, xmin, ymax, xmax].
  181. # Display the bounding box in the first thread only.
  182. if not thread_id:
  183. image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
  184. bbox)
  185. tf.image_summary('image_with_bounding_boxes', image_with_box)
  186. # A large fraction of image datasets contain a human-annotated bounding
  187. # box delineating the region of the image containing the object of interest.
  188. # We choose to create a new bounding box for the object which is a randomly
  189. # distorted version of the human-annotated bounding box that obeys an allowed
  190. # range of aspect ratios, sizes and overlap with the human-annotated
  191. # bounding box. If no box is supplied, then we assume the bounding box is
  192. # the entire image.
  193. sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
  194. tf.shape(image),
  195. bounding_boxes=bbox,
  196. min_object_covered=0.1,
  197. aspect_ratio_range=[0.75, 1.33],
  198. area_range=[0.05, 1.0],
  199. max_attempts=100,
  200. use_image_if_no_bounding_boxes=True)
  201. bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
  202. if not thread_id:
  203. image_with_distorted_box = tf.image.draw_bounding_boxes(
  204. tf.expand_dims(image, 0), distort_bbox)
  205. tf.image_summary('images_with_distorted_bounding_box',
  206. image_with_distorted_box)
  207. # Crop the image to the specified bounding box.
  208. distorted_image = tf.slice(image, bbox_begin, bbox_size)
  209. # This resizing operation may distort the images because the aspect
  210. # ratio is not respected. We select a resize method in a round robin
  211. # fashion based on the thread number.
  212. # Note that ResizeMethod contains 4 enumerated resizing methods.
  213. resize_method = thread_id % 4
  214. distorted_image = tf.image.resize_images(distorted_image, [height, width],
  215. method=resize_method)
  216. # Restore the shape since the dynamic slice based upon the bbox_size loses
  217. # the third dimension.
  218. distorted_image.set_shape([height, width, 3])
  219. if not thread_id:
  220. tf.image_summary('cropped_resized_image',
  221. tf.expand_dims(distorted_image, 0))
  222. # Randomly flip the image horizontally.
  223. distorted_image = tf.image.random_flip_left_right(distorted_image)
  224. # Randomly distort the colors.
  225. distorted_image = distort_color(distorted_image, thread_id)
  226. if not thread_id:
  227. tf.image_summary('final_distorted_image',
  228. tf.expand_dims(distorted_image, 0))
  229. return distorted_image
  230. def eval_image(image, height, width, scope=None):
  231. """Prepare one image for evaluation.
  232. Args:
  233. image: 3-D float Tensor
  234. height: integer
  235. width: integer
  236. scope: Optional scope for op_scope.
  237. Returns:
  238. 3-D float Tensor of prepared image.
  239. """
  240. with tf.op_scope([image, height, width], scope, 'eval_image'):
  241. # Crop the central region of the image with an area containing 87.5% of
  242. # the original image.
  243. image = tf.image.central_crop(image, central_fraction=0.875)
  244. # Resize the image to the original height and width.
  245. image = tf.expand_dims(image, 0)
  246. image = tf.image.resize_bilinear(image, [height, width],
  247. align_corners=False)
  248. image = tf.squeeze(image, [0])
  249. return image
  250. def image_preprocessing(image_buffer, bbox, train, thread_id=0):
  251. """Decode and preprocess one image for evaluation or training.
  252. Args:
  253. image_buffer: JPEG encoded string Tensor
  254. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  255. where each coordinate is [0, 1) and the coordinates are arranged as
  256. [ymin, xmin, ymax, xmax].
  257. train: boolean
  258. thread_id: integer indicating preprocessing thread
  259. Returns:
  260. 3-D float Tensor containing an appropriately scaled image
  261. Raises:
  262. ValueError: if user does not provide bounding box
  263. """
  264. if bbox is None:
  265. raise ValueError('Please supply a bounding box.')
  266. image = decode_jpeg(image_buffer)
  267. height = FLAGS.image_size
  268. width = FLAGS.image_size
  269. if train:
  270. image = distort_image(image, height, width, bbox, thread_id)
  271. else:
  272. image = eval_image(image, height, width)
  273. # Finally, rescale to [-1,1] instead of [0, 1)
  274. image = tf.sub(image, 0.5)
  275. image = tf.mul(image, 2.0)
  276. return image
  277. def parse_example_proto(example_serialized):
  278. """Parses an Example proto containing a training example of an image.
  279. The output of the build_image_data.py image preprocessing script is a dataset
  280. containing serialized Example protocol buffers. Each Example proto contains
  281. the following fields:
  282. image/height: 462
  283. image/width: 581
  284. image/colorspace: 'RGB'
  285. image/channels: 3
  286. image/class/label: 615
  287. image/class/synset: 'n03623198'
  288. image/class/text: 'knee pad'
  289. image/object/bbox/xmin: 0.1
  290. image/object/bbox/xmax: 0.9
  291. image/object/bbox/ymin: 0.2
  292. image/object/bbox/ymax: 0.6
  293. image/object/bbox/label: 615
  294. image/format: 'JPEG'
  295. image/filename: 'ILSVRC2012_val_00041207.JPEG'
  296. image/encoded: <JPEG encoded string>
  297. Args:
  298. example_serialized: scalar Tensor tf.string containing a serialized
  299. Example protocol buffer.
  300. Returns:
  301. image_buffer: Tensor tf.string containing the contents of a JPEG file.
  302. label: Tensor tf.int32 containing the label.
  303. bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
  304. where each coordinate is [0, 1) and the coordinates are arranged as
  305. [ymin, xmin, ymax, xmax].
  306. text: Tensor tf.string containing the human-readable label.
  307. """
  308. # Dense features in Example proto.
  309. feature_map = {
  310. 'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
  311. default_value=''),
  312. 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
  313. default_value=-1),
  314. 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
  315. default_value=''),
  316. }
  317. sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
  318. # Sparse features in Example proto.
  319. feature_map.update(
  320. {k: sparse_float32 for k in ['image/object/bbox/xmin',
  321. 'image/object/bbox/ymin',
  322. 'image/object/bbox/xmax',
  323. 'image/object/bbox/ymax']})
  324. features = tf.parse_single_example(example_serialized, feature_map)
  325. label = tf.cast(features['image/class/label'], dtype=tf.int32)
  326. xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
  327. ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
  328. xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
  329. ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
  330. # Note that we impose an ordering of (y, x) just to make life difficult.
  331. bbox = tf.concat(0, [ymin, xmin, ymax, xmax])
  332. # Force the variable number of bounding boxes into the shape
  333. # [1, num_boxes, coords].
  334. bbox = tf.expand_dims(bbox, 0)
  335. bbox = tf.transpose(bbox, [0, 2, 1])
  336. return features['image/encoded'], label, bbox, features['image/class/text']
  337. def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
  338. num_readers=1):
  339. """Contruct batches of training or evaluation examples from the image dataset.
  340. Args:
  341. dataset: instance of Dataset class specifying the dataset.
  342. See dataset.py for details.
  343. batch_size: integer
  344. train: boolean
  345. num_preprocess_threads: integer, total number of preprocessing threads
  346. num_readers: integer, number of parallel readers
  347. Returns:
  348. images: 4-D float Tensor of a batch of images
  349. labels: 1-D integer Tensor of [batch_size].
  350. Raises:
  351. ValueError: if data is not found
  352. """
  353. with tf.name_scope('batch_processing'):
  354. data_files = dataset.data_files()
  355. if data_files is None:
  356. raise ValueError('No data files found for this dataset')
  357. # Create filename_queue
  358. if train:
  359. filename_queue = tf.train.string_input_producer(data_files,
  360. shuffle=True,
  361. capacity=16)
  362. else:
  363. filename_queue = tf.train.string_input_producer(data_files,
  364. shuffle=False,
  365. capacity=1)
  366. if num_preprocess_threads is None:
  367. num_preprocess_threads = FLAGS.num_preprocess_threads
  368. if num_preprocess_threads % 4:
  369. raise ValueError('Please make num_preprocess_threads a multiple '
  370. 'of 4 (%d % 4 != 0).', num_preprocess_threads)
  371. if num_readers is None:
  372. num_readers = FLAGS.num_readers
  373. if num_readers < 1:
  374. raise ValueError('Please make num_readers at least 1')
  375. # Approximate number of examples per shard.
  376. examples_per_shard = 1024
  377. # Size the random shuffle queue to balance between good global
  378. # mixing (more examples) and memory use (fewer examples).
  379. # 1 image uses 299*299*3*4 bytes = 1MB
  380. # The default input_queue_memory_factor is 16 implying a shuffling queue
  381. # size: examples_per_shard * 16 * 1MB = 17.6GB
  382. min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
  383. if train:
  384. examples_queue = tf.RandomShuffleQueue(
  385. capacity=min_queue_examples + 3 * batch_size,
  386. min_after_dequeue=min_queue_examples,
  387. dtypes=[tf.string])
  388. else:
  389. examples_queue = tf.FIFOQueue(
  390. capacity=examples_per_shard + 3 * batch_size,
  391. dtypes=[tf.string])
  392. # Create multiple readers to populate the queue of examples.
  393. if num_readers > 1:
  394. enqueue_ops = []
  395. for _ in range(num_readers):
  396. reader = dataset.reader()
  397. _, value = reader.read(filename_queue)
  398. enqueue_ops.append(examples_queue.enqueue([value]))
  399. tf.train.queue_runner.add_queue_runner(
  400. tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
  401. example_serialized = examples_queue.dequeue()
  402. else:
  403. reader = dataset.reader()
  404. _, example_serialized = reader.read(filename_queue)
  405. images_and_labels = []
  406. for thread_id in range(num_preprocess_threads):
  407. # Parse a serialized Example proto to extract the image and metadata.
  408. image_buffer, label_index, bbox, _ = parse_example_proto(
  409. example_serialized)
  410. image = image_preprocessing(image_buffer, bbox, train, thread_id)
  411. images_and_labels.append([image, label_index])
  412. images, label_index_batch = tf.train.batch_join(
  413. images_and_labels,
  414. batch_size=batch_size,
  415. capacity=2 * num_preprocess_threads * batch_size)
  416. # Reshape images into these desired dimensions.
  417. height = FLAGS.image_size
  418. width = FLAGS.image_size
  419. depth = 3
  420. images = tf.cast(images, tf.float32)
  421. images = tf.reshape(images, shape=[batch_size, height, width, depth])
  422. # Display the training images in the visualizer.
  423. tf.image_summary('images', images)
  424. return images, tf.reshape(label_index_batch, [batch_size])