trainer_lib.py

  1. """Utility functions to build DRAGNN MasterSpecs and schedule model training.
  2. Provides functions to finish a MasterSpec, building required lexicons for it and
  3. adding them as resources, as well as setting features sizes.
  4. """

import random

import tensorflow as tf

from tensorflow.core.framework.summary_pb2 import Summary
from tensorflow.python.framework import errors
from tensorflow.python.platform import gfile

flags = tf.app.flags
FLAGS = flags.FLAGS


def calculate_component_accuracies(eval_res_values):
  """Transforms the DRAGNN eval_res output to float accuracies of components."""
  # The structure of eval_res_values is
  # [comp1_total, comp1_correct, comp2_total, comp2_correct, ...].
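  # For example (hypothetical values), [10, 8, 0, 0] maps to [0.8, nan]:
  # component 1 got 8 of 10 correct, component 2 saw no examples.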
  return [
      1.0 * eval_res_values[2 * i + 1] / eval_res_values[2 * i]
      if eval_res_values[2 * i] > 0 else float('nan')
      for i in xrange(len(eval_res_values) // 2)
  ]


def _write_summary(summary_writer, label, value, step):
  """Writes a summary value for a single evaluation metric."""
  summary = Summary(value=[Summary.Value(tag=label, simple_value=float(value))])
  summary_writer.add_summary(summary, step)
  summary_writer.flush()


def annotate_dataset(sess, annotator, eval_corpus):
  """Annotates eval_corpus given a model."""
  batch_size = min(len(eval_corpus), 1024)
  processed = []
  tf.logging.info('Annotating dataset: %d examples', len(eval_corpus))
  for start in range(0, len(eval_corpus), batch_size):
    end = min(start + batch_size, len(eval_corpus))
    serialized_annotations = sess.run(
        annotator['annotations'],
        feed_dict={annotator['input_batch']: eval_corpus[start:end]})
    assert len(serialized_annotations) == end - start
    processed.extend(serialized_annotations)
  tf.logging.info('Done. Produced %d annotations', len(processed))
  return processed


def get_summary_writer(tensorboard_dir):
  """Creates a directory for writing summaries and returns a writer."""
  tf.logging.info('TensorBoard directory: %s', tensorboard_dir)
  tf.logging.info('Deleting prior data if it exists...')
  try:
    gfile.DeleteRecursively(tensorboard_dir)
  except errors.OpError as err:
    tf.logging.error('Directory did not exist? Error: %s', err)
  tf.logging.info('Deleted! Creating the directory again...')
  gfile.MakeDirs(tensorboard_dir)
  tf.logging.info('Created! Instantiating SummaryWriter...')
  summary_writer = tf.summary.FileWriter(tensorboard_dir)
  return summary_writer


def run_training_step(sess, trainer, train_corpus, batch_size):
  """Runs a single iteration of train_op on a randomly sampled batch."""
  batch = random.sample(train_corpus, batch_size)
  sess.run(trainer['run'], feed_dict={trainer['input_batch']: batch})


def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
                 train_steps, train_corpus, eval_corpus, eval_gold,
                 batch_size, summary_writer, report_every, saver,
                 checkpoint_filename, checkpoint_stats=None):
  """Runs multi-task DRAGNN training on a single corpus.

  Arguments:
    sess: TF session to use.
    trainers: List of training ops to use.
    annotator: Annotation op.
    evaluator: Function taking two serialized corpora and returning a dict of
      scalar summaries representing evaluation metrics. The 'eval_metric'
      summary will be used for early stopping.
    pretrain_steps: List of the no. of pre-training steps for each train op.
    train_steps: List of the total no. of steps for each train op.
    train_corpus: Training corpus to use.
    eval_corpus: Holdout corpus for early stopping.
    eval_gold: Reference version of eval_corpus for computing accuracy.
      eval_corpus and eval_gold are allowed to be the same if eval_corpus
      already contains gold annotation. Note that for segmentation,
      eval_corpus and eval_gold are always different, since eval_corpus
      contains sentences whose tokens are UTF-8 characters while eval_gold's
      tokens are gold words.
    batch_size: How many examples to send to the train op at a time.
    summary_writer: TF SummaryWriter to use to write summaries.
    report_every: How often to compute summaries (in steps).
    saver: TF saver op to save variables.
    checkpoint_filename: File to save checkpoints to.
    checkpoint_stats: Stats of checkpoint.
  """
  random.seed(0x31337)

  if not checkpoint_stats:
    checkpoint_stats = [0] * (len(train_steps) + 1)

  tf.logging.info('Determining the training schedule...')
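  # Pre-training steps for each target run first, in order. After that, each
  # remaining step is assigned to a target with probability proportional to
  # that target's remaining budget in train_steps, interleaving the tasks.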
  target_for_step = []
  for target_idx in xrange(len(pretrain_steps)):
    target_for_step += [target_idx] * pretrain_steps[target_idx]
  while sum(train_steps) > 0:
    step = random.randint(0, sum(train_steps) - 1)
    cumulative_steps = 0
    for target_idx in xrange(len(train_steps)):
      cumulative_steps += train_steps[target_idx]
      if step < cumulative_steps:
        break
    assert train_steps[target_idx] > 0
    train_steps[target_idx] -= 1
    target_for_step.append(target_idx)
  tf.logging.info('Training schedule defined!')

  best_eval_metric = -1.0
  tf.logging.info('Starting training...')
  actual_step = sum(checkpoint_stats[1:])
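  # Resume the global step count from any prior run: checkpoint_stats[1:]
  # holds the number of steps already taken for each target.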
  for step, target_idx in enumerate(target_for_step):
    run_training_step(sess, trainers[target_idx], train_corpus, batch_size)
    checkpoint_stats[target_idx + 1] += 1
    if step % 100 == 0:
      tf.logging.info('training step: %d, actual: %d', step, actual_step + step)
    if step % report_every == 0:
      tf.logging.info('finished step: %d, actual: %d', step, actual_step + step)
      annotated = annotate_dataset(sess, annotator, eval_corpus)
      summaries = evaluator(eval_gold, annotated)
      for label, metric in summaries.iteritems():
        _write_summary(summary_writer, label, metric, actual_step + step)
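      # Checkpoint only when the held-out 'eval_metric' improves; this is the
      # early-stopping criterion mentioned in the docstring.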
      eval_metric = summaries['eval_metric']
      if best_eval_metric < eval_metric:
        tf.logging.info('Updating best eval to %.2f%%, saving checkpoint.',
                        eval_metric)
        best_eval_metric = eval_metric
        saver.save(sess, checkpoint_filename)
        with gfile.GFile('%s.stats' % checkpoint_filename, 'w') as f:
          stats_str = ','.join([str(x) for x in checkpoint_stats])
          f.write(stats_str)
          tf.logging.info('Writing stats: %s', stats_str)

  tf.logging.info('Finished training!')
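

# A minimal, hypothetical driver sketch. The `graph`, `trainers`, `annotator`,
# and `my_evaluator` objects are assumptions here; in practice they are built
# elsewhere (e.g. by the DRAGNN graph builder and trainer binary), with
# `trainers` and `annotator` being dicts of ops keyed as used above:
#
#   summary_writer = get_summary_writer('/tmp/tensorboard')
#   with tf.Session(graph=graph) as sess:
#     sess.run(tf.global_variables_initializer())
#     run_training(sess, trainers, annotator, my_evaluator,
#                  pretrain_steps=[0], train_steps=[1000],
#                  train_corpus=train_corpus, eval_corpus=eval_corpus,
#                  eval_gold=eval_gold, batch_size=32,
#                  summary_writer=summary_writer, report_every=100,
#                  saver=tf.train.Saver(),
#                  checkpoint_filename='/tmp/model/checkpoint')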