configuration.py

# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Image-to-text model and training configurations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


class ModelConfig(object):
  """Wrapper class for model hyperparameters."""

  def __init__(self):
    """Sets the default model hyperparameters."""
    # File pattern of sharded TFRecord file containing SequenceExample protos.
    # Must be provided in training and evaluation modes.
    self.input_file_pattern = None

    # Image format ("jpeg" or "png").
    self.image_format = "jpeg"

    # Approximate number of values per input shard. Used to ensure sufficient
    # mixing between shards in training.
    self.values_per_input_shard = 2300
    # Minimum number of shards to keep in the input queue.
    self.input_queue_capacity_factor = 2
    # Number of threads for prefetching SequenceExample protos.
    self.num_input_reader_threads = 1

    # Name of the SequenceExample context feature containing image data.
    self.image_feature_name = "image/data"
    # Name of the SequenceExample feature list containing integer captions.
    self.caption_feature_name = "image/caption_ids"

    # Number of unique words in the vocab (plus 1, for <UNK>).
    # The default value is larger than the expected actual vocab size to allow
    # for differences between tokenizer versions used in preprocessing. There
    # is no harm in using a value greater than the actual vocab size, but using
    # a value less than the actual vocab size will result in an error.
    self.vocab_size = 12000

    # Number of threads for image preprocessing. Should be a multiple of 2.
    self.num_preprocess_threads = 4

    # Batch size.
    self.batch_size = 32

    # File containing an Inception v3 checkpoint to initialize the variables
    # of the Inception model. Must be provided when starting training for the
    # first time.
    self.inception_checkpoint_file = None

    # Dimensions of Inception v3 input images.
    self.image_height = 299
    self.image_width = 299

    # Scale used to initialize model variables.
    self.initializer_scale = 0.08

    # LSTM input and output dimensionality, respectively.
    self.embedding_size = 512
    self.num_lstm_units = 512

    # If < 1.0, the dropout keep probability applied to LSTM variables.
    self.lstm_dropout_keep_prob = 0.7


class TrainingConfig(object):
  """Wrapper class for training hyperparameters."""

  def __init__(self):
    """Sets the default training hyperparameters."""
    # Number of examples per epoch of training data.
    self.num_examples_per_epoch = 586363

    # Optimizer for training the model.
    self.optimizer = "SGD"

    # Learning rate for the initial phase of training.
    self.initial_learning_rate = 2.0
    self.learning_rate_decay_factor = 0.5
    self.num_epochs_per_decay = 8.0

    # Learning rate when fine tuning the Inception v3 parameters.
    self.train_inception_learning_rate = 0.0005

    # If not None, clip gradients to this value.
    self.clip_gradients = 5.0

    # How many model checkpoints to keep.
    self.max_checkpoints_to_keep = 5
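
For context, a minimal sketch (not part of the file above) of how these config objects are typically consumed by a training script. It assumes this file is importable as a module named configuration; the file paths are placeholders for illustration only.

# Usage sketch: construct the default configs, then override the fields that
# must be provided before training (input_file_pattern and
# inception_checkpoint_file are None by default).
import configuration

model_config = configuration.ModelConfig()
model_config.input_file_pattern = "/path/to/train-?????-of-00256"      # placeholder path
model_config.inception_checkpoint_file = "/path/to/inception_v3.ckpt"  # placeholder path

training_config = configuration.TrainingConfig()
training_config.initial_learning_rate = 2.0  # override any defaults as needed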