Fix a bug in the im2txt code where the Saver is created before the
optimizer.
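
A minimal sketch of the underlying issue (not part of the commit; assumes a TF 1.x-style graph API, slightly newer than the code below): tf.train.Saver() only captures the variables that exist in the graph when it is constructed, so a Saver built before the optimizer silently omits the optimizer's slot variables and the resulting checkpoints cannot fully restore training state.

    import tensorflow as tf

    w = tf.Variable(tf.zeros([10]), name="w")
    early_saver = tf.train.Saver()  # sees only "w"
    vars_before = set(tf.global_variables())

    loss = tf.reduce_sum(tf.square(w))
    # minimize() adds a momentum slot variable for "w" to the graph.
    train_op = tf.train.MomentumOptimizer(0.1, 0.9).minimize(loss)

    late_saver = tf.train.Saver()  # sees "w" and its momentum slot
    new_vars = set(tf.global_variables()) - vars_before
    print(sorted(v.op.name for v in new_vars))  # slot variable missed by early_saver

Hence the change below: the training Saver is created in train.py after the train op (and its optimizer variables) exist, while evaluation and inference build their own Savers since they have no optimizer state to restore.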

Christopher Shallue 9 years ago
parent commit cd5e9b7c2b

+ 3 - 4
im2txt/im2txt/configuration.py

@@ -77,10 +77,6 @@ class ModelConfig(object):
     # If < 1.0, the dropout keep probability applied to LSTM variables.
     self.lstm_dropout_keep_prob = 0.7
 
-    # How many model checkpoints to keep.
-    self.max_checkpoints_to_keep = 5
-    self.keep_checkpoint_every_n_hours = 10000
-
 
 class TrainingConfig(object):
   """Wrapper class for training hyperparameters."""
@@ -103,3 +99,6 @@ class TrainingConfig(object):
 
     # If not None, clip gradients to this value.
     self.clip_gradients = 5.0
+
+    # How many model checkpoints to keep.
+    self.max_checkpoints_to_keep = 5

+ 7 - 3
im2txt/im2txt/evaluate.py

@@ -104,11 +104,12 @@ def evaluate_model(sess, model, global_step, summary_writer, summary_op):
                   global_step)
 
 
-def run_once(model, summary_writer, summary_op):
+def run_once(model, saver, summary_writer, summary_op):
   """Evaluates the latest model checkpoint.
 
   Args:
     model: Instance of ShowAndTellModel; the model to evaluate.
+    saver: Instance of tf.train.Saver for restoring model Variables.
     summary_writer: Instance of SummaryWriter.
     summary_op: Op for generating model summaries.
   """
@@ -121,7 +122,7 @@ def run_once(model, summary_writer, summary_op):
   with tf.Session() as sess:
     # Load model from checkpoint.
     tf.logging.info("Loading model from checkpoint: %s", model_path)
-    model.saver.restore(sess, model_path)
+    saver.restore(sess, model_path)
     global_step = tf.train.global_step(sess, model.global_step.name)
     tf.logging.info("Successfully loaded %s at global step = %d.",
                     os.path.basename(model_path), global_step)
@@ -166,6 +167,9 @@ def run():
     model = show_and_tell_model.ShowAndTellModel(model_config, mode="eval")
     model.build()
 
+    # Create the Saver to restore model Variables.
+    saver = tf.train.Saver()
+
     # Create the summary operation and the summary writer.
     summary_op = tf.merge_all_summaries()
     summary_writer = tf.train.SummaryWriter(eval_dir)
@@ -177,7 +181,7 @@ def run():
       start = time.time()
       tf.logging.info("Starting evaluation at " + time.strftime(
           "%Y-%m-%d-%H:%M:%S", time.localtime()))
-      run_once(model, summary_writer, summary_op)
+      run_once(model, saver, summary_writer, summary_op)
       time_to_next_eval = start + FLAGS.eval_interval_secs - time.time()
       if time_to_next_eval > 0:
         time.sleep(time_to_next_eval)

+ 2 - 4
im2txt/im2txt/inference_utils/inference_wrapper_base.py

@@ -112,10 +112,8 @@ class InferenceWrapperBase(object):
         from the checkpoint file.
     """
     tf.logging.info("Building model.")
-    model = self.build_model(model_config)
-    saver = model.saver
-    if not saver:
-      saver = tf.Saver()
+    self.build_model(model_config)
+    saver = tf.train.Saver()
 
     return self._create_restore_fn(checkpoint_path, saver)
 

+ 0 - 7
im2txt/im2txt/show_and_tell_model.py

@@ -347,12 +347,6 @@ class ShowAndTellModel(object):
 
     self.global_step = global_step
 
-  def setup_saver(self):
-    """Sets up the Saver for loading and saving model checkpoints."""
-    self.saver = tf.train.Saver(
-        max_to_keep=self.config.max_checkpoints_to_keep,
-        keep_checkpoint_every_n_hours=self.config.keep_checkpoint_every_n_hours)
-
   def build(self):
     """Creates all ops for training and evaluation."""
     self.build_inputs()
@@ -361,4 +355,3 @@ class ShowAndTellModel(object):
     self.build_model()
     self.setup_inception_initializer()
     self.setup_global_step()
-    self.setup_saver()

+ 4 - 1
im2txt/im2txt/train.py

@@ -95,6 +95,9 @@ def main(unused_argv):
         clip_gradients=training_config.clip_gradients,
         learning_rate_decay_fn=learning_rate_decay_fn)
 
+    # Set up the Saver for saving and restoring model checkpoints.
+    saver = tf.train.Saver(max_to_keep=training_config.max_checkpoints_to_keep)
+
   # Run training.
   tf.contrib.slim.learning.train(
       train_op,
@@ -104,7 +107,7 @@ def main(unused_argv):
       global_step=model.global_step,
       number_of_steps=FLAGS.number_of_steps,
       init_fn=model.init_fn,
-      saver=model.saver)
+      saver=saver)
 
 
 if __name__ == "__main__":