소스 검색

Merge pull request #1127 from AsterAI/fix_cifar_logs

1. Little bit optimized code
Neal Wu 8 년 전
부모
커밋
0505b99c40
1개의 변경된 파일11개의 추가작업 그리고 7개의 파일을 삭제
  1. 11 7
      tutorials/image/cifar10/cifar10_train.py

+ 11 - 7
tutorials/image/cifar10/cifar10_train.py

@@ -52,6 +52,8 @@ tf.app.flags.DEFINE_integer('max_steps', 1000000,
                             """Number of batches to run.""")
 tf.app.flags.DEFINE_boolean('log_device_placement', False,
                             """Whether to log device placement.""")
+tf.app.flags.DEFINE_integer('log_frequency', 10,
+                            """How often to log results to the console.""")
 
 
 def train():
@@ -78,19 +80,21 @@ def train():
 
       def begin(self):
         self._step = -1
+        self._start_time = time.time()
 
       def before_run(self, run_context):
         self._step += 1
-        self._start_time = time.time()
         return tf.train.SessionRunArgs(loss)  # Asks for loss value.
 
       def after_run(self, run_context, run_values):
-        duration = time.time() - self._start_time
-        loss_value = run_values.results
-        if self._step % 10 == 0:
-          num_examples_per_step = FLAGS.batch_size
-          examples_per_sec = num_examples_per_step / duration
-          sec_per_batch = float(duration)
+        if self._step % FLAGS.log_frequency == 0:
+          current_time = time.time()
+          duration = current_time - self._start_time
+          self._start_time = current_time
+
+          loss_value = run_values.results
+          examples_per_sec = FLAGS.log_frequency * FLAGS.batch_size / duration
+          sec_per_batch = float(duration / FLAGS.log_frequency)
 
           format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
                         'sec/batch)')