bulk_component.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478
  1. # Copyright 2017 Google Inc. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. """Component builders for non-recurrent networks in DRAGNN."""
  16. import tensorflow as tf
  17. from tensorflow.python.platform import tf_logging as logging
  18. from dragnn.python import component
  19. from dragnn.python import dragnn_ops
  20. from dragnn.python import network_units
  21. from syntaxnet.util import check
  22. def fetch_linked_embedding(comp, network_states, feature_spec):
  23. """Looks up linked embeddings in other components.
  24. Args:
  25. comp: ComponentBuilder object with respect to which the feature is to be
  26. fetched
  27. network_states: dictionary of NetworkState objects
  28. feature_spec: FeatureSpec proto for the linked feature to be looked up
  29. Returns:
  30. NamedTensor containing the linked feature tensor
  31. Raises:
  32. NotImplementedError: if a linked feature with source translator other than
  33. 'identity' is configured.
  34. RuntimeError: if a recurrent linked feature is configured.
  35. """
  36. if feature_spec.source_translator != 'identity':
  37. raise NotImplementedError(feature_spec.source_translator)
  38. if feature_spec.source_component == comp.name:
  39. raise RuntimeError(
  40. 'Recurrent linked features are not supported in bulk extraction.')
  41. tf.logging.info('[%s] Adding linked feature "%s"', comp.name,
  42. feature_spec.name)
  43. source = comp.master.lookup_component[feature_spec.source_component]
  44. return network_units.NamedTensor(
  45. network_states[source.name].activations[
  46. feature_spec.source_layer].bulk_tensor,
  47. feature_spec.name)
  48. def _validate_embedded_fixed_features(comp):
  49. """Checks that the embedded fixed features of |comp| are set up properly."""
  50. for feature in comp.spec.fixed_feature:
  51. check.Gt(feature.embedding_dim, 0,
  52. 'Embeddings requested for non-embedded feature: %s' % feature)
  53. if feature.is_constant:
  54. check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
  55. 'Constant embeddings must be pretrained: %s' % feature)
  56. def fetch_differentiable_fixed_embeddings(comp, state, stride):
  57. """Looks up fixed features with separate, differentiable, embedding lookup.
  58. Args:
  59. comp: Component whose fixed features we wish to look up.
  60. state: live MasterState object for the component.
  61. stride: Tensor containing current batch * beam size.
  62. Returns:
  63. state handle: updated state handle to be used after this call
  64. fixed_embeddings: list of NamedTensor objects
  65. """
  66. _validate_embedded_fixed_features(comp)
  67. num_channels = len(comp.spec.fixed_feature)
  68. if not num_channels:
  69. return state.handle, []
  70. state.handle, indices, ids, weights, num_steps = (
  71. dragnn_ops.bulk_fixed_features(
  72. state.handle, component=comp.name, num_channels=num_channels))
  73. fixed_embeddings = []
  74. for channel, feature_spec in enumerate(comp.spec.fixed_feature):
  75. differentiable_or_constant = ('constant' if feature_spec.is_constant else
  76. 'differentiable')
  77. tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
  78. differentiable_or_constant, feature_spec.name)
  79. size = stride * num_steps * feature_spec.size
  80. fixed_embedding = network_units.embedding_lookup(
  81. comp.get_variable(network_units.fixed_embeddings_name(channel)),
  82. indices[channel], ids[channel], weights[channel], size)
  83. if feature_spec.is_constant:
  84. fixed_embedding = tf.stop_gradient(fixed_embedding)
  85. fixed_embeddings.append(
  86. network_units.NamedTensor(fixed_embedding, feature_spec.name))
  87. return state.handle, fixed_embeddings
  88. def fetch_fast_fixed_embeddings(comp, state):
  89. """Looks up fixed features with fast, non-differentiable, op.
  90. Since BulkFixedEmbeddings is non-differentiable with respect to the
  91. embeddings, the idea is to call this function only when the graph is
  92. not being used for training.
  93. Args:
  94. comp: Component whose fixed features we wish to look up.
  95. state: live MasterState object for the component.
  96. Returns:
  97. state handle: updated state handle to be used after this call
  98. fixed_embeddings: list of NamedTensor objects
  99. """
  100. _validate_embedded_fixed_features(comp)
  101. num_channels = len(comp.spec.fixed_feature)
  102. if not num_channels:
  103. return state.handle, []
  104. tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)
  105. state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
  106. state.handle, [
  107. comp.get_variable(network_units.fixed_embeddings_name(c))
  108. for c in range(num_channels)
  109. ],
  110. component=comp.name)
  111. bulk_embeddings = network_units.NamedTensor(bulk_embeddings,
  112. 'bulk-%s-fixed-features' %
  113. comp.name)
  114. return state.handle, [bulk_embeddings]
  115. def extract_fixed_feature_ids(comp, state, stride):
  116. """Extracts fixed feature IDs.
  117. Args:
  118. comp: Component whose fixed feature IDs we wish to extract.
  119. state: Live MasterState object for the component.
  120. stride: Tensor containing current batch * beam size.
  121. Returns:
  122. state handle: Updated state handle to be used after this call.
  123. ids: List of [stride * num_steps, 1] feature IDs per channel. Missing IDs
  124. (e.g., due to batch padding) are set to -1.
  125. """
  126. num_channels = len(comp.spec.fixed_feature)
  127. if not num_channels:
  128. return state.handle, []
  129. for feature_spec in comp.spec.fixed_feature:
  130. check.Eq(feature_spec.size, 1, 'All features must have size=1')
  131. check.Lt(feature_spec.embedding_dim, 0, 'All features must be non-embedded')
  132. state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
  133. state.handle, component=comp.name, num_channels=num_channels)
  134. size = stride * num_steps
  135. fixed_ids = []
  136. for channel, feature_spec in enumerate(comp.spec.fixed_feature):
  137. tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
  138. feature_spec.name)
  139. # The +1 and -1 increments ensure that missing IDs default to -1.
  140. #
  141. # TODO(googleuser): This formula breaks if multiple IDs are extracted at some
  142. # step. Try using tf.unique() to enforce the unique-IDS precondition.
  143. sums = tf.unsorted_segment_sum(ids[channel] + 1, indices[channel], size) - 1
  144. sums = tf.expand_dims(sums, axis=1)
  145. fixed_ids.append(network_units.NamedTensor(sums, feature_spec.name, dim=1))
  146. return state.handle, fixed_ids
  147. def update_network_states(comp, tensors, network_states, stride):
  148. """Stores Tensor objects corresponding to layer outputs.
  149. For use in subsequent tasks.
  150. Args:
  151. comp: Component for which the tensor handles are being stored.
  152. tensors: list of Tensors to store
  153. network_states: dictionary of component NetworkState objects
  154. stride: stride of the stored tensor.
  155. """
  156. network_state = network_states[comp.name]
  157. with tf.name_scope(comp.name + '/stored_act'):
  158. for index, network_tensor in enumerate(tensors):
  159. network_state.activations[comp.network.layers[index].name] = (
  160. network_units.StoredActivations(tensor=network_tensor, stride=stride,
  161. dim=comp.network.layers[index].dim))
  162. def build_cross_entropy_loss(logits, gold):
  163. """Constructs a cross entropy from logits and one-hot encoded gold labels.
  164. Supports skipping rows where the gold label is the magic -1 value.
  165. Args:
  166. logits: float Tensor of scores.
  167. gold: int Tensor of one-hot labels.
  168. Returns:
  169. cost, correct, total: the total cost, the total number of correctly
  170. predicted labels, and the total number of valid labels.
  171. """
  172. valid = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  173. gold = tf.gather(gold, valid)
  174. logits = tf.gather(logits, valid)
  175. correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  176. total = tf.size(gold)
  177. cost = tf.reduce_sum(
  178. tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(
  179. logits, tf.cast(gold, tf.int64))) / tf.cast(total, tf.float32)
  180. return cost, correct, total
  181. class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
  182. """A component builder to bulk extract features.
  183. Both fixed and linked features are supported, with some restrictions:
  184. 1. Fixed features may not be recurrent. Fixed features are extracted along the
  185. gold path, which does not work during inference.
  186. 2. Linked features may not be recurrent and are 'untranslated'. For now,
  187. linked features are extracted without passing them through any transition
  188. system or source translator.
  189. """
  190. def build_greedy_training(self, state, network_states):
  191. """Extracts features and advances a batch using the oracle path.
  192. Args:
  193. state: MasterState from the 'AdvanceMaster' op that advances the
  194. underlying master to this component.
  195. network_states: dictionary of component NetworkState objects
  196. Returns:
  197. state handle: final state after advancing
  198. cost: regularization cost, possibly associated with embedding matrices
  199. correct: since no gold path is available, 0.
  200. total: since no gold path is available, 0.
  201. """
  202. logging.info('Building component: %s', self.spec.name)
  203. stride = state.current_batch_size * self.training_beam_size
  204. with tf.variable_scope(self.name, reuse=True):
  205. state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
  206. self, state, stride)
  207. linked_embeddings = [
  208. fetch_linked_embedding(self, network_states, spec)
  209. for spec in self.spec.linked_feature
  210. ]
  211. with tf.variable_scope(self.name, reuse=True):
  212. tensors = self.network.create(
  213. fixed_embeddings, linked_embeddings, None, None, True, stride=stride)
  214. update_network_states(self, tensors, network_states, stride)
  215. cost = self.add_regularizer(tf.constant(0.))
  216. correct, total = tf.constant(0), tf.constant(0)
  217. return state.handle, cost, correct, total
  218. def build_greedy_inference(self, state, network_states,
  219. during_training=False):
  220. """Extracts features and advances a batch using the oracle path.
  221. NOTE(danielandor) For now this method cannot be called during training.
  222. That is to say, unroll_using_oracle for this component must be set to true.
  223. This will be fixed by separating train_with_oracle and train_with_inference.
  224. Args:
  225. state: MasterState from the 'AdvanceMaster' op that advances the
  226. underlying master to this component.
  227. network_states: dictionary of component NetworkState objects
  228. during_training: whether the graph is being constructed during training
  229. Returns:
  230. state handle: final state after advancing
  231. """
  232. logging.info('Building component: %s', self.spec.name)
  233. if during_training:
  234. stride = state.current_batch_size * self.training_beam_size
  235. else:
  236. stride = state.current_batch_size * self.inference_beam_size
  237. with tf.variable_scope(self.name, reuse=True):
  238. if during_training:
  239. state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
  240. self, state, stride)
  241. else:
  242. state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(self,
  243. state)
  244. linked_embeddings = [
  245. fetch_linked_embedding(self, network_states, spec)
  246. for spec in self.spec.linked_feature
  247. ]
  248. with tf.variable_scope(self.name, reuse=True):
  249. tensors = self.network.create(
  250. fixed_embeddings,
  251. linked_embeddings,
  252. None,
  253. None,
  254. during_training=during_training,
  255. stride=stride)
  256. update_network_states(self, tensors, network_states, stride)
  257. return state.handle
  258. class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
  259. """A component builder to bulk extract feature IDs.
  260. This is a variant of BulkFeatureExtractorComponentBuilder that only supports
  261. fixed features, and extracts raw feature IDs instead of feature embeddings.
  262. Since the extracted feature IDs are integers, the results produced by this
  263. component are in general not differentiable.
  264. """
  265. def __init__(self, master, component_spec):
  266. """Initializes the feature ID extractor component.
  267. Args:
  268. master: dragnn.MasterBuilder object.
  269. component_spec: dragnn.ComponentSpec proto to be built.
  270. """
  271. super(BulkFeatureIdExtractorComponentBuilder, self).__init__(
  272. master, component_spec)
  273. check.Eq(len(self.spec.linked_feature), 0, 'Linked features are forbidden')
  274. for feature_spec in self.spec.fixed_feature:
  275. check.Lt(feature_spec.embedding_dim, 0,
  276. 'Features must be non-embedded: %s' % feature_spec)
  277. def build_greedy_training(self, state, network_states):
  278. """See base class."""
  279. state.handle = self._extract_feature_ids(state, network_states, True)
  280. cost = self.add_regularizer(tf.constant(0.))
  281. correct, total = tf.constant(0), tf.constant(0)
  282. return state.handle, cost, correct, total
  283. def build_greedy_inference(self, state, network_states,
  284. during_training=False):
  285. """See base class."""
  286. return self._extract_feature_ids(state, network_states, during_training)
  287. def _extract_feature_ids(self, state, network_states, during_training):
  288. """Extracts feature IDs and advances a batch using the oracle path.
  289. Args:
  290. state: MasterState from the 'AdvanceMaster' op that advances the
  291. underlying master to this component.
  292. network_states: Dictionary of component NetworkState objects.
  293. during_training: Whether the graph is being constructed during training.
  294. Returns:
  295. state handle: Final state after advancing.
  296. """
  297. logging.info('Building component: %s', self.spec.name)
  298. if during_training:
  299. stride = state.current_batch_size * self.training_beam_size
  300. else:
  301. stride = state.current_batch_size * self.inference_beam_size
  302. with tf.variable_scope(self.name, reuse=True):
  303. state.handle, ids = extract_fixed_feature_ids(self, state, stride)
  304. with tf.variable_scope(self.name, reuse=True):
  305. tensors = self.network.create(
  306. ids, [], None, None, during_training, stride=stride)
  307. update_network_states(self, tensors, network_states, stride)
  308. return state.handle
  309. class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
  310. """A component builder to bulk annotate or compute the cost of a gold path.
  311. This component can be used with features that don't depend on the
  312. transition system state.
  313. Since no feature extraction is performed, only non-recurrent
  314. 'identity' linked features are supported.
  315. If a FeedForwardNetwork is configured with no hidden units, this component
  316. acts as a 'bulk softmax' component.
  317. """
  318. def build_greedy_training(self, state, network_states):
  319. """Advances a batch using oracle paths, returning the overall CE cost.
  320. Args:
  321. state: MasterState from the 'AdvanceMaster' op that advances the
  322. underlying master to this component.
  323. network_states: dictionary of component NetworkState objects
  324. Returns:
  325. (state handle, cost, correct, total): TF ops corresponding to the final
  326. state after unrolling, the total cost, the total number of correctly
  327. predicted actions, and the total number of actions.
  328. Raises:
  329. RuntimeError: if fixed features are configured.
  330. """
  331. logging.info('Building component: %s', self.spec.name)
  332. if self.spec.fixed_feature:
  333. raise RuntimeError(
  334. 'Fixed features are not compatible with bulk annotation. '
  335. 'Use the "bulk-features" component instead.')
  336. linked_embeddings = [
  337. fetch_linked_embedding(self, network_states, spec)
  338. for spec in self.spec.linked_feature
  339. ]
  340. stride = state.current_batch_size * self.training_beam_size
  341. with tf.variable_scope(self.name, reuse=True):
  342. network_tensors = self.network.create([], linked_embeddings, None, None,
  343. True, stride)
  344. update_network_states(self, network_tensors, network_states, stride)
  345. logits = self.network.get_logits(network_tensors)
  346. state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
  347. state.handle, component=self.name)
  348. cost, correct, total = build_cross_entropy_loss(logits, gold)
  349. cost = self.add_regularizer(cost)
  350. return state.handle, cost, correct, total
  351. def build_greedy_inference(self, state, network_states,
  352. during_training=False):
  353. """Annotates a batch of documents using network scores.
  354. Args:
  355. state: MasterState from the 'AdvanceMaster' op that advances the
  356. underlying master to this component.
  357. network_states: dictionary of component NetworkState objects
  358. during_training: whether the graph is being constructed during training
  359. Returns:
  360. Handle to the state once inference is complete for this Component.
  361. Raises:
  362. RuntimeError: if fixed features are configured
  363. """
  364. logging.info('Building component: %s', self.spec.name)
  365. if self.spec.fixed_feature:
  366. raise RuntimeError(
  367. 'Fixed features are not compatible with bulk annotation. '
  368. 'Use the "bulk-features" component instead.')
  369. linked_embeddings = [
  370. fetch_linked_embedding(self, network_states, spec)
  371. for spec in self.spec.linked_feature
  372. ]
  373. if during_training:
  374. stride = state.current_batch_size * self.training_beam_size
  375. else:
  376. stride = state.current_batch_size * self.inference_beam_size
  377. with tf.variable_scope(self.name, reuse=True):
  378. network_tensors = self.network.create(
  379. [], linked_embeddings, None, None, during_training, stride)
  380. update_network_states(self, network_tensors, network_states, stride)
  381. logits = self.network.get_logits(network_tensors)
  382. return dragnn_ops.bulk_advance_from_prediction(
  383. state.handle, logits, component=self.name)