parser_trainer.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """A program to train a tensorflow neural net parser from a a conll file."""

import os
import os.path
import time

import tensorflow as tf

from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging

from google.protobuf import text_format

from syntaxnet import graph_builder
from syntaxnet import structured_graph_builder
from syntaxnet.ops import gen_parser_ops
from syntaxnet import task_spec_pb2

flags = tf.app.flags
FLAGS = flags.FLAGS

flags.DEFINE_string('tf_master', '',
                    'TensorFlow execution engine to connect to.')
flags.DEFINE_string('output_path', '', 'Top level for output.')
flags.DEFINE_string('task_context', '',
                    'Path to a task context with resource locations and '
                    'parameters.')
flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
flags.DEFINE_string('params', '0', 'Unique identifier of parameter grid point.')
flags.DEFINE_string('training_corpus', 'training-corpus',
                    'Name of the context input to read training data from.')
flags.DEFINE_string('tuning_corpus', 'tuning-corpus',
                    'Name of the context input to read tuning data from.')
flags.DEFINE_string('word_embeddings', None,
                    'Recordio containing pretrained word embeddings, will be '
                    'loaded as the first embedding matrix.')
flags.DEFINE_bool('compute_lexicon', False, '')
flags.DEFINE_bool('projectivize_training_set', False, '')
flags.DEFINE_string('hidden_layer_sizes', '200,200',
                    'Comma separated list of hidden layer sizes.')
flags.DEFINE_string('graph_builder', 'greedy',
                    'Graph builder to use, either "greedy" or "structured".')
flags.DEFINE_integer('batch_size', 32,
                     'Number of sentences to process in parallel.')
flags.DEFINE_integer('beam_size', 10, 'Number of slots for beam parsing.')
flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
flags.DEFINE_integer('max_steps', 50,
                     'Max number of parser steps during a training step.')
flags.DEFINE_integer('report_every', 100,
                     'Report cost and training accuracy every this many steps.')
flags.DEFINE_integer('checkpoint_every', 5000,
                     'Measure tuning UAS and checkpoint every this many steps.')
flags.DEFINE_bool('slim_model', False,
                  'Whether to remove non-averaged variables, for compactness.')
flags.DEFINE_float('learning_rate', 0.1, 'Initial learning rate parameter.')
flags.DEFINE_integer('decay_steps', 4000,
                     'Decay learning rate by 0.96 every this many steps.')
flags.DEFINE_float('momentum', 0.9,
                   'Momentum parameter for momentum optimizer.')
flags.DEFINE_string('seed', '0', 'Initialization seed for TF variables.')
flags.DEFINE_string('pretrained_params', None,
                    'Path to model from which to load params.')
flags.DEFINE_string('pretrained_params_names', None,
                    'List of names of tensors to load from pretrained model.')
flags.DEFINE_float('averaging_decay', 0.9999,
                   'Decay for exponential moving average when computing '
                   'averaged parameters; set to 1 to do vanilla averaging.')
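

# All artifacts for a run are staged under
# <output_path>/<arg_prefix>/<graph_builder>/<params>/, so several grid
# points can share a single top-level output directory.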
def StageName():
  return os.path.join(FLAGS.arg_prefix, FLAGS.graph_builder)


def OutputPath(path):
  return os.path.join(FLAGS.output_path, StageName(), FLAGS.params, path)
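

# RewriteContext points every input resource created by this stage at the
# run's output directory and writes the patched task spec to
# OutputPath('context'), which all later ops read instead of the original.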
def RewriteContext():
  context = task_spec_pb2.TaskSpec()
  with gfile.FastGFile(FLAGS.task_context, 'rb') as fin:
    text_format.Merge(fin.read(), context)
  for resource in context.input:
    if resource.creator == StageName():
      del resource.part[:]
      part = resource.part.add()
      part.file_pattern = os.path.join(OutputPath(resource.name))
  with gfile.FastGFile(OutputPath('context'), 'w') as fout:
    fout.write(str(context))


def WriteStatus(num_steps, eval_metric, best_eval_metric):
  status = os.path.join(os.getenv('GOOGLE_STATUS_DIR') or '/tmp', 'STATUS')
  message = ('Parameters: %s | Steps: %d | Tuning score: %.2f%% | '
             'Best tuning score: %.2f%%' % (FLAGS.params, num_steps,
                                            eval_metric, best_eval_metric))
  with gfile.FastGFile(status, 'w') as fout:
    fout.write(message)
  with gfile.FastGFile(OutputPath('status'), 'a') as fout:
    fout.write(message + '\n')


def Eval(sess, parser, num_steps, best_eval_metric):
  """Evaluates a network and checkpoints it to disk.

  Args:
    sess: tensorflow session to use
    parser: graph builder containing all ops references
    num_steps: number of training steps taken, for logging
    best_eval_metric: current best eval metric, to decide whether this model is
        the best so far

  Returns:
    new best eval metric
  """
  logging.info('Evaluating training network.')
  t = time.time()
  num_epochs = None
  num_tokens = 0
  num_correct = 0
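  # Pull evaluation batches until the reader's epoch counter advances, i.e.
  # until exactly one full pass over the tuning corpus has been scored.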
  while True:
    tf_eval_epochs, tf_eval_metrics = sess.run([
        parser.evaluation['epochs'], parser.evaluation['eval_metrics']
    ])
    num_tokens += tf_eval_metrics[0]
    num_correct += tf_eval_metrics[1]
    if num_epochs is None:
      num_epochs = tf_eval_epochs
    elif num_epochs < tf_eval_epochs:
      break
  eval_metric = 0 if num_tokens == 0 else (100.0 * num_correct / num_tokens)
  logging.info('Seconds elapsed in evaluation: %.2f, '
               'eval metric: %.2f%%', time.time() - t, eval_metric)
  WriteStatus(num_steps, eval_metric, max(eval_metric, best_eval_metric))

  # Save parameters.
  if FLAGS.output_path:
    logging.info('Writing out trained parameters.')
    parser.saver.save(sess, OutputPath('latest-model'))
    if eval_metric > best_eval_metric:
      parser.saver.save(sess, OutputPath('model'))

  return max(eval_metric, best_eval_metric)


def Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
  """Builds and trains the network.

  Args:
    sess: tensorflow session to use.
    num_actions: number of possible golden actions.
    feature_sizes: size of each feature vector.
    domain_sizes: number of possible feature ids in each feature vector.
    embedding_dims: embedding dimension to use for each feature group.
  """
  t = time.time()
  hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
  logging.info('Building training network with parameters: feature_sizes: %s '
               'domain_sizes: %s', feature_sizes, domain_sizes)
  if FLAGS.graph_builder == 'greedy':
    parser = graph_builder.GreedyParser(num_actions,
                                        feature_sizes,
                                        domain_sizes,
                                        embedding_dims,
                                        hidden_layer_sizes,
                                        seed=int(FLAGS.seed),
                                        gate_gradients=True,
                                        averaging_decay=FLAGS.averaging_decay,
                                        arg_prefix=FLAGS.arg_prefix)
  else:
    parser = structured_graph_builder.StructuredGraphBuilder(
        num_actions,
        feature_sizes,
        domain_sizes,
        embedding_dims,
        hidden_layer_sizes,
        seed=int(FLAGS.seed),
        gate_gradients=True,
        averaging_decay=FLAGS.averaging_decay,
        arg_prefix=FLAGS.arg_prefix,
        beam_size=FLAGS.beam_size,
        max_steps=FLAGS.max_steps)

  task_context = OutputPath('context')
  if FLAGS.word_embeddings is not None:
    parser.AddPretrainedEmbeddings(0, FLAGS.word_embeddings, task_context)

  corpus_name = ('projectivized-training-corpus' if
                 FLAGS.projectivize_training_set else FLAGS.training_corpus)
  parser.AddTraining(task_context,
                     FLAGS.batch_size,
                     learning_rate=FLAGS.learning_rate,
                     momentum=FLAGS.momentum,
                     decay_steps=FLAGS.decay_steps,
                     corpus_name=corpus_name)
  parser.AddEvaluation(task_context,
                       FLAGS.batch_size,
                       corpus_name=FLAGS.tuning_corpus)
  parser.AddSaver(FLAGS.slim_model)

  # Save graph.
  if FLAGS.output_path:
    with gfile.FastGFile(OutputPath('graph'), 'w') as f:
      f.write(sess.graph_def.SerializeToString())

  logging.info('Initializing...')
  num_epochs = 0
  cost_sum = 0.0
  num_steps = 0
  best_eval_metric = 0.0
  sess.run(parser.inits.values())
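
  # Restore a subset of tensors from a pretrained checkpoint by locating
  # the saver's restore ops (named 'save/Assign*') whose inputs match the
  # requested tensor names, then running only those ops with the checkpoint
  # path fed to the saver's filename constant, 'save/Const:0'.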
  if FLAGS.pretrained_params is not None:
    logging.info('Loading pretrained params from %s', FLAGS.pretrained_params)
    feed_dict = {'save/Const:0': FLAGS.pretrained_params}
    targets = []
    for node in sess.graph_def.node:
      if (node.name.startswith('save/Assign') and
          node.input[0] in FLAGS.pretrained_params_names.split(',')):
        logging.info('Loading %s with op %s', node.input[0], node.name)
        targets.append(node.name)
    sess.run(targets, feed_dict=feed_dict)

  logging.info('Training...')
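  # Each step below processes one mini-batch; the reader op reports how many
  # full passes over the corpus have completed, which bounds the loop.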
  while num_epochs < FLAGS.num_epochs:
    tf_epochs, tf_cost, _ = sess.run([parser.training[
        'epochs'], parser.training['cost'], parser.training['train_op']])
    num_epochs = tf_epochs
    num_steps += 1
    cost_sum += tf_cost
    if num_steps % FLAGS.report_every == 0:
      logging.info('Epochs: %d, num steps: %d, '
                   'seconds elapsed: %.2f, avg cost: %.2f', num_epochs,
                   num_steps, time.time() - t, cost_sum / FLAGS.report_every)
      cost_sum = 0.0
    if num_steps % FLAGS.checkpoint_every == 0:
      best_eval_metric = Eval(sess, parser, num_steps, best_eval_metric)


def main(unused_argv):
  logging.set_verbosity(logging.INFO)
  if not gfile.IsDirectory(OutputPath('')):
    gfile.MakeDirs(OutputPath(''))

  # Rewrite the task context so resources point into the output directory.
  RewriteContext()

  # Create the necessary term maps.
  if FLAGS.compute_lexicon:
    logging.info('Computing lexicon...')
    with tf.Session(FLAGS.tf_master) as sess:
      gen_parser_ops.lexicon_builder(task_context=OutputPath('context'),
                                     corpus_name=FLAGS.training_corpus).run()
  with tf.Session(FLAGS.tf_master) as sess:
    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
        gen_parser_ops.feature_size(task_context=OutputPath('context'),
                                    arg_prefix=FLAGS.arg_prefix))
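
  # The preprocessing pass below runs a small TF graph to completion: a
  # document_source op feeds batches through well_formed_filter and
  # projectivize_filter into a document_sink op, which writes the corpus
  # named 'projectivized-training-corpus' for Train() to consume.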
  # Filter out malformed parse trees and projectivize the training set.
  if FLAGS.projectivize_training_set:
    logging.info('Preprocessing...')
    with tf.Session(FLAGS.tf_master) as sess:
      source, last = gen_parser_ops.document_source(
          task_context=OutputPath('context'),
          batch_size=FLAGS.batch_size,
          corpus_name=FLAGS.training_corpus)
      sink = gen_parser_ops.document_sink(
          task_context=OutputPath('context'),
          corpus_name='projectivized-training-corpus',
          documents=gen_parser_ops.projectivize_filter(
              gen_parser_ops.well_formed_filter(source,
                                                task_context=OutputPath(
                                                    'context')),
              task_context=OutputPath('context')))
      while True:
        tf_last, _ = sess.run([last, sink])
        if tf_last:
          break

  logging.info('Training...')
  with tf.Session(FLAGS.tf_master) as sess:
    Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)


if __name__ == '__main__':
  tf.app.run()