structured_graph_builder.py

# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Build structured parser models."""

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_ops

from syntaxnet import graph_builder
from syntaxnet.ops import gen_parser_ops

tf.NoGradient('BeamParseReader')
tf.NoGradient('BeamParser')
tf.NoGradient('BeamParserOutput')

def AddCrossEntropy(batch_size, n):
  """Adds a cross entropy cost function."""
  cross_entropies = []

  def _Pass():
    return tf.constant(0, dtype=tf.float32, shape=[1])

  for beam_id in range(batch_size):
    beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1])

    def _ComputeCrossEntropy():
      """Adds ops to compute cross entropy of the gold path in a beam."""
      # Requires a cast so that UnsortedSegmentSum, in the gradient,
      # is happy with the type of its input 'segment_ids', which
      # must be int32.
      idx = tf.cast(
          tf.reshape(
              tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
      num = tf.shape(idx)
      return tf.nn.softmax_cross_entropy_with_logits(
          beam_scores, tf.expand_dims(
              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0))

    # The conditional here is needed to deal with the last few batches of the
    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
    cross_entropies.append(cf.cond(
        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}
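
# Illustrative note, not part of the original module: AddCrossEntropy consumes
# flat, per-path tensors produced later in AddTraining. With batch_size=2 one
# might see (made-up values):
#   n['beam_ids']        = [0, 0, 0, 1, 1]   # beam that owns each path
#   n['all_path_scores'] = [1.2, 0.3, -0.7, 2.1, 0.5]
#   n['gold_slot']       = [2, 0]             # gold path's slot within its beam
# For beam 0 the logits are [1.2, 0.3, -0.7] and the one-hot target marks
# slot 2; a beam whose gold_slot is -1 (an empty batch slot) contributes the
# constant 0 from _Pass instead.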

class StructuredGraphBuilder(graph_builder.GreedyParser):
  """Extends the standard GreedyParser with a CRF objective using a beam.

  The constructor takes two additional keyword arguments.
  beam_size: the maximum size the beam can grow to.
  max_steps: the maximum number of steps in any particular beam.

  The model supports batch training with the batch_size argument to the
  AddTraining method.
  """

  def __init__(self, *args, **kwargs):
    self._beam_size = kwargs.pop('beam_size', 10)
    self._max_steps = kwargs.pop('max_steps', 25)
    super(StructuredGraphBuilder, self).__init__(*args, **kwargs)

  def _AddBeamReader(self,
                     task_context,
                     batch_size,
                     corpus_name,
                     until_all_final=False,
                     always_start_new_sentences=False):
    """Adds an op capable of reading sentences and parsing them with a beam."""
    features, state, epochs = gen_parser_ops.beam_parse_reader(
        task_context=task_context,
        feature_size=self._feature_size,
        beam_size=self._beam_size,
        batch_size=batch_size,
        corpus_name=corpus_name,
        allow_feature_weights=self._allow_feature_weights,
        arg_prefix=self._arg_prefix,
        continue_until_all_final=until_all_final,
        always_start_new_sentences=always_start_new_sentences)
    return {'state': state, 'features': features, 'epochs': epochs}

  def _BuildSequence(self,
                     batch_size,
                     max_steps,
                     features,
                     state,
                     use_average=False):
    """Adds a sequence of beam parsing steps."""
    def Advance(state, step, scores_array, alive, alive_steps, *features):
      scores = self._BuildNetwork(features,
                                  return_average=use_average)['logits']
      scores_array = scores_array.write(step, scores)
      features, state, alive = (
          gen_parser_ops.beam_parser(state, scores, self._feature_size))
      return [state, step + 1, scores_array, alive, alive_steps + tf.cast(
          alive, tf.int32)] + list(features)

    # args: (state, step, scores_array, alive, alive_steps, *features)
    def KeepGoing(*args):
      return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))

    step = tf.constant(0, tf.int32, [])
    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
                                                size=0,
                                                dynamic_size=True)
    alive = tf.constant(True, tf.bool, [batch_size])
    alive_steps = tf.constant(0, tf.int32, [batch_size])
    t = tf.while_loop(
        KeepGoing,
        Advance,
        [state, step, scores_array, alive, alive_steps] + list(features),
        parallel_iterations=100)

    # Link to the final nodes/values of ops that have passed through While:
    return {'state': t[0],
            'concat_scores': t[2].concat(),
            'alive': t[3],
            'alive_steps': t[4]}
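
  # Illustrative note, not part of the original code: tf.while_loop threads the
  # loop variables (state, step, scores_array, alive, alive_steps, *features)
  # through Advance until KeepGoing returns false, i.e. until max_steps is hit
  # or no beam in the batch is still alive. 'concat_scores' concatenates the
  # per-step logits written into the TensorArray, and 'alive_steps' counts how
  # many steps each batch element's beam stayed alive.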

  def AddTraining(self,
                  task_context,
                  batch_size,
                  learning_rate=0.1,
                  decay_steps=4000,
                  momentum=None,
                  corpus_name='documents'):
    """Builds the training graph: beam search plus a cross-entropy cost."""
    with tf.name_scope('training'):
      n = self.training
      n['accumulated_alive_steps'] = self._AddVariable(
          [batch_size], tf.int32, 'accumulated_alive_steps',
          tf.zeros_initializer)
      n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
      # This adds a required 'step' node too:
      learning_rate = tf.constant(learning_rate, dtype=tf.float32)
      n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)
      # Call BuildNetwork *only* to set up the params outside of the main loop.
      self._BuildNetwork(list(n['features']))

      n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
                                   n['state']))

      flat_concat_scores = tf.reshape(n['concat_scores'], [-1])
      (indices_and_paths, beams_and_slots, n['gold_slot'], n[
          'beam_path_scores']) = gen_parser_ops.beam_parser_output(n[
              'state'])
      n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
      n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
      n['all_path_scores'] = tf.sparse_segment_sum(
          flat_concat_scores, n['indices'], n['path_ids'])
      n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])

      n.update(AddCrossEntropy(batch_size, n))

      if self._only_train:
        trainable_params = {k: v for k, v in self.params.iteritems()
                            if k in self._only_train}
      else:
        trainable_params = self.params
      for p in trainable_params:
        tf.logging.info('trainable_param: %s', p)

      regularized_params = [
          tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
          if k.startswith('weights') or k.startswith('bias')]
      l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0

      n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')

      n['gradients'] = tf.gradients(n['cost'], trainable_params.values())

      with tf.control_dependencies([n['alive_steps']]):
        update_accumulators = tf.group(
            tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))

      def ResetAccumulators():
        return tf.assign(
            n['accumulated_alive_steps'], tf.zeros([batch_size], tf.int32))

      n['reset_accumulators_func'] = ResetAccumulators

      optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
                                             momentum,
                                             use_locking=self._use_locking)
      train_op = optimizer.minimize(n['cost'],
                                    var_list=trainable_params.values())
      for param in trainable_params.values():
        slot = optimizer.get_slot(param, 'momentum')
        self.inits[slot.name] = state_ops.init_variable(slot,
                                                        tf.zeros_initializer)
        self.variables[slot.name] = slot

      def NumericalChecks():
        return tf.group(*[
            tf.check_numerics(param, message='Parameter is not finite.')
            for param in trainable_params.values()
            if param.dtype.base_dtype in [tf.float32, tf.float64]])

      check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every), 0),
                         NumericalChecks, tf.no_op)
      avg_update_op = tf.group(*self._averaging.values())
      train_ops = [train_op]
      if self._check_parameters:
        train_ops.append(check_op)
      if self._use_averaging:
        train_ops.append(avg_update_op)
      with tf.control_dependencies([update_accumulators]):
        n['train_op'] = tf.group(*train_ops, name='train_op')
      n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
      return n

  def AddEvaluation(self,
                    task_context,
                    batch_size,
                    evaluation_max_steps=300,
                    corpus_name=None):
    """Builds the evaluation graph: beam-parses a corpus and emits metrics."""
    with tf.name_scope('evaluation'):
      n = self.evaluation
      n.update(self._AddBeamReader(task_context,
                                   batch_size,
                                   corpus_name,
                                   until_all_final=True,
                                   always_start_new_sentences=True))
      self._BuildNetwork(
          list(n['features']),
          return_average=self._use_averaging)
      n.update(self._BuildSequence(batch_size, evaluation_max_steps, n[
          'features'], n['state'], use_average=self._use_averaging))
      n['eval_metrics'], n['documents'] = (
          gen_parser_ops.beam_eval_output(n['state']))
      return n
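
# A rough usage sketch, not part of the original module. The positional
# arguments below mirror graph_builder.GreedyParser's constructor and all
# values are placeholders; in practice parser_trainer.py derives the feature
# and embedding dimensions from the task context before building the graph.
#
#   parser = StructuredGraphBuilder(num_actions, num_features, num_feature_ids,
#                                   embedding_sizes, hidden_layer_sizes,
#                                   beam_size=8, max_steps=25)
#   training_nodes = parser.AddTraining('task_context.pbtxt', batch_size=4)
#   evaluation_nodes = parser.AddEvaluation('task_context.pbtxt', batch_size=4)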