# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Utility functions to build DRAGNN MasterSpecs and schedule model training.
  16. Provides functions to finish a MasterSpec, building required lexicons for it and
  17. adding them as resources, as well as setting features sizes.
  18. """

import random

import tensorflow as tf

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile

flags = tf.app.flags
FLAGS = flags.FLAGS


def calculate_component_accuracies(eval_res_values):
  """Transforms the DRAGNN eval_res output to float accuracies of components."""
  # The structure of eval_res_values is
  # [comp1_total, comp1_correct, comp2_total, comp2_correct, ...]
  return [
      1.0 * eval_res_values[2 * i + 1] / eval_res_values[2 * i]
      if eval_res_values[2 * i] > 0 else float('nan')
      for i in xrange(len(eval_res_values) // 2)
  ]
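
# Example (hypothetical values): eval_res_values = [10, 9, 4, 2] would yield
# [0.9, 0.5], i.e. 9/10 correct for the first component and 2/4 for the second.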


def write_summary(summary_writer, label, value, step):
  """Write a summary for a certain evaluation."""
  summary = Summary(value=[Summary.Value(tag=label, simple_value=float(value))])
  summary_writer.add_summary(summary, step)
  summary_writer.flush()
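
# Example (hypothetical tag name):
#   write_summary(summary_writer, 'tagger/accuracy', 0.93, step)
# records a scalar that TensorBoard plots under the 'tagger/accuracy' tag.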


def annotate_dataset(sess, annotator, eval_corpus):
  """Annotate eval_corpus given a model."""
  batch_size = min(len(eval_corpus), 1024)
  processed = []
  tf.logging.info('Annotating dataset: %d examples', len(eval_corpus))
  for start in range(0, len(eval_corpus), batch_size):
    end = min(start + batch_size, len(eval_corpus))
    serialized_annotations = sess.run(
        annotator['annotations'],
        feed_dict={annotator['input_batch']: eval_corpus[start:end]})
    assert len(serialized_annotations) == end - start
    processed.extend(serialized_annotations)
  tf.logging.info('Done. Produced %d annotations', len(processed))
  return processed
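
# Example usage (hypothetical setup): given a session `sess` with a restored
# model and eval_corpus as a list of serialized sentences,
#   annotated = annotate_dataset(sess, annotator, eval_corpus)
# returns one serialized annotation per input example, in the original order.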


def get_summary_writer(tensorboard_dir):
  """Creates a directory for writing summaries and returns a writer."""
  tf.logging.info('TensorBoard directory: %s', tensorboard_dir)
  tf.logging.info('Deleting prior data if it exists...')
  try:
    gfile.DeleteRecursively(tensorboard_dir)
  except errors.OpError as err:
    tf.logging.error('Directory did not exist? Error: %s', err)
  tf.logging.info('Deleted! Creating the directory again...')
  gfile.MakeDirs(tensorboard_dir)
  tf.logging.info('Created! Instantiating SummaryWriter...')
  summary_writer = tf.summary.FileWriter(tensorboard_dir)
  return summary_writer


def run_training_step(sess, trainer, train_corpus, batch_size):
  """Runs a single iteration of train_op on a randomly sampled batch."""
  batch = random.sample(train_corpus, batch_size)
  sess.run(trainer['run'], feed_dict={trainer['input_batch']: batch})


def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
                 train_steps, train_corpus, eval_corpus, eval_gold,
                 batch_size, summary_writer, report_every, saver,
                 checkpoint_filename, checkpoint_stats=None):
  """Runs multi-task DRAGNN training on a single corpus.

  Arguments:
    sess: TF session to use.
    trainers: List of training ops to use.
    annotator: Annotation op.
    evaluator: Function taking two serialized corpora and returning a dict of
      scalar summaries representing evaluation metrics. The 'eval_metric'
      summary will be used for early stopping.
    pretrain_steps: List of the no. of pre-training steps for each train op.
    train_steps: List of the total no. of steps for each train op.
    train_corpus: Training corpus to use.
    eval_corpus: Holdout corpus for early stopping.
    eval_gold: Reference version of eval_corpus for computing accuracy.
      eval_corpus and eval_gold are allowed to be the same if eval_corpus
      already contains gold annotation. Note that for segmentation, eval_corpus
      and eval_gold are always different since eval_corpus contains sentences
      whose tokens are utf8-characters while eval_gold's tokens are gold words.
    batch_size: How many examples to send to the train op at a time.
    summary_writer: TF SummaryWriter to use to write summaries.
    report_every: How often to compute summaries (in steps).
    saver: TF saver op to save variables.
    checkpoint_filename: File to save checkpoints to.
    checkpoint_stats: Stats of checkpoint.
  """
  random.seed(0x31337)
  if not checkpoint_stats:
    checkpoint_stats = [0] * (len(train_steps) + 1)
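
  # One stats slot per train op, plus one: slot 0 is not updated in this
  # function, while slots 1..len(train_steps) count completed steps per train
  # op and are summed below to recover the overall step count across restarts.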
  tf.logging.info('Determining the training schedule...')
  target_for_step = []
  for target_idx in xrange(len(pretrain_steps)):
    target_for_step += [target_idx] * pretrain_steps[target_idx]
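
  # After the sequential pre-training phase above, interleave the remaining
  # train steps at random: each draw selects a target with probability
  # proportional to its remaining step budget (sampling without replacement).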
  while sum(train_steps) > 0:
    step = random.randint(0, sum(train_steps) - 1)
    cumulative_steps = 0
    for target_idx in xrange(len(train_steps)):
      cumulative_steps += train_steps[target_idx]
      if step < cumulative_steps:
        break
    assert train_steps[target_idx] > 0
    train_steps[target_idx] -= 1
    target_for_step.append(target_idx)
  tf.logging.info('Training schedule defined!')

  best_eval_metric = -1.0
  tf.logging.info('Starting training...')
  actual_step = sum(checkpoint_stats[1:])
  for step, target_idx in enumerate(target_for_step):
    run_training_step(sess, trainers[target_idx], train_corpus, batch_size)
    checkpoint_stats[target_idx + 1] += 1
    if step % 100 == 0:
      tf.logging.info('training step: %d, actual: %d', step, actual_step + step)
    if step % report_every == 0:
      tf.logging.info('finished step: %d, actual: %d', step, actual_step + step)

      annotated = annotate_dataset(sess, annotator, eval_corpus)
      summaries = evaluator(eval_gold, annotated)
      for label, metric in summaries.iteritems():
        write_summary(summary_writer, label, metric, actual_step + step)
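
      # Early stopping on 'eval_metric' (see docstring): a checkpoint is saved
      # only when the holdout metric improves, so the checkpoint on disk is
      # always the best model seen so far.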
      eval_metric = summaries['eval_metric']
      if best_eval_metric < eval_metric:
        tf.logging.info('Updating best eval to %.2f%%, saving checkpoint.',
                        eval_metric)
        best_eval_metric = eval_metric
        saver.save(sess, checkpoint_filename)

        with gfile.GFile('%s.stats' % checkpoint_filename, 'w') as f:
          stats_str = ','.join([str(x) for x in checkpoint_stats])
          f.write(stats_str)
          tf.logging.info('Writing stats: %s', stats_str)

  tf.logging.info('Finished training!')