# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Builds parser models."""
import tensorflow as tf

import syntaxnet.load_parser_ops

from tensorflow.python.ops import control_flow_ops as cf
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import logging

from syntaxnet.ops import gen_parser_ops

def BatchedSparseToDense(sparse_indices, output_size):
  """Batch compatible sparse to dense conversion.

  This is useful for one-hot coded target labels.

  Args:
    sparse_indices: [batch_size] tensor containing one index per batch
    output_size: needed in order to generate the correct dense output

  Returns:
    A [batch_size, output_size] dense tensor.
  """
  eye = tf.diag(tf.fill([output_size], tf.constant(1, tf.float32)))
  return tf.nn.embedding_lookup(eye, sparse_indices)
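
# For instance, given a hypothetical batch of gold action ids [1, 3] with
# output_size=4, the identity matrix built above has one-hot rows, so the
# lookup returns
#   [[0., 1., 0., 0.],
#    [0., 0., 0., 1.]]
# which is the dense form later consumed by the softmax cross entropy loss.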

def EmbeddingLookupFeatures(params, sparse_features, allow_weights):
  """Computes embeddings for each entry of sparse features sparse_features.

  Args:
    params: list of 2D tensors containing vector embeddings
    sparse_features: 1D tensor of strings. Each entry is a string encoding of
      dist_belief.SparseFeatures, and represents a variable length list of
      feature ids, and optionally, corresponding weight values.
    allow_weights: boolean to control whether the weights returned from the
      SparseFeatures are used to multiply the embeddings.

  Returns:
    A tensor representing the combined embeddings for the sparse features.
    For each entry s in sparse_features, the function looks up the embeddings
    for each id and sums them into a single tensor, weighting them by the
    weight of each id. It returns a tensor with each entry of sparse_features
    replaced by this combined embedding.
  """
  if not isinstance(params, list):
    params = [params]
  # Lookup embeddings.
  sparse_features = tf.convert_to_tensor(sparse_features)
  indices, ids, weights = gen_parser_ops.unpack_sparse_features(sparse_features)
  embeddings = tf.nn.embedding_lookup(params, ids)

  if allow_weights:
    # Multiply by weights, reshaping to allow broadcast.
    broadcast_weights_shape = tf.concat(0, [tf.shape(weights), [1]])
    embeddings *= tf.reshape(weights, broadcast_weights_shape)

  # Sum embeddings by index.
  return tf.unsorted_segment_sum(embeddings, indices, tf.size(sparse_features))
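
# As an illustrative sketch, suppose a hypothetical batch of two encoded
# SparseFeatures entries unpacks to
#   indices = [0, 0, 1], ids = [5, 9, 5], weights = [1.0, 0.5, 1.0].
# The embedding rows for ids 5 and 9 are looked up, the second row is scaled
# by 0.5 when allow_weights is True, and unsorted_segment_sum then adds
# together rows that share an index, yielding one combined embedding per
# input entry.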

class GreedyParser(object):
  """Builds a Chen & Manning style greedy neural net parser.

  Builds a graph with an optional reader op connected at one end and
  operations needed to train the network on the other. Supports multiple
  network instantiations sharing the same parameters and network topology.

  The following named nodes are added to the training and eval networks:
    epochs: a tensor containing the current epoch number
    cost: a tensor containing the current training step cost
    gold_actions: a tensor containing actions from gold decoding
    feature_endpoints: a list of sparse feature vectors
    logits: output of the final layer before computing softmax

  The training network also contains:
    train_op: an op that executes a single training step

  Typical usage:

  parser = graph_builder.GreedyParser(num_actions, num_features,
                                      num_feature_ids, embedding_sizes,
                                      hidden_layer_sizes)
  parser.AddTraining(task_context, batch_size=5)
  with tf.Session('local') as sess:
    # This works because the session uses the same default graph as the
    # GraphBuilder did.
    sess.run(parser.inits.values())
    while True:
      tf_epoch, _ = sess.run([parser.training['epochs'],
                              parser.training['train_op']])
      if tf_epoch[0] > 0:
        break
  """

  def __init__(self,
               num_actions,
               num_features,
               num_feature_ids,
               embedding_sizes,
               hidden_layer_sizes,
               seed=None,
               gate_gradients=False,
               use_locking=False,
               embedding_init=1.0,
               relu_init=1e-4,
               bias_init=0.2,
               softmax_init=1e-4,
               averaging_decay=0.9999,
               use_averaging=True,
               check_parameters=True,
               check_every=1,
               allow_feature_weights=False,
               only_train='',
               arg_prefix=None,
               **unused_kwargs):
    """Initialize the graph builder with parameters defining the network.

    Args:
      num_actions: int size of the set of parser actions
      num_features: int list of dimensions of the feature vectors
      num_feature_ids: int list of same length as num_features corresponding to
        the sizes of the input feature spaces
      embedding_sizes: int list of same length as num_features of the desired
        embedding layer sizes
      hidden_layer_sizes: int list of desired relu layer sizes; may be empty
      seed: optional random initializer seed to enable reproducibility
      gate_gradients: if True, gradient updates are computed synchronously,
        ensuring consistency and reproducibility
      use_locking: if True, use locking to avoid read-write contention when
        updating Variables
      embedding_init: sets the std dev of normal initializer of embeddings to
        embedding_init / embedding_size ** .5
      relu_init: sets the std dev of normal initializer of relu weights
        to relu_init
      bias_init: sets constant initializer of relu bias to bias_init
      softmax_init: sets the std dev of normal initializer of softmax weights
        to softmax_init
      averaging_decay: decay for exponential moving average when computing
        averaged parameters, set to 1 to do vanilla averaging
      use_averaging: whether to use moving averages of parameters during evals
      check_parameters: whether to check for NaN/Inf parameters during
        training
      check_every: checks numerics every check_every steps.
      allow_feature_weights: whether feature weights are allowed.
      only_train: the comma separated set of parameter names to train. If empty,
        all model parameters will be trained.
      arg_prefix: prefix for context parameters.
    """
    self._num_actions = num_actions
    self._num_features = num_features
    self._num_feature_ids = num_feature_ids
    self._embedding_sizes = embedding_sizes
    self._hidden_layer_sizes = hidden_layer_sizes
    self._seed = seed
    self._gate_gradients = gate_gradients
    self._use_locking = use_locking
    self._use_averaging = use_averaging
    self._check_parameters = check_parameters
    self._check_every = check_every
    self._allow_feature_weights = allow_feature_weights
    self._only_train = set(only_train.split(',')) if only_train else None
    self._feature_size = len(embedding_sizes)
    self._embedding_init = embedding_init
    self._relu_init = relu_init
    self._softmax_init = softmax_init
    self._arg_prefix = arg_prefix
    # Parameters of the network with respect to which training is done.
    self.params = {}
    # Other variables, with respect to which no training is done, but which we
    # nonetheless need to save in order to capture the state of the graph.
    self.variables = {}
    # Operations to initialize any nodes that require initialization.
    self.inits = {}
    # Training- and eval-related nodes.
    self.training = {}
    self.evaluation = {}
    self.saver = None
    # Nodes to compute moving averages of parameters, called every train step.
    self._averaging = {}
    self._averaging_decay = averaging_decay
    # Pretrained embeddings that can be used instead of constant initializers.
    self._pretrained_embeddings = {}
    # After the following 'with' statement, we'll be able to re-enter the
    # 'params' scope by re-using the self._param_scope member variable. See for
    # instance _AddParam.
    with tf.name_scope('params') as self._param_scope:
      self._relu_bias_init = tf.constant_initializer(bias_init)

  @property
  def embedding_size(self):
    size = 0
    for i in range(self._feature_size):
      size += self._num_features[i] * self._embedding_sizes[i]
    return size
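
  # For example, with hypothetical settings num_features = [20, 20, 12] and
  # embedding_sizes = [64, 32, 32], the concatenated embedding layer has
  # 20 * 64 + 20 * 32 + 12 * 32 = 2304 units feeding the first hidden layer.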

  def _AddParam(self,
                shape,
                dtype,
                name,
                initializer=None,
                return_average=False):
    """Add a model parameter w.r.t. which we expect to compute gradients.

    _AddParam creates both regular parameters (usually for training) and
    averaged nodes (usually for inference). It returns one or the other based
    on the 'return_average' arg.

    Args:
      shape: int list, tensor shape of the parameter to create
      dtype: tf.DataType, data type of the parameter
      name: string, name of the parameter in the TF graph
      initializer: optional initializer for the parameter
      return_average: if False, return parameter otherwise return moving average

    Returns:
      parameter or averaged parameter
    """
    if name not in self.params:
      step = tf.cast(self.GetStep(), tf.float32)
      # Put all parameters and their initializing ops in their own scope
      # irrespective of the current scope (training or eval).
      with tf.name_scope(self._param_scope):
        self.params[name] = tf.get_variable(name, shape, dtype, initializer)
        param = self.params[name]
        if initializer is not None:
          self.inits[name] = state_ops.init_variable(param, initializer)
        if self._averaging_decay == 1:
          logging.info('Using vanilla averaging of parameters.')
          ema = tf.train.ExponentialMovingAverage(decay=(step / (step + 1.0)),
                                                  num_updates=None)
        else:
          ema = tf.train.ExponentialMovingAverage(decay=self._averaging_decay,
                                                  num_updates=step)
        self._averaging[name + '_avg_update'] = ema.apply([param])
        self.variables[name + '_avg_var'] = ema.average(param)
        self.inits[name + '_avg_init'] = state_ops.init_variable(
            ema.average(param), tf.zeros_initializer)
    return (self.variables[name + '_avg_var'] if return_average else
            self.params[name])

  def GetStep(self):
    def OnesInitializer(shape, dtype=tf.float32):
      return tf.ones(shape, dtype)
    return self._AddVariable([], tf.int32, 'step', OnesInitializer)

  def _AddVariable(self, shape, dtype, name, initializer=None):
    if name in self.variables:
      return self.variables[name]
    self.variables[name] = tf.get_variable(name, shape, dtype, initializer)
    if initializer is not None:
      self.inits[name] = state_ops.init_variable(self.variables[name],
                                                 initializer)
    return self.variables[name]

  def _ReluWeightInitializer(self):
    with tf.name_scope(self._param_scope):
      return tf.random_normal_initializer(stddev=self._relu_init,
                                          seed=self._seed)

  def _EmbeddingMatrixInitializer(self, index, embedding_size):
    if index in self._pretrained_embeddings:
      return self._pretrained_embeddings[index]
    else:
      return tf.random_normal_initializer(
          stddev=self._embedding_init / embedding_size**.5,
          seed=self._seed)

  def _AddEmbedding(self,
                    features,
                    num_features,
                    num_ids,
                    embedding_size,
                    index,
                    return_average=False):
    """Adds an embedding matrix and passes the `features` vector through it."""
    embedding_matrix = self._AddParam(
        [num_ids, embedding_size],
        tf.float32,
        'embedding_matrix_%d' % index,
        self._EmbeddingMatrixInitializer(index, embedding_size),
        return_average=return_average)
    embedding = EmbeddingLookupFeatures(embedding_matrix,
                                        tf.reshape(features,
                                                   [-1],
                                                   name='feature_%d' % index),
                                        self._allow_feature_weights)
    return tf.reshape(embedding, [-1, num_features * embedding_size])

  def _BuildNetwork(self, feature_endpoints, return_average=False):
    """Builds a feed-forward part of the net given features as input.

    The network topology is already defined in the constructor, so multiple
    calls to _BuildNetwork build multiple networks whose parameters are all
    shared. It is the source of the input features and the use of the output
    that distinguishes each network.

    Args:
      feature_endpoints: tensors with input features to the network
      return_average: whether to use moving averages as model parameters

    Returns:
      logits: output of the final layer before computing softmax
    """
    assert len(feature_endpoints) == self._feature_size

    # Create embedding layer.
    embeddings = []
    for i in range(self._feature_size):
      embeddings.append(self._AddEmbedding(feature_endpoints[i],
                                           self._num_features[i],
                                           self._num_feature_ids[i],
                                           self._embedding_sizes[i],
                                           i,
                                           return_average=return_average))

    last_layer = tf.concat(1, embeddings)
    last_layer_size = self.embedding_size

    # Create ReLU layers.
    for i, hidden_layer_size in enumerate(self._hidden_layer_sizes):
      weights = self._AddParam(
          [last_layer_size, hidden_layer_size],
          tf.float32,
          'weights_%d' % i,
          self._ReluWeightInitializer(),
          return_average=return_average)
      bias = self._AddParam([hidden_layer_size],
                            tf.float32,
                            'bias_%d' % i,
                            self._relu_bias_init,
                            return_average=return_average)
      last_layer = tf.nn.relu_layer(last_layer,
                                    weights,
                                    bias,
                                    name='layer_%d' % i)
      last_layer_size = hidden_layer_size

    # Create softmax layer.
    softmax_weight = self._AddParam(
        [last_layer_size, self._num_actions],
        tf.float32,
        'softmax_weight',
        tf.random_normal_initializer(stddev=self._softmax_init,
                                     seed=self._seed),
        return_average=return_average)
    softmax_bias = self._AddParam(
        [self._num_actions],
        tf.float32,
        'softmax_bias',
        tf.zeros_initializer,
        return_average=return_average)
    logits = tf.nn.xw_plus_b(last_layer,
                             softmax_weight,
                             softmax_bias,
                             name='logits')
    return {'logits': logits}

  def _AddGoldReader(self, task_context, batch_size, corpus_name):
    features, epochs, gold_actions = (
        gen_parser_ops.gold_parse_reader(task_context,
                                         self._feature_size,
                                         batch_size,
                                         corpus_name=corpus_name,
                                         arg_prefix=self._arg_prefix))
    return {'gold_actions': tf.identity(gold_actions,
                                        name='gold_actions'),
            'epochs': tf.identity(epochs,
                                  name='epochs'),
            'feature_endpoints': features}

  def _AddDecodedReader(self, task_context, batch_size, transition_scores,
                        corpus_name):
    features, epochs, eval_metrics, documents = (
        gen_parser_ops.decoded_parse_reader(transition_scores,
                                            task_context,
                                            self._feature_size,
                                            batch_size,
                                            corpus_name=corpus_name,
                                            arg_prefix=self._arg_prefix))
    return {'eval_metrics': eval_metrics,
            'epochs': tf.identity(epochs,
                                  name='epochs'),
            'feature_endpoints': features,
            'documents': documents}

  def _AddCostFunction(self, batch_size, gold_actions, logits):
    """Cross entropy plus L2 loss on weights and biases of the hidden layers."""
    dense_golden = BatchedSparseToDense(gold_actions, self._num_actions)
    cross_entropy = tf.div(
        tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(
            logits, dense_golden)), batch_size)
    regularized_params = [tf.nn.l2_loss(p)
                          for k, p in self.params.items()
                          if k.startswith('weights') or k.startswith('bias')]
    l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0
    return {'cost': tf.add(cross_entropy, l2_loss, name='cost')}
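
  # Written out, the 'cost' node computes
  #   cost = (1 / batch_size) * sum_i cross_entropy(logits_i, gold_i)
  #          + 1e-4 * sum_w l2_loss(w)
  # where the regularizer runs only over the hidden-layer weights and biases
  # (parameter names starting with 'weights' or 'bias'), not the softmax or
  # embedding parameters.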

  def AddEvaluation(self,
                    task_context,
                    batch_size,
                    evaluation_max_steps=300,
                    corpus_name='documents'):
    """Builds the forward network only without the training operation.

    Args:
      task_context: file path from which to read the task context.
      batch_size: batch size to request from reader op.
      evaluation_max_steps: max number of parsing actions during evaluation,
        only used in beam parsing.
      corpus_name: name of the task input to read parses from.

    Returns:
      Dictionary of named eval nodes.
    """
    def _AssignTransitionScores():
      return tf.assign(nodes['transition_scores'],
                       nodes['logits'], validate_shape=False)

    def _Pass():
      return tf.constant(-1.0)

    unused_evaluation_max_steps = evaluation_max_steps
    with tf.name_scope('evaluation'):
      nodes = self.evaluation
      nodes['transition_scores'] = self._AddVariable(
          [batch_size, self._num_actions], tf.float32, 'transition_scores',
          tf.constant_initializer(-1.0))
      nodes.update(self._AddDecodedReader(task_context, batch_size, nodes[
          'transition_scores'], corpus_name))
      nodes.update(self._BuildNetwork(nodes['feature_endpoints'],
                                      return_average=self._use_averaging))
      nodes['eval_metrics'] = cf.with_dependencies(
          [tf.cond(tf.greater(tf.size(nodes['logits']), 0),
                   _AssignTransitionScores, _Pass)],
          nodes['eval_metrics'], name='eval_metrics')
    return nodes
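
  # A minimal evaluation sketch, assuming a hypothetical task_context path and
  # checkpoint location:
  #
  #   parser.AddEvaluation(task_context, batch_size=32)
  #   parser.AddSaver()
  #   with tf.Session() as sess:
  #     sess.run(parser.inits.values())
  #     parser.saver.restore(sess, 'path/to/checkpoint')  # hypothetical path
  #     tf_eval_metrics, tf_documents = sess.run(
  #         [parser.evaluation['eval_metrics'],
  #          parser.evaluation['documents']])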

  def _IncrementCounter(self, counter):
    return state_ops.assign_add(counter, 1, use_locking=True)

  def _AddLearningRate(self, initial_learning_rate, decay_steps):
    """Returns a learning rate that decays by 0.96 every decay_steps.

    Args:
      initial_learning_rate: initial value of the learning rate
      decay_steps: decay by 0.96 every this many steps

    Returns:
      learning rate variable.
    """
    step = self.GetStep()
    return cf.with_dependencies(
        [self._IncrementCounter(step)],
        tf.train.exponential_decay(initial_learning_rate,
                                   step,
                                   decay_steps,
                                   0.96,
                                   staircase=True))
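
  # With staircase=True this yields a piecewise-constant schedule,
  #   lr = initial_learning_rate * 0.96 ** (step // decay_steps),
  # and the step counter is incremented once each time the learning rate
  # tensor is evaluated (i.e. once per training step).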

  def AddPretrainedEmbeddings(self, index, embeddings_path, task_context):
    """Embeddings at the given index will be set to pretrained values."""

    def _Initializer(shape, dtype=tf.float32):
      unused_dtype = dtype
      t = gen_parser_ops.word_embedding_initializer(
          vectors=embeddings_path,
          task_context=task_context,
          embedding_init=self._embedding_init)

      t.set_shape(shape)
      return t

    self._pretrained_embeddings[index] = _Initializer

  def AddTraining(self,
                  task_context,
                  batch_size,
                  learning_rate=0.1,
                  decay_steps=4000,
                  momentum=0.9,
                  corpus_name='documents'):
    """Builds a trainer to minimize the cross entropy cost function.

    Args:
      task_context: file path from which to read the task context
      batch_size: batch size to request from reader op
      learning_rate: initial value of the learning rate
      decay_steps: decay learning rate by 0.96 every this many steps
      momentum: momentum parameter used when training with momentum
      corpus_name: name of the task input to read parses from

    Returns:
      Dictionary of named training nodes.
    """
    with tf.name_scope('training'):
      nodes = self.training
      nodes.update(self._AddGoldReader(task_context, batch_size, corpus_name))
      nodes.update(self._BuildNetwork(nodes['feature_endpoints'],
                                      return_average=False))
      nodes.update(self._AddCostFunction(batch_size, nodes['gold_actions'],
                                         nodes['logits']))
      # Add the optimizer.
      if self._only_train:
        trainable_params = [v
                            for k, v in self.params.iteritems()
                            if k in self._only_train]
      else:
        trainable_params = self.params.values()
      lr = self._AddLearningRate(learning_rate, decay_steps)
      optimizer = tf.train.MomentumOptimizer(lr,
                                             momentum,
                                             use_locking=self._use_locking)
      train_op = optimizer.minimize(nodes['cost'], var_list=trainable_params)
      for param in trainable_params:
        slot = optimizer.get_slot(param, 'momentum')
        self.inits[slot.name] = state_ops.init_variable(slot,
                                                        tf.zeros_initializer)
        self.variables[slot.name] = slot
      numerical_checks = [
          tf.check_numerics(param,
                            message='Parameter is not finite.')
          for param in trainable_params
          if param.dtype.base_dtype in [tf.float32, tf.float64]
      ]
      check_op = tf.group(*numerical_checks)
      avg_update_op = tf.group(*self._averaging.values())
      train_ops = [train_op]
      if self._check_parameters:
        train_ops.append(check_op)
      if self._use_averaging:
        train_ops.append(avg_update_op)
      nodes['train_op'] = tf.group(*train_ops, name='train_op')
    return nodes

  def AddSaver(self, slim_model=False):
    """Adds ops to save and restore model parameters.

    Args:
      slim_model: whether only averaged variables are saved.

    Returns:
      the saver object.
    """
    # We have to put the save op in the root scope otherwise running
    # "save/restore_all" won't find the "save/Const" node it expects.
    with tf.name_scope(None):
      variables_to_save = self.params.copy()
      variables_to_save.update(self.variables)
      if slim_model:
        for key in variables_to_save.keys():
          if not key.endswith('avg_var'):
            del variables_to_save[key]
      self.saver = tf.train.Saver(variables_to_save)
    return self.saver