structured_graph_builder.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Build structured parser models."""

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops

from syntaxnet import graph_builder
from syntaxnet.ops import gen_parser_ops

tf.NotDifferentiable('BeamParseReader')
tf.NotDifferentiable('BeamParser')
tf.NotDifferentiable('BeamParserOutput')


def AddCrossEntropy(batch_size, n):
  """Adds a cross entropy cost function."""
  cross_entropies = []

  def _Pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

  for beam_id in range(batch_size):
    beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1])

    def _ComputeCrossEntropy():
      """Adds ops to compute cross entropy of the gold path in a beam."""
      # Requires a cast so that UnsortedSegmentSum, in the gradient,
      # is happy with the type of its input 'segment_ids', which
      # must be int32.
      idx = tf.cast(
          tf.reshape(
              tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
      num = tf.shape(idx)
      return tf.nn.softmax_cross_entropy_with_logits(
          beam_scores, tf.expand_dims(
              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0))

    # The conditional here is needed to deal with the last few batches of the
    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
    cross_entropies.append(cf.cond(
        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}
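
# Illustrative note on AddCrossEntropy above (the numbers are made up, not
# taken from this module): if beam 0 ends with four surviving paths whose
# summed scores are [2.1, 0.3, -1.0, 0.5] and gold_slot[0] == 2, that beam's
# loss is the softmax cross entropy between those four logits and the one-hot
# target [0., 0., 1., 0.], i.e. -log(softmax(scores)[2]). Beams whose
# gold_slot is -1 (empty batch slots) contribute a constant 0 via _Pass.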


class StructuredGraphBuilder(graph_builder.GreedyParser):
  """Extends the standard GreedyParser with a CRF objective using a beam.

  The constructor takes two additional keyword arguments.
  beam_size: the maximum size the beam can grow to.
  max_steps: the maximum number of steps in any particular beam.

  The model supports batch training with the batch_size argument to the
  AddTraining method.
  """

  def __init__(self, *args, **kwargs):
    self._beam_size = kwargs.pop('beam_size', 10)
    self._max_steps = kwargs.pop('max_steps', 25)
    super(StructuredGraphBuilder, self).__init__(*args, **kwargs)
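
  # A minimal construction sketch (illustrative only; every argument name
  # other than beam_size and max_steps belongs to the parent GreedyParser
  # and is an assumption here, not defined in this file):
  #
  #   parser = StructuredGraphBuilder(
  #       num_actions, feature_sizes, domain_sizes, embedding_dims,
  #       hidden_layer_sizes, beam_size=8, max_steps=25)
  #
  # beam_size and max_steps are popped off above; everything else is
  # forwarded unchanged to GreedyParser.__init__ via *args/**kwargs.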

  def _AddBeamReader(self,
                     task_context,
                     batch_size,
                     corpus_name,
                     until_all_final=False,
                     always_start_new_sentences=False):
    """Adds an op capable of reading sentences and parsing them with a beam."""
    features, state, epochs = gen_parser_ops.beam_parse_reader(
        task_context=task_context,
        feature_size=self._feature_size,
        beam_size=self._beam_size,
        batch_size=batch_size,
        corpus_name=corpus_name,
        allow_feature_weights=self._allow_feature_weights,
        arg_prefix=self._arg_prefix,
        continue_until_all_final=until_all_final,
        always_start_new_sentences=always_start_new_sentences)
    return {'state': state, 'features': features, 'epochs': epochs}

  def _BuildSequence(self,
                     batch_size,
                     max_steps,
                     features,
                     state,
                     use_average=False):
    """Adds a sequence of beam parsing steps."""

    def Advance(state, step, scores_array, alive, alive_steps, *features):
      scores = self._BuildNetwork(features,
                                  return_average=use_average)['logits']
      scores_array = scores_array.write(step, scores)
      features, state, alive = (
          gen_parser_ops.beam_parser(state, scores, self._feature_size))
      return [state, step + 1, scores_array, alive, alive_steps + tf.cast(
          alive, tf.int32)] + list(features)

    # args: (state, step, scores_array, alive, alive_steps, *features)
    def KeepGoing(*args):
      return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))

    step = tf.constant(0, tf.int32, [])
    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=0,
                                                dynamic_size=True)
    alive = tf.constant(True, tf.bool, [batch_size])
    alive_steps = tf.constant(0, tf.int32, [batch_size])
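    # The loop below carries (state, step, scores_array, alive, alive_steps,
    # *features) from one parse step to the next: Advance runs one beam
    # transition and KeepGoing stops once max_steps is reached or no beam in
    # the batch is still alive. The shape invariants are left fully
    # unconstrained, presumably because the per-step feature tensors can
    # change shape as the beams grow and shrink.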
    t = tf.while_loop(
        KeepGoing,
        Advance,
        [state, step, scores_array, alive, alive_steps] + list(features),
        shape_invariants=[tf.TensorShape(None)] * (len(features) + 5),
        parallel_iterations=100)

    # Link to the final nodes/values of ops that have passed through While:
    return {'state': t[0],
            'concat_scores': t[2].concat(),
            'alive': t[3],
            'alive_steps': t[4]}

  def AddTraining(self,
                  task_context,
                  batch_size,
                  learning_rate=0.1,
                  decay_steps=4000,
                  momentum=None,
                  corpus_name='documents'):
    with tf.name_scope('training'):
      n = self.training
      n['accumulated_alive_steps'] = self._AddVariable(
          [batch_size], tf.int32, 'accumulated_alive_steps',
          tf.zeros_initializer)
      n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
      # This adds a required 'step' node too:
      learning_rate = tf.constant(learning_rate, dtype=tf.float32)
      n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)

      # Call BuildNetwork *only* to set up the params outside of the main loop.
      self._BuildNetwork(list(n['features']))

      n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
                                   n['state']))

      flat_concat_scores = tf.reshape(n['concat_scores'], [-1])
      (indices_and_paths, beams_and_slots, n['gold_slot'], n[
          'beam_path_scores']) = gen_parser_ops.beam_parser_output(n[
              'state'])
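      # In the lines below, row 0 of indices_and_paths indexes into the
      # flattened per-step scores and row 1 labels each entry with the path it
      # belongs to, so the sparse_segment_sum yields one total score per
      # candidate path; row 0 of beams_and_slots then maps each path back to
      # its beam, which is what AddCrossEntropy needs.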
      n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
      n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
      n['all_path_scores'] = tf.sparse_segment_sum(
          flat_concat_scores, n['indices'], n['path_ids'])
      n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])

      n.update(AddCrossEntropy(batch_size, n))

      if self._only_train:
        trainable_params = {k: v for k, v in self.params.iteritems()
                            if k in self._only_train}
      else:
        trainable_params = self.params
      for p in trainable_params:
        tf.logging.info('trainable_param: %s', p)

      regularized_params = [
          tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
          if k.startswith('weights') or k.startswith('bias')]
      l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0

      n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')

      n['gradients'] = tf.gradients(n['cost'], trainable_params.values())

      with tf.control_dependencies([n['alive_steps']]):
        update_accumulators = tf.group(
            tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))

      def ResetAccumulators():
        return tf.assign(
            n['accumulated_alive_steps'], tf.zeros([batch_size], tf.int32))

      n['reset_accumulators_func'] = ResetAccumulators

      optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
                                             momentum,
                                             use_locking=self._use_locking)
      train_op = optimizer.minimize(n['cost'],
                                    var_list=trainable_params.values())
      for param in trainable_params.values():
        slot = optimizer.get_slot(param, 'momentum')
        self.inits[slot.name] = state_ops.init_variable(slot,
                                                        tf.zeros_initializer)
        self.variables[slot.name] = slot

      def NumericalChecks():
        return tf.group(*[
            tf.check_numerics(param, message='Parameter is not finite.')
            for param in trainable_params.values()
            if param.dtype.base_dtype in [tf.float32, tf.float64]])

      check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every), 0),
                         NumericalChecks, tf.no_op)
      avg_update_op = tf.group(*self._averaging.values())
      train_ops = [train_op]
      if self._check_parameters:
        train_ops.append(check_op)
      if self._use_averaging:
        train_ops.append(avg_update_op)
      with tf.control_dependencies([update_accumulators]):
        n['train_op'] = tf.group(*train_ops, name='train_op')
      n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
      return n

  def AddEvaluation(self,
                    task_context,
                    batch_size,
                    evaluation_max_steps=300,
                    corpus_name=None):
    with tf.name_scope('evaluation'):
      n = self.evaluation
      n.update(self._AddBeamReader(task_context,
                                   batch_size,
                                   corpus_name,
                                   until_all_final=True,
                                   always_start_new_sentences=True))
      self._BuildNetwork(
          list(n['features']),
          return_average=self._use_averaging)
      n.update(self._BuildSequence(batch_size, evaluation_max_steps, n[
          'features'], n['state'], use_average=self._use_averaging))
      n['eval_metrics'], n['documents'] = (
          gen_parser_ops.beam_eval_output(n['state']))
      return n
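
# A rough sketch of how the training graph built above might be driven
# (illustrative only; the session loop and num_steps are assumptions, not
# part of this module). 'parser' is a StructuredGraphBuilder constructed as
# in the sketch near __init__:
#
#   training = parser.AddTraining(task_context, batch_size=8)
#   with tf.Session() as sess:
#     sess.run(parser.inits.values())
#     for _ in range(num_steps):
#       _, cost = sess.run([training['train_op'], training['cost']])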