  1. """Builds a DRAGNN graph for local training."""
  2. import tensorflow as tf
  3. from tensorflow.core.protobuf import saver_pb2
  4. from tensorflow.python.platform import tf_logging as logging
  5. from dragnn.protos import spec_pb2
  6. from dragnn.python import component
  7. from dragnn.python import composite_optimizer
  8. from dragnn.python import dragnn_ops
  9. from syntaxnet.util import check
  10. try:
  11. tf.NotDifferentiable('ExtractFixedFeatures')
  12. except KeyError, e:
  13. logging.info(str(e))


def _create_learning_rate(hyperparams, step_var):
  """Creates learning rate var, with decay and switching for CompositeOptimizer.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    step_var: tf.Variable, global training step.

  Returns:
    a scalar `Tensor`, the learning rate based on current step and hyperparams.
  """
  if hyperparams.learning_method != 'composite':
    base_rate = hyperparams.learning_rate
  else:
    spec = hyperparams.composite_optimizer_spec
    switch = tf.less(step_var, spec.switch_after_steps)
    base_rate = tf.cond(switch,
                        lambda: tf.constant(spec.method1.learning_rate),
                        lambda: tf.constant(spec.method2.learning_rate))
  return tf.train.exponential_decay(
      base_rate,
      step_var,
      hyperparams.decay_steps,
      hyperparams.decay_base,
      staircase=hyperparams.decay_staircase)


def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
  """Creates an optimizer object for a given spec, learning rate and step var.

  Args:
    hyperparams: a GridPoint proto containing optimizer spec, particularly
      learning_method to determine optimizer class to use.
    learning_rate_var: a `tf.Tensor`, the learning rate.
    step_var: a `tf.Variable`, global training step.

  Returns:
    a `tf.train.Optimizer` object that was built.
  """
  if hyperparams.learning_method == 'gradient_descent':
    return tf.train.GradientDescentOptimizer(
        learning_rate_var, use_locking=True)
  elif hyperparams.learning_method == 'adam':
    return tf.train.AdamOptimizer(
        learning_rate_var,
        beta1=hyperparams.adam_beta1,
        beta2=hyperparams.adam_beta2,
        epsilon=hyperparams.adam_eps,
        use_locking=True)
  elif hyperparams.learning_method == 'momentum':
    return tf.train.MomentumOptimizer(
        learning_rate_var, hyperparams.momentum, use_locking=True)
  elif hyperparams.learning_method == 'composite':
    spec = hyperparams.composite_optimizer_spec
    optimizer1 = _create_optimizer(spec.method1, learning_rate_var, step_var)
    optimizer2 = _create_optimizer(spec.method2, learning_rate_var, step_var)
    if step_var is None:
      logging.fatal('step_var is required for CompositeOptimizer')
    switch = tf.less(step_var, spec.switch_after_steps)
    return composite_optimizer.CompositeOptimizer(
        optimizer1, optimizer2, switch, use_locking=True)
  else:
    logging.fatal('Unknown learning method (optimizer)')
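

# For reference, a composite optimizer is selected by a GridPoint shaped
# roughly like the sketch below (field names match the accessors used above;
# the exact text format and defaults depend on spec_pb2.GridPoint, so treat
# this as an illustrative sketch rather than a canonical config):
#
#   learning_method: "composite"
#   composite_optimizer_spec {
#     switch_after_steps: 2500
#     method1 { learning_method: "momentum" learning_rate: 0.1 momentum: 0.9 }
#     method2 { learning_method: "adam" learning_rate: 0.001 }
#   }
#
# Before `switch_after_steps`, method1 drives both the optimizer and the
# learning rate; afterwards the CompositeOptimizer switches to method2.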


class MasterBuilder(object):
  """A builder for a DRAGNN stack of models.

  This class is the major factory for all DRAGNN models. It provides
  common hooks to build training and evaluation targets from a single
  MasterSpec and hyperparameter configuration.

  The key concept is as follows: to execute a DRAGNN graph, one needs
  two stateful pieces:

  1. A handle to a C++ dragnn state, managed outside of TensorFlow and
     accessed via the custom dragnn ops.
  2. A set of StoredActivations, one for each component, that contain network
     activations that can be used across components.

  TODO(googleuser): Update these comments to be accurate.

  Both of these can be handled automatically "under-the-hood" by the
  MasterBuilder API. For #1, the key consideration is that each C++
  ComputeSession is allocated statically, meaning memory is shared
  across different tensorflow::Session invocations. ComputeSessions are
  allocated from pools. The `pool_scope` identifies the pool, unique to this
  MasterBuilder, from which the ComputeSession is allocated. From there,
  GetSession takes care of handing out ComputeSessions with unique handles.
  Each ComputeSession can then be run concurrently.

  Attributes:
    spec: the MasterSpec proto.
    hyperparams: the GridPoint proto containing hyperparameters.
    pool_scope: string identifier for the ComputeSession pool to use.
    components: a list of ComponentBuilders in the order they are defined
      in the MasterSpec.
    lookup_component: a dictionary to lookup ComponentBuilders by name.
    optimizer: handle to the tf.train Optimizer object used to train this
      model.
    master_vars: dictionary of globally shared tf.Variable objects (e.g.
      the global training step and learning rate.)
  """

  def __init__(self, master_spec, hyperparam_config=None, pool_scope='shared'):
    """Initializes the MasterBuilder from specifications.

    During construction, all components are initialized along with their
    parameter tf.Variables.

    Args:
      master_spec: dragnn.MasterSpec proto.
      hyperparam_config: dragnn.GridPoint proto specifying hyperparameters.
        Defaults to empty specification.
      pool_scope: string identifier for the compute session pool to use.

    Raises:
      ValueError: if a component is not found in the registry.
    """
    self.spec = master_spec
    self.hyperparams = (spec_pb2.GridPoint()
                        if hyperparam_config is None else hyperparam_config)
    self.pool_scope = pool_scope

    # Construct all utility classes and variables for each Component.
    self.components = []
    self.lookup_component = {}
    for component_spec in master_spec.component:
      component_type = component_spec.component_builder.registered_name

      # Raises ValueError if not found.
      comp = component.ComponentBuilderBase.Create(component_type, self,
                                                   component_spec)
      self.lookup_component[comp.name] = comp
      self.components.append(comp)

    # Add global step variable.
    self.master_vars = {}
    with tf.variable_scope('master', reuse=False):
      self.master_vars['step'] = tf.get_variable(
          'step', [], initializer=tf.zeros_initializer(), dtype=tf.int32)
      self.master_vars['learning_rate'] = _create_learning_rate(
          self.hyperparams, self.master_vars['step'])

    # Construct optimizer.
    self.optimizer = _create_optimizer(self.hyperparams,
                                       self.master_vars['learning_rate'],
                                       self.master_vars['step'])

  @property
  def component_names(self):
    return tuple(c.name for c in self.components)

  def _get_compute_session(self):
    """Returns a new ComputeSession handle."""
    return dragnn_ops.get_session(
        self.pool_scope,
        master_spec=self.spec.SerializeToString(),
        grid_point=self.hyperparams.SerializeToString(),
        name='GetSession')

  def _get_session_with_reader(self, enable_tracing):
    """Utility to create ComputeSession management ops.

    Creates a new ComputeSession handle and provides the following
    named nodes:
      ComputeSession/InputBatch -- a placeholder for attaching a string
        specification for AttachReader.
      ComputeSession/AttachReader -- the AttachReader op.

    Args:
      enable_tracing: bool, whether to enable tracing before attaching the
        data.

    Returns:
      handle: handle to a new ComputeSession returned by the AttachReader op.
      input_batch: InputBatch placeholder.
    """
    with tf.name_scope('ComputeSession'):
      input_batch = tf.placeholder(
          dtype=tf.string, shape=[None], name='InputBatch')

      # Get the ComputeSession and chain some essential ops.
      handle = self._get_compute_session()
      if enable_tracing:
        handle = dragnn_ops.set_tracing(handle, True)
      handle = dragnn_ops.attach_data_reader(
          handle, input_batch, name='AttachReader')
    return handle, input_batch

  def _outputs_with_release(self, handle, inputs, outputs):
    """Ensures ComputeSession is released before outputs are returned.

    Args:
      handle: Handle to ComputeSession on which all computation until now has
        depended. It will be released and assigned to the output 'run'.
      inputs: dictionary of nodes we want to pass through without any
        dependencies.
      outputs: dictionary of nodes whose access should ensure the
        ComputeSession is safely released.

    Returns:
      A dictionary of both input and output nodes.
    """
    with tf.control_dependencies(outputs.values()):
      with tf.name_scope('ComputeSession'):
        release_op = dragnn_ops.release_session(handle)
      run_op = tf.group(release_op, name='run')
      for output in outputs:
        with tf.control_dependencies([release_op]):
          outputs[output] = tf.identity(outputs[output], name=output)
    all_nodes = inputs.copy()
    all_nodes.update(outputs)

    # Add an alias for simply running without collecting outputs.
    # Common, for instance, with training.
    all_nodes['run'] = run_op
    return all_nodes

  def build_training(self,
                     handle,
                     compute_gradients=True,
                     use_moving_average=False,
                     advance_counters=True,
                     component_weights=None,
                     unroll_using_oracle=None,
                     max_index=-1):
    """Builds a training pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      compute_gradients: Whether to generate gradients and an optimizer op.
        When False, build_training will return a 'dry run' training op,
        used normally only for oracle tracing.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.
      advance_counters: Whether or not this loop should increment the
        per-component step counters.
      component_weights: If set, this is a list of relative weights
        each component's cost should get in the pipeline. Defaults to 1.0 for
        each component.
      unroll_using_oracle: If set, this is a list of booleans indicating
        whether or not to use the gold decodings for each component. Defaults
        to True for each component.
      max_index: Training will use only the first max_index components,
        or -1 for all components.

    Returns:
      handle: to the ComputeSession, conditioned on completing training step.
      outputs: a dictionary of useful training tensors.

    Raises:
      IndexError: if max_index is positive but out of bounds.
    """
    check.IsFalse(compute_gradients and use_moving_average,
                  'It is not possible to make gradient updates when reading '
                  'from the moving average variables.')

    self.read_from_avg = use_moving_average
    if max_index < 0:
      max_index = len(self.components)
    else:
      if not 0 < max_index <= len(self.components):
        raise IndexError('Invalid max_index {} for components {}; handle {}'.
                         format(max_index, self.component_names, handle.name))

    # By default, we train every component supervised.
    if not component_weights:
      component_weights = [1] * max_index
    if not unroll_using_oracle:
      unroll_using_oracle = [True] * max_index

    component_weights = component_weights[:max_index]
    total_weight = float(sum(component_weights))
    component_weights = [w / total_weight for w in component_weights]

    unroll_using_oracle = unroll_using_oracle[:max_index]

    logging.info('Creating training target:')
    logging.info('\tWeights: %s', component_weights)
    logging.info('\tOracle: %s', unroll_using_oracle)

    metrics_list = []
    cost = tf.constant(0.)
    effective_batch = tf.constant(0)

    avg_ops = []
    params_to_train = []

    network_states = {}
    for component_index in range(0, max_index):
      comp = self.components[component_index]
      network_states[comp.name] = component.NetworkState()

      logging.info('Initializing data for component "%s"', comp.name)
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.training_beam_size, component=comp.name)
      # TODO(googleuser): Phase out component.MasterState.
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle, cost]):
        component_cost = tf.constant(0.)
        component_correct = tf.constant(0)
        component_total = tf.constant(0)
        if unroll_using_oracle[component_index]:
          handle, component_cost, component_correct, component_total = (
              comp.build_greedy_training(master_state, network_states))
        else:
          handle = comp.build_greedy_inference(
              master_state, network_states, during_training=True)

        weighted_component_cost = tf.multiply(
            component_cost,
            tf.constant(float(component_weights[component_index])),
            name='weighted_component_cost')

        cost += weighted_component_cost
        effective_batch += component_total
        metrics_list += [[component_total], [component_correct]]

        if advance_counters:
          with tf.control_dependencies(
              [comp.advance_counters(component_total)]):
            cost = tf.identity(cost)

        # Keep track of which parameters will be trained, and any moving
        # average updates to apply for these parameters.
        params_to_train += comp.network.params
        if self.hyperparams.use_moving_average:
          avg_ops += comp.avg_ops

    # Concatenate evaluation results; the resulting vector interleaves
    # per-component [total, correct] counts.
    metrics = tf.concat(metrics_list, 0)

    # If gradient computation is requested, then:
    # 1. compute the gradients,
    # 2. add an optimizer to update the parameters using the gradients,
    # 3. make the ComputeSession handle depend on the optimizer.
    if compute_gradients:
      logging.info('Creating train op with %d variables:\n\t%s',
                   len(params_to_train),
                   '\n\t'.join([x.name for x in params_to_train]))

      grads_and_vars = self.optimizer.compute_gradients(
          cost, var_list=params_to_train)
      clipped_gradients = [(self._clip_gradients(g), v)
                           for g, v in grads_and_vars]
      minimize_op = self.optimizer.apply_gradients(
          clipped_gradients, global_step=self.master_vars['step'])

      if self.hyperparams.use_moving_average:
        with tf.control_dependencies([minimize_op]):
          minimize_op = tf.group(*avg_ops)

      # Make sure all the side-effectful minimization ops finish before
      # proceeding.
      with tf.control_dependencies([minimize_op]):
        handle = tf.identity(handle)

    # Reset so that subsequent builds don't read from the moving average by
    # default.
    self.read_from_avg = False

    # Return named access to common outputs.
    outputs = {
        'cost': cost,
        'batch': effective_batch,
        'metrics': metrics,
    }
    return handle, outputs

  def _clip_gradients(self, grad):
    """Clips gradients if the hyperparameter `gradient_clip_norm` requires it.

    Sparse tensors, in the form of IndexedSlices returned for the
    gradients of embeddings, require special handling.

    Args:
      grad: Gradient Tensor, IndexedSlices, or None.

    Returns:
      Optionally clipped gradient.
    """
    if grad is not None and self.hyperparams.gradient_clip_norm > 0:
      logging.info('Clipping gradient %s', grad)
      if isinstance(grad, tf.IndexedSlices):
        tmp = tf.clip_by_norm(grad.values, self.hyperparams.gradient_clip_norm)
        return tf.IndexedSlices(tmp, grad.indices, grad.dense_shape)
      else:
        return tf.clip_by_norm(grad, self.hyperparams.gradient_clip_norm)
    else:
      return grad

  def build_post_restore_hook(self):
    """Builds a graph that should be executed after the restore op.

    This graph is intended to be run once, before the inference pipeline is
    run.

    Returns:
      setup_op - An op that, when run, guarantees all setup ops will run.
    """
    with tf.control_dependencies(
        [comp.build_post_restore_hook() for comp in self.components]):
      return tf.no_op(name='post_restore_hook_master')

  def build_inference(self, handle, use_moving_average=False):
    """Builds an inference pipeline.

    This always uses the whole pipeline.

    Args:
      handle: Handle tensor for the ComputeSession.
      use_moving_average: Whether or not to read from the moving
        average variables instead of the true parameters. Note: it is not
        possible to make gradient updates when this is True.

    Returns:
      handle: Handle after annotation.
    """
    self.read_from_avg = use_moving_average

    network_states = {}
    for comp in self.components:
      network_states[comp.name] = component.NetworkState()
      handle = dragnn_ops.init_component_data(
          handle, beam_size=comp.inference_beam_size, component=comp.name)
      master_state = component.MasterState(handle,
                                           dragnn_ops.batch_size(
                                               handle, component=comp.name))
      with tf.control_dependencies([handle]):
        handle = comp.build_greedy_inference(master_state, network_states)
      handle = dragnn_ops.write_annotations(handle, component=comp.name)

    self.read_from_avg = False
    return handle

  def add_training_from_config(self,
                               target_config,
                               prefix='train-',
                               trace_only=False,
                               **kwargs):
    """Constructs a training pipeline from a TrainTarget proto.

    This constructs a separately managed pipeline for a given target:
    it has its own ComputeSession, InputSpec placeholder, etc. The ops
    are given standardized names to allow access from the C++ API. It
    passes the values in target_config to build_training() above.

    For the default prefix ('train-'), and a target named 'target', this will
    construct the following targets in the graph:

      train-target/ComputeSession/* (the standard ComputeSession controls)
      train-target/run (handle to a completed training step)
      train-target/metrics (per-decision metrics from gold oracles)
      train-target/cost (total cost across all components)

    Enabling `trace_only` effectively creates a graph that is a 'dry run'.
    There will be no side effects. In addition, the gradients won't be
    computed and the model parameters will not be updated.

    Args:
      target_config: the TrainTarget proto.
      prefix: Prepends target_config.name with this to construct a unique
        identifier.
      trace_only: Enabling this will result in:
        1. Tracing will be enabled for the ComputeSession.
        2. A 'traces' node will be added to the outputs.
        3. Gradients will not be computed.
      **kwargs: Passed on to build_training() above.

    Returns:
      Dictionary of training targets.
    """
    logging.info('Creating new training target %s from config: %s',
                 target_config.name, str(target_config))
    scope_id = prefix + target_config.name
    with tf.name_scope(scope_id):
      # Construct training targets. Disable tracing during training.
      handle, input_batch = self._get_session_with_reader(trace_only)
      if trace_only:
        # Build a training graph that doesn't have any side effects.
        handle, outputs = self.build_training(
            handle,
            compute_gradients=False,
            advance_counters=False,
            component_weights=target_config.component_weights,
            unroll_using_oracle=target_config.unroll_using_oracle,
            max_index=target_config.max_index,
            **kwargs)
        outputs['traces'] = dragnn_ops.get_component_trace(
            handle, component=self.spec.component[-1].name)
      else:
        # The standard training scenario has gradients and updates counters.
        handle, outputs = self.build_training(
            handle,
            compute_gradients=True,
            advance_counters=True,
            component_weights=target_config.component_weights,
            unroll_using_oracle=target_config.unroll_using_oracle,
            max_index=target_config.max_index,
            **kwargs)
        # In addition, it keeps track of the number of training steps.
        outputs['target_step'] = tf.get_variable(
            scope_id + '/TargetStep', [],
            initializer=tf.zeros_initializer(),
            dtype=tf.int32)
        increment_target_step = tf.assign_add(
            outputs['target_step'], 1, use_locking=True)
        with tf.control_dependencies([increment_target_step]):
          handle = tf.identity(handle)
      return self._outputs_with_release(handle, {'input_batch': input_batch},
                                        outputs)
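
  # A hedged sketch of driving a training target from a tf.Session. The names
  # assume the default 'train-' prefix and a hypothetical target named
  # 'tagger'; `batch` stands in for a list of serialized input strings fed to
  # the InputBatch placeholder:
  #
  #   feed = {'train-tagger/ComputeSession/InputBatch:0': batch}
  #   _, cost = sess.run(['train-tagger/run', 'train-tagger/cost:0'],
  #                      feed_dict=feed)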

  def add_annotation(self, name_scope='annotation', enable_tracing=False):
    """Adds an annotation pipeline to the graph.

    This will create the following additional named targets by default, for
    use in C++ annotation code (as well as regular ComputeSession targets):

      annotation/ComputeSession/session_id (placeholder for giving unique id)
      annotation/EmitAnnotations (get annotated data)
      annotation/GetComponentTrace (get trace data)
      annotation/SetTracing (sets tracing based on annotation/tracing_on)

    Args:
      name_scope: Scope for the annotation pipeline.
      enable_tracing: Enabling this will result in two things:
        1. Tracing will be enabled during inference.
        2. A 'traces' node will be added to the outputs.

    Returns:
      A dictionary of input and output nodes.
    """
    with tf.name_scope(name_scope):
      handle, input_batch = self._get_session_with_reader(enable_tracing)
      handle = self.build_inference(handle, use_moving_average=True)

      annotations = dragnn_ops.emit_annotations(
          handle, component=self.spec.component[-1].name)
      outputs = {'annotations': annotations}
      if enable_tracing:
        outputs['traces'] = dragnn_ops.get_component_trace(
            handle, component=self.spec.component[-1].name)
      return self._outputs_with_release(handle, {'input_batch': input_batch},
                                        outputs)
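
  # Similarly, annotations can be fetched by name (default name_scope
  # 'annotation'; `batch` again stands in for serialized input strings):
  #
  #   annotated = sess.run(
  #       'annotation/annotations:0',
  #       feed_dict={'annotation/ComputeSession/InputBatch:0': batch})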

  def add_post_restore_hook(self, name_scope):
    """Adds the post restore ops."""
    with tf.name_scope(name_scope):
      return self.build_post_restore_hook()

  def add_saver(self):
    """Adds a Saver for all variables in the graph."""
    logging.info('Saving non-quantized variables:\n\t%s', '\n\t'.join(
        [x.name for x in tf.global_variables() if 'quantized' not in x.name]))
    self.saver = tf.train.Saver(
        var_list=[
            x for x in tf.global_variables() if 'quantized' not in x.name
        ],
        write_version=saver_pb2.SaverDef.V1)
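

# A rough end-to-end sketch of how this builder is typically wired together,
# kept as a comment so module behavior is unchanged; `master_spec`,
# `hyperparam_config`, `target_config`, and `checkpoint_path` are
# placeholders:
#
#   with tf.Graph().as_default():
#     builder = MasterBuilder(master_spec, hyperparam_config)
#     train_nodes = builder.add_training_from_config(target_config)
#     anno_nodes = builder.add_annotation()
#     post_restore = builder.add_post_restore_hook('post_restore')
#     builder.add_saver()
#     with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       # ... feed/fetch the named nodes as sketched above ...
#       builder.saver.save(sess, checkpoint_path)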