Browse Source

1. Little bit optimized code
2. Fixed accuracy when calcuting logs more than 10 steps

Alexandr Baranezky 8 years ago
parent
commit
3bc7f6c288
1 changed files with 13 additions and 6 deletions
  1. 13 6
      tutorials/image/cifar10/cifar10_train.py

+ 13 - 6
tutorials/image/cifar10/cifar10_train.py

@@ -53,6 +53,9 @@ tf.app.flags.DEFINE_integer('max_steps', 1000000,
 tf.app.flags.DEFINE_boolean('log_device_placement', False,
                             """Whether to log device placement.""")
 
+tf.app.flags.DEFINE_integer('log_steps_count', 10,
+                            """Log process results per count.""")
+
 
 def train():
   """Train CIFAR-10 for a number of steps."""
@@ -78,19 +81,23 @@ def train():
 
       def begin(self):
         self._step = -1
+        self._start_time = time.time()
 
       def before_run(self, run_context):
         self._step += 1
-        self._start_time = time.time()
         return tf.train.SessionRunArgs(loss)  # Asks for loss value.
 
       def after_run(self, run_context, run_values):
-        duration = time.time() - self._start_time
-        loss_value = run_values.results
-        if self._step % 10 == 0:
+        log_steps = FLAGS.log_steps_count
+        if self._step % log_steps == 0:
+          duration = time.time() - self._start_time
+          self._start_time = time.time()
+
+          loss_value = run_values.results
+
           num_examples_per_step = FLAGS.batch_size
-          examples_per_sec = num_examples_per_step / duration
-          sec_per_batch = float(duration)
+          examples_per_sec = num_examples_per_step * FLAGS.log_steps_count / duration
+          sec_per_batch = float(duration / log_steps)
 
           format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                         'sec/batch)')