eval.py

# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Eval Cross Convolutional Model."""
  16. import io
  17. import os
  18. import sys
  19. import time
  20. import numpy as np
  21. import tensorflow as tf
  22. import model as cross_conv_model
  23. import reader
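
# model.py provides the CrossConvModel graph and reader.py the input pipeline
# used below.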

FLAGS = tf.flags.FLAGS
tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
tf.flags.DEFINE_string('data_filepattern',
                       'est',
                       'Training data file pattern.')
tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image.')
tf.flags.DEFINE_float('scale', 10.0,
                      'Scale the image after norm_scale and move the diff '
                      'to the positive realm.')
tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
tf.flags.DEFINE_integer('eval_batch_count', 100,
                        'Average the result over this number of batches.')
tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')

slim = tf.contrib.slim


def _Eval():
  params = dict()
  params['batch_size'] = FLAGS.batch_size
  params['seq_len'] = FLAGS.sequence_length
  params['image_size'] = FLAGS.image_size
  params['is_training'] = False
  params['norm_scale'] = FLAGS.norm_scale
  params['scale'] = FLAGS.scale
  params['l2_loss'] = FLAGS.l2_loss
  params['reconstr_loss'] = FLAGS.reconstr_loss
  params['kl_loss'] = FLAGS.kl_loss
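
  # Build the input pipeline and the model graph once; the checkpoints
  # restored inside the loop below supply the weights.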
  eval_dir = os.path.join(FLAGS.log_root, 'eval')

  images = reader.ReadInput(
      FLAGS.data_filepattern, shuffle=False, params=params)
  images *= params['scale']
  # Increasing the value makes training much faster.
  image_diff_list = reader.SequenceToImageAndDiff(images)
  model = cross_conv_model.CrossConvModel(image_diff_list, params)
  model.Build()
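
  # The summary writer and session are created once and reused for every
  # checkpoint evaluation; summaries land under FLAGS.log_root/eval.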
  summary_writer = tf.summary.FileWriter(eval_dir)
  saver = tf.train.Saver()
  sess = tf.Session('', config=tf.ConfigProto(allow_soft_placement=True))
  tf.train.start_queue_runners(sess)
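
  # Poll FLAGS.log_root once a minute and evaluate each checkpoint as it
  # appears.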
  while True:
    time.sleep(60)
    try:
      ckpt_state = tf.train.get_checkpoint_state(FLAGS.log_root)
    except tf.errors.OutOfRangeError as e:
      sys.stderr.write('Cannot restore checkpoint: %s\n' % e)
      continue
    if not (ckpt_state and ckpt_state.model_checkpoint_path):
      sys.stderr.write('No model to eval yet at %s\n' % FLAGS.log_root)
      continue
    sys.stderr.write('Loading checkpoint %s\n' %
                     ckpt_state.model_checkpoint_path)
    saver.restore(sess, ckpt_state.model_checkpoint_path)

    # Use the empirical distribution of z from the training set.
    if not tf.gfile.Exists(os.path.join(FLAGS.log_root, 'z_mean.npy')):
      sys.stderr.write('No z at %s\n' % FLAGS.log_root)
      continue

    with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy')) as f:
      sample_z_mean = np.load(io.BytesIO(f.read()))
    with tf.gfile.Open(
        os.path.join(FLAGS.log_root, 'z_stddev_log.npy')) as f:
      sample_z_stddev_log = np.load(io.BytesIO(f.read()))
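
    # Average the loss over FLAGS.eval_batch_count batches, feeding the stored
    # z statistics from training.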
    total_loss = 0.0
    for _ in range(FLAGS.eval_batch_count):
      loss_val, total_steps, summaries = sess.run(
          [model.loss, model.global_step, model.summary_op],
          feed_dict={model.z_mean: sample_z_mean,
                     model.z_stddev_log: sample_z_stddev_log})
      total_loss += loss_val
      summary_writer.add_summary(summaries, total_steps)
    sys.stderr.write('steps: %d, loss: %f\n' %
                     (total_steps, total_loss / FLAGS.eval_batch_count))


def main(_):
  _Eval()


if __name__ == '__main__':
  tf.app.run()