# train_normal_normal_def.py (2.0 KB)
  1. # Copyright 2016 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Wraps train_normal_normal_def."""
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import tensorflow as tf
  20. import train_normal_normal_def_lib
  21. from ops import tf_lib
  22. FLAGS = tf.flags.FLAGS
  23. def main(unused_argv):
  24. """Trains a gaussian def locally."""
  25. if tf.gfile.Exists(FLAGS.logdir) and FLAGS.delete_logdir:
  26. tf.gfile.DeleteRecursively(FLAGS.logdir)
  27. tf.gfile.MakeDirs(FLAGS.logdir)
  28. # The HParams are commented in training_params.py
  29. try:
  30. hparams = tf.HParams
  31. except AttributeError:
  32. hparams = tf_lib.HParams
  33. hparams = hparams(
  34. dataset='MNIST',
  35. z_dim=200,
  36. timestep_dim=28,
  37. n_timesteps=28,
  38. batch_size=50,
  39. samples_to_save=10,
  40. learning_rate=0.1,
  41. n_examples=50,
  42. momentum=0.0,
  43. p_z_sigma=1.,
  44. p_w_mu_sigma=1.,
  45. p_w_sigma_sigma=1.,
  46. init_q_sigma_scale=0.1,
  47. fixed_p_z_sigma=True,
  48. fixed_q_z_sigma=False,
  49. fixed_q_w_mu_sigma=False,
  50. fixed_q_w_sigma_sigma=False,
  51. fixed_q_w_0_sigma=False,
  52. dtype='float32')
  53. tf.logging.info('Starting experiment in %s with params %s', FLAGS.logdir,
  54. hparams)
  55. train_normal_normal_def_lib.run_training(
  56. hparams, FLAGS.logdir, FLAGS.max_steps, None,
  57. trainer=FLAGS.trainer)
  58. if __name__ == '__main__':
  59. tf.app.run()