# configuration.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default configuration for model architecture and training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
  19. class _HParams(object):
  20. """Wrapper for configuration parameters."""
  21. pass
  22. def model_config(input_file_pattern=None,
  23. input_queue_capacity=640000,
  24. num_input_reader_threads=1,
  25. shuffle_input_data=True,
  26. uniform_init_scale=0.1,
  27. vocab_size=20000,
  28. batch_size=128,
  29. word_embedding_dim=620,
  30. bidirectional_encoder=False,
  31. encoder_dim=2400):
  32. """Creates a model configuration object.
  33. Args:
  34. input_file_pattern: File pattern of sharded TFRecord files containing
  35. tf.Example protobufs.
  36. input_queue_capacity: Number of examples to keep in the input queue.
  37. num_input_reader_threads: Number of threads for prefetching input
  38. tf.Examples.
  39. shuffle_input_data: Whether to shuffle the input data.
  40. uniform_init_scale: Scale of random uniform initializer.
  41. vocab_size: Number of unique words in the vocab.
  42. batch_size: Batch size (training and evaluation only).
  43. word_embedding_dim: Word embedding dimension.
  44. bidirectional_encoder: Whether to use a bidirectional or unidirectional
  45. encoder RNN.
  46. encoder_dim: Number of output dimensions of the sentence encoder.
  47. Returns:
  48. An object containing model configuration parameters.
  49. """
  50. config = _HParams()
  51. config.input_file_pattern = input_file_pattern
  52. config.input_queue_capacity = input_queue_capacity
  53. config.num_input_reader_threads = num_input_reader_threads
  54. config.shuffle_input_data = shuffle_input_data
  55. config.uniform_init_scale = uniform_init_scale
  56. config.vocab_size = vocab_size
  57. config.batch_size = batch_size
  58. config.word_embedding_dim = word_embedding_dim
  59. config.bidirectional_encoder = bidirectional_encoder
  60. config.encoder_dim = encoder_dim
  61. return config
  62. def training_config(learning_rate=0.0008,
  63. learning_rate_decay_factor=0.5,
  64. learning_rate_decay_steps=400000,
  65. number_of_steps=500000,
  66. clip_gradient_norm=5.0,
  67. save_model_secs=600,
  68. save_summaries_secs=600):
  69. """Creates a training configuration object.
  70. Args:
  71. learning_rate: Initial learning rate.
  72. learning_rate_decay_factor: If > 0, the learning rate decay factor.
  73. learning_rate_decay_steps: The number of steps before the learning rate
  74. decays by learning_rate_decay_factor.
  75. number_of_steps: The total number of training steps to run. Passing None
  76. will cause the training script to run indefinitely.
  77. clip_gradient_norm: If not None, then clip gradients to this value.
  78. save_model_secs: How often (in seconds) to save model checkpoints.
  79. save_summaries_secs: How often (in seconds) to save model summaries.
  80. Returns:
  81. An object containing training configuration parameters.
  82. Raises:
  83. ValueError: If learning_rate_decay_factor is set and
  84. learning_rate_decay_steps is unset.
  85. """
  86. if learning_rate_decay_factor and not learning_rate_decay_steps:
  87. raise ValueError(
  88. "learning_rate_decay_factor requires learning_rate_decay_steps.")
  89. config = _HParams()
  90. config.learning_rate = learning_rate
  91. config.learning_rate_decay_factor = learning_rate_decay_factor
  92. config.learning_rate_decay_steps = learning_rate_decay_steps
  93. config.number_of_steps = number_of_steps
  94. config.clip_gradient_norm = clip_gradient_norm
  95. config.save_model_secs = save_model_secs
  96. config.save_summaries_secs = save_summaries_secs
  97. return config