bulk_component.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. """Component builders for non-recurrent networks in DRAGNN."""
  2. import tensorflow as tf
  3. from tensorflow.python.platform import tf_logging as logging
  4. from dragnn.python import component
  5. from dragnn.python import dragnn_ops
  6. from dragnn.python import network_units
  7. from syntaxnet.util import check
  8. def fetch_linked_embedding(comp, network_states, feature_spec):
  9. """Looks up linked embeddings in other components.
  10. Args:
  11. comp: ComponentBuilder object with respect to which the feature is to be
  12. fetched
  13. network_states: dictionary of NetworkState objects
  14. feature_spec: FeatureSpec proto for the linked feature to be looked up
  15. Returns:
  16. NamedTensor containing the linked feature tensor
  17. Raises:
  18. NotImplementedError: if a linked feature with source translator other than
  19. 'identity' is configured.
  20. RuntimeError: if a recurrent linked feature is configured.
  21. """
  22. if feature_spec.source_translator != 'identity':
  23. raise NotImplementedError(feature_spec.source_translator)
  24. if feature_spec.source_component == comp.name:
  25. raise RuntimeError(
  26. 'Recurrent linked features are not supported in bulk extraction.')
  27. tf.logging.info('[%s] Adding linked feature "%s"', comp.name,
  28. feature_spec.name)
  29. source = comp.master.lookup_component[feature_spec.source_component]
  30. return network_units.NamedTensor(
  31. network_states[source.name].activations[
  32. feature_spec.source_layer].bulk_tensor,
  33. feature_spec.name)
  34. def _validate_embedded_fixed_features(comp):
  35. """Checks that the embedded fixed features of |comp| are set up properly."""
  36. for feature in comp.spec.fixed_feature:
  37. check.Gt(feature.embedding_dim, 0,
  38. 'Embeddings requested for non-embedded feature: %s' % feature)
  39. if feature.is_constant:
  40. check.IsTrue(feature.HasField('pretrained_embedding_matrix'),
  41. 'Constant embeddings must be pretrained: %s' % feature)
  42. def fetch_differentiable_fixed_embeddings(comp, state, stride):
  43. """Looks up fixed features with separate, differentiable, embedding lookup.
  44. Args:
  45. comp: Component whose fixed features we wish to look up.
  46. state: live MasterState object for the component.
  47. stride: Tensor containing current batch * beam size.
  48. Returns:
  49. state handle: updated state handle to be used after this call
  50. fixed_embeddings: list of NamedTensor objects
  51. """
  52. _validate_embedded_fixed_features(comp)
  53. num_channels = len(comp.spec.fixed_feature)
  54. if not num_channels:
  55. return state.handle, []
  56. state.handle, indices, ids, weights, num_steps = (
  57. dragnn_ops.bulk_fixed_features(
  58. state.handle, component=comp.name, num_channels=num_channels))
  59. fixed_embeddings = []
  60. for channel, feature_spec in enumerate(comp.spec.fixed_feature):
  61. differentiable_or_constant = ('constant' if feature_spec.is_constant else
  62. 'differentiable')
  63. tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name,
  64. differentiable_or_constant, feature_spec.name)
  65. size = stride * num_steps * feature_spec.size
  66. fixed_embedding = network_units.embedding_lookup(
  67. comp.get_variable(network_units.fixed_embeddings_name(channel)),
  68. indices[channel], ids[channel], weights[channel], size)
  69. if feature_spec.is_constant:
  70. fixed_embedding = tf.stop_gradient(fixed_embedding)
  71. fixed_embeddings.append(
  72. network_units.NamedTensor(fixed_embedding, feature_spec.name))
  73. return state.handle, fixed_embeddings
  74. def fetch_fast_fixed_embeddings(comp, state):
  75. """Looks up fixed features with fast, non-differentiable, op.
  76. Since BulkFixedEmbeddings is non-differentiable with respect to the
  77. embeddings, the idea is to call this function only when the graph is
  78. not being used for training.
  79. Args:
  80. comp: Component whose fixed features we wish to look up.
  81. state: live MasterState object for the component.
  82. Returns:
  83. state handle: updated state handle to be used after this call
  84. fixed_embeddings: list of NamedTensor objects
  85. """
  86. _validate_embedded_fixed_features(comp)
  87. num_channels = len(comp.spec.fixed_feature)
  88. if not num_channels:
  89. return state.handle, []
  90. tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)
  91. state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
  92. state.handle, [
  93. comp.get_variable(network_units.fixed_embeddings_name(c))
  94. for c in range(num_channels)
  95. ],
  96. component=comp.name)
  97. bulk_embeddings = network_units.NamedTensor(bulk_embeddings,
  98. 'bulk-%s-fixed-features' %
  99. comp.name)
  100. return state.handle, [bulk_embeddings]
  101. def extract_fixed_feature_ids(comp, state, stride):
  102. """Extracts fixed feature IDs.
  103. Args:
  104. comp: Component whose fixed feature IDs we wish to extract.
  105. state: Live MasterState object for the component.
  106. stride: Tensor containing current batch * beam size.
  107. Returns:
  108. state handle: Updated state handle to be used after this call.
  109. ids: List of [stride * num_steps, 1] feature IDs per channel. Missing IDs
  110. (e.g., due to batch padding) are set to -1.
  111. """
  112. num_channels = len(comp.spec.fixed_feature)
  113. if not num_channels:
  114. return state.handle, []
  115. for feature_spec in comp.spec.fixed_feature:
  116. check.Eq(feature_spec.size, 1, 'All features must have size=1')
  117. check.Lt(feature_spec.embedding_dim, 0, 'All features must be non-embedded')
  118. state.handle, indices, ids, _, num_steps = dragnn_ops.bulk_fixed_features(
  119. state.handle, component=comp.name, num_channels=num_channels)
  120. size = stride * num_steps
  121. fixed_ids = []
  122. for channel, feature_spec in enumerate(comp.spec.fixed_feature):
  123. tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
  124. feature_spec.name)
  125. # The +1 and -1 increments ensure that missing IDs default to -1.
  126. #
  127. # TODO(googleuser): This formula breaks if multiple IDs are extracted at some
  128. # step. Try using tf.unique() to enforce the unique-IDS precondition.
  129. sums = tf.unsorted_segment_sum(ids[channel] + 1, indices[channel], size) - 1
  130. sums = tf.expand_dims(sums, axis=1)
  131. fixed_ids.append(network_units.NamedTensor(sums, feature_spec.name, dim=1))
  132. return state.handle, fixed_ids
  133. def update_network_states(comp, tensors, network_states, stride):
  134. """Stores Tensor objects corresponding to layer outputs.
  135. For use in subsequent tasks.
  136. Args:
  137. comp: Component for which the tensor handles are being stored.
  138. tensors: list of Tensors to store
  139. network_states: dictionary of component NetworkState objects
  140. stride: stride of the stored tensor.
  141. """
  142. network_state = network_states[comp.name]
  143. with tf.name_scope(comp.name + '/stored_act'):
  144. for index, network_tensor in enumerate(tensors):
  145. network_state.activations[comp.network.layers[index].name] = (
  146. network_units.StoredActivations(tensor=network_tensor, stride=stride,
  147. dim=comp.network.layers[index].dim))
  148. def build_cross_entropy_loss(logits, gold):
  149. """Constructs a cross entropy from logits and one-hot encoded gold labels.
  150. Supports skipping rows where the gold label is the magic -1 value.
  151. Args:
  152. logits: float Tensor of scores.
  153. gold: int Tensor of one-hot labels.
  154. Returns:
  155. cost, correct, total: the total cost, the total number of correctly
  156. predicted labels, and the total number of valid labels.
  157. """
  158. valid = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  159. gold = tf.gather(gold, valid)
  160. logits = tf.gather(logits, valid)
  161. correct = tf.reduce_sum(tf.to_int32(tf.nn.in_top_k(logits, gold, 1)))
  162. total = tf.size(gold)
  163. cost = tf.reduce_sum(
  164. tf.contrib.nn.deprecated_flipped_sparse_softmax_cross_entropy_with_logits(
  165. logits, tf.cast(gold, tf.int64))) / tf.cast(total, tf.float32)
  166. return cost, correct, total
  167. class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
  168. """A component builder to bulk extract features.
  169. Both fixed and linked features are supported, with some restrictions:
  170. 1. Fixed features may not be recurrent. Fixed features are extracted along the
  171. gold path, which does not work during inference.
  172. 2. Linked features may not be recurrent and are 'untranslated'. For now,
  173. linked features are extracted without passing them through any transition
  174. system or source translator.
  175. """
  176. def build_greedy_training(self, state, network_states):
  177. """Extracts features and advances a batch using the oracle path.
  178. Args:
  179. state: MasterState from the 'AdvanceMaster' op that advances the
  180. underlying master to this component.
  181. network_states: dictionary of component NetworkState objects
  182. Returns:
  183. state handle: final state after advancing
  184. cost: regularization cost, possibly associated with embedding matrices
  185. correct: since no gold path is available, 0.
  186. total: since no gold path is available, 0.
  187. """
  188. logging.info('Building component: %s', self.spec.name)
  189. stride = state.current_batch_size * self.training_beam_size
  190. with tf.variable_scope(self.name, reuse=True):
  191. state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
  192. self, state, stride)
  193. linked_embeddings = [
  194. fetch_linked_embedding(self, network_states, spec)
  195. for spec in self.spec.linked_feature
  196. ]
  197. with tf.variable_scope(self.name, reuse=True):
  198. tensors = self.network.create(
  199. fixed_embeddings, linked_embeddings, None, None, True, stride=stride)
  200. update_network_states(self, tensors, network_states, stride)
  201. cost = self.add_regularizer(tf.constant(0.))
  202. return state.handle, cost, 0, 0
  203. def build_greedy_inference(self, state, network_states,
  204. during_training=False):
  205. """Extracts features and advances a batch using the oracle path.
  206. NOTE(danielandor) For now this method cannot be called during training.
  207. That is to say, unroll_using_oracle for this component must be set to true.
  208. This will be fixed by separating train_with_oracle and train_with_inference.
  209. Args:
  210. state: MasterState from the 'AdvanceMaster' op that advances the
  211. underlying master to this component.
  212. network_states: dictionary of component NetworkState objects
  213. during_training: whether the graph is being constructed during training
  214. Returns:
  215. state handle: final state after advancing
  216. """
  217. logging.info('Building component: %s', self.spec.name)
  218. if during_training:
  219. stride = state.current_batch_size * self.training_beam_size
  220. else:
  221. stride = state.current_batch_size * self.inference_beam_size
  222. with tf.variable_scope(self.name, reuse=True):
  223. if during_training:
  224. state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
  225. self, state, stride)
  226. else:
  227. state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(self,
  228. state)
  229. linked_embeddings = [
  230. fetch_linked_embedding(self, network_states, spec)
  231. for spec in self.spec.linked_feature
  232. ]
  233. with tf.variable_scope(self.name, reuse=True):
  234. tensors = self.network.create(
  235. fixed_embeddings,
  236. linked_embeddings,
  237. None,
  238. None,
  239. during_training=during_training,
  240. stride=stride)
  241. update_network_states(self, tensors, network_states, stride)
  242. return state.handle
  243. class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
  244. """A component builder to bulk extract feature IDs.
  245. This is a variant of BulkFeatureExtractorComponentBuilder that only supports
  246. fixed features, and extracts raw feature IDs instead of feature embeddings.
  247. Since the extracted feature IDs are integers, the results produced by this
  248. component are in general not differentiable.
  249. """
  250. def __init__(self, master, component_spec):
  251. """Initializes the feature ID extractor component.
  252. Args:
  253. master: dragnn.MasterBuilder object.
  254. component_spec: dragnn.ComponentSpec proto to be built.
  255. """
  256. super(BulkFeatureIdExtractorComponentBuilder, self).__init__(
  257. master, component_spec)
  258. check.Eq(len(self.spec.linked_feature), 0, 'Linked features are forbidden')
  259. for feature_spec in self.spec.fixed_feature:
  260. check.Lt(feature_spec.embedding_dim, 0,
  261. 'Features must be non-embedded: %s' % feature_spec)
  262. def build_greedy_training(self, state, network_states):
  263. """See base class."""
  264. state.handle = self._extract_feature_ids(state, network_states, True)
  265. cost = self.add_regularizer(tf.constant(0.))
  266. return state.handle, cost, 0, 0
  267. def build_greedy_inference(self, state, network_states,
  268. during_training=False):
  269. """See base class."""
  270. return self._extract_feature_ids(state, network_states, during_training)
  271. def _extract_feature_ids(self, state, network_states, during_training):
  272. """Extracts feature IDs and advances a batch using the oracle path.
  273. Args:
  274. state: MasterState from the 'AdvanceMaster' op that advances the
  275. underlying master to this component.
  276. network_states: Dictionary of component NetworkState objects.
  277. during_training: Whether the graph is being constructed during training.
  278. Returns:
  279. state handle: Final state after advancing.
  280. """
  281. logging.info('Building component: %s', self.spec.name)
  282. if during_training:
  283. stride = state.current_batch_size * self.training_beam_size
  284. else:
  285. stride = state.current_batch_size * self.inference_beam_size
  286. with tf.variable_scope(self.name, reuse=True):
  287. state.handle, ids = extract_fixed_feature_ids(self, state, stride)
  288. with tf.variable_scope(self.name, reuse=True):
  289. tensors = self.network.create(
  290. ids, [], None, None, during_training, stride=stride)
  291. update_network_states(self, tensors, network_states, stride)
  292. return state.handle
  293. class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
  294. """A component builder to bulk annotate or compute the cost of a gold path.
  295. This component can be used with features that don't depend on the
  296. transition system state.
  297. Since no feature extraction is performed, only non-recurrent
  298. 'identity' linked features are supported.
  299. If a FeedForwardNetwork is configured with no hidden units, this component
  300. acts as a 'bulk softmax' component.
  301. """
  302. def build_greedy_training(self, state, network_states):
  303. """Advances a batch using oracle paths, returning the overall CE cost.
  304. Args:
  305. state: MasterState from the 'AdvanceMaster' op that advances the
  306. underlying master to this component.
  307. network_states: dictionary of component NetworkState objects
  308. Returns:
  309. (state handle, cost, correct, total): TF ops corresponding to the final
  310. state after unrolling, the total cost, the total number of correctly
  311. predicted actions, and the total number of actions.
  312. Raises:
  313. RuntimeError: if fixed features are configured.
  314. """
  315. logging.info('Building component: %s', self.spec.name)
  316. if self.spec.fixed_feature:
  317. raise RuntimeError(
  318. 'Fixed features are not compatible with bulk annotation. '
  319. 'Use the "bulk-features" component instead.')
  320. linked_embeddings = [
  321. fetch_linked_embedding(self, network_states, spec)
  322. for spec in self.spec.linked_feature
  323. ]
  324. stride = state.current_batch_size * self.training_beam_size
  325. with tf.variable_scope(self.name, reuse=True):
  326. network_tensors = self.network.create([], linked_embeddings, None, None,
  327. True, stride)
  328. update_network_states(self, network_tensors, network_states, stride)
  329. logits = self.network.get_logits(network_tensors)
  330. state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
  331. state.handle, component=self.name)
  332. cost, correct, total = build_cross_entropy_loss(logits, gold)
  333. cost = self.add_regularizer(cost)
  334. return state.handle, cost, correct, total
  335. def build_greedy_inference(self, state, network_states,
  336. during_training=False):
  337. """Annotates a batch of documents using network scores.
  338. Args:
  339. state: MasterState from the 'AdvanceMaster' op that advances the
  340. underlying master to this component.
  341. network_states: dictionary of component NetworkState objects
  342. during_training: whether the graph is being constructed during training
  343. Returns:
  344. Handle to the state once inference is complete for this Component.
  345. Raises:
  346. RuntimeError: if fixed features are configured
  347. """
  348. logging.info('Building component: %s', self.spec.name)
  349. if self.spec.fixed_feature:
  350. raise RuntimeError(
  351. 'Fixed features are not compatible with bulk annotation. '
  352. 'Use the "bulk-features" component instead.')
  353. linked_embeddings = [
  354. fetch_linked_embedding(self, network_states, spec)
  355. for spec in self.spec.linked_feature
  356. ]
  357. if during_training:
  358. stride = state.current_batch_size * self.training_beam_size
  359. else:
  360. stride = state.current_batch_size * self.inference_beam_size
  361. with tf.variable_scope(self.name, reuse=True):
  362. network_tensors = self.network.create(
  363. [], linked_embeddings, None, None, during_training, stride)
  364. update_network_states(self, network_tensors, network_states, stride)
  365. logits = self.network.get_logits(network_tensors)
  366. return dragnn_ops.bulk_advance_from_prediction(
  367. state.handle, logits, component=self.name)