# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines hyperparameters for training the def model.
"""
  17. from __future__ import absolute_import
  18. from __future__ import division
  19. from __future__ import print_function
  20. import tensorflow as tf
  21. from ops import tf_lib
  22. flags = tf.flags
  23. # Dataset options
  24. flags.DEFINE_enum('dataset', 'MNIST', ['MNIST', 'alternating'],
  25. 'Dataset to use. mnist or synthetic bernoulli data')
  26. flags.DEFINE_integer('n_timesteps', 28, 'Number of timesteps per example')
  27. flags.DEFINE_integer('timestep_dim', 28, 'Dimensionality of each timestep')
  28. flags.DEFINE_integer('n_examples', 50000, 'Number of examples to use from the '
  29. 'dataset.')
  30. # Model options
  31. flags.DEFINE_integer('z_dim', 2, 'Latent dimensionality')
  32. flags.DEFINE_float('p_z_sigma', 1., 'Prior variance for latent variables')
  33. flags.DEFINE_float('p_w_mu_sigma', 1., 'Prior variance for weights for mean')
  34. flags.DEFINE_float('p_w_sigma_sigma', 1., 'Prior variance for weights for '
  35. 'standard deviation')
  36. flags.DEFINE_boolean('fixed_p_z_sigma', True, 'Whether to have the variance '
  37. 'depend recurrently across timesteps')
  38. # Variational family options
  39. flags.DEFINE_float('init_q_sigma_scale', 0.1, 'Factor by which to scale prior'
  40. ' variances to use as initialization for variational stddev')
  41. flags.DEFINE_boolean('fixed_q_z_sigma', False, 'Whether to learn variational '
  42. 'variance parameters for latents')
  43. flags.DEFINE_boolean('fixed_q_w_mu_sigma', False, 'Whether to learn variational'
  44. 'variance parameters for weights for mean')
  45. flags.DEFINE_boolean('fixed_q_w_sigma_sigma', False, 'Whether to learn '
  46. 'variance parameters for weights for variance')
  47. flags.DEFINE_boolean('fixed_q_w_0_sigma', False, 'Whether to learn '
  48. 'variance parameters for weights for observations')
  49. # Training options
  50. flags.DEFINE_enum('optimizer', 'Adam', ['Adam', 'RMSProp', 'SGD', 'Adagrad'],
  51. 'Optimizer to use')
  52. flags.DEFINE_float('learning_rate', 1e-3, 'Learning rate')
  53. flags.DEFINE_float('momentum', 0., 'Momentum for optimizer')
  54. flags.DEFINE_integer('batch_size', 10, 'Batch size')
  55. FLAGS = tf.flags.FLAGS
  56. def h_params():
  57. """Returns hyperparameters defaulting to the corresponding flag values."""
  58. try:
  59. hparams = tf.HParams
  60. except AttributeError:
  61. hparams = tf_lib.HParams
  62. return hparams(
  63. dataset=FLAGS.dataset,
  64. z_dim=FLAGS.z_dim,
  65. timestep_dim=FLAGS.timestep_dim,
  66. n_timesteps=FLAGS.n_timesteps,
  67. batch_size=FLAGS.batch_size,
  68. learning_rate=FLAGS.learning_rate,
  69. n_examples=FLAGS.n_examples,
  70. momentum=FLAGS.momentum,
  71. p_z_sigma=FLAGS.p_z_sigma,
  72. p_w_mu_sigma=FLAGS.p_w_mu_sigma,
  73. p_w_sigma_sigma=FLAGS.p_w_sigma_sigma,
  74. init_q_sigma_scale=FLAGS.init_q_sigma_scale,
  75. fixed_p_z_sigma=FLAGS.fixed_p_z_sigma,
  76. fixed_q_z_sigma=FLAGS.fixed_q_z_sigma,
  77. fixed_q_w_mu_sigma=FLAGS.fixed_q_w_mu_sigma,
  78. fixed_q_w_sigma_sigma=FLAGS.fixed_q_w_sigma_sigma,
  79. fixed_q_w_0_sigma=FLAGS.fixed_q_w_0_sigma,
  80. dtype='float32')