123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478 |
- # Copyright 2017 Google Inc. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Component builders for non-recurrent networks in DRAGNN."""
- import tensorflow as tf
- from tensorflow.python.platform import tf_logging as logging
- from dragnn.python import component
- from dragnn.python import dragnn_ops
- from dragnn.python import network_units
- from syntaxnet.util import check
def fetch_linked_embedding(comp, network_states, feature_spec):
  """Fetches the stored bulk activations that back a linked feature.

  The tensor is read straight from the activations recorded for the feature's
  source component; no source translation is applied.

  Args:
    comp: ComponentBuilder object with respect to which the feature is
      fetched.
    network_states: dictionary of NetworkState objects, keyed by component
      name.
    feature_spec: spec proto of the linked feature to look up.

  Returns:
    NamedTensor wrapping the source layer's bulk tensor, named after the
    linked feature.

  Raises:
    NotImplementedError: if the feature's source translator is anything other
      than 'identity'.
    RuntimeError: if the feature's source component is |comp| itself, i.e. the
      linked feature is recurrent.
  """
  translator = feature_spec.source_translator
  if translator != 'identity':
    raise NotImplementedError(translator)

  is_recurrent = feature_spec.source_component == comp.name
  if is_recurrent:
    raise RuntimeError(
        'Recurrent linked features are not supported in bulk extraction.')

  tf.logging.info('[%s] Adding linked feature "%s"', comp.name,
                  feature_spec.name)

  # Resolve the source component through the master, then pick the requested
  # layer out of that component's stored activations.
  source = comp.master.lookup_component[feature_spec.source_component]
  source_activations = network_states[source.name].activations
  stored = source_activations[feature_spec.source_layer]
  return network_units.NamedTensor(stored.bulk_tensor, feature_spec.name)
def _validate_embedded_fixed_features(comp):
  """Verifies that every fixed feature of |comp| can be embedded.

  Each fixed feature must declare a positive embedding dimension, and each
  feature marked constant must additionally carry a pretrained embedding
  matrix.  Violations are reported through the syntaxnet `check` helpers.

  Args:
    comp: Component whose fixed feature specs are validated.
  """
  for spec in comp.spec.fixed_feature:
    check.Gt(spec.embedding_dim, 0,
             'Embeddings requested for non-embedded feature: %s' % spec)
    if not spec.is_constant:
      continue
    # Constant embeddings are never trained, so they must come pretrained.
    check.IsTrue(spec.HasField('pretrained_embedding_matrix'),
                 'Constant embeddings must be pretrained: %s' % spec)
def fetch_differentiable_fixed_embeddings(comp, state, stride):
  """Embeds fixed features channel by channel with a differentiable lookup.

  Feature indices, IDs and weights are extracted in bulk, and each channel is
  then embedded separately via network_units.embedding_lookup().  Gradients
  are stopped for channels whose feature spec is marked constant.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: live MasterState object for the component.
    stride: Tensor containing current batch * beam size.

  Returns:
    state handle: updated state handle to be used after this call
    fixed_embeddings: list of NamedTensor objects, one per fixed feature
      channel (empty if the component has no fixed features)
  """
  _validate_embedded_fixed_features(comp)
  feature_specs = comp.spec.fixed_feature
  num_channels = len(feature_specs)
  if num_channels == 0:
    return state.handle, []

  outputs = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=num_channels)
  state.handle, indices, ids, weights, num_steps = outputs

  embeddings = []
  for channel, spec in enumerate(feature_specs):
    if spec.is_constant:
      kind = 'constant'
    else:
      kind = 'differentiable'
    tf.logging.info('[%s] Adding %s fixed feature "%s"', comp.name, kind,
                    spec.name)

    size = stride * num_steps * spec.size
    matrix = comp.get_variable(network_units.fixed_embeddings_name(channel))
    embedding = network_units.embedding_lookup(
        matrix, indices[channel], ids[channel], weights[channel], size)
    if spec.is_constant:
      # Constant embeddings must not receive gradient updates.
      embedding = tf.stop_gradient(embedding)
    embeddings.append(network_units.NamedTensor(embedding, spec.name))

  return state.handle, embeddings
def fetch_fast_fixed_embeddings(comp, state):
  """Embeds all fixed features at once with the fast, non-differentiable op.

  BulkFixedEmbeddings is not differentiable with respect to the embedding
  matrices, so this function is meant for graphs that are not used for
  training.

  Args:
    comp: Component whose fixed features we wish to look up.
    state: live MasterState object for the component.

  Returns:
    state handle: updated state handle to be used after this call
    fixed_embeddings: list holding a single NamedTensor with the embeddings of
      all channels (empty if the component has no fixed features)
  """
  _validate_embedded_fixed_features(comp)
  num_channels = len(comp.spec.fixed_feature)
  if num_channels == 0:
    return state.handle, []

  tf.logging.info('[%s] Adding %d fast fixed features', comp.name, num_channels)

  # One embedding matrix per channel, in channel order.
  embedding_matrices = []
  for channel in range(num_channels):
    embedding_matrices.append(
        comp.get_variable(network_units.fixed_embeddings_name(channel)))

  state.handle, bulk_embeddings, _ = dragnn_ops.bulk_fixed_embeddings(
      state.handle, embedding_matrices, component=comp.name)

  tensor_name = 'bulk-%s-fixed-features' % comp.name
  named_embeddings = network_units.NamedTensor(bulk_embeddings, tensor_name)
  return state.handle, [named_embeddings]
def extract_fixed_feature_ids(comp, state, stride):
  """Extracts raw, non-embedded fixed feature IDs in bulk.

  Every fixed feature of |comp| must have size 1 and a negative embedding
  dimension (i.e. be non-embedded).

  Args:
    comp: Component whose fixed feature IDs we wish to extract.
    state: Live MasterState object for the component.
    stride: Tensor containing current batch * beam size.

  Returns:
    state handle: Updated state handle to be used after this call.
    ids: List of NamedTensors, one per channel, each holding a
      [stride * num_steps, 1] tensor of feature IDs.  Positions with no
      extracted ID (e.g., due to batch padding) are set to -1.
  """
  feature_specs = comp.spec.fixed_feature
  num_channels = len(feature_specs)
  if num_channels == 0:
    return state.handle, []

  for spec in feature_specs:
    check.Eq(spec.size, 1, 'All features must have size=1')
    check.Lt(spec.embedding_dim, 0, 'All features must be non-embedded')

  outputs = dragnn_ops.bulk_fixed_features(
      state.handle, component=comp.name, num_channels=num_channels)
  state.handle, indices, ids, _, num_steps = outputs
  size = stride * num_steps

  fixed_ids = []
  for channel, spec in enumerate(feature_specs):
    tf.logging.info('[%s] Adding fixed feature IDs "%s"', comp.name,
                    spec.name)

    # Shifting by +1 before the segment sum and by -1 afterwards makes
    # positions that received no ID (segment sum of 0) come out as -1.
    #
    # TODO(googleuser): This formula breaks if multiple IDs are extracted at
    # some step.  Try using tf.unique() to enforce the unique-IDS precondition.
    shifted_ids = ids[channel] + 1
    summed = tf.unsorted_segment_sum(shifted_ids, indices[channel], size)
    channel_ids = tf.expand_dims(summed - 1, axis=1)
    fixed_ids.append(network_units.NamedTensor(channel_ids, spec.name, dim=1))

  return state.handle, fixed_ids
def update_network_states(comp, tensors, network_states, stride):
  """Records layer output tensors in the component's NetworkState.

  The i-th tensor is stored under the name of the i-th layer of the
  component's network, so that subsequent components can look it up.

  Args:
    comp: Component for which the tensor handles are being stored.
    tensors: list of Tensors to store, in network layer order
    network_states: dictionary of component NetworkState objects
    stride: stride of the stored tensor.
  """
  activations = network_states[comp.name].activations
  with tf.name_scope(comp.name + '/stored_act'):
    for index, layer_tensor in enumerate(tensors):
      layer = comp.network.layers[index]
      activations[layer.name] = network_units.StoredActivations(
          tensor=layer_tensor, stride=stride, dim=layer.dim)
def build_cross_entropy_loss(logits, gold):
  """Builds the mean softmax cross entropy over the validly labelled rows.

  Rows whose gold label is the magic -1 value (or any negative value) are
  dropped from both |logits| and |gold| before the loss and the accuracy
  counts are computed.

  Args:
    logits: float Tensor of scores, one row per example.
    gold: int Tensor of gold class indices, one per row of |logits|.  These
      are sparse labels (class indices), not one-hot vectors; -1 marks a row
      to be skipped.

  Returns:
    cost, correct, total: the cross entropy summed over the valid rows and
        divided by the number of valid rows (0 if there are none), the number
        of valid rows whose gold label is the top-scoring class, and the
        number of valid rows.
  """
  # Keep only the rows that carry a real (non-negative) gold label.
  valid = tf.reshape(tf.where(tf.greater(gold, -1)), [-1])
  gold = tf.gather(gold, valid)
  logits = tf.gather(logits, valid)

  correct = tf.reduce_sum(
      tf.cast(tf.nn.in_top_k(logits, gold, 1), tf.int32))
  total = tf.size(gold)

  # Use the supported keyword-argument API instead of the deprecated
  # "flipped" positional wrapper from tf.contrib.
  summed_cost = tf.reduce_sum(
      tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=tf.cast(gold, tf.int64), logits=logits))

  # When every row is masked out, |total| is 0 and a plain division would
  # yield NaN (0 / 0).  Clamping the denominator makes the cost 0 in that case
  # and leaves it unchanged otherwise.
  denominator = tf.cast(tf.maximum(total, 1), tf.float32)
  cost = summed_cost / denominator
  return cost, correct, total
class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
  """A component builder that extracts features in bulk.

  Fixed and linked features are both supported, subject to these restrictions:

  1. Fixed features may not be recurrent. Fixed features are extracted along
     the gold path, which does not work during inference.
  2. Linked features may not be recurrent and are 'untranslated'. For now,
     linked features are extracted without passing them through any transition
     system or source translator.
  """

  def build_greedy_training(self, state, network_states):
    """Extracts features in bulk and builds the network for training.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: dictionary of component NetworkState objects

    Returns:
      state handle: final state after advancing
      cost: regularization cost, possibly associated with embedding matrices
      correct: constant 0; this component does not score a gold path.
      total: constant 0; this component does not score a gold path.
    """
    logging.info('Building component: %s', self.spec.name)
    stride = state.current_batch_size * self.training_beam_size

    with tf.variable_scope(self.name, reuse=True):
      state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
          self, state, stride)

    linked_embeddings = []
    for linked_spec in self.spec.linked_feature:
      linked_embeddings.append(
          fetch_linked_embedding(self, network_states, linked_spec))

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          fixed_embeddings, linked_embeddings, None, None, True, stride=stride)
    update_network_states(self, tensors, network_states, stride)

    # The only cost contributed by this component is regularization.
    cost = self.add_regularizer(tf.constant(0.))
    correct = tf.constant(0)
    total = tf.constant(0)
    return state.handle, cost, correct, total

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """Extracts features in bulk and builds the network for inference.

    When |during_training| is true the training beam size and the
    differentiable embedding lookup are used; otherwise the inference beam
    size and the fast, non-differentiable lookup are used.

    NOTE(danielandor) For now this method cannot be called during training.
    That is to say, unroll_using_oracle for this component must be set to true.
    This will be fixed by separating train_with_oracle and
    train_with_inference.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: dictionary of component NetworkState objects
      during_training: whether the graph is being constructed during training

    Returns:
      state handle: final state after advancing
    """
    logging.info('Building component: %s', self.spec.name)
    beam_size = (self.training_beam_size
                 if during_training else self.inference_beam_size)
    stride = state.current_batch_size * beam_size

    with tf.variable_scope(self.name, reuse=True):
      if during_training:
        state.handle, fixed_embeddings = fetch_differentiable_fixed_embeddings(
            self, state, stride)
      else:
        state.handle, fixed_embeddings = fetch_fast_fixed_embeddings(
            self, state)

    linked_embeddings = []
    for linked_spec in self.spec.linked_feature:
      linked_embeddings.append(
          fetch_linked_embedding(self, network_states, linked_spec))

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          fixed_embeddings, linked_embeddings, None, None,
          during_training=during_training, stride=stride)

    update_network_states(self, tensors, network_states, stride)
    return state.handle
class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
  """A component builder that extracts feature IDs in bulk.

  This is a variant of BulkFeatureExtractorComponentBuilder that only supports
  fixed features, and extracts raw feature IDs instead of feature embeddings.
  Since the extracted feature IDs are integers, the results produced by this
  component are in general not differentiable.
  """

  def __init__(self, master, component_spec):
    """Initializes the feature ID extractor component.

    Checks that the spec declares no linked features and that every fixed
    feature is non-embedded (negative embedding dimension).

    Args:
      master: dragnn.MasterBuilder object.
      component_spec: dragnn.ComponentSpec proto to be built.
    """
    super(BulkFeatureIdExtractorComponentBuilder, self).__init__(
        master, component_spec)

    num_linked = len(self.spec.linked_feature)
    check.Eq(num_linked, 0, 'Linked features are forbidden')
    for fixed_spec in self.spec.fixed_feature:
      check.Lt(fixed_spec.embedding_dim, 0,
               'Features must be non-embedded: %s' % fixed_spec)

  def build_greedy_training(self, state, network_states):
    """See base class."""
    state.handle = self._extract_feature_ids(state, network_states, True)

    # The only cost contributed by this component is regularization; no
    # accuracy is computed, so the counts are constant zeros.
    cost = self.add_regularizer(tf.constant(0.))
    correct = tf.constant(0)
    total = tf.constant(0)
    return state.handle, cost, correct, total

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """See base class."""
    return self._extract_feature_ids(state, network_states, during_training)

  def _extract_feature_ids(self, state, network_states, during_training):
    """Extracts feature IDs in bulk and feeds them to the network.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: Dictionary of component NetworkState objects.
      during_training: Whether the graph is being constructed during training.
        Selects the training or the inference beam size for the stride.

    Returns:
      state handle: Final state after advancing.
    """
    logging.info('Building component: %s', self.spec.name)
    beam_size = (self.training_beam_size
                 if during_training else self.inference_beam_size)
    stride = state.current_batch_size * beam_size

    with tf.variable_scope(self.name, reuse=True):
      state.handle, ids = extract_fixed_feature_ids(self, state, stride)

    with tf.variable_scope(self.name, reuse=True):
      tensors = self.network.create(
          ids, [], None, None, during_training, stride=stride)

    update_network_states(self, tensors, network_states, stride)
    return state.handle
class BulkAnnotatorComponentBuilder(component.ComponentBuilderBase):
  """A component builder to bulk annotate or compute the cost of a gold path.

  This component can be used with features that don't depend on the
  transition system state.

  Since no feature extraction is performed, only non-recurrent
  'identity' linked features are supported.

  If a FeedForwardNetwork is configured with no hidden units, this component
  acts as a 'bulk softmax' component.
  """

  def build_greedy_training(self, state, network_states):
    """Advances a batch using oracle paths, returning the overall CE cost.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: dictionary of component NetworkState objects

    Returns:
      (state handle, cost, correct, total): TF ops corresponding to the final
      state after unrolling, the cross-entropy cost with regularization added,
      the number of correctly predicted actions, and the number of actions.

    Raises:
      RuntimeError: if fixed features are configured.
    """
    logging.info('Building component: %s', self.spec.name)
    has_fixed_features = bool(self.spec.fixed_feature)
    if has_fixed_features:
      raise RuntimeError(
          'Fixed features are not compatible with bulk annotation. '
          'Use the "bulk-features" component instead.')

    linked_embeddings = []
    for linked_spec in self.spec.linked_feature:
      linked_embeddings.append(
          fetch_linked_embedding(self, network_states, linked_spec))

    stride = state.current_batch_size * self.training_beam_size
    with tf.variable_scope(self.name, reuse=True):
      network_tensors = self.network.create(
          [], linked_embeddings, None, None, True, stride)
    update_network_states(self, network_tensors, network_states, stride)

    # Score against the gold labels obtained by advancing along the oracle.
    logits = self.network.get_logits(network_tensors)
    state.handle, gold = dragnn_ops.bulk_advance_from_oracle(
        state.handle, component=self.name)
    cost, correct, total = build_cross_entropy_loss(logits, gold)

    regularized_cost = self.add_regularizer(cost)
    return state.handle, regularized_cost, correct, total

  def build_greedy_inference(self, state, network_states,
                             during_training=False):
    """Annotates a batch of documents using network scores.

    Args:
      state: MasterState from the 'AdvanceMaster' op that advances the
        underlying master to this component.
      network_states: dictionary of component NetworkState objects
      during_training: whether the graph is being constructed during training;
        selects the training or the inference beam size for the stride.

    Returns:
      Handle to the state once inference is complete for this Component.

    Raises:
      RuntimeError: if fixed features are configured
    """
    logging.info('Building component: %s', self.spec.name)
    has_fixed_features = bool(self.spec.fixed_feature)
    if has_fixed_features:
      raise RuntimeError(
          'Fixed features are not compatible with bulk annotation. '
          'Use the "bulk-features" component instead.')

    linked_embeddings = []
    for linked_spec in self.spec.linked_feature:
      linked_embeddings.append(
          fetch_linked_embedding(self, network_states, linked_spec))

    beam_size = (self.training_beam_size
                 if during_training else self.inference_beam_size)
    stride = state.current_batch_size * beam_size

    with tf.variable_scope(self.name, reuse=True):
      network_tensors = self.network.create(
          [], linked_embeddings, None, None, during_training, stride)
    update_network_states(self, network_tensors, network_states, stride)

    # Advance the batch using the network's own predictions.
    logits = self.network.get_logits(network_tensors)
    return dragnn_ops.bulk_advance_from_prediction(
        state.handle, logits, component=self.name)
|