prediction_train.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250
  1. # Copyright 2016 The TensorFlow Authors All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Code for training the prediction model."""
  16. import numpy as np
  17. import tensorflow as tf
  18. from tensorflow.python.platform import app
  19. from tensorflow.python.platform import flags
  20. from prediction_input import build_tfrecord_input
  21. from prediction_model import construct_model
  22. # How often to record tensorboard summaries.
  23. SUMMARY_INTERVAL = 40
  24. # How often to run a batch through the validation model.
  25. VAL_INTERVAL = 200
  26. # How often to save a model checkpoint
  27. SAVE_INTERVAL = 2000
  28. # tf record data location:
  29. DATA_DIR = 'push/push_train'
  30. # local output directory
  31. OUT_DIR = '/tmp/data'
  32. FLAGS = flags.FLAGS
  33. flags.DEFINE_string('data_dir', DATA_DIR, 'directory containing data.')
  34. flags.DEFINE_string('output_dir', OUT_DIR, 'directory for model checkpoints.')
  35. flags.DEFINE_string('event_log_dir', OUT_DIR, 'directory for writing summary.')
  36. flags.DEFINE_integer('num_iterations', 100000, 'number of training iterations.')
  37. flags.DEFINE_string('pretrained_model', '',
  38. 'filepath of a pretrained model to initialize from.')
  39. flags.DEFINE_integer('sequence_length', 10,
  40. 'sequence length, including context frames.')
  41. flags.DEFINE_integer('context_frames', 2, '# of frames before predictions.')
  42. flags.DEFINE_integer('use_state', 1,
  43. 'Whether or not to give the state+action to the model')
  44. flags.DEFINE_string('model', 'CDNA',
  45. 'model architecture to use - CDNA, DNA, or STP')
  46. flags.DEFINE_integer('num_masks', 10,
  47. 'number of masks, usually 1 for DNA, 10 for CDNA, STN.')
  48. flags.DEFINE_float('schedsamp_k', 900.0,
  49. 'The k hyperparameter for scheduled sampling,'
  50. '-1 for no scheduled sampling.')
  51. flags.DEFINE_float('train_val_split', 0.95,
  52. 'The percentage of files to use for the training set,'
  53. ' vs. the validation set.')
  54. flags.DEFINE_integer('batch_size', 32, 'batch size for training')
  55. flags.DEFINE_float('learning_rate', 0.001,
  56. 'the base learning rate of the generator')
  57. ## Helper functions
  58. def peak_signal_to_noise_ratio(true, pred):
  59. """Image quality metric based on maximal signal power vs. power of the noise.
  60. Args:
  61. true: the ground truth image.
  62. pred: the predicted image.
  63. Returns:
  64. peak signal to noise ratio (PSNR)
  65. """
  66. return 10.0 * tf.log(1.0 / mean_squared_error(true, pred)) / tf.log(10.0)
  67. def mean_squared_error(true, pred):
  68. """L2 distance between tensors true and pred.
  69. Args:
  70. true: the ground truth image.
  71. pred: the predicted image.
  72. Returns:
  73. mean squared error between ground truth and predicted image.
  74. """
  75. return tf.reduce_sum(tf.square(true - pred)) / tf.to_float(tf.size(pred))
  76. class Model(object):
  77. def __init__(self,
  78. images=None,
  79. actions=None,
  80. states=None,
  81. sequence_length=None,
  82. reuse_scope=None):
  83. if sequence_length is None:
  84. sequence_length = FLAGS.sequence_length
  85. self.prefix = prefix = tf.placeholder(tf.string, [])
  86. self.iter_num = tf.placeholder(tf.float32, [])
  87. summaries = []
  88. # Split into timesteps.
  89. actions = tf.split(axis=1, num_or_size_splits=actions.get_shape()[1], value=actions)
  90. actions = [tf.squeeze(act) for act in actions]
  91. states = tf.split(axis=1, num_or_size_splits=states.get_shape()[1], value=states)
  92. states = [tf.squeeze(st) for st in states]
  93. images = tf.split(axis=1, num_or_size_splits=images.get_shape()[1], value=images)
  94. images = [tf.squeeze(img) for img in images]
  95. if reuse_scope is None:
  96. gen_images, gen_states = construct_model(
  97. images,
  98. actions,
  99. states,
  100. iter_num=self.iter_num,
  101. k=FLAGS.schedsamp_k,
  102. use_state=FLAGS.use_state,
  103. num_masks=FLAGS.num_masks,
  104. cdna=FLAGS.model == 'CDNA',
  105. dna=FLAGS.model == 'DNA',
  106. stp=FLAGS.model == 'STP',
  107. context_frames=FLAGS.context_frames)
  108. else: # If it's a validation or test model.
  109. with tf.variable_scope(reuse_scope, reuse=True):
  110. gen_images, gen_states = construct_model(
  111. images,
  112. actions,
  113. states,
  114. iter_num=self.iter_num,
  115. k=FLAGS.schedsamp_k,
  116. use_state=FLAGS.use_state,
  117. num_masks=FLAGS.num_masks,
  118. cdna=FLAGS.model == 'CDNA',
  119. dna=FLAGS.model == 'DNA',
  120. stp=FLAGS.model == 'STP',
  121. context_frames=FLAGS.context_frames)
  122. # L2 loss, PSNR for eval.
  123. loss, psnr_all = 0.0, 0.0
  124. for i, x, gx in zip(
  125. range(len(gen_images)), images[FLAGS.context_frames:],
  126. gen_images[FLAGS.context_frames - 1:]):
  127. recon_cost = mean_squared_error(x, gx)
  128. psnr_i = peak_signal_to_noise_ratio(x, gx)
  129. psnr_all += psnr_i
  130. summaries.append(
  131. tf.summary.scalar(prefix + '_recon_cost' + str(i), recon_cost))
  132. summaries.append(tf.summary.scalar(prefix + '_psnr' + str(i), psnr_i))
  133. loss += recon_cost
  134. for i, state, gen_state in zip(
  135. range(len(gen_states)), states[FLAGS.context_frames:],
  136. gen_states[FLAGS.context_frames - 1:]):
  137. state_cost = mean_squared_error(state, gen_state) * 1e-4
  138. summaries.append(
  139. tf.summary.scalar(prefix + '_state_cost' + str(i), state_cost))
  140. loss += state_cost
  141. summaries.append(tf.summary.scalar(prefix + '_psnr_all', psnr_all))
  142. self.psnr_all = psnr_all
  143. self.loss = loss = loss / np.float32(len(images) - FLAGS.context_frames)
  144. summaries.append(tf.summary.scalar(prefix + '_loss', loss))
  145. self.lr = tf.placeholder_with_default(FLAGS.learning_rate, ())
  146. self.train_op = tf.train.AdamOptimizer(self.lr).minimize(loss)
  147. self.summ_op = tf.summary.merge(summaries)
  148. def main(unused_argv):
  149. print 'Constructing models and inputs.'
  150. with tf.variable_scope('model', reuse=None) as training_scope:
  151. images, actions, states = build_tfrecord_input(training=True)
  152. model = Model(images, actions, states, FLAGS.sequence_length)
  153. with tf.variable_scope('val_model', reuse=None):
  154. val_images, val_actions, val_states = build_tfrecord_input(training=False)
  155. val_model = Model(val_images, val_actions, val_states,
  156. FLAGS.sequence_length, training_scope)
  157. print 'Constructing saver.'
  158. # Make saver.
  159. saver = tf.train.Saver(
  160. tf.get_collection(tf.GraphKeys.VARIABLES), max_to_keep=0)
  161. # Make training session.
  162. sess = tf.InteractiveSession()
  163. summary_writer = tf.summary.FileWriter(
  164. FLAGS.event_log_dir, graph=sess.graph, flush_secs=10)
  165. if FLAGS.pretrained_model:
  166. saver.restore(sess, FLAGS.pretrained_model)
  167. tf.train.start_queue_runners(sess)
  168. sess.run(tf.global_variables_initializer())
  169. tf.logging.info('iteration number, cost')
  170. # Run training.
  171. for itr in range(FLAGS.num_iterations):
  172. # Generate new batch of data.
  173. feed_dict = {model.prefix: 'train',
  174. model.iter_num: np.float32(itr),
  175. model.lr: FLAGS.learning_rate}
  176. cost, _, summary_str = sess.run([model.loss, model.train_op, model.summ_op],
  177. feed_dict)
  178. # Print info: iteration #, cost.
  179. tf.logging.info(str(itr) + ' ' + str(cost))
  180. if (itr) % VAL_INTERVAL == 2:
  181. # Run through validation set.
  182. feed_dict = {val_model.lr: 0.0,
  183. val_model.prefix: 'val',
  184. val_model.iter_num: np.float32(itr)}
  185. _, val_summary_str = sess.run([val_model.train_op, val_model.summ_op],
  186. feed_dict)
  187. summary_writer.add_summary(val_summary_str, itr)
  188. if (itr) % SAVE_INTERVAL == 2:
  189. tf.logging.info('Saving model.')
  190. saver.save(sess, FLAGS.output_dir + '/model' + str(itr))
  191. if (itr) % SUMMARY_INTERVAL:
  192. summary_writer.add_summary(summary_str, itr)
  193. tf.logging.info('Saving model.')
  194. saver.save(sess, FLAGS.output_dir + '/model')
  195. tf.logging.info('Training complete')
  196. tf.logging.flush()
  197. if __name__ == '__main__':
  198. app.run()