@@ -15,8 +15,8 @@

"""ResNet Train/Eval module.
"""
-import sys
import time
+import sys

import cifar_input
import numpy as np
@@ -26,8 +26,10 @@ import tensorflow as tf
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
-tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
-tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
+tf.app.flags.DEFINE_string('train_data_path', '',
+                           'Filepattern for training data.')
+tf.app.flags.DEFINE_string('eval_data_path', '',
+                           'Filepattern for eval data.')
tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
tf.app.flags.DEFINE_string('train_dir', '',
                           'Directory to keep training outputs.')
@@ -50,50 +52,72 @@ def train(hps):
      FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
-  summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
-
-  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
-                           is_chief=True,
-                           summary_op=None,
-                           save_summaries_secs=60,
-                           save_model_secs=300,
-                           global_step=model.global_step)
-  sess = sv.prepare_or_wait_for_session(
-      config=tf.ConfigProto(allow_soft_placement=True))
-
-  step = 0
-  lrn_rate = 0.1
-
-  while not sv.should_stop():
-    (_, summaries, loss, predictions, truth, train_step) = sess.run(
-        [model.train_op, model.summaries, model.cost, model.predictions,
-         model.labels, model.global_step],
-        feed_dict={model.lrn_rate: lrn_rate})
-
-    if train_step < 40000:
-      lrn_rate = 0.1
-    elif train_step < 60000:
-      lrn_rate = 0.01
-    elif train_step < 80000:
-      lrn_rate = 0.001
-    else:
-      lrn_rate = 0.0001
-
-    truth = np.argmax(truth, axis=1)
-    predictions = np.argmax(predictions, axis=1)
-    precision = np.mean(truth == predictions)
-
-    step += 1
-    if step % 100 == 0:
-      precision_summ = tf.Summary()
-      precision_summ.value.add(
-          tag='Precision', simple_value=precision)
-      summary_writer.add_summary(precision_summ, train_step)
-      summary_writer.add_summary(summaries, train_step)
-      tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
-      summary_writer.flush()
-
-  sv.Stop()
+
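+  # tfprof: print per-variable parameter statistics and the total count.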
+  param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.
+          TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+  sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
+
+  tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
+
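+  # In-graph precision: fraction of predictions matching the labels.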
+  truth = tf.argmax(model.labels, axis=1)
+  predictions = tf.argmax(model.predictions, axis=1)
+  precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
+
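+  # Saves summaries every 100 steps, replacing the manual summary writer.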
+  summary_hook = tf.train.SummarySaverHook(
+      save_steps=100,
+      output_dir=FLAGS.train_dir,
+      summary_op=[model.summaries,
+                  tf.summary.scalar('Precision', precision)])
+
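+  # Logs step, loss and precision to tf.logging every 100 steps.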
+  logging_hook = tf.train.LoggingTensorHook(
+      tensors={'step': model.global_step,
+               'loss': model.cost,
+               'precision': precision},
+      every_n_iter=100)
+
+  class _LearningRateSetterHook(tf.train.SessionRunHook):
+    """Sets learning_rate based on global step."""
+
+    def begin(self):
+      self._lrn_rate = 0.1
+
+    def before_run(self, run_context):
+      return tf.train.SessionRunArgs(
+          model.global_step,  # Asks for global step value.
+          feed_dict={model.lrn_rate: self._lrn_rate})  # Sets learning rate.
+
+    def after_run(self, run_context, run_values):
+      train_step = run_values.results
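+      # Staircase schedule: drop the rate at steps 40k, 60k and 80k.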
+      if train_step < 40000:
+        self._lrn_rate = 0.1
+      elif train_step < 60000:
+        self._lrn_rate = 0.01
+      elif train_step < 80000:
+        self._lrn_rate = 0.001
+      else:
+        self._lrn_rate = 0.0001
+
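+  # MonitoredTrainingSession creates the session, restores and saves
+  # checkpoints in log_root, and runs the hooks; it replaces the Supervisor.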
+  with tf.train.MonitoredTrainingSession(
+      checkpoint_dir=FLAGS.log_root,
+      hooks=[logging_hook, _LearningRateSetterHook()],
+      chief_only_hooks=[summary_hook],
+      # Since we provide a SummarySaverHook, we need to disable the default
+      # SummarySaverHook. To do that, we set save_summaries_steps to 0.
+      save_summaries_steps=0,
+      config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
+    while not mon_sess.should_stop():
+      mon_sess.run(model.train_op)


def evaluate(hps):
@@ -103,7 +127,7 @@ def evaluate(hps):
  model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
  model.build_graph()
  saver = tf.train.Saver()
-  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
+  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)

  sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
  tf.train.start_queue_runners(sess)