# train.py
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Train the cross convolutional model."""
  16. import os
  17. import sys
  18. import numpy as np
  19. import tensorflow as tf
  20. import model as cross_conv_model
  21. import reader
  22. FLAGS = tf.flags.FLAGS
  23. tf.flags.DEFINE_string('master', '', 'Session address.')
  24. tf.flags.DEFINE_string('log_root', '/tmp/moving_obj', 'The root dir of output.')
  25. tf.flags.DEFINE_string('data_filepattern', '',
  26. 'training data file pattern.')
  27. tf.flags.DEFINE_integer('image_size', 64, 'Image height and width.')
  28. tf.flags.DEFINE_integer('batch_size', 1, 'Batch size.')
  29. tf.flags.DEFINE_float('norm_scale', 1.0, 'Normalize the original image')
  30. tf.flags.DEFINE_float('scale', 10.0,
  31. 'Scale the image after norm_scale and move the diff '
  32. 'to the positive realm.')
  33. tf.flags.DEFINE_integer('sequence_length', 2, 'tf.SequenceExample length.')
  34. tf.flags.DEFINE_float('learning_rate', 0.8, 'Learning rate.')
  35. tf.flags.DEFINE_bool('l2_loss', True, 'If true, include l2_loss.')
  36. tf.flags.DEFINE_bool('reconstr_loss', False, 'If true, include reconstr_loss.')
  37. tf.flags.DEFINE_bool('kl_loss', True, 'If true, include KL loss.')
  38. slim = tf.contrib.slim
  39. def _Train():
  40. params = dict()
  41. params['batch_size'] = FLAGS.batch_size
  42. params['seq_len'] = FLAGS.sequence_length
  43. params['image_size'] = FLAGS.image_size
  44. params['is_training'] = True
  45. params['norm_scale'] = FLAGS.norm_scale
  46. params['scale'] = FLAGS.scale
  47. params['learning_rate'] = FLAGS.learning_rate
  48. params['l2_loss'] = FLAGS.l2_loss
  49. params['reconstr_loss'] = FLAGS.reconstr_loss
  50. params['kl_loss'] = FLAGS.kl_loss
  51. train_dir = os.path.join(FLAGS.log_root, 'train')
  52. images = reader.ReadInput(FLAGS.data_filepattern, shuffle=True, params=params)
  53. images *= params['scale']
  54. # Increase the value makes training much faster.
  55. image_diff_list = reader.SequenceToImageAndDiff(images)
  56. model = cross_conv_model.CrossConvModel(image_diff_list, params)
  57. model.Build()
  58. tf.contrib.tfprof.model_analyzer.print_model_analysis(tf.get_default_graph())
  59. summary_writer = tf.summary.FileWriter(train_dir)
  60. sv = tf.train.Supervisor(logdir=FLAGS.log_root,
  61. summary_op=None,
  62. is_chief=True,
  63. save_model_secs=60,
  64. global_step=model.global_step)
  65. sess = sv.prepare_or_wait_for_session(
  66. FLAGS.master, config=tf.ConfigProto(allow_soft_placement=True))
  67. total_loss = 0.0
  68. step = 0
  69. sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
  70. sample_z_stddev_log = np.zeros(model.z_stddev_log.get_shape().as_list())
  71. sample_step = 0
  72. while True:
  73. _, loss_val, total_steps, summaries, z_mean, z_stddev_log = sess.run(
  74. [model.train_op, model.loss, model.global_step,
  75. model.summary_op,
  76. model.z_mean, model.z_stddev_log])
  77. sample_z_mean += z_mean
  78. sample_z_stddev_log += z_stddev_log
  79. total_loss += loss_val
  80. step += 1
  81. sample_step += 1
  82. if step % 100 == 0:
  83. summary_writer.add_summary(summaries, total_steps)
  84. sys.stderr.write('step: %d, loss: %f\n' %
  85. (total_steps, total_loss / step))
  86. total_loss = 0.0
  87. step = 0
  88. # Sampled z is used for eval.
  89. # It seems 10k is better than 1k. Maybe try 100k next?
  90. if sample_step % 10000 == 0:
  91. with tf.gfile.Open(os.path.join(FLAGS.log_root, 'z_mean.npy'), 'w') as f:
  92. np.save(f, sample_z_mean / sample_step)
  93. with tf.gfile.Open(
  94. os.path.join(FLAGS.log_root, 'z_stddev_log.npy'), 'w') as f:
  95. np.save(f, sample_z_stddev_log / sample_step)
  96. sample_z_mean = np.zeros(model.z_mean.get_shape().as_list())
  97. sample_z_stddev_log = np.zeros(
  98. model.z_stddev_log.get_shape().as_list())
  99. sample_step = 0
  100. def main(_):
  101. _Train()
  102. if __name__ == '__main__':
  103. tf.app.run()