# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Trainer for generic DRAGNN models.
  16. This trainer uses a "model directory" for both input and output. When invoked,
  17. the model directory should contain the following inputs:
  18. <model_dir>/config.txt: A stringified dict that defines high-level
  19. configuration parameters. Unset parameters default to False.
  20. <model_dir>/master.pbtxt: A text-format MasterSpec proto that defines
  21. the DRAGNN network to train.
  22. <model_dir>/hyperparameters.pbtxt: A text-format GridPoint proto that
  23. defines training hyper-parameters.
  24. <model_dir>/targets.pbtxt: (Optional) A text-format TrainingGridSpec whose
  25. "target" field defines the training targets. If missing, then default
  26. training targets are used instead.
  27. On success, the model directory will contain the following outputs:
  28. <model_dir>/checkpoints/best: The best checkpoint seen during training, as
  29. measured by accuracy on the eval corpus.
  30. <model_dir>/tensorboard: TensorBoard log directory.
  31. Outside of the files and subdirectories named above, the model directory should
  32. contain any other necessary files (e.g., pretrained embeddings). See the model
  33. builders in dragnn/examples.
  34. """

import ast
import collections
import os
import os.path

import tensorflow as tf
from google.protobuf import text_format

from dragnn.protos import spec_pb2
from dragnn.python import evaluation
from dragnn.python import graph_builder
from dragnn.python import sentence_io
from dragnn.python import spec_builder
from dragnn.python import trainer_lib
from syntaxnet.ops import gen_parser_ops
from syntaxnet.util import check

# Imported for their side effect of registering the DRAGNN and SyntaxNet C++
# ops with TensorFlow.
import dragnn.python.load_dragnn_cc_impl
import syntaxnet.load_parser_ops

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('tf_master', '',
                    'TensorFlow execution engine to connect to.')
flags.DEFINE_string('model_dir', None, 'Path to a prepared model directory.')
flags.DEFINE_string(
    'pretrain_steps', None,
    'Comma-delimited list of pre-training steps per training target.')
flags.DEFINE_string(
    'pretrain_epochs', None,
    'Comma-delimited list of pre-training epochs per training target.')
flags.DEFINE_string(
    'train_steps', None,
    'Comma-delimited list of training steps per training target.')
flags.DEFINE_string(
    'train_epochs', None,
    'Comma-delimited list of training epochs per training target.')
flags.DEFINE_integer('batch_size', 4, 'Batch size.')
flags.DEFINE_integer('report_every', 200,
                     'Report cost and training accuracy every this many steps.')
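
# A hypothetical invocation, with placeholder values. For each of the pretrain
# and train phases, exactly one of the *_steps or *_epochs flags must be set,
# with one comma-delimited entry per training target (two targets shown here):
#
#   python model_trainer.py \
#       --model_dir=/path/to/model_dir \
#       --pretrain_epochs=1,0 \
#       --train_epochs=8,8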


def _read_text_proto(path, proto_type):
  """Reads a text-format instance of |proto_type| from the |path|."""
  proto = proto_type()
  with tf.gfile.FastGFile(path) as proto_file:
    text_format.Parse(proto_file.read(), proto)
  return proto


def _convert_to_char_corpus(corpus):
  """Converts the word-based |corpus| into a char-based corpus."""
  with tf.Session(graph=tf.Graph()) as tmp_session:
    conversion_op = gen_parser_ops.segmenter_training_data_constructor(corpus)
    return tmp_session.run(conversion_op)


def _get_steps(steps_flag, epochs_flag, corpus_length):
  """Converts the |steps_flag| or |epochs_flag| into a list of step counts."""
  if steps_flag:
    return map(int, steps_flag.split(','))
  return [corpus_length * int(epochs) for epochs in epochs_flag.split(',')]
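
# For example, given a 1,000-sentence training corpus:
#   _get_steps('500,200', None, 1000)  yields step counts [500, 200]
#   _get_steps(None, '2,1', 1000)      yields step counts [2000, 1000]
# i.e., one step count per comma-delimited training target.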


def main(unused_argv):
  tf.logging.set_verbosity(tf.logging.INFO)

  check.NotNone(FLAGS.model_dir, '--model_dir is required')
  check.Ne(FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
           'Exactly one of --pretrain_steps or --pretrain_epochs is required')
  check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
           'Exactly one of --train_steps or --train_epochs is required')

  config_path = os.path.join(FLAGS.model_dir, 'config.txt')
  master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
  hyperparameters_path = os.path.join(FLAGS.model_dir, 'hyperparameters.pbtxt')
  targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
  checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
  tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')

  with tf.gfile.FastGFile(config_path) as config_file:
    config = collections.defaultdict(bool,
                                     ast.literal_eval(config_file.read()))
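  # |config| is a defaultdict(bool), so parameters not present in config.txt
  # read as False (cf. the module docstring).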
  train_corpus_path = config['train_corpus_path']
  tune_corpus_path = config['tune_corpus_path']
  projectivize_train_corpus = config['projectivize_train_corpus']

  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)
  targets = spec_builder.default_targets_from_spec(master)
  if tf.gfile.Exists(targets_path):
    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target

  # Build the TensorFlow graph.
  graph = tf.Graph()
  with graph.as_default():
    tf.set_random_seed(hyperparameters.seed)
    builder = graph_builder.MasterBuilder(master, hyperparameters)
    trainers = [
        builder.add_training_from_config(target) for target in targets
    ]
    annotator = builder.add_annotation()
    builder.add_saver()

  # Read in serialized protos from training data.
  train_corpus = sentence_io.ConllSentenceReader(
      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
  tune_corpus = sentence_io.ConllSentenceReader(
      tune_corpus_path, projectivize=False).corpus()
  gold_tune_corpus = tune_corpus

  # Convert to char-based corpora, if requested.
  if config['convert_to_char_corpora']:
    # NB: Do not convert the |gold_tune_corpus|, which should remain word-based
    # for segmentation evaluation purposes.
    train_corpus = _convert_to_char_corpus(train_corpus)
    tune_corpus = _convert_to_char_corpus(tune_corpus)

  pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
                              len(train_corpus))
  train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
                           len(train_corpus))
  check.Eq(len(targets), len(pretrain_steps),
           'Length mismatch between training targets and --pretrain_steps')
  check.Eq(len(targets), len(train_steps),
           'Length mismatch between training targets and --train_steps')

  # Ready to train!
  tf.logging.info('Training on %d sentences.', len(train_corpus))
  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))

  tf.logging.info('Creating TensorFlow checkpoint dir...')
  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)
  checkpoint_dir = os.path.dirname(checkpoint_path)
  if tf.gfile.IsDirectory(checkpoint_dir):
    tf.gfile.DeleteRecursively(checkpoint_dir)
  elif tf.gfile.Exists(checkpoint_dir):
    tf.gfile.Remove(checkpoint_dir)
  tf.gfile.MakeDirs(checkpoint_dir)

  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
    # Make sure to re-initialize all underlying state.
    sess.run(tf.global_variables_initializer())
    trainer_lib.run_training(sess, trainers, annotator,
                             evaluation.parser_summaries, pretrain_steps,
                             train_steps, train_corpus, tune_corpus,
                             gold_tune_corpus, FLAGS.batch_size, summary_writer,
                             FLAGS.report_every, builder.saver, checkpoint_path)

  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)


if __name__ == '__main__':
  tf.app.run()