component.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Builds a DRAGNN graph for local training."""
  16. from abc import ABCMeta
  17. from abc import abstractmethod
  18. import tensorflow as tf
  19. from tensorflow.python.platform import tf_logging as logging
  20. from dragnn.python import dragnn_ops
  21. from dragnn.python import network_units
  22. from syntaxnet.util import check
  23. from syntaxnet.util import registry
  24. class NetworkState(object):
  25. """Simple utility to manage the state of a DRAGNN network.
  26. This class encapsulates the variables that are a specific to any
  27. particular instance of a DRAGNN stack, as constructed by the
  28. MasterBuilder below.
  29. Attributes:
  30. activations: Dictionary mapping layer names to StoredActivation objects.
  31. """
  32. def __init__(self):
  33. self.activations = {}
  34. class MasterState(object):
  35. """Simple utility to encapsulate tensors associated with the master state.
  36. Attributes:
  37. handle: string tensor handle to the underlying nlp_saft::dragnn::MasterState
  38. current_batch_size: int tensor containing the batch size following the most
  39. recent MasterState::Reset().
  40. """
  41. def __init__(self, handle, current_batch_size):
  42. self.handle = handle
  43. self.current_batch_size = current_batch_size
  44. @registry.RegisteredClass
  45. class ComponentBuilderBase(object):
  46. """Utility to build a single Component in a DRAGNN stack of models.
  47. This class handles converting a ComponentSpec proto into various TF
  48. sub-graphs. It will stitch together various neural units with dynamic
  49. unrolling inside a tf.while loop.
  50. All variables for parameters are created during the constructor within the
  51. scope of the component's name, e.g. 'tagger/embedding_matrix_0' for a
  52. component named 'tagger'.
  53. As part of the specification, ComponentBuilder will wrap an underlying
  54. NetworkUnit which generates the actual network layout.
  55. """
  56. __metaclass__ = ABCMeta # required for @abstractmethod
  57. def __init__(self, master, component_spec, attr_defaults=None):
  58. """Initializes the ComponentBuilder from specifications.
  59. Args:
  60. master: dragnn.MasterBuilder object.
  61. component_spec: dragnn.ComponentSpec proto to be built.
  62. attr_defaults: Optional dict of component attribute defaults. If not
  63. provided or if empty, attributes are not extracted.
  64. """
  65. self.master = master
  66. self.num_actions = component_spec.num_actions
  67. self.name = component_spec.name
  68. self.spec = component_spec
  69. self.moving_average = None
  70. # Determine if this component should apply self-normalization.
  71. self.eligible_for_self_norm = (
  72. not self.master.hyperparams.self_norm_components_filter or self.name in
  73. self.master.hyperparams.self_norm_components_filter.split(','))
  74. # Extract component attributes before make_network(), so the network unit
  75. # can access them.
  76. self._attrs = {}
  77. if attr_defaults:
  78. self._attrs = network_units.get_attrs_with_defaults(
  79. self.spec.component_builder.parameters, attr_defaults)
  80. with tf.variable_scope(self.name):
  81. self.training_beam_size = tf.constant(
  82. self.spec.training_beam_size, name='TrainingBeamSize')
  83. self.inference_beam_size = tf.constant(
  84. self.spec.inference_beam_size, name='InferenceBeamSize')
  85. self.locally_normalize = tf.constant(False, name='LocallyNormalize')
  86. self._step = tf.get_variable(
  87. 'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
  88. self._total = tf.get_variable(
  89. 'total', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
  90. # Construct network variables.
  91. self.network = self.make_network(self.spec.network_unit)
  92. # Construct moving average.
  93. if self.master.hyperparams.use_moving_average:
  94. self.moving_average = tf.train.ExponentialMovingAverage(
  95. decay=self.master.hyperparams.average_weight, num_updates=self._step)
  96. self.avg_ops = [self.moving_average.apply(self.network.params)]
  97. def make_network(self, network_unit):
  98. """Makes a NetworkUnitInterface object based on the network_unit spec.
  99. Components may override this method to exert control over the
  100. network unit construction, such as which network units are supported.
  101. Args:
  102. network_unit: RegisteredModuleSpec proto defining the network unit.
  103. Returns:
  104. An implementation of NetworkUnitInterface.
  105. Raises:
  106. ValueError: if the requested network unit is not found in the registry.
  107. """
  108. network_type = network_unit.registered_name
  109. with tf.variable_scope(self.name):
  110. # Raises ValueError if not found.
  111. return network_units.NetworkUnitInterface.Create(network_type, self)
  112. @abstractmethod
  113. def build_greedy_training(self, state, network_states):
  114. """Builds a training graph for this component.
  115. Two assumptions are made about the resulting graph:
  116. 1. An oracle will be used to unroll the state and compute the cost.
  117. 2. The graph will be differentiable when the cost is being minimized.
  118. Args:
  119. state: MasterState from the 'AdvanceMaster' op that advances the
  120. underlying master to this component.
  121. network_states: dictionary of component NetworkState objects.
  122. Returns:
  123. (state, cost, correct, total) -- These are TF ops corresponding to
  124. the final state after unrolling, the total cost, the total number of
  125. correctly predicted actions, and the total number of actions.
  126. """
  127. pass
  128. def build_structured_training(self, state, network_states):
  129. """Builds a beam search based training loop for this component.
  130. The default implementation builds a dummy graph and raises a
  131. TensorFlow runtime exception to indicate that structured training
  132. is not implemented.
  133. Args:
  134. state: MasterState from the 'AdvanceMaster' op that advances the
  135. underlying master to this component.
  136. network_states: dictionary of component NetworkState objects.
  137. Returns:
  138. (handle, cost, correct, total) -- These are TF ops corresponding
  139. to the final handle after unrolling, the total cost, and the
  140. total number of actions. Since the number of correctly predicted
  141. actions is not applicable in the structured training setting, a
  142. dummy value should returned.
  143. """
  144. del network_states # Unused.
  145. with tf.control_dependencies([tf.Assert(False, ['Not implemented.'])]):
  146. handle = tf.identity(state.handle)
  147. cost = tf.constant(0.)
  148. correct, total = tf.constant(0), tf.constant(0)
  149. return handle, cost, correct, total
  150. @abstractmethod
  151. def build_greedy_inference(self, state, network_states,
  152. during_training=False):
  153. """Builds an inference graph for this component.
  154. If this graph is being constructed 'during_training', then it needs to be
  155. differentiable even though it doesn't return an explicit cost.
  156. There may be other cases where the distinction between training and eval is
  157. important. The handling of dropout is an example of this.
  158. Args:
  159. state: MasterState from the 'AdvanceMaster' op that advances the
  160. underlying master to this component.
  161. network_states: dictionary of component NetworkState objects.
  162. during_training: whether the graph is being constructed during training
  163. Returns:
  164. Handle to the state once inference is complete for this Component.
  165. """
  166. pass
  167. def get_summaries(self):
  168. """Constructs a set of summaries for this component.
  169. Returns:
  170. List of Summary ops to get parameter norms, progress reports, and
  171. so forth for this component.
  172. """
  173. def combine_norm(matrices):
  174. # Handles None in cases where the optimizer or moving average slot is
  175. # not present.
  176. squares = [tf.reduce_sum(tf.square(m)) for m in matrices if m is not None]
  177. # Some components may not have any parameters, in which case we simply
  178. # return zero.
  179. if squares:
  180. return tf.sqrt(tf.add_n(squares))
  181. else:
  182. return tf.constant(0, tf.float32)
  183. summaries = []
  184. summaries.append(tf.summary.scalar('%s step' % self.name, self._step))
  185. summaries.append(tf.summary.scalar('%s total' % self.name, self._total))
  186. if self.network.params:
  187. summaries.append(
  188. tf.summary.scalar('%s parameter Norm' % self.name,
  189. combine_norm(self.network.params)))
  190. slot_names = self.master.optimizer.get_slot_names()
  191. for name in slot_names:
  192. slot_params = [
  193. self.master.optimizer.get_slot(p, name) for p in self.network.params
  194. ]
  195. summaries.append(
  196. tf.summary.scalar('%s %s Norm' % (self.name, name),
  197. combine_norm(slot_params)))
  198. # Construct moving average.
  199. if self.master.hyperparams.use_moving_average:
  200. summaries.append(
  201. tf.summary.scalar('%s avg Norm' % self.name,
  202. combine_norm([
  203. self.moving_average.average(p)
  204. for p in self.network.params
  205. ])))
  206. return summaries
  207. def get_variable(self, var_name=None, var_params=None):
  208. """Returns either the original or averaged version of a given variable.
  209. If the master.read_from_avg flag is set to True, and the
  210. ExponentialMovingAverage (EMA) object has been attached, then this will ask
  211. the EMA object for the given variable.
  212. This is to allow executing inference from the averaged version of
  213. parameters.
  214. Arguments:
  215. var_name: Name of the variable.
  216. var_params: tf.Variable for which to retrieve an average.
  217. Only one of |var_name| or |var_params| needs to be provided. If both are
  218. provided, |var_params| takes precedence.
  219. Returns:
  220. tf.Variable object corresponding to original or averaged version.
  221. """
  222. if var_params:
  223. var_name = var_params.name
  224. else:
  225. check.NotNone(var_name, 'specify at least one of var_name or var_params')
  226. var_params = tf.get_variable(var_name)
  227. if self.moving_average and self.master.read_from_avg:
  228. logging.info('Retrieving average for: %s', var_name)
  229. var_params = self.moving_average.average(var_params)
  230. assert var_params
  231. logging.info('Returning: %s', var_params.name)
  232. return var_params
  233. def advance_counters(self, total):
  234. """Returns ops to advance the per-component step and total counters.
  235. Args:
  236. total: Total number of actions to increment counters by.
  237. Returns:
  238. tf.Group op incrementing 'step' by 1 and 'total' by total.
  239. """
  240. update_total = tf.assign_add(self._total, total, use_locking=True)
  241. update_step = tf.assign_add(self._step, 1, use_locking=True)
  242. return tf.group(update_total, update_step)
  243. def add_regularizer(self, cost):
  244. """Adds L2 regularization for parameters which have it turned on.
  245. Args:
  246. cost: float cost before regularization.
  247. Returns:
  248. Updated cost optionally including regularization.
  249. """
  250. if self.network is None:
  251. return cost
  252. regularized_weights = self.network.get_l2_regularized_weights()
  253. if not regularized_weights:
  254. return cost
  255. l2_coeff = self.master.hyperparams.l2_regularization_coefficient
  256. if l2_coeff == 0.0:
  257. return cost
  258. tf.logging.info('[%s] Regularizing parameters: %s', self.name,
  259. [w.name for w in regularized_weights])
  260. l2_costs = [tf.nn.l2_loss(p) for p in regularized_weights]
  261. return tf.add(cost, l2_coeff * tf.add_n(l2_costs), name='regularizer')
  262. def build_post_restore_hook(self):
  263. """Builds a post restore graph for this component.
  264. This is a run-once graph that prepares any state necessary for the
  265. inference portion of the component. It is generally a no-op.
  266. Returns:
  267. A no-op state.
  268. """
  269. logging.info('Building default post restore hook for component: %s',
  270. self.spec.name)
  271. return tf.no_op(name='setup_%s' % self.spec.name)
  272. def attr(self, name):
  273. """Returns the value of the component attribute with the |name|."""
  274. return self._attrs[name]
  275. def update_tensor_arrays(network_tensors, arrays):
  276. """Updates a list of tensor arrays from the network's output tensors.
  277. Arguments:
  278. network_tensors: Output tensors from the underlying NN unit.
  279. arrays: TensorArrays to be updated.
  280. Returns:
  281. New list of TensorArrays after writing activations.
  282. """
  283. # TODO(googleuser): Only store activations that will be used later in linked
  284. # feature specifications.
  285. next_arrays = []
  286. for index, network_tensor in enumerate(network_tensors):
  287. array = arrays[index]
  288. size = array.size()
  289. array = array.write(size, network_tensor)
  290. next_arrays.append(array)
  291. return next_arrays
  292. class DynamicComponentBuilder(ComponentBuilderBase):
  293. """Component builder for recurrent DRAGNN networks.
  294. Feature extraction and annotation are done sequentially in a tf.while_loop
  295. so fixed and linked features can be recurrent.
  296. """
  297. def build_greedy_training(self, state, network_states):
  298. """Builds a training loop for this component.
  299. This loop repeatedly evaluates the network and computes the loss, but it
  300. does not advance using the predictions of the network. Instead, it advances
  301. using the oracle defined in the underlying transition system. The final
  302. state will always correspond to the gold annotation.
  303. Args:
  304. state: MasterState from the 'AdvanceMaster' op that advances the
  305. underlying master to this component.
  306. network_states: NetworkState object containing component TensorArrays.
  307. Returns:
  308. (state, cost, correct, total) -- These are TF ops corresponding to
  309. the final state after unrolling, the total cost, the total number of
  310. correctly predicted actions, and the total number of actions.
  311. """
  312. logging.info('Building component: %s', self.spec.name)
  313. with tf.control_dependencies([tf.assert_equal(self.training_beam_size, 1)]):
  314. stride = state.current_batch_size * self.training_beam_size
  315. cost = tf.constant(0.)
  316. correct = tf.constant(0)
  317. total = tf.constant(0)
  318. def cond(handle, *_):
  319. all_final = dragnn_ops.emit_all_final(handle, component=self.name)
  320. return tf.logical_not(tf.reduce_all(all_final))
  321. def body(handle, cost, correct, total, *arrays):
  322. """Runs the network and advances the state by a step."""
  323. with tf.control_dependencies([handle, cost, correct, total] +
  324. [x.flow for x in arrays]):
  325. # Get a copy of the network inside this while loop.
  326. updated_state = MasterState(handle, state.current_batch_size)
  327. network_tensors = self._feedforward_unit(
  328. updated_state, arrays, network_states, stride, during_training=True)
  329. # Every layer is written to a TensorArray, so that it can be backprop'd.
  330. next_arrays = update_tensor_arrays(network_tensors, arrays)
  331. with tf.control_dependencies([x.flow for x in next_arrays]):
  332. with tf.name_scope('compute_loss'):
  333. # A gold label > -1 determines that the sentence is still
  334. # in a valid state. Otherwise, the sentence has ended.
  335. #
  336. # We add only the valid sentences to the loss, in the following way:
  337. # 1. We compute 'valid_ix', the indices in gold that contain
  338. # valid oracle actions.
  339. # 2. We compute the cost function by comparing logits and gold
  340. # only for the valid indices.
  341. gold = dragnn_ops.emit_oracle_labels(handle, component=self.name)
  342. gold.set_shape([None])
  343. valid = tf.greater(gold, -1)
  344. valid_ix = tf.reshape(tf.where(valid), [-1])
  345. gold = tf.gather(gold, valid_ix)
  346. logits = self.network.get_logits(network_tensors)
  347. logits = tf.gather(logits, valid_ix)
  348. cost += tf.reduce_sum(
  349. tf.nn.sparse_softmax_cross_entropy_with_logits(
  350. labels=tf.cast(gold, tf.int64), logits=logits))
  351. if (self.eligible_for_self_norm and
  352. self.master.hyperparams.self_norm_alpha > 0):
  353. log_z = tf.reduce_logsumexp(logits, [1])
  354. cost += (self.master.hyperparams.self_norm_alpha *
  355. tf.nn.l2_loss(log_z))
  356. correct += tf.reduce_sum(
  357. tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  358. total += tf.size(gold)
  359. with tf.control_dependencies([cost, correct, total, gold]):
  360. handle = dragnn_ops.advance_from_oracle(handle, component=self.name)
  361. return [handle, cost, correct, total] + next_arrays
  362. with tf.name_scope(self.name + '/train_state'):
  363. init_arrays = []
  364. for layer in self.network.layers:
  365. init_arrays.append(layer.create_array(state.current_batch_size))
  366. output = tf.while_loop(
  367. cond,
  368. body, [state.handle, cost, correct, total] + init_arrays,
  369. name='train_%s' % self.name)
  370. # Saves completed arrays and return final state and cost.
  371. state.handle = output[0]
  372. correct = output[2]
  373. total = output[3]
  374. arrays = output[4:]
  375. cost = output[1]
  376. # Store handles to the final output for use in subsequent tasks.
  377. network_state = network_states[self.name]
  378. with tf.name_scope(self.name + '/stored_act'):
  379. for index, layer in enumerate(self.network.layers):
  380. network_state.activations[layer.name] = network_units.StoredActivations(
  381. array=arrays[index])
  382. # Normalize the objective by the total # of steps taken.
  383. with tf.control_dependencies([tf.assert_greater(total, 0)]):
  384. cost /= tf.to_float(total)
  385. # Adds regularization for the hidden weights.
  386. cost = self.add_regularizer(cost)
  387. with tf.control_dependencies([x.flow for x in arrays]):
  388. return tf.identity(state.handle), cost, correct, total
  389. def build_greedy_inference(self, state, network_states,
  390. during_training=False):
  391. """Builds an inference loop for this component.
  392. Repeatedly evaluates the network and advances the underlying state according
  393. to the predicted scores.
  394. Args:
  395. state: MasterState from the 'AdvanceMaster' op that advances the
  396. underlying master to this component.
  397. network_states: NetworkState object containing component TensorArrays.
  398. during_training: whether the graph is being constructed during training
  399. Returns:
  400. Handle to the state once inference is complete for this Component.
  401. """
  402. logging.info('Building component: %s', self.spec.name)
  403. if during_training:
  404. stride = state.current_batch_size * self.training_beam_size
  405. else:
  406. stride = state.current_batch_size * self.inference_beam_size
  407. def cond(handle, *_):
  408. all_final = dragnn_ops.emit_all_final(handle, component=self.name)
  409. return tf.logical_not(tf.reduce_all(all_final))
  410. def body(handle, *arrays):
  411. """Runs the network and advances the state by a step."""
  412. with tf.control_dependencies([handle] + [x.flow for x in arrays]):
  413. # Get a copy of the network inside this while loop.
  414. updated_state = MasterState(handle, state.current_batch_size)
  415. network_tensors = self._feedforward_unit(
  416. updated_state,
  417. arrays,
  418. network_states,
  419. stride,
  420. during_training=during_training)
  421. next_arrays = update_tensor_arrays(network_tensors, arrays)
  422. with tf.control_dependencies([x.flow for x in next_arrays]):
  423. logits = self.network.get_logits(network_tensors)
  424. logits = tf.cond(self.locally_normalize,
  425. lambda: tf.nn.log_softmax(logits), lambda: logits)
  426. handle = dragnn_ops.advance_from_prediction(
  427. handle, logits, component=self.name)
  428. return [handle] + next_arrays
  429. # Create the TensorArray's to store activations for downstream/recurrent
  430. # connections.
  431. with tf.name_scope(self.name + '/inference_state'):
  432. init_arrays = []
  433. for layer in self.network.layers:
  434. init_arrays.append(layer.create_array(stride))
  435. output = tf.while_loop(
  436. cond,
  437. body, [state.handle] + init_arrays,
  438. name='inference_%s' % self.name)
  439. # Saves completed arrays and returns final state.
  440. state.handle = output[0]
  441. arrays = output[1:]
  442. network_state = network_states[self.name]
  443. with tf.name_scope(self.name + '/stored_act'):
  444. for index, layer in enumerate(self.network.layers):
  445. network_state.activations[layer.name] = network_units.StoredActivations(
  446. array=arrays[index])
  447. with tf.control_dependencies([x.flow for x in arrays]):
  448. return tf.identity(state.handle)
  449. def _feedforward_unit(self, state, arrays, network_states, stride,
  450. during_training):
  451. """Constructs a single instance of a feed-forward cell.
  452. Given an input state and access to the arrays storing activations, this
  453. function encapsulates creation of a single network unit. This will *not*
  454. create new variables.
  455. Args:
  456. state: MasterState for the state that will be used to extract features.
  457. arrays: List of TensorArrays corresponding to network outputs from this
  458. component. These are used for recurrent link features; the arrays from
  459. other components are used for stack-prop style connections.
  460. network_states: NetworkState object containing the TensorArrays from
  461. *all* components.
  462. stride: int Tensor with the current beam * batch size.
  463. during_training: Whether to build a unit for training (vs inference).
  464. Returns:
  465. List of tensors generated by the underlying network implementation.
  466. """
  467. with tf.variable_scope(self.name, reuse=True):
  468. fixed_embeddings = []
  469. for channel_id, feature_spec in enumerate(self.spec.fixed_feature):
  470. fixed_embedding = network_units.fixed_feature_lookup(
  471. self, state, channel_id, stride)
  472. if feature_spec.is_constant:
  473. fixed_embedding.tensor = tf.stop_gradient(fixed_embedding.tensor)
  474. fixed_embeddings.append(fixed_embedding)
  475. linked_embeddings = []
  476. for channel_id, feature_spec in enumerate(self.spec.linked_feature):
  477. if feature_spec.source_component == self.name:
  478. # Recurrent feature: pull from the local arrays.
  479. index = self.network.get_layer_index(feature_spec.source_layer)
  480. source_array = arrays[index]
  481. source_layer_size = self.network.layers[index].dim
  482. linked_embeddings.append(
  483. network_units.activation_lookup_recurrent(
  484. self, state, channel_id, source_array, source_layer_size,
  485. stride))
  486. else:
  487. # Stackprop style feature: pull from another component's arrays.
  488. source = self.master.lookup_component[feature_spec.source_component]
  489. source_tensor = network_states[source.name].activations[
  490. feature_spec.source_layer]
  491. source_layer_size = source.network.get_layer_size(
  492. feature_spec.source_layer)
  493. linked_embeddings.append(
  494. network_units.activation_lookup_other(
  495. self, state, channel_id, source_tensor.dynamic_tensor,
  496. source_layer_size))
  497. context_tensor_arrays = []
  498. for context_layer in self.network.context_layers:
  499. index = self.network.get_layer_index(context_layer.name)
  500. context_tensor_arrays.append(arrays[index])
  501. if self.spec.attention_component:
  502. logging.info('%s component has attention over %s', self.name,
  503. self.spec.attention_component)
  504. source = self.master.lookup_component[self.spec.attention_component]
  505. network_state = network_states[self.spec.attention_component]
  506. with tf.control_dependencies(
  507. [tf.assert_equal(state.current_batch_size, 1)]):
  508. attention_tensor = tf.identity(
  509. network_state.activations['layer_0'].bulk_tensor)
  510. else:
  511. attention_tensor = None
  512. return self.network.create(fixed_embeddings, linked_embeddings,
  513. context_tensor_arrays, attention_tensor,
  514. during_training)