|
|
@@ -53,6 +53,9 @@ tf.app.flags.DEFINE_integer('max_steps', 1000000,
|
|
|
tf.app.flags.DEFINE_boolean('log_device_placement', False,
|
|
|
"""Whether to log device placement.""")
|
|
|
|
|
|
+tf.app.flags.DEFINE_integer('log_steps_count', 10,
|
|
|
+ """Log process results per count.""")
|
|
|
+
|
|
|
|
|
|
def train():
|
|
|
"""Train CIFAR-10 for a number of steps."""
|
|
|
@@ -78,19 +81,23 @@ def train():
|
|
|
|
|
|
def begin(self):
|
|
|
self._step = -1
|
|
|
+ self._start_time = time.time()
|
|
|
|
|
|
def before_run(self, run_context):
|
|
|
self._step += 1
|
|
|
- self._start_time = time.time()
|
|
|
return tf.train.SessionRunArgs(loss) # Asks for loss value.
|
|
|
|
|
|
def after_run(self, run_context, run_values):
|
|
|
- duration = time.time() - self._start_time
|
|
|
- loss_value = run_values.results
|
|
|
- if self._step % 10 == 0:
|
|
|
+ log_steps = FLAGS.log_steps_count
|
|
|
+ if self._step % log_steps == 0:
|
|
|
+ duration = time.time() - self._start_time
|
|
|
+ self._start_time = time.time()
|
|
|
+
|
|
|
+ loss_value = run_values.results
|
|
|
+
|
|
|
num_examples_per_step = FLAGS.batch_size
|
|
|
- examples_per_sec = num_examples_per_step / duration
|
|
|
- sec_per_batch = float(duration)
|
|
|
+ examples_per_sec = num_examples_per_step * FLAGS.log_steps_count / duration
|
|
|
+ sec_per_batch = float(duration / log_steps)
|
|
|
|
|
|
format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f '
|
|
|
'sec/batch)')
|