소스 검색

Improve image processing (#45)

* improve image processing performance for Inception.
Jianmin Chen 9 년 전
부모
커밋
5d7612c63c
1개의 변경된 파일57개의 추가작업 그리고 26개의 파일을 삭제
  1. 57 26
      inception/inception/image_processing.py

+ 57 - 26
inception/inception/image_processing.py

@@ -40,7 +40,6 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-
 import tensorflow as tf
 
 FLAGS = tf.app.flags.FLAGS
@@ -52,6 +51,8 @@ tf.app.flags.DEFINE_integer('image_size', 299,
 tf.app.flags.DEFINE_integer('num_preprocess_threads', 4,
                             """Number of preprocessing threads per tower. """
                             """Please make this a multiple of 4.""")
+tf.app.flags.DEFINE_integer('num_readers', 4,
+                            """Number of parallel readers during train.""")
 
 # Images are preprocessed asynchronously using multiple threads specifed by
 # --num_preprocss_threads and the resulting processed images are stored in a
@@ -97,7 +98,8 @@ def inputs(dataset, batch_size=None, num_preprocess_threads=None):
   with tf.device('/cpu:0'):
     images, labels = batch_inputs(
         dataset, batch_size, train=False,
-        num_preprocess_threads=num_preprocess_threads)
+        num_preprocess_threads=num_preprocess_threads,
+        num_readers=1)
 
   return images, labels
 
@@ -130,7 +132,8 @@ def distorted_inputs(dataset, batch_size=None, num_preprocess_threads=None):
   with tf.device('/cpu:0'):
     images, labels = batch_inputs(
         dataset, batch_size, train=True,
-        num_preprocess_threads=num_preprocess_threads)
+        num_preprocess_threads=num_preprocess_threads,
+        num_readers=FLAGS.num_readers)
   return images, labels
 
 
@@ -401,7 +404,8 @@ def parse_example_proto(example_serialized):
   return features['image/encoded'], label, bbox, features['image/class/text']
 
 
-def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
+def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None,
+                 num_readers=1):
   """Contruct batches of training or evaluation examples from the image dataset.
 
   Args:
@@ -410,6 +414,7 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     batch_size: integer
     train: boolean
     num_preprocess_threads: integer, total number of preprocessing threads
+    num_readers: integer, number of parallel readers
 
   Returns:
     images: 4-D float Tensor of a batch of images
@@ -422,26 +427,28 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     data_files = dataset.data_files()
     if data_files is None:
       raise ValueError('No data files found for this dataset')
-    filename_queue = tf.train.string_input_producer(data_files, capacity=16)
 
+    # Create filename_queue
+    if train:
+      filename_queue = tf.train.string_input_producer(data_files,
+                                                      shuffle=True,
+                                                      capacity=16)
+    else:
+      filename_queue = tf.train.string_input_producer(data_files,
+                                                      shuffle=False,
+                                                      capacity=1)
     if num_preprocess_threads is None:
       num_preprocess_threads = FLAGS.num_preprocess_threads
 
     if num_preprocess_threads % 4:
       raise ValueError('Please make num_preprocess_threads a multiple '
                        'of 4 (%d % 4 != 0).', num_preprocess_threads)
-    # Create a subgraph with its own reader (but sharing the
-    # filename_queue) for each preprocessing thread.
-    images_and_labels = []
-    for thread_id in range(num_preprocess_threads):
-      reader = dataset.reader()
-      _, example_serialized = reader.read(filename_queue)
 
-      # Parse a serialized Example proto to extract the image and metadata.
-      image_buffer, label_index, bbox, _ = parse_example_proto(
-          example_serialized)
-      image = image_preprocessing(image_buffer, bbox, train, thread_id)
-      images_and_labels.append([image, label_index])
+    if num_readers is None:
+      num_readers = FLAGS.num_readers
+
+    if num_readers < 1:
+      raise ValueError('Please make num_readers at least 1')
 
     # Approximate number of examples per shard.
     examples_per_shard = 1024
@@ -451,19 +458,43 @@ def batch_inputs(dataset, batch_size, train, num_preprocess_threads=None):
     # The default input_queue_memory_factor is 16 implying a shuffling queue
     # size: examples_per_shard * 16 * 1MB = 17.6GB
     min_queue_examples = examples_per_shard * FLAGS.input_queue_memory_factor
-
-    # Create a queue that produces the examples in batches after shuffling.
     if train:
-      images, label_index_batch = tf.train.shuffle_batch_join(
-          images_and_labels,
-          batch_size=batch_size,
+      examples_queue = tf.RandomShuffleQueue(
           capacity=min_queue_examples + 3 * batch_size,
-          min_after_dequeue=min_queue_examples)
+          min_after_dequeue=min_queue_examples,
+          dtypes=[tf.string])
+    else:
+      examples_queue = tf.FIFOQueue(
+          capacity=examples_per_shard + 3 * batch_size,
+          dtypes=[tf.string])
+
+    # Create multiple readers to populate the queue of examples.
+    if num_readers > 1:
+      enqueue_ops = []
+      for _ in range(num_readers):
+        reader = dataset.reader()
+        _, value = reader.read(filename_queue)
+        enqueue_ops.append(examples_queue.enqueue([value]))
+
+      tf.train.queue_runner.add_queue_runner(
+          tf.train.queue_runner.QueueRunner(examples_queue, enqueue_ops))
+      example_serialized = examples_queue.dequeue()
     else:
-      images, label_index_batch = tf.train.batch_join(
-          images_and_labels,
-          batch_size=batch_size,
-          capacity=min_queue_examples + 3 * batch_size)
+      reader = dataset.reader()
+      _, example_serialized = reader.read(filename_queue)
+
+    images_and_labels = []
+    for thread_id in range(num_preprocess_threads):
+      # Parse a serialized Example proto to extract the image and metadata.
+      image_buffer, label_index, bbox, _ = parse_example_proto(
+          example_serialized)
+      image = image_preprocessing(image_buffer, bbox, train, thread_id)
+      images_and_labels.append([image, label_index])
+
+    images, label_index_batch = tf.train.batch_join(
+        images_and_labels,
+        batch_size=batch_size,
+        capacity=2 * num_preprocess_threads * batch_size)
 
     # Reshape images into these desired dimensions.
     height = FLAGS.image_size