structured_graph_builder.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Build structured parser models."""

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops

from syntaxnet import graph_builder
from syntaxnet.ops import gen_parser_ops
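
# The beam ops below are C++ kernels with no registered gradient; declaring
# them not-differentiable lets tf.gradients treat their outputs as constants
# instead of failing on a missing gradient function.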
tf.NotDifferentiable('BeamParseReader')
tf.NotDifferentiable('BeamParser')
tf.NotDifferentiable('BeamParserOutput')


def AddCrossEntropy(batch_size, n):
  """Adds a cross entropy cost function."""
  cross_entropies = []

  def _Pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

  for beam_id in range(batch_size):
    beam_gold_slot = tf.reshape(
        tf.strided_slice(n['gold_slot'], [beam_id], [beam_id + 1]), [1])

    def _ComputeCrossEntropy():
      """Adds ops to compute cross entropy of the gold path in a beam."""
      # Requires a cast so that UnsortedSegmentSum, in the gradient,
      # is happy with the type of its input 'segment_ids', which
      # must be int32.
      idx = tf.cast(
          tf.reshape(
              tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
      num = tf.shape(idx)
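      # Illustration: if this beam kept four paths and the gold path sits in
      # slot 2, the dense target built below is [0., 0., 1., 0.] and
      # beam_scores has shape [1, 4].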
      return tf.nn.softmax_cross_entropy_with_logits(
          labels=tf.expand_dims(
              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0),
          logits=beam_scores)
    # The conditional here is needed to deal with the last few batches of the
    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
    cross_entropies.append(cf.cond(
        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}


class StructuredGraphBuilder(graph_builder.GreedyParser):
  """Extends the standard GreedyParser with a CRF objective using a beam.

  The constructor takes two additional keyword arguments.
  beam_size: the maximum size the beam can grow to.
  max_steps: the maximum number of steps in any particular beam.

  The model supports batch training with the batch_size argument to the
  AddTraining method.
  """

  def __init__(self, *args, **kwargs):
    self._beam_size = kwargs.pop('beam_size', 10)
    self._max_steps = kwargs.pop('max_steps', 25)
    super(StructuredGraphBuilder, self).__init__(*args, **kwargs)

  def _AddBeamReader(self,
                     task_context,
                     batch_size,
                     corpus_name,
                     until_all_final=False,
                     always_start_new_sentences=False):
    """Adds an op capable of reading sentences and parsing them with a beam."""
    features, state, epochs = gen_parser_ops.beam_parse_reader(
        task_context=task_context,
        feature_size=self._feature_size,
        beam_size=self._beam_size,
        batch_size=batch_size,
        corpus_name=corpus_name,
        allow_feature_weights=self._allow_feature_weights,
        arg_prefix=self._arg_prefix,
        continue_until_all_final=until_all_final,
        always_start_new_sentences=always_start_new_sentences)
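    # 'state' is the opaque beam state consumed by the beam_parser steps,
    # 'features' feed _BuildNetwork, and 'epochs' counts passes over the corpus.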
    return {'state': state, 'features': features, 'epochs': epochs}

  def _BuildSequence(self,
                     batch_size,
                     max_steps,
                     features,
                     state,
                     use_average=False):
    """Adds a sequence of beam parsing steps."""
    def Advance(state, step, scores_array, alive, alive_steps, *features):
      scores = self._BuildNetwork(features,
                                  return_average=use_average)['logits']
      scores_array = scores_array.write(step, scores)
      features, state, alive = (
          gen_parser_ops.beam_parser(state, scores, self._feature_size))
      return [state, step + 1, scores_array, alive, alive_steps + tf.cast(
          alive, tf.int32)] + list(features)

    # args: (state, step, scores_array, alive, alive_steps, *features)
    def KeepGoing(*args):
      return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))

    step = tf.constant(0, tf.int32, [])
    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=0,
                                                dynamic_size=True)
    alive = tf.constant(True, tf.bool, [batch_size])
    alive_steps = tf.constant(0, tf.int32, [batch_size])
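    # Advance every beam in lockstep until all beams are final or max_steps is
    # reached; the shape invariants are left unconstrained because the number
    # of surviving paths per beam changes from step to step.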
    t = tf.while_loop(
        KeepGoing,
        Advance,
        [state, step, scores_array, alive, alive_steps] + list(features),
        shape_invariants=[tf.TensorShape(None)] * (len(features) + 5),
        parallel_iterations=100)

    # Link to the final nodes/values of ops that have passed through While:
    return {'state': t[0],
            'concat_scores': t[2].concat(),
            'alive': t[3],
            'alive_steps': t[4]}

  def AddTraining(self,
                  task_context,
                  batch_size,
                  learning_rate=0.1,
                  decay_steps=4000,
                  momentum=None,
                  corpus_name='documents'):
    with tf.name_scope('training'):
      n = self.training
      n['accumulated_alive_steps'] = self._AddVariable(
          [batch_size], tf.int32, 'accumulated_alive_steps',
          tf.zeros_initializer())
      n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
      # This adds a required 'step' node too:
      learning_rate = tf.constant(learning_rate, dtype=tf.float32)
      n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)
      # Call BuildNetwork *only* to set up the params outside of the main loop.
      self._BuildNetwork(list(n['features']))

      n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
                                   n['state']))

      flat_concat_scores = tf.reshape(n['concat_scores'], [-1])

      (indices_and_paths, beams_and_slots, n['gold_slot'], n[
          'beam_path_scores']) = gen_parser_ops.beam_parser_output(n[
              'state'])
      n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
      n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
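      # 'indices' selects each path's per-step scores out of the flattened
      # score matrix and 'path_ids' groups those entries by path, so the
      # segment sum below yields one total score per path.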
      n['all_path_scores'] = tf.sparse_segment_sum(
          flat_concat_scores, n['indices'], n['path_ids'])
      n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])

      n.update(AddCrossEntropy(batch_size, n))

      if self._only_train:
        trainable_params = {k: v for k, v in self.params.iteritems()
                            if k in self._only_train}
      else:
        trainable_params = self.params
      for p in trainable_params:
        tf.logging.info('trainable_param: %s', p)

      regularized_params = [
          tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
          if k.startswith('weights') or k.startswith('bias')]
      l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0

      n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')

      n['gradients'] = tf.gradients(n['cost'], trainable_params.values())

      with tf.control_dependencies([n['alive_steps']]):
        update_accumulators = tf.group(
            tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))

      def ResetAccumulators():
        return tf.assign(
            n['accumulated_alive_steps'], tf.zeros([batch_size], tf.int32))
      n['reset_accumulators_func'] = ResetAccumulators

      optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
                                             momentum,
                                             use_locking=self._use_locking)
      train_op = optimizer.minimize(n['cost'],
                                    var_list=trainable_params.values())
      for param in trainable_params.values():
        slot = optimizer.get_slot(param, 'momentum')
        self.inits[slot.name] = state_ops.init_variable(slot,
                                                        tf.zeros_initializer())
        self.variables[slot.name] = slot

      def NumericalChecks():
        return tf.group(*[
            tf.check_numerics(param, message='Parameter is not finite.')
            for param in trainable_params.values()
            if param.dtype.base_dtype in [tf.float32, tf.float64]])
      check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every), 0),
                         NumericalChecks, tf.no_op)
      avg_update_op = tf.group(*self._averaging.values())
      train_ops = [train_op]
      if self._check_parameters:
        train_ops.append(check_op)
      if self._use_averaging:
        train_ops.append(avg_update_op)
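      # Force the alive-step accumulator update to run alongside every
      # training step, together with the optional checks and averaging ops.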
      with tf.control_dependencies([update_accumulators]):
        n['train_op'] = tf.group(*train_ops, name='train_op')
      n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
    return n

  def AddEvaluation(self,
                    task_context,
                    batch_size,
                    evaluation_max_steps=300,
                    corpus_name=None):
    with tf.name_scope('evaluation'):
      n = self.evaluation
      n.update(self._AddBeamReader(task_context,
                                   batch_size,
                                   corpus_name,
                                   until_all_final=True,
                                   always_start_new_sentences=True))
      self._BuildNetwork(
          list(n['features']),
          return_average=self._use_averaging)
      n.update(self._BuildSequence(batch_size, evaluation_max_steps, n[
          'features'], n['state'], use_average=self._use_averaging))
      n['eval_metrics'], n['documents'] = (
          gen_parser_ops.beam_eval_output(n['state']))
    return n
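

# Example usage sketch. The size arguments below are hypothetical placeholders;
# a real setup (e.g. parser_trainer.py) derives them from the task context.
#
#   parser = StructuredGraphBuilder(
#       num_actions, feature_sizes, domain_sizes, embedding_dims,
#       hidden_layer_sizes, beam_size=8, max_steps=25)
#   parser.AddTraining(task_context, batch_size=8,
#                      corpus_name='training-corpus')
#   parser.AddEvaluation(task_context, batch_size=8,
#                        corpus_name='tuning-corpus')
#   with tf.Session() as sess:
#     sess.run(parser.inits.values())
#     sess.run(parser.training['train_op'])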