graph_builder.py

# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
  15. """Builds a DRAGNN graph for local training."""
  16. import tensorflow as tf
  17. from tensorflow.core.protobuf import saver_pb2
  18. from tensorflow.python.platform import tf_logging as logging
  19. from dragnn.protos import spec_pb2
  20. from dragnn.python import component
  21. from dragnn.python import composite_optimizer
  22. from dragnn.python import dragnn_ops
  23. from syntaxnet.util import check

try:
  tf.NotDifferentiable('ExtractFixedFeatures')
except KeyError as e:
  logging.info(str(e))


def _create_learning_rate(hyperparams, step_var):
  """Creates learning rate var, with decay and switching for CompositeOptimizer.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    step_var: tf.Variable, global training step.

  Returns:
    a scalar `Tensor`, the learning rate based on current step and hyperparams.
  """
  if hyperparams.learning_method != 'composite':
    base_rate = hyperparams.learning_rate
  else:
    spec = hyperparams.composite_optimizer_spec
    switch = tf.less(step_var, spec.switch_after_steps)
    base_rate = tf.cond(switch, lambda: tf.constant(spec.method1.learning_rate),
                        lambda: tf.constant(spec.method2.learning_rate))
  return tf.train.exponential_decay(
      base_rate,
      step_var,
      hyperparams.decay_steps,
      hyperparams.decay_base,
      staircase=hyperparams.decay_staircase)
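
# Note on `_create_learning_rate` above: `tf.train.exponential_decay` computes,
# in effect (a sketch of the arithmetic, with `staircase=False`),
#
#   decayed_rate = base_rate * decay_base ** (step / decay_steps)
#
# With `staircase=True` the exponent is floored to an integer, so the rate
# drops in discrete jumps every `decay_steps` steps instead of continuously.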


def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
  """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
  if hyperparams.learning_method == 'gradient_descent':
    return tf.train.GradientDescentOptimizer(
        learning_rate_var, use_locking=True)
  elif hyperparams.learning_method == 'adam':
    return tf.train.AdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'lazyadam':
    return tf.contrib.opt.LazyAdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'momentum':
    return tf.train.MomentumOptimizer(
        learning_rate_var, hyperparams.momentum, use_locking=True)
  elif hyperparams.learning_method == 'composite':
    spec = hyperparams.composite_optimizer_spec
    optimizer1 = _create_optimizer(spec.method1, learning_rate_var, step_var)
    optimizer2 = _create_optimizer(spec.method2, learning_rate_var, step_var)
    if step_var is None:
      logging.fatal('step_var is required for CompositeOptimizer')
    switch = tf.less(step_var, spec.switch_after_steps)
    return composite_optimizer.CompositeOptimizer(
        optimizer1, optimizer2, switch, use_locking=True)
  else:
    logging.fatal('Unknown learning method (optimizer)')
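
# Example of how the helpers above might be driven (a sketch in comments only;
# the numeric values are illustrative assumptions, not defaults from this
# codebase):
#
#   hyperparams = spec_pb2.GridPoint()
#   hyperparams.learning_method = 'adam'
#   hyperparams.learning_rate = 0.001
#   hyperparams.decay_steps = 4096
#   hyperparams.decay_base = 0.96
#   hyperparams.adam_beta1 = 0.9
#   hyperparams.adam_beta2 = 0.999
#   hyperparams.adam_eps = 1e-8
#
#   step = tf.get_variable('step', [], dtype=tf.int32,
#                          initializer=tf.zeros_initializer())
#   learning_rate = _create_learning_rate(hyperparams, step)
#   optimizer = _create_optimizer(hyperparams, learning_rate, step)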


class MasterBuilder(object):
  """A builder for a DRAGNN stack of models.

  This class is the major factory for all DRAGNN models. It provides
  common hooks to build training and evaluation targets from a single
  MasterSpec and hyperparameter configuration.

  The key concept is as follows: to execute a DRAGNN graph, one needs
  two stateful pieces:

  1. A handle to a C++ dragnn state, managed outside of TensorFlow and
     accessed via the custom dragnn ops.
  2. A set of StoredActivations, one for each component, that contain network
     activations that can be used across components.

  TODO(googleuser): Update these comments to be accurate.
  Both of these can be handled automatically "under-the-hood" by the
  MasterBuilder API. For #1, the key consideration is that each C++
  ComputeSession is allocated statically, meaning memory is shared
  across different tensorflow::Session invocations. ComputeSessions are
  allocated from pools. The `pool_scope` identifies the pool, unique to this
  MasterBuilder, from which the ComputeSession is allocated. From there,
  GetSession takes care of handing out ComputeSessions with unique handles.
  Each ComputeSession can then be run concurrently.

  Attributes:
    spec: the MasterSpec proto.
    hyperparams: the GridPoint proto containing hyperparameters.
    pool_scope: string identifier for the ComputeSession pool to use.
    components: a list of ComponentBuilders in the order they are defined
      in the MasterSpec.
    lookup_component: a dictionary to look up ComponentBuilders by name.
    optimizer: handle to the tf.train Optimizer object used to train this
      model.
    master_vars: dictionary of globally shared tf.Variable objects (e.g.
      the global training step and learning rate).
  """

  def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
    """Initializes the MasterBuilder from specifications.

    During construction, all components are initialized along with their
    parameter tf.Variables.

    Args:
      master_spec: dragnn.MasterSpec proto.
      hyperparam_config: dragnn.GridPoint proto specifying hyperparameters.
        Defaults to empty specification.
      pool_scope: string identifier for the compute session pool to use.

    Raises:
      ValueError: if a component is not found in the registry.
    """
    self.spec = master_spec
    self.hyperparams = (spec_pb2.GridPoint()
                        if hyperparam_config is None else hyperparam_config)
    self.pool_scope = pool_scope

    # Set the graph-level random seed before creating the Components so the ops
    # they create will use this seed.
    tf.set_random_seed(self.hyperparams.seed)

    # Construct all utility classes and variables for each Component.
    self.components = []
    self.lookup_component = {}
    for component_spec in master_spec.component:
      component_type = component_spec.component_builder.registered_name

      # Raises ValueError if not found.
      comp = component.ComponentBuilderBase.Create(component_type, self,
                                                   component_spec)

      self.lookup_component[comp.name] = comp
      self.components.append(comp)

    # Add global step variable.
    self.master_vars = {}
    with tf.variable_scope('master', reuse=False):
      self.master_vars['step'] = tf.get_variable(
          'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
      self.master_vars['learning_rate'] = _create_learning_rate(
          self.hyperparams, self.master_vars['step'])

    # Construct optimizer.
    self.optimizer = _create_optimizer(self.hyperparams,
                                       self.master_vars['learning_rate'],
                                       self.master_vars['step'])

  @property
  def component_names(self):
    return tuple(c.name for c in self.components)

  def _get_compute_session(self):
    """Returns a new ComputeSession handle."""
    return dragnn_ops.get_session(
        self.pool_scope,
        master_spec=self.spec.SerializeToString(),
        grid_point=self.hyperparams.SerializeToString(),
        name='GetSession')

  def _get_session_with_reader(self, enable_tracing):
    """Utility to create ComputeSession management ops.

    Creates a new ComputeSession handle and provides the following
    named nodes:
      ComputeSession/InputBatch -- a placeholder for attaching a string
        specification for AttachReader.
      ComputeSession/AttachReader -- the AttachReader op.

    Args:
      enable_tracing: bool, whether to enable tracing before attaching the
        data.

    Returns:
      handle: handle to a new ComputeSession returned by the AttachReader op.
      input_batch: InputBatch placeholder.
    """
    with tf.name_scope('ComputeSession'):
      input_batch = tf.placeholder(
          dtype=tf.string, shape=[None], name='InputBatch')

      # Get the ComputeSession and chain some essential ops.
      handle = self._get_compute_session()
      if enable_tracing:
        handle = dragnn_ops.set_tracing(handle, True)
      handle = dragnn_ops.attach_data_reader(
          handle, input_batch, name='AttachReader')
    return handle, input_batch

  def _outputs_with_release(self, handle, inputs, outputs):
    """Ensures ComputeSession is released before outputs are returned.

    Args:
      handle: Handle to ComputeSession on which all computation until now has
        depended. It will be released and assigned to the output 'run'.
      inputs: dictionary of nodes we want to pass through without any
        dependencies.
      outputs: dictionary of nodes whose access should ensure the
        ComputeSession is safely released.

    Returns:
      A dictionary of both input and output nodes.
    """
    with tf.control_dependencies(outputs.values()):
      with tf.name_scope('ComputeSession'):
        release_op = dragnn_ops.release_session(handle)
      run_op = tf.group(release_op, name='run')
      for output in outputs:
        with tf.control_dependencies([release_op]):
          outputs[output] = tf.identity(outputs[output], name=output)
    all_nodes = inputs.copy()
    all_nodes.update(outputs)

    # Add an alias for simply running without collecting outputs.
    # Common, for instance, with training.
    all_nodes['run'] = run_op
    return all_nodes

  def build_training(self,
                     handle,
                     compute_gradients=True,
                     use_moving_average=False,
                     advance_counters=True,
                     component_weights=None,
                     unroll_using_oracle=None,
                     max_index=-1):
    """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      compute_gradients: Whether to generate gradients and an optimizer op.
        When False, build_training will return a 'dry run' training op,
        used normally only for oracle tracing.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      advance_counters: Whether or not this loop should increment the
        per-component step counters.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: to the ComputeSession, conditioned on completing a training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
    check.IsFalse(compute_gradients and use_moving_average,
                  'It is not possible to make gradient updates when reading '
                  'from the moving average variables.')

    self.read_from_avg = use_moving_average
    if max_index < 0:
      max_index = len(self.components)
    else:
      if not 0 < max_index <= len(self.components):
        raise IndexError('Invalid max_index {} for components {}; handle {}'.
                         format(max_index, self.component_names, handle.name))

    # By default, we train every component supervised.
    if not component_weights:
      component_weights = [1] * max_index
    if not unroll_using_oracle:
      unroll_using_oracle = [True] * max_index

    component_weights = component_weights[:max_index]
    total_weight = float(sum(component_weights))
    component_weights = [w / total_weight for w in component_weights]

    unroll_using_oracle = unroll_using_oracle[:max_index]

    logging.info('Creating training target:')
    logging.info('\tWeights: %s', component_weights)
    logging.info('\tOracle: %s', unroll_using_oracle)

    metrics_list = []
    cost = tf.constant(0.)
    effective_batch = tf.constant(0)

    avg_ops = []
    params_to_train = []

    network_states = {}
    for component_index in range(0, max_index):
      comp = self.components[component_index]
      network_states[comp.name] = component.NetworkState()

      logging.info('Initializing data for component "%s"', comp.name)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.training_beam_size, component=comp.name)

      # TODO(googleuser): Phase out component.MasterState.
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle, cost]):
        args = (master_state, network_states)
        if unroll_using_oracle[component_index]:
          handle, component_cost, component_correct, component_total = (tf.cond(
              comp.training_beam_size > 1,
              lambda: comp.build_structured_training(*args),
              lambda: comp.build_greedy_training(*args)))
        else:
          handle = comp.build_greedy_inference(*args, during_training=True)
          component_cost = tf.constant(0.)
          component_correct, component_total = tf.constant(0), tf.constant(0)

        weighted_component_cost = tf.multiply(
            component_cost,
            tf.constant(float(component_weights[component_index])),
            name='weighted_component_cost')

        cost += weighted_component_cost
        effective_batch += component_total
        metrics_list += [[component_total], [component_correct]]

        if advance_counters:
          with tf.control_dependencies(
              [comp.advance_counters(component_total)]):
            cost = tf.identity(cost)

        # Keep track of which parameters will be trained, and any moving
        # average updates to apply for these parameters.
        params_to_train += comp.network.params
        if self.hyperparams.use_moving_average:
          avg_ops += comp.avg_ops

    # Concatenate evaluation results.
    metrics = tf.concat(metrics_list, 0)

    # If gradient computation is requested, then:
    # 1. compute the gradients,
    # 2. add an optimizer to update the parameters using the gradients,
    # 3. make the ComputeSession handle depend on the optimizer.
    if compute_gradients:
      logging.info('Creating train op with %d variables:\n\t%s',
                   len(params_to_train),
                   '\n\t'.join([x.name for x in params_to_train]))

      grads_and_vars = self.optimizer.compute_gradients(
          cost, var_list=params_to_train)
      clipped_gradients = [(self._clip_gradients(g), v)
                           for g, v in grads_and_vars]
      minimize_op = self.optimizer.apply_gradients(
          clipped_gradients, global_step=self.master_vars['step'])

      if self.hyperparams.use_moving_average:
        with tf.control_dependencies([minimize_op]):
          minimize_op = tf.group(*avg_ops)

      # Make sure all the side-effectful minimization ops finish before
      # proceeding.
      with tf.control_dependencies([minimize_op]):
        handle = tf.identity(handle)

    # Reset so that subsequent builds don't read from the moving average
    # by default.
    self.read_from_avg = False

    # Returns named access to common outputs.
    outputs = {
        'cost': cost,
        'batch': effective_batch,
        'metrics': metrics,
    }
    return handle, outputs

  def _clip_gradients(self, grad):
    """Clips gradients if the hyperparameter `gradient_clip_norm` requires it.

    Sparse tensors, in the form of IndexedSlices returned for the
    gradients of embeddings, require special handling.

    Args:
      grad: Gradient Tensor, IndexedSlices, or None.

    Returns:
      Optionally clipped gradient.
    """
    if grad is not None and self.hyperparams.gradient_clip_norm > 0:
      logging.info('Clipping gradient %s', grad)
      if isinstance(grad, tf.IndexedSlices):
        tmp = tf.clip_by_norm(grad.values, self.hyperparams.gradient_clip_norm)
        return tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        return tf.clip_by_norm(grad, self.hyperparams.gradient_clip_norm)
    else:
      return grad
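
  # For reference (a sketch of the arithmetic, not part of the pipeline logic):
  # `tf.clip_by_norm(g, c)` used above rescales a gradient `g` to
  #
  #   g * c / max(||g||_2, c)
  #
  # so gradients whose L2 norm is already at most `c` pass through unchanged.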

  def build_post_restore_hook(self):
    """Builds a graph that should be executed after the restore op.

    This graph is intended to be run once, before the inference pipeline is
    run.

    Returns:
      setup_op - An op that, when run, guarantees all setup ops will run.
    """
    with tf.control_dependencies(
        [comp.build_post_restore_hook() for comp in self.components]):
      return tf.no_op(name='post_restore_hook_master')

  def build_inference(self, handle, use_moving_average=False):
    """Builds an inference pipeline.

    This always uses the whole pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.

    Returns:
      handle: Handle after annotation.
    """
    self.read_from_avg = use_moving_average
    network_states = {}
    for comp in self.components:
      network_states[comp.name] = component.NetworkState()
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.inference_beam_size, component=comp.name)
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle]):
        handle = comp.build_greedy_inference(master_state, network_states)
      handle = dragnn_ops.write_annotations(handle, component=comp.name)

    self.read_from_avg = False
    return handle

  def add_training_from_config(self,
                               target_config,
                               prefix='train-',
                               trace_only=False,
                               **kwargs):
    """Constructs a training pipeline from a TrainTarget proto.

    This constructs a separately managed pipeline for a given target:
    it has its own ComputeSession, InputSpec placeholder, etc. The ops
    are given standardized names to allow access from the C++ API. It
    passes the values in target_config to build_training() above.

    For the default prefix ('train-'), and a target named 'target', this will
    construct the following targets in the graph:
      train-target/ComputeSession/* (the standard ComputeSession controls)
      train-target/run (handle to a completed training step)
      train-target/metrics (per-decision metrics from gold oracles)
      train-target/cost (total cost across all components)

    Enabling `trace_only` effectively creates a graph that is a 'dry run'.
    There will be no side effects. In addition, the gradients won't be computed
    and the model parameters will not be updated.

    Args:
      target_config: the TrainTarget proto.
      prefix: Prepended to target_config.name to construct a unique identifier.
      trace_only: Enabling this will result in:
        1. Tracing will be enabled for the ComputeSession.
        2. A 'traces' node will be added to the outputs.
        3. Gradients will not be computed.
      **kwargs: Passed on to build_training() above.

    Returns:
      Dictionary of training targets.
    """
    logging.info('Creating new training target %s from config: %s',
                 target_config.name, str(target_config))
    scope_id = prefix + target_config.name
    with tf.name_scope(scope_id):
      # Construct training targets. Disable tracing during training.
      handle, input_batch = self._get_session_with_reader(trace_only)

      # If `trace_only` is True, the training graph shouldn't have any
      # side effects. Otherwise, the standard training scenario should
      # generate gradients and update counters.
      handle, outputs = self.build_training(
          handle,
          compute_gradients=not trace_only,
          advance_counters=not trace_only,
          component_weights=target_config.component_weights,
          unroll_using_oracle=target_config.unroll_using_oracle,
          max_index=target_config.max_index,
          **kwargs)
      if trace_only:
        outputs['traces'] = dragnn_ops.get_component_trace(
            handle, component=self.spec.component[-1].name)
      else:
        # Standard training keeps track of the number of training steps.
        outputs['target_step'] = tf.get_variable(
            scope_id + '/TargetStep', [],
            initializer=tf.zeros_initializer(),
            dtype=tf.int32)
        increment_target_step = tf.assign_add(
            outputs['target_step'], 1, use_locking=True)
        with tf.control_dependencies([increment_target_step]):
          handle = tf.identity(handle)
      return self._outputs_with_release(handle, {'input_batch': input_batch},
                                        outputs)

  def add_annotation(self, name_scope='annotation', enable_tracing=False):
    """Adds an annotation pipeline to the graph.

    This will create the following additional named targets by default, for use
    in C++ annotation code (as well as regular ComputeSession targets):
      annotation/ComputeSession/session_id (placeholder for giving unique id)
      annotation/EmitAnnotations (get annotated data)
      annotation/GetComponentTrace (get trace data)
      annotation/SetTracing (sets tracing based on annotation/tracing_on)

    Args:
      name_scope: Scope for the annotation pipeline.
      enable_tracing: Enabling this will result in two things:
        1. Tracing will be enabled during inference.
        2. A 'traces' node will be added to the outputs.

    Returns:
      A dictionary of input and output nodes.
    """
    with tf.name_scope(name_scope):
      handle, input_batch = self._get_session_with_reader(enable_tracing)
      handle = self.build_inference(handle, use_moving_average=True)

      annotations = dragnn_ops.emit_annotations(
          handle, component=self.spec.component[-1].name)
      outputs = {'annotations': annotations}

      if enable_tracing:
        outputs['traces'] = dragnn_ops.get_component_trace(
            handle, component=self.spec.component[-1].name)
      return self._outputs_with_release(handle, {'input_batch': input_batch},
                                        outputs)

  def add_post_restore_hook(self, name_scope):
    """Adds the post restore ops."""
    with tf.name_scope(name_scope):
      return self.build_post_restore_hook()

  def add_saver(self):
    """Adds a Saver for all variables in the graph."""
    logging.info('Saving non-quantized variables:\n\t%s', '\n\t'.join(
        [x.name for x in tf.global_variables() if 'quantized' not in x.name]))
    self.saver = tf.train.Saver(
        var_list=[
            x for x in tf.global_variables() if 'quantized' not in x.name
        ],
        write_version=saver_pb2.SaverDef.V1)
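

# A rough end-to-end usage sketch (comments only, not executed by this module;
# the spec file path, GridPoint settings, TrainTarget configuration, and input
# batch below are illustrative assumptions, and `text_format` would come from
# `google.protobuf`):
#
#   master_spec = spec_pb2.MasterSpec()
#   with open('master_spec.textproto') as spec_file:
#     text_format.Parse(spec_file.read(), master_spec)
#
#   hyperparams = spec_pb2.GridPoint()
#   hyperparams.learning_method = 'adam'
#   hyperparams.learning_rate = 0.001
#   hyperparams.decay_steps = 4096
#   hyperparams.decay_base = 0.96
#
#   with tf.Graph().as_default():
#     builder = MasterBuilder(master_spec, hyperparams)
#     target = spec_pb2.TrainTarget(name='all')
#     target.unroll_using_oracle.extend([True] * len(builder.components))
#     train_nodes = builder.add_training_from_config(target)
#     anno_nodes = builder.add_annotation()
#     builder.add_saver()
#
#     with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       # `batch` is a list of serialized input records in whatever format the
#       # configured reader expects.
#       _, cost = sess.run([train_nodes['run'], train_nodes['cost']],
#                          feed_dict={train_nodes['input_batch']: batch})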