Browse Source

Convert resnet model to use monitored_session

Mustafa Ispir 8 years ago
parent
commit
931c70a1b8
3 changed files with 67 additions and 50 deletions
  1. resnet/cifar_input.py (+1 -1)
  2. resnet/resnet_main.py (+65 -48)
  3. resnet/resnet_model.py (+1 -1)

+ 1 - 1
resnet/cifar_input.py

@@ -73,7 +73,7 @@ def build_input(dataset, data_path, batch_size, mode):
     # image = tf.image.random_brightness(image, max_delta=63. / 255.)
     # image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
     # image = tf.image.random_contrast(image, lower=0.2, upper=1.8)
-    image = tf.image.per_image_whitening(image)
+    image = tf.image.per_image_standardization(image)
 
     example_queue = tf.RandomShuffleQueue(
         capacity=16 * batch_size,

+ 65 - 48
resnet/resnet_main.py

@@ -15,8 +15,8 @@
 
 """ResNet Train/Eval module.
 """
-import sys
 import time
+import sys
 
 import cifar_input
 import numpy as np
@@ -26,8 +26,10 @@ import tensorflow as tf
 FLAGS = tf.app.flags.FLAGS
 tf.app.flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or cifar100.')
 tf.app.flags.DEFINE_string('mode', 'train', 'train or eval.')
-tf.app.flags.DEFINE_string('train_data_path', '', 'Filepattern for training data.')
-tf.app.flags.DEFINE_string('eval_data_path', '', 'Filepattern for eval data')
+tf.app.flags.DEFINE_string('train_data_path', '',
+                           'Filepattern for training data.')
+tf.app.flags.DEFINE_string('eval_data_path', '',
+                           'Filepattern for eval data')
 tf.app.flags.DEFINE_integer('image_size', 32, 'Image side length.')
 tf.app.flags.DEFINE_string('train_dir', '',
                            'Directory to keep training outputs.')
@@ -50,50 +52,65 @@ def train(hps):
       FLAGS.dataset, FLAGS.train_data_path, hps.batch_size, FLAGS.mode)
   model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
   model.build_graph()
-  summary_writer = tf.train.SummaryWriter(FLAGS.train_dir)
-
-  sv = tf.train.Supervisor(logdir=FLAGS.log_root,
-                           is_chief=True,
-                           summary_op=None,
-                           save_summaries_secs=60,
-                           save_model_secs=300,
-                           global_step=model.global_step)
-  sess = sv.prepare_or_wait_for_session(
-      config=tf.ConfigProto(allow_soft_placement=True))
-
-  step = 0
-  lrn_rate = 0.1
-
-  while not sv.should_stop():
-    (_, summaries, loss, predictions, truth, train_step) = sess.run(
-        [model.train_op, model.summaries, model.cost, model.predictions,
-         model.labels, model.global_step],
-        feed_dict={model.lrn_rate: lrn_rate})
-
-    if train_step < 40000:
-      lrn_rate = 0.1
-    elif train_step < 60000:
-      lrn_rate = 0.01
-    elif train_step < 80000:
-      lrn_rate = 0.001
-    else:
-      lrn_rate = 0.0001
-
-    truth = np.argmax(truth, axis=1)
-    predictions = np.argmax(predictions, axis=1)
-    precision = np.mean(truth == predictions)
-
-    step += 1
-    if step % 100 == 0:
-      precision_summ = tf.Summary()
-      precision_summ.value.add(
-          tag='Precision', simple_value=precision)
-      summary_writer.add_summary(precision_summ, train_step)
-      summary_writer.add_summary(summaries, train_step)
-      tf.logging.info('loss: %.3f, precision: %.3f\n' % (loss, precision))
-      summary_writer.flush()
-
-  sv.Stop()
+
+  param_stats = tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.
+          TRAINABLE_VARS_PARAMS_STAT_OPTIONS)
+  sys.stdout.write('total_params: %d\n' % param_stats.total_parameters)
+
+  tf.contrib.tfprof.model_analyzer.print_model_analysis(
+      tf.get_default_graph(),
+      tfprof_options=tf.contrib.tfprof.model_analyzer.FLOAT_OPS_OPTIONS)
+
+  truth = tf.argmax(model.labels, axis=1)
+  predictions = tf.argmax(model.predictions, axis=1)
+  precision = tf.reduce_mean(tf.to_float(tf.equal(predictions, truth)))
+
+  summary_hook = tf.train.SummarySaverHook(
+      save_steps=100,
+      output_dir=FLAGS.train_dir,
+      summary_op=[model.summaries,
+                  tf.summary.scalar('Precision', precision)])
+
+  logging_hook = tf.train.LoggingTensorHook(
+      tensors={'step': model.global_step,
+               'loss': model.cost,
+               'precision': precision},
+      every_n_iter=100)
+
+  class _LearningRateSetterHook(tf.train.SessionRunHook):
+    """Sets learning_rate based on global step."""
+
+    def begin(self):
+      self._lrn_rate = 0.1
+
+    def before_run(self, run_context):
+      return tf.train.SessionRunArgs(
+          model.global_step,  # Asks for global step value.
+          feed_dict={model.lrn_rate: self._lrn_rate})  # Sets learning rate
+
+    def after_run(self, run_context, run_values):
+      train_step = run_values.results
+      if train_step < 40000:
+        self._lrn_rate = 0.1
+      elif train_step < 60000:
+        self._lrn_rate = 0.01
+      elif train_step < 80000:
+        self._lrn_rate = 0.001
+      else:
+        self._lrn_rate = 0.0001
+
+  with tf.train.MonitoredTrainingSession(
+      checkpoint_dir=FLAGS.log_root,
+      hooks=[logging_hook, _LearningRateSetterHook()],
+      chief_only_hooks=[summary_hook],
+      # Since we provide a SummarySaverHook, we need to disable the default
+      # SummarySaverHook. To do that we set save_summaries_steps to 0.
+      save_summaries_steps=0,
+      config=tf.ConfigProto(allow_soft_placement=True)) as mon_sess:
+    while not mon_sess.should_stop():
+      mon_sess.run(model.train_op)
 
 
 def evaluate(hps):
@@ -103,7 +120,7 @@ def evaluate(hps):
   model = resnet_model.ResNet(hps, images, labels, FLAGS.mode)
   model.build_graph()
   saver = tf.train.Saver()
-  summary_writer = tf.train.SummaryWriter(FLAGS.eval_dir)
+  summary_writer = tf.summary.FileWriter(FLAGS.eval_dir)
 
   sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
   tf.train.start_queue_runners(sess)

+ 1 - 1
resnet/resnet_model.py

@@ -55,7 +55,7 @@ class ResNet(object):
 
   def build_graph(self):
     """Build a whole graph for the model."""
-    self.global_step = tf.Variable(0, name='global_step', trainable=False)
+    self.global_step = tf.contrib.framework.get_or_create_global_step()
     self._build_model()
     if self.mode == 'train':
       self._build_train_op()