
Adding SyntaxNet to tensorflow/models (#63)

calberti 9 years ago
parent
commit
32ab5a58dd
100 files changed with 81002 additions and 2 deletions
  1. 0 2
      .gitignore
  2. 3 0
      .gitmodules
  3. 7 0
      syntaxnet/.gitignore
  4. 610 0
      syntaxnet/README.md
  5. 38 0
      syntaxnet/WORKSPACE
  6. BIN
      syntaxnet/beam_search_training.png
  7. BIN
      syntaxnet/ff_nn_schematic.png
  8. BIN
      syntaxnet/looping-parser.gif
  9. BIN
      syntaxnet/sawman.png
  10. 637 0
      syntaxnet/syntaxnet/BUILD
  11. 263 0
      syntaxnet/syntaxnet/affix.cc
  12. 155 0
      syntaxnet/syntaxnet/affix.h
  13. 305 0
      syntaxnet/syntaxnet/arc_standard_transitions.cc
  14. 117 0
      syntaxnet/syntaxnet/arc_standard_transitions_test.cc
  15. 53 0
      syntaxnet/syntaxnet/base.h
  16. 893 0
      syntaxnet/syntaxnet/beam_reader_ops.cc
  17. 230 0
      syntaxnet/syntaxnet/beam_reader_ops_test.py
  18. 93 0
      syntaxnet/syntaxnet/conll2tree.py
  19. 156 0
      syntaxnet/syntaxnet/context.pbtxt
  20. 56 0
      syntaxnet/syntaxnet/demo.sh
  21. 57 0
      syntaxnet/syntaxnet/dictionary.proto
  22. 335 0
      syntaxnet/syntaxnet/document_filters.cc
  23. 23 0
      syntaxnet/syntaxnet/document_format.cc
  24. 63 0
      syntaxnet/syntaxnet/document_format.h
  25. 80 0
      syntaxnet/syntaxnet/embedding_feature_extractor.cc
  26. 222 0
      syntaxnet/syntaxnet/embedding_feature_extractor.h
  27. 122 0
      syntaxnet/syntaxnet/feature_extractor.cc
  28. 624 0
      syntaxnet/syntaxnet/feature_extractor.h
  29. 34 0
      syntaxnet/syntaxnet/feature_extractor.proto
  30. 176 0
      syntaxnet/syntaxnet/feature_types.h
  31. 291 0
      syntaxnet/syntaxnet/fml_parser.cc
  32. 113 0
      syntaxnet/syntaxnet/fml_parser.h
  33. 569 0
      syntaxnet/syntaxnet/graph_builder.py
  34. 325 0
      syntaxnet/syntaxnet/graph_builder_test.py
  35. 82 0
      syntaxnet/syntaxnet/kbest_syntax.proto
  36. 248 0
      syntaxnet/syntaxnet/lexicon_builder.cc
  37. 174 0
      syntaxnet/syntaxnet/lexicon_builder_test.py
  38. 23 0
      syntaxnet/syntaxnet/load_parser_ops.py
  39. 189 0
      syntaxnet/syntaxnet/models/parsey_mcparseface/context.pbtxt
  40. 52 0
      syntaxnet/syntaxnet/models/parsey_mcparseface/fine-to-universal.map
  41. 47 0
      syntaxnet/syntaxnet/models/parsey_mcparseface/label-map
  42. BIN
      syntaxnet/syntaxnet/models/parsey_mcparseface/parser-params
  43. BIN
      syntaxnet/syntaxnet/models/parsey_mcparseface/prefix-table
  44. BIN
      syntaxnet/syntaxnet/models/parsey_mcparseface/suffix-table
  45. 50 0
      syntaxnet/syntaxnet/models/parsey_mcparseface/tag-map
  46. BIN
      syntaxnet/syntaxnet/models/parsey_mcparseface/tagger-params
  47. 64037 0
      syntaxnet/syntaxnet/models/parsey_mcparseface/word-map
  48. 274 0
      syntaxnet/syntaxnet/ops/parser_ops.cc
  49. 149 0
      syntaxnet/syntaxnet/parser_eval.py
  50. 213 0
      syntaxnet/syntaxnet/parser_features.cc
  51. 150 0
      syntaxnet/syntaxnet/parser_features.h
  52. 144 0
      syntaxnet/syntaxnet/parser_features_test.cc
  53. 248 0
      syntaxnet/syntaxnet/parser_state.cc
  54. 233 0
      syntaxnet/syntaxnet/parser_state.h
  55. 303 0
      syntaxnet/syntaxnet/parser_trainer.py
  56. 115 0
      syntaxnet/syntaxnet/parser_trainer_test.sh
  57. 30 0
      syntaxnet/syntaxnet/parser_transitions.cc
  58. 208 0
      syntaxnet/syntaxnet/parser_transitions.h
  59. 151 0
      syntaxnet/syntaxnet/populate_test_inputs.cc
  60. 153 0
      syntaxnet/syntaxnet/populate_test_inputs.h
  61. 242 0
      syntaxnet/syntaxnet/proto_io.h
  62. 563 0
      syntaxnet/syntaxnet/reader_ops.cc
  63. 198 0
      syntaxnet/syntaxnet/reader_ops_test.py
  64. 28 0
      syntaxnet/syntaxnet/registry.cc
  65. 243 0
      syntaxnet/syntaxnet/registry.h
  66. 61 0
      syntaxnet/syntaxnet/sentence.proto
  67. 45 0
      syntaxnet/syntaxnet/sentence_batch.cc
  68. 78 0
      syntaxnet/syntaxnet/sentence_batch.h
  69. 192 0
      syntaxnet/syntaxnet/sentence_features.cc
  70. 317 0
      syntaxnet/syntaxnet/sentence_features.h
  71. 155 0
      syntaxnet/syntaxnet/sentence_features_test.cc
  72. 91 0
      syntaxnet/syntaxnet/shared_store.cc
  73. 234 0
      syntaxnet/syntaxnet/shared_store.h
  74. 242 0
      syntaxnet/syntaxnet/shared_store_test.cc
  75. 19 0
      syntaxnet/syntaxnet/sparse.proto
  76. 240 0
      syntaxnet/syntaxnet/structured_graph_builder.py
  77. 107 0
      syntaxnet/syntaxnet/syntaxnet.bzl
  78. 258 0
      syntaxnet/syntaxnet/tagger_transitions.cc
  79. 113 0
      syntaxnet/syntaxnet/tagger_transitions_test.cc
  80. 173 0
      syntaxnet/syntaxnet/task_context.cc
  81. 80 0
      syntaxnet/syntaxnet/task_context.h
  82. 82 0
      syntaxnet/syntaxnet/task_spec.proto
  83. 188 0
      syntaxnet/syntaxnet/term_frequency_map.cc
  84. 117 0
      syntaxnet/syntaxnet/term_frequency_map.h
  85. 45 0
      syntaxnet/syntaxnet/test_main.cc
  86. 87 0
      syntaxnet/syntaxnet/testdata/context.pbtxt
  87. 145 0
      syntaxnet/syntaxnet/testdata/document
  88. 1017 0
      syntaxnet/syntaxnet/testdata/mini-training-set
  89. 399 0
      syntaxnet/syntaxnet/text_formats.cc
  90. 108 0
      syntaxnet/syntaxnet/text_formats_test.py
  91. 111 0
      syntaxnet/syntaxnet/unpack_sparse_features.cc
  92. 260 0
      syntaxnet/syntaxnet/utils.cc
  93. 171 0
      syntaxnet/syntaxnet/utils.h
  94. 50 0
      syntaxnet/syntaxnet/workspace.cc
  95. 215 0
      syntaxnet/syntaxnet/workspace.h
  96. 1 0
      syntaxnet/tensorflow
  97. 34 0
      syntaxnet/third_party/utf/BUILD
  98. 13 0
      syntaxnet/third_party/utf/README
  99. 357 0
      syntaxnet/third_party/utf/rune.c
  100. 0 0
      syntaxnet/third_party/utf/runestrcat.c

+ 0 - 2
.gitignore

@@ -1,2 +0,0 @@
-autoencoder/MNIST_data/*
-*.pyc

+ 3 - 0
.gitmodules

@@ -0,0 +1,3 @@
+[submodule "tensorflow"]
+	path = syntaxnet/tensorflow
+	url = https://github.com/tensorflow/tensorflow.git

+ 7 - 0
syntaxnet/.gitignore

@@ -0,0 +1,7 @@
+/bazel-bin
+/bazel-genfiles
+/bazel-out
+/bazel-tensorflow
+/bazel-testlogs
+/bazel-tf
+/bazel-syntaxnet

+ 610 - 0
syntaxnet/README.md

@@ -0,0 +1,610 @@
+# SyntaxNet: Neural Models of Syntax.
+
+*A TensorFlow implementation of the models described in [Andor et al. (2016)]
+(http://arxiv.org/pdf/1603.06042v1.pdf).*
+
+At Google, we spend a lot of time thinking about how computer systems can read
+and understand human language in order to process it in intelligent ways. We are
+excited to share the fruits of our research with the broader community by
+releasing SyntaxNet, an open-source neural network framework for [TensorFlow]
+(http://www.tensorflow.org) that provides a foundation for Natural Language
+Understanding (NLU) systems. Our release includes all the code needed to train
+new SyntaxNet models on your own data, as well as *Parsey McParseface*, an
+English parser that we have trained for you, and that you can use to analyze
+English text.
+
+So, how accurate is Parsey McParseface? For this release, we tried to balance a
+model that runs fast enough to be useful on a single machine (e.g. ~600
+words/second on a modern desktop) and that is also the most accurate parser
+available. Here's how Parsey McParseface compares to the academic literature on
+several different English domains: (all numbers are % correct head assignments
+in the tree, or unlabelled attachment score)
+
+Model                                                                                                           | News  | Web   | Questions
+--------------------------------------------------------------------------------------------------------------- | :---: | :---: | :-------:
+[Martins et al. (2013)](http://www.cs.cmu.edu/~ark/TurboParser/)                                                | 93.10 | 88.23 | 94.21
+[Zhang and McDonald (2014)](http://research.google.com/pubs/archive/38148.pdf)                                  | 93.32 | 88.65 | 93.37
+[Weiss et al. (2015)](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43800.pdf) | 93.91 | 89.29 | 94.17
+[Andor et al. (2016)](http://arxiv.org/pdf/1603.06042v1.pdf)*                                                   | 94.44 | 90.17 | 95.40
+Parsey McParseface                                                                                              | 94.15 | 89.08 | 94.77
+
+We see that Parsey McParseface is state-of-the-art; more importantly, with
+SyntaxNet you can train larger networks with more hidden units and bigger beam
+sizes if you want to push the accuracy even further: [Andor et al. (2016)]
+(http://arxiv.org/pdf/1603.06042v1.pdf)* is simply a SyntaxNet model with a
+larger beam and network. For further information on the datasets, see that paper
+under the section "Treebank Union".
+
+Parsey McParseface is also state-of-the-art for part-of-speech (POS) tagging
+(numbers below are per-token accuracy):
+
+Model                                                                      | News  | Web   | Questions
+-------------------------------------------------------------------------- | :---: | :---: | :-------:
+[Ling et al. (2015)](http://www.cs.cmu.edu/~lingwang/papers/emnlp2015.pdf) | 97.78 | 94.03 | 96.18
+[Andor et al. (2016)](http://arxiv.org/pdf/1603.06042v1.pdf)*              | 97.77 | 94.80 | 96.86
+Parsey McParseface                                                         | 97.52 | 94.24 | 96.45
+
+The first part of this tutorial describes how to install the necessary tools and
+use the already trained models provided in this release. In the second part of
+the tutorial we provide more background about the models, as well as
+instructions for training models on other datasets.
+
+## Contents
+* [Installation](#installation)
+* [Getting Started](#getting-started)
+    * [Parsing from Standard Input](#parsing-from-standard-input)
+    * [Annotating a Corpus](#annotating-a-corpus)
+    * [Configuring the Python Scripts](#configuring-the-python-scripts)
+    * [Next Steps](#next-steps)
+* [Detailed Tutorial: Building an NLP Pipeline with SyntaxNet](#detailed-tutorial-building-an-nlp-pipeline-with-syntaxnet)
+    * [Obtaining Data](#obtaining-data)
+    * [Part-of-Speech Tagging](#part-of-speech-tagging)
+    * [Training the SyntaxNet POS Tagger](#training-the-syntaxnet-pos-tagger)
+    * [Preprocessing with the Tagger](#preprocessing-with-the-tagger)
+    * [Dependency Parsing: Transition-Based Parsing](#dependency-parsing-transition-based-parsing)
+    * [Training a Parser Step 1: Local Pretraining](#training-a-parser-step-1-local-pretraining)
+    * [Training a Parser Step 2: Global Training](#training-a-parser-step-2-global-training)
+* [Contact](#contact)
+* [Credits](#credits)
+
+## Installation
+
+Running and training SyntaxNet models requires building this package from
+source. You'll need to install:
+
+*   bazel:
+    *   follow the instructions [here](http://bazel.io/docs/install.html)
+    *   **Note: You must use bazel version 0.2.2, NOT 0.2.2b, due to a WORKSPACE
+        issue**
+*   swig:
+    *   `apt-get install swig` on Ubuntu
+    *   `brew install swig` on OSX
+*   protocol buffers, with a version supported by TensorFlow:
+    *   check your protobuf version with `pip freeze | grep protobuf`
+    *   upgrade to a supported version with `pip install -U protobuf==3.0.0b2`
+*   asciitree, to draw parse trees on the console for the demo:
+    *   `pip install asciitree`
+
+Once you have completed the above steps, you can build and test SyntaxNet with
+following commands:
+
+```shell
+  git clone --recursive https://github.com/tensorflow/models.git
+  cd models/syntaxnet/tensorflow
+  ./configure
+  cd ..
+  bazel test syntaxnet/... util/utf8/...
+  # On Mac, run the following:
+  bazel test --linkopt=-headerpad_max_install_names \
+    syntaxnet/... util/utf8/...
+```
+
+Bazel should finish and report that all tests passed.
+
+## Getting Started
+
+Once you have successfully built SyntaxNet, you can start parsing text right
+away with Parsey McParseface, located under `syntaxnet/models`. The easiest
+thing is to use or modify the included script `syntaxnet/demo.sh`, which shows a
+basic setup to parse English taking plain text as input.
+
+### Parsing from Standard Input
+
+Simply pass one sentence per line of text into the script at
+`syntaxnet/demo.sh`. The script will break the text into words, run the POS
+tagger, run the parser, and then generate an ASCII version of the parse tree:
+
+```shell
+echo 'Bob brought the pizza to Alice.' | syntaxnet/demo.sh
+
+Input: Bob brought the pizza to Alice .
+Parse:
+brought VBD ROOT
+ +-- Bob NNP nsubj
+ +-- pizza NN dobj
+ |   +-- the DT det
+ +-- to IN prep
+ |   +-- Alice NNP pobj
+ +-- . . punct
+```
+
+The ASCII tree shows the text organized as in the parse, not left-to-right as
+visualized in our tutorial graphs. In this example, we see that the verb
+"brought" is the root of the sentence, with the subject "Bob", the object
+"pizza", and the prepositional phrase "to Alice".
+
+If you want to feed in tokenized, CONLL-formatted text, you can run `demo.sh
+--conll`.
+
+### Annotating a Corpus
+
+To change the pipeline to read and write to specific files (as opposed to piping
+through stdin and stdout), we have to modify the `demo.sh` to point to the files
+we want. The SyntaxNet models are configured via a combination of run-time flags
+(which are easy to change) and a text format `TaskSpec` protocol buffer. The
+spec file used in the demo is in `syntaxnet/models/treebank_union/context`.
+
+To use corpora instead of stdin/stdout, we have to:
+
+1.  Create or modify an `input` field inside the `TaskSpec`, with the
+    `file_pattern` specifying the location we want. If the input corpus is in
+    CONLL format, make sure to put `record_format: 'conll-sentence'`.
+1.  Change the `--input` and/or `--output` flag to use the name of the resource
+    as the output, instead of `stdin` and `stdout`.
+
+E.g., if we wanted to POS tag the CONLL corpus `./wsj.conll`, we would create
+two entries, one for the input and one for the output:
+
+```proto
+input {
+  name: 'wsj-data'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: './wsj.conll'
+  }
+}
+input {
+  name: 'wsj-data-tagged'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: './wsj-tagged.conll'
+  }
+}
+```
+
+Then we can use `--input=wsj-data --output=wsj-data-tagged` on the command line
+to specify reading and writing to these files.
+
+### Configuring the Python Scripts
+
+As mentioned above, the python scripts are configured in two ways:
+
+1.  **Run-time flags** are used to point to the `TaskSpec` file, switch between
+    inputs for reading and writing, and set various run-time model parameters.
+    At training time, these flags are used to set the learning rate, hidden
+    layer sizes, and other key parameters.
+1.  The **`TaskSpec` proto** stores configuration about the transition system,
+    the features, and a set of named static resources required by the parser. It
+    is specified via the `--task_context` flag. A few key notes to remember:
+
+    -   The `Parameter` settings in the `TaskSpec` have a prefix: either
+        `brain_pos` (they apply to the tagger) or `brain_parser` (they apply to
+        the parser). The `--prefix` run-time flag switches between reading from
+        the two configurations.
+    -   The resources will be created and/or modified during multiple stages of
+        training. As described above, the resources can also be used at
+        evaluation time to read or write to specific files. These resources are
+        also separate from the model parameters, which are saved separately via
+        calls to TensorFlow ops, and loaded via the `--model_path` flag.
+    -   Because the `TaskSpec` contains file paths, remember that copying this
+        file around is not enough to relocate a trained model: you need to move
+        the files and update all the paths as well.
+
+Note that some run-time flags need to be consistent between training and testing
+(e.g. the number of hidden units).
+
+### Next Steps
+
+There are many ways to extend this framework, e.g. adding new features, changing
+the model structure, training on other languages, etc. We suggest reading the
+detailed tutorial below to get a handle on the rest of the framework.
+
+## Detailed Tutorial: Building an NLP Pipeline with SyntaxNet
+
+In this tutorial, we'll go over how to train new models, and explain in a bit
+more technical detail the NLP side of the models. Our goal here is to explain
+the NLP pipeline produced by this package.
+
+### Obtaining Data
+
+The included English parser, Parsey McParseface, was trained on the standard
+corpora of the [Penn Treebank](https://catalog.ldc.upenn.edu/LDC99T42) and
+[OntoNotes](https://catalog.ldc.upenn.edu/LDC2013T19), as well as the [English
+Web Treebank](https://catalog.ldc.upenn.edu/LDC2012T13), but these are
+unfortunately not freely available.
+
+However, the [Universal Dependencies](http://universaldependencies.org/) project
+provides freely available treebank data in a number of languages. SyntaxNet can
+be trained and evaluated on any of these corpora.
+
+### Part-of-Speech Tagging
+
+Consider the following sentence, which exhibits several ambiguities that affect
+its interpretation:
+
+> I saw the man with glasses.
+
+This sentence is composed of words: strings of characters that are segmented
+into groups (e.g. "I", "saw", etc.) Each word in the sentence has a *grammatical
+function* that can be useful for understanding the meaning of language. For
+example, "saw" in this example is a past tense of the verb "to see". But any
+given word might have different meanings in different contexts: "saw" could just
+as well be a noun (e.g., a saw used for cutting) or a present tense verb (using
+a saw to cut something).
+
+A logical first step in understanding language is figuring out these roles for
+each word in the sentence. This process is called *Part-of-Speech (POS)
+Tagging*. The roles are called POS tags. Although a given word might have
+multiple possible tags depending on the context, given any one interpretation of
+a sentence each word will generally only have one tag.
+
+One interesting challenge of POS tagging is that the problem of defining a
+vocabulary of POS tags for a given language is quite involved. While the concept
+of nouns and verbs is pretty common, it has been traditionally difficult to
+agree on a standard set of roles across all languages. The [Universal
+Dependencies](http://www.universaldependencies.org) project aims to solve this
+problem.
+
+### Training the SyntaxNet POS Tagger
+
+In general, determining the correct POS tag requires understanding the entire
+sentence and the context in which it is uttered. In practice, we can do very
+well just by considering a small window of words around the word of interest.
+For example, words that follow the word ‘the’ tend to be adjectives or nouns,
+rather than verbs.
+
+To predict POS tags, we use a simple setup. We process the sentences
+left-to-right. For any given word, we extract features of that word and a window
+around it, and use these as inputs to a feed-forward neural network classifier,
+which predicts a probability distribution over POS tags. Because we make
+decisions in left-to-right order, we also use prior decisions as features in
+subsequent ones (e.g. "the previous predicted tag was a noun.").
+
+All the models in this package use a flexible markup language to define
+features. For example, the features in the POS tagger are found in the
+`brain_pos_features` parameter in the `TaskSpec`, and look like this (modulo
+spacing):
+
+```
+stack(3).word stack(2).word stack(1).word stack.word input.word input(1).word input(2).word input(3).word;
+input.digit input.hyphen;
+stack.suffix(length=2) input.suffix(length=2) input(1).suffix(length=2);
+stack.prefix(length=2) input.prefix(length=2) input(1).prefix(length=2)
+```
+
+Note that `stack` here means "words we have already tagged." Thus, this feature
+spec uses three types of features: words, suffixes, and prefixes. The features
+are grouped into blocks that share an embedding matrix, concatenated together,
+and fed into a chain of hidden layers. This structure is based upon the model
+proposed by [Chen and Manning (2014)]
+(http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf).
+
+We show this layout in the schematic below: the state of the system (a stack and
+a buffer, visualized below for both the POS and the dependency parsing task) is
+used to extract sparse features, which are fed into the network in groups. We
+show only a small subset of the features to simplify the presentation in the
+schematic:
+
+![Schematic](ff_nn_schematic.png "Feed-forward Network Structure")
+
+In the configuration above, each block gets its own embedding matrix, and the
+blocks are delineated with a semi-colon. The
+dimensions of each block are controlled in the `brain_pos_embedding_dims`
+parameter. **Important note:** unlike many simple NLP models, this is *not* a
+bag of words model. Remember that although certain features share embedding
+matrices, the above features will be concatenated, so the interpretation of
+`input.word` will be quite different from `input(1).word`. This also means that
+adding features increases the dimension of the `concat` layer of the model as
+well as the number of parameters for the first hidden layer.
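+
+As a minimal illustration of this point, the NumPy sketch below (with made-up
+vocabulary sizes, embedding dimensions, and feature ids, rather than the real
+SyntaxNet graph) shows how each feature group gets its own embedding matrix and
+how the looked-up vectors are concatenated rather than summed, so every feature
+occupies its own slice of the input to the first hidden layer:
+
+```python
+import numpy as np
+
+# Purely illustrative: sizes and dimensions are invented; only the grouping
+# and concatenation mirror the description above.
+rng = np.random.RandomState(0)
+vocab_sizes = {"words": 50000, "suffixes": 4000, "prefixes": 4000}
+dims = {"words": 64, "suffixes": 8, "prefixes": 8}
+embeddings = {name: 0.01 * rng.randn(size, dims[name])
+              for name, size in vocab_sizes.items()}
+
+def embed_and_concat(feature_ids):
+    """Looks up every id in its group's embedding matrix and concatenates the
+    results, so input.word and input(1).word land in different slices of the
+    output even though they share a matrix."""
+    parts = []
+    for group, ids in feature_ids.items():
+        parts.extend(embeddings[group][i] for i in ids)
+    return np.concatenate(parts)
+
+# Eight word features, three suffixes, three prefixes, as in the spec above:
+x = embed_and_concat({"words": [1, 2, 3, 4, 5, 6, 7, 8],
+                      "suffixes": [10, 11, 12],
+                      "prefixes": [20, 21, 22]})
+print(x.shape)  # (8 * 64 + 3 * 8 + 3 * 8,) -> (560,)
+```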
+
+To train the model, first edit `syntaxnet/context.pbtxt` so that the inputs
+`training-corpus`, `tuning-corpus`, and `dev-corpus` point to the location of
+your training data. You can then train a part-of-speech tagger with:
+
+```shell
+bazel-bin/syntaxnet/parser_trainer \
+  --task_context=syntaxnet/context.pbtxt \
+  --arg_prefix=brain_pos \  # read from POS configuration
+  --compute_lexicon \       # required for first stage of pipeline
+  --graph_builder=greedy \  # no beam search
+  --training_corpus=training-corpus \  # names of training/tuning set
+  --tuning_corpus=tuning-corpus \
+  --output_path=models \  # where to save new resources
+  --batch_size=32 \       # Hyper-parameters
+  --decay_steps=3600 \
+  --hidden_layer_sizes=128 \
+  --learning_rate=0.08 \
+  --momentum=0.9 \
+  --seed=0 \
+  --params=128-0.08-3600-0.9-0  # name for these parameters
+```
+
+This will read in the data, construct a lexicon, build a tensorflow graph for
+the model with the specific hyperparameters, and train the model. Every so often
+the model will be evaluated on the tuning set, and only the checkpoint with the
+highest accuracy on this set will be saved. **Note that you should never use a
+corpus you intend to test your model on as your tuning set, as you will inflate
+your test set results.**
+
+For best results, you should repeat this command with at least 3 different
+seeds, and possibly with a few different values for `--learning_rate` and
+`--decay_steps`. Good values for `--learning_rate` are usually close to 0.1, and
+you usually want `--decay_steps` to correspond to about one tenth of your
+corpus. The `--params` flag is only a human readable identifier for the model
+being trained, used to construct the full output path, so that you don't need to
+worry about clobbering old models by accident.
+
+The `--arg_prefix` flag controls which parameters should be read from the task
+context file `context.pbtxt`. In this case `arg_prefix` is set to `brain_pos`,
+so the parameters being used in this training run are
+`brain_pos_transition_system`, `brain_pos_embedding_dims`, `brain_pos_features`
+and `brain_pos_embedding_names`. To train the dependency parser later,
+`arg_prefix` will be set to `brain_parser`.
+
+### Preprocessing with the Tagger
+
+Now that we have a trained POS tagging model, we want to use the output of this
+model as features in the parser. Thus the next step is to run the trained model
+over our training, tuning, and dev (evaluation) sets. We can use the
+`parser_eval.py` script for this.
+
+For example, the model `128-0.08-3600-0.9-0` trained above can be run over the
+training, tuning, and dev sets with the following command:
+
+```shell
+PARAMS=128-0.08-3600-0.9-0
+for SET in training tuning dev; do
+  bazel-bin/syntaxnet/parser_eval \
+    --task_context=models/brain_pos/greedy/$PARAMS/context \
+    --hidden_layer_sizes=128 \
+    --input=$SET-corpus \
+    --output=tagged-$SET-corpus \
+    --arg_prefix=brain_pos \
+    --graph_builder=greedy \
+    --model_path=models/brain_pos/greedy/$PARAMS/model
+done
+```
+
+**Important note:** This command only works because we have created entries for
+you in `context.pbtxt` that correspond to `tagged-training-corpus`,
+`tagged-dev-corpus`, and `tagged-tuning-corpus`. From these default settings,
+the above will write tagged versions of the training, tuning, and dev set to the
+directory `models/brain_pos/greedy/$PARAMS/`. This location is chosen because
+the `input` entries do not have `file_pattern` set: instead, they have `creator:
+brain_pos/greedy`, which means that `parser_trainer.py` will construct *new*
+files when called with `--arg_prefix=brain_pos --graph_builder=greedy` using the
+`--model_path` flag to determine the location.
+
+For convenience, `parser_eval.py` also logs POS tagging accuracy after the
+output tagged datasets have been written.
+
+### Dependency Parsing: Transition-Based Parsing
+
+Now that we have a prediction for the grammatical role of the words, we want to
+understand how the words in the sentence relate to each other. This parser is
+built around the *head-modifier* construction: for each word, we choose a
+*syntactic head* that it modifies according to some grammatical role.
+
+An example for the above sentence is as follows:
+
+![Figure](sawman.png)
+
+Below each word in the sentence we see both a fine-grained part-of-speech
+(*PRP*, *VBD*, *DT*, *NN* etc.), and a coarse-grained part-of-speech (*PRON*,
+*VERB*, *DET*, *NOUN*, etc.). Coarse-grained POS tags encode basic grammatical
+categories, while the fine-grained POS tags make further distinctions: for
+example *NN* is a singular noun (as opposed, for example, to *NNS*, which is a
+plural noun), and *VBD* is a past-tense verb. For more discussion see [Petrov et
+al. (2012)](http://www.lrec-conf.org/proceedings/lrec2012/pdf/274_Paper.pdf).
+
+Crucially, we also see directed arcs signifying grammatical relationships
+between different words in the sentence. For example *I* is the subject of
+*saw*, as signified by the directed arc labeled *nsubj* between these words;
+*man* is the direct object (dobj) of *saw*; the preposition *with* modifies
+*man* with a prep relation, signifying modification by a prepositional phrase;
+and so on. In addition the verb *saw* is identified as the *root* of the entire
+sentence.
+
+Whenever we have a directed arc between two words, we refer to the word at the
+start of the arc as the *head*, and the word at the end of the arc as the
+*modifier*. For example we have one arc where the head is *saw* and the modifier
+is *I*, another where the head is *saw* and the modifier is *man*, and so on.
+
+The grammatical relationships encoded in dependency structures are directly
+related to the underlying meaning of the sentence in question. They allow us to
+easily recover the answers to various questions, for example *whom did I see?*,
+*who saw the man with glasses?*, and so on.
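+
+As a rough illustration, assuming the standard 10-column CoNLL format that the
+`conll-sentence` readers use (and re-using the tags and labels from the figure
+above), a few lines of Python are enough to read such an answer directly off
+the head and label columns of the parse:
+
+```python
+# Sketch only: real CoNLL files are tab-separated; spaces are used here for
+# readability. Columns: id, form, lemma, cpos, pos, feats, head, label, _, _.
+conll = """
+1 I       _ PRON PRP _ 2 nsubj _ _
+2 saw     _ VERB VBD _ 0 ROOT  _ _
+3 the     _ DET  DT  _ 4 det   _ _
+4 man     _ NOUN NN  _ 2 dobj  _ _
+5 with    _ ADP  IN  _ 4 prep  _ _
+6 glasses _ NOUN NNS _ 5 pobj  _ _
+7 .       _ .    .   _ 2 punct _ _
+"""
+
+tokens = [line.split() for line in conll.splitlines() if line.strip()]
+root = next(t for t in tokens if t[6] == "0")           # the root verb "saw"
+deps = {t[7]: t[1] for t in tokens if t[6] == root[0]}  # dependents of the root
+print(deps["nsubj"], root[1], deps["dobj"])             # -> I saw man
+```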
+
+SyntaxNet is a **transition-based** dependency parser [Nivre (2007)]
+(http://www.mitpressjournals.org/doi/pdfplus/10.1162/coli.07-056-R1-07-027) that
+constructs a parse incrementally. Like the tagger, it processes words
+left-to-right. The words all start as unprocessed input, called the *buffer*. As
+words are encountered they are put onto a *stack*. At each step, the parser can
+do one of three things:
+
+1.  **SHIFT:** Push another word onto the top of the stack, i.e. shifting one
+    token from the buffer to the stack.
+1.  **LEFT_ARC:** Pop the top two words from the stack. Attach the second to the
+    first, creating an arc pointing to the **left**. Push the **first** word
+    back on the stack.
+1.  **RIGHT_ARC:** Pop the top two words from the stack. Attach the first to
+    the second, creating an arc pointing to the **right**. Push the **second**
+    word back on the stack.
+
+At each step, we call the combination of the stack and the buffer the
+*configuration* of the parser. For the left and right actions, we also assign a
+dependency relation label to that arc. This process is visualized in the
+following animation for a short sentence:
+
+![Animation](looping-parser.gif "Parsing in Action")
+
+Note that this parser is following a sequence of actions, called a
+**derivation**, to produce a "gold" tree labeled by a linguist. We can use this
+sequence of decisions to learn a classifier that takes a configuration and
+predicts the next action to take.
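+
+To make the action set concrete, here is a deliberately simplified Python
+sketch of the three transitions (unlabeled, and not the actual implementation
+in `arc_standard_transitions.cc`), replaying a plausible gold derivation for a
+short sentence:
+
+```python
+# Arcs are stored as (head, dependent) pairs; labels are omitted.
+def shift(stack, buffer, arcs):
+    stack.append(buffer.pop(0))
+
+def left_arc(stack, buffer, arcs):
+    first, second = stack.pop(), stack.pop()  # first = top of the stack
+    arcs.append((first, second))              # head = first, dependent = second
+    stack.append(first)
+
+def right_arc(stack, buffer, arcs):
+    first, second = stack.pop(), stack.pop()
+    arcs.append((second, first))              # head = second, dependent = first
+    stack.append(second)
+
+# A gold derivation for "I saw the man":
+stack, buffer, arcs = [], ["I", "saw", "the", "man"], []
+for action in [shift, shift, left_arc, shift, shift, left_arc, right_arc]:
+    action(stack, buffer, arcs)
+print(arcs)  # [('saw', 'I'), ('man', 'the'), ('saw', 'man')]
+```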
+
+### Training a Parser Step 1: Local Pretraining
+
+As described in our [paper](http://arxiv.org/pdf/1603.06042v1.pdf), the first
+step in training the model is to *pre-train* using *local* decisions. In this
+phase, we use the gold dependencies to guide the parser, and train a softmax layer
+to predict the correct action given these gold dependencies. This can be
+performed very efficiently, since the parser's decisions are all independent in
+this setting.
+
+Once the tagged datasets are available, a locally normalized dependency parsing
+model can be trained with the following command:
+
+```shell
+bazel-bin/syntaxnet/parser_trainer \
+  --arg_prefix=brain_parser \
+  --batch_size=32 \
+  --projectivize_training_set \
+  --decay_steps=4400 \
+  --graph_builder=greedy \
+  --hidden_layer_sizes=200,200 \
+  --learning_rate=0.08 \
+  --momentum=0.85 \
+  --output_path=models \
+  --task_context=models/brain_pos/greedy/$PARAMS/context \
+  --seed=4 \
+  --training_corpus=tagged-training-corpus \
+  --tuning_corpus=tagged-tuning-corpus \
+  --params=200x200-0.08-4400-0.85-4
+```
+
+Note that we point the trainer to the context corresponding to the POS tagger
+that we picked previously. This allows the parser to reuse the lexicons and the
+tagged datasets that were created in the previous steps. Processing data can be
+done similarly to how tagging was done above. For example, if we picked the
+parameters `200x200-0.08-4400-0.85-4`, the training, tuning and dev sets
+can be parsed with the following command:
+
+```shell
+PARAMS=200x200-0.08-4400-0.85-4
+for SET in training tuning dev; do
+  bazel-bin/syntaxnet/parser_eval \
+    --task_context=models/brain_parser/greedy/$PARAMS/context \
+    --hidden_layer_sizes=200,200 \
+    --input=tagged-$SET-corpus \
+    --output=parsed-$SET-corpus \
+    --arg_prefix=brain_parser \
+    --graph_builder=greedy \
+    --model_path=models/brain_parser/greedy/$PARAMS/model
+done
+```
+
+### Training a Parser Step 2: Global Training
+
+As we describe in the paper, there are several problems with the locally
+normalized models we just trained. The most important is the *label-bias*
+problem: the model doesn't learn what a good parse looks like, only what action
+to take given a history of gold decisions. This is because the scores are
+normalized *locally* using a softmax for each decision.
+
+In the paper, we show how we can achieve much better results using a *globally*
+normalized model: in this model, the softmax scores are summed in log space, and
+the scores are not normalized until we reach a final decision. When the parser
+stops, the scores of each hypothesis are normalized against a small set of
+possible parses (in the case of this model, a beam size of 8). When training, we
+force the parser to stop during parsing when the gold derivation falls off the
+beam (a strategy known as early-updates).
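+
+As a toy sketch of that objective (plain Python rather than the TensorFlow
+implementation in `structured_graph_builder.py`, and with invented scores), the
+loss for one beam is the log-sum-exp of the summed hypothesis scores minus the
+summed score of the gold derivation:
+
+```python
+import math
+
+def global_loss(beam_step_scores, gold_index):
+    """beam_step_scores[i] is the list of per-step scores of hypothesis i;
+    gold_index selects the hypothesis that follows the gold derivation."""
+    totals = [sum(steps) for steps in beam_step_scores]  # sum in log space
+    log_z = math.log(sum(math.exp(t) for t in totals))   # normalize over beam
+    return log_z - totals[gold_index]                    # -log P(gold | beam)
+
+# Toy beam of size 3; the gold derivation is hypothesis 0.
+print(global_loss([[1.2, 0.7, 2.0], [1.5, 0.1, 1.1], [0.3, 0.9, 0.8]], 0))
+```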
+
+We give a simplified view of how this training works for a [garden path
+sentence](https://en.wikipedia.org/wiki/Garden_path_sentence), where it is
+important to maintain multiple hypotheses. A single mistake early on in parsing
+leads to a completely incorrect parse; after training, the model learns to
+prefer the second (correct) parse.
+
+![Beam search training](beam_search_training.png)
+
+Parsey McParseface correctly parses this sentence. Even though the correct parse
+is initially ranked 4th out of multiple hypotheses, when the end of the garden
+path is reached, Parsey McParseface can recover due to the beam; using a larger
+beam will get a more accurate model, but it will be slower (we used beam 32 for
+the models in the paper).
+
+Once you have the pre-trained locally normalized model, a globally normalized
+parsing model can now be trained with the following command:
+
+```shell
+bazel-bin/syntaxnet/parser_trainer \
+  --arg_prefix=brain_parser \
+  --batch_size=8 \
+  --decay_steps=100 \
+  --graph_builder=structured \
+  --hidden_layer_sizes=200,200 \
+  --learning_rate=0.02 \
+  --momentum=0.9 \
+  --output_path=models \
+  --task_context=models/brain_parser/greedy/$PARAMS/context \
+  --seed=0 \
+  --training_corpus=projectivized-training-corpus \
+  --tuning_corpus=tagged-tuning-corpus \
+  --params=200x200-0.02-100-0.9-0 \
+  --pretrained_params=models/brain_parser/greedy/$PARAMS/model \
+  --pretrained_params_names=\
+embedding_matrix_0,embedding_matrix_1,embedding_matrix_2,\
+bias_0,weights_0,bias_1,weights_1
+```
+
+Training a beam model with the structured builder will take a lot longer than
+the greedy training runs above, perhaps 3 or 4 times longer. Note once again
+that multiple restarts of training will yield the most reliable results.
+Evaluation can again be done with `parser_eval.py`. In this case we use
+parameters `200x200-0.02-100-0.9-0` to evaluate on the training, tuning and dev
+sets with the following command:
+
+```shell
+PARAMS=200x200-0.02-100-0.9-0
+for SET in training tuning dev; do
+  bazel-bin/syntaxnet/parser_eval \
+    --task_context=models/brain_parser/structured/$PARAMS/context \
+    --hidden_layer_sizes=200,200 \
+    --input=tagged-$SET-corpus \
+    --output=beam-parsed-$SET-corpus \
+    --arg_prefix=brain_parser \
+    --graph_builder=structured \
+    --model_path=models/brain_parser/structured/$PARAMS/model
+done
+```
+
+Hooray! You now have your very own cousin of Parsey McParseface, ready to go out
+and parse text in the wild.
+
+## Contact
+
+To ask questions or report issues please contact syntaxnet-users@google.com.
+
+## Credits
+
+Original authors of the code in this package include (in alphabetical order):
+
+*   apresta@google.com (Alessandro Presta)
+*   bohnetbd@google.com (Bernd Bohnet)
+*   chrisalberti@google.com (Chris Alberti)
+*   credo@google.com (Tim Credo)
+*   danielandor@google.com (Daniel Andor)
+*   djweiss@google.com (David Weiss)
+*   epitler@google.com (Emily Pitler)
+*   gcoppola@google.com (Greg Coppola)
+*   golding@google.com (Andy Golding)
+*   istefan@google.com (Stefan Istrate)
+*   kbhall@google.com (Keith Hall)
+*   kuzman@google.com (Kuzman Ganchev)
+*   mjcollins@google.com (Michael Collins)
+*   ringgaard@google.com (Michael Ringgaard)
+*   ryanmcd@google.com (Ryan McDonald)
+*   severyn@google.com (Aliaksei Severyn)
+*   slav@google.com (Slav Petrov)
+*   terrykoo@google.com (Terry Koo)

+ 38 - 0
syntaxnet/WORKSPACE

@@ -0,0 +1,38 @@
+local_repository(
+  name = "tf",
+  path = __workspace_dir__ + "/tensorflow",
+)
+
+load('//tensorflow/tensorflow:workspace.bzl', 'tf_workspace')
+tf_workspace("tensorflow/", "@tf")
+
+# Specify the minimum required Bazel version.
+load("@tf//tensorflow:tensorflow.bzl", "check_version")
+check_version("0.2.0")
+
+# ===== gRPC dependencies =====
+
+bind(
+    name = "libssl",
+    actual = "@boringssl_git//:ssl",
+)
+
+git_repository(
+    name = "boringssl_git",
+    commit = "436432d849b83ab90f18773e4ae1c7a8f148f48d",
+    init_submodules = True,
+    remote = "https://github.com/mdsteele/boringssl-bazel.git",
+)
+
+bind(
+    name = "zlib",
+    actual = "@zlib_archive//:zlib",
+)
+
+new_http_archive(
+    name = "zlib_archive",
+    build_file = "zlib.BUILD",
+    sha256 = "879d73d8cd4d155f31c1f04838ecd567d34bebda780156f0e82a20721b3973d5",
+    strip_prefix = "zlib-1.2.8",
+    url = "http://zlib.net/zlib128.zip",
+)

BIN
syntaxnet/beam_search_training.png


BIN
syntaxnet/ff_nn_schematic.png


BIN
syntaxnet/looping-parser.gif


BIN
syntaxnet/sawman.png


+ 637 - 0
syntaxnet/syntaxnet/BUILD

@@ -0,0 +1,637 @@
+# Description:
+# A syntactic parser and part-of-speech tagger in TensorFlow.
+
+package(
+    default_visibility = ["//visibility:private"],
+    features = ["-layering_check"],
+)
+
+licenses(["notice"])  # Apache 2.0
+
+load(
+    "syntaxnet",
+    "tf_proto_library",
+    "tf_proto_library_py",
+    "tf_gen_op_libs",
+    "tf_gen_op_wrapper_py",
+)
+
+# proto libraries
+
+tf_proto_library(
+    name = "feature_extractor_proto",
+    srcs = ["feature_extractor.proto"],
+)
+
+tf_proto_library(
+    name = "sentence_proto",
+    srcs = ["sentence.proto"],
+)
+
+tf_proto_library_py(
+    name = "sentence_py_pb2",
+    srcs = ["sentence.proto"],
+)
+
+tf_proto_library(
+    name = "dictionary_proto",
+    srcs = ["dictionary.proto"],
+)
+
+tf_proto_library_py(
+    name = "dictionary_py_pb2",
+    srcs = ["dictionary.proto"],
+)
+
+tf_proto_library(
+    name = "kbest_syntax_proto",
+    srcs = ["kbest_syntax.proto"],
+    deps = [":sentence_proto"],
+)
+
+tf_proto_library(
+    name = "task_spec_proto",
+    srcs = ["task_spec.proto"],
+)
+
+tf_proto_library_py(
+    name = "task_spec_py_pb2",
+    srcs = ["task_spec.proto"],
+)
+
+tf_proto_library(
+    name = "sparse_proto",
+    srcs = ["sparse.proto"],
+)
+
+tf_proto_library_py(
+    name = "sparse_py_pb2",
+    srcs = ["sparse.proto"],
+)
+
+# cc libraries for feature extraction and parsing
+
+cc_library(
+    name = "base",
+    hdrs = ["base.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "@re2//:re2",
+        "@tf//google/protobuf",
+        "@tf//third_party/eigen3",
+    ] + select({
+        "//conditions:default": [
+            "@tf//tensorflow/core:framework",
+            "@tf//tensorflow/core:lib",
+        ],
+        "@tf//tensorflow:darwin": [
+            "@tf//tensorflow/core:framework_headers_lib",
+        ],
+    }),
+)
+
+cc_library(
+    name = "utils",
+    srcs = ["utils.cc"],
+    hdrs = [
+        "utils.h",
+    ],
+    deps = [
+        ":base",
+        "//util/utf8:unicodetext",
+    ],
+)
+
+cc_library(
+    name = "test_main",
+    testonly = 1,
+    srcs = ["test_main.cc"],
+    linkopts = ["-lm"],
+    deps = [
+        "@tf//tensorflow/core:lib",
+        "@tf//tensorflow/core:testlib",
+        "//external:gtest",
+    ],
+)
+
+cc_library(
+    name = "document_format",
+    srcs = ["document_format.cc"],
+    hdrs = ["document_format.h"],
+    deps = [
+        ":registry",
+        ":sentence_proto",
+        ":task_context",
+    ],
+)
+
+cc_library(
+    name = "text_formats",
+    srcs = ["text_formats.cc"],
+    deps = [
+        ":document_format",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "fml_parser",
+    srcs = ["fml_parser.cc"],
+    hdrs = ["fml_parser.h"],
+    deps = [
+        ":feature_extractor_proto",
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "proto_io",
+    hdrs = ["proto_io.h"],
+    deps = [
+        ":feature_extractor_proto",
+        ":fml_parser",
+        ":kbest_syntax_proto",
+        ":sentence_proto",
+        ":task_context",
+    ],
+)
+
+cc_library(
+    name = "feature_extractor",
+    srcs = ["feature_extractor.cc"],
+    hdrs = [
+        "feature_extractor.h",
+        "feature_types.h",
+    ],
+    deps = [
+        ":document_format",
+        ":feature_extractor_proto",
+        ":kbest_syntax_proto",
+        ":proto_io",
+        ":sentence_proto",
+        ":task_context",
+        ":utils",
+        ":workspace",
+    ],
+)
+
+cc_library(
+    name = "affix",
+    srcs = ["affix.cc"],
+    hdrs = ["affix.h"],
+    deps = [
+        ":dictionary_proto",
+        ":feature_extractor",
+        ":shared_store",
+        ":term_frequency_map",
+        ":utils",
+        ":workspace",
+    ],
+)
+
+cc_library(
+    name = "sentence_features",
+    srcs = ["sentence_features.cc"],
+    hdrs = ["sentence_features.h"],
+    deps = [
+        ":affix",
+        ":feature_extractor",
+        ":registry",
+    ],
+)
+
+cc_library(
+    name = "shared_store",
+    srcs = ["shared_store.cc"],
+    hdrs = ["shared_store.h"],
+    deps = [
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "registry",
+    srcs = ["registry.cc"],
+    hdrs = ["registry.h"],
+    deps = [
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "workspace",
+    srcs = ["workspace.cc"],
+    hdrs = ["workspace.h"],
+    deps = [
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "task_context",
+    srcs = ["task_context.cc"],
+    hdrs = ["task_context.h"],
+    deps = [
+        ":task_spec_proto",
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "term_frequency_map",
+    srcs = ["term_frequency_map.cc"],
+    hdrs = ["term_frequency_map.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":utils",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "parser_transitions",
+    srcs = [
+        "arc_standard_transitions.cc",
+        "parser_state.cc",
+        "parser_transitions.cc",
+        "tagger_transitions.cc",
+    ],
+    hdrs = [
+        "parser_state.h",
+        "parser_transitions.h",
+    ],
+    deps = [
+        ":kbest_syntax_proto",
+        ":registry",
+        ":shared_store",
+        ":task_context",
+        ":term_frequency_map",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "populate_test_inputs",
+    testonly = 1,
+    srcs = ["populate_test_inputs.cc"],
+    hdrs = ["populate_test_inputs.h"],
+    deps = [
+        ":dictionary_proto",
+        ":sentence_proto",
+        ":task_context",
+        ":term_frequency_map",
+        ":test_main",
+    ],
+)
+
+cc_library(
+    name = "parser_features",
+    srcs = ["parser_features.cc"],
+    hdrs = ["parser_features.h"],
+    deps = [
+        ":affix",
+        ":feature_extractor",
+        ":parser_transitions",
+        ":registry",
+        ":sentence_features",
+        ":sentence_proto",
+        ":task_context",
+        ":term_frequency_map",
+        ":workspace",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "embedding_feature_extractor",
+    srcs = ["embedding_feature_extractor.cc"],
+    hdrs = ["embedding_feature_extractor.h"],
+    deps = [
+        ":feature_extractor",
+        ":parser_features",
+        ":parser_transitions",
+        ":sparse_proto",
+        ":task_context",
+        ":workspace",
+    ],
+)
+
+cc_library(
+    name = "sentence_batch",
+    srcs = ["sentence_batch.cc"],
+    hdrs = ["sentence_batch.h"],
+    deps = [
+        ":embedding_feature_extractor",
+        ":feature_extractor",
+        ":parser_features",
+        ":parser_transitions",
+        ":sparse_proto",
+        ":task_context",
+        ":task_spec_proto",
+        ":term_frequency_map",
+        ":workspace",
+    ],
+)
+
+cc_library(
+    name = "reader_ops",
+    srcs = [
+        "beam_reader_ops.cc",
+        "reader_ops.cc",
+    ],
+    deps = [
+        ":parser_features",
+        ":parser_transitions",
+        ":sentence_batch",
+        ":sentence_proto",
+        ":task_context",
+        ":task_spec_proto",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "document_filters",
+    srcs = ["document_filters.cc"],
+    deps = [
+        ":document_format",
+        ":parser_features",
+        ":parser_transitions",
+        ":sentence_batch",
+        ":sentence_proto",
+        ":task_context",
+        ":task_spec_proto",
+        ":text_formats",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "lexicon_builder",
+    srcs = ["lexicon_builder.cc"],
+    deps = [
+        ":document_format",
+        ":parser_features",
+        ":parser_transitions",
+        ":sentence_batch",
+        ":sentence_proto",
+        ":task_context",
+        ":task_spec_proto",
+        ":text_formats",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "unpack_sparse_features",
+    srcs = ["unpack_sparse_features.cc"],
+    deps = [
+        ":sparse_proto",
+        ":utils",
+    ],
+)
+
+cc_library(
+    name = "parser_ops_cc",
+    srcs = ["ops/parser_ops.cc"],
+    deps = [
+        ":base",
+        ":document_filters",
+        ":lexicon_builder",
+        ":reader_ops",
+        ":unpack_sparse_features",
+    ],
+    alwayslink = 1,
+)
+
+cc_binary(
+    name = "parser_ops.so",
+    linkopts = select({
+        "//conditions:default": ["-lm"],
+        "@tf//tensorflow:darwin": [],
+    }),
+    linkshared = 1,
+    linkstatic = 1,
+    deps = [
+        ":parser_ops_cc",
+    ],
+)
+
+# cc tests
+
+filegroup(
+    name = "testdata",
+    srcs = [
+        "testdata/context.pbtxt",
+        "testdata/document",
+        "testdata/mini-training-set",
+    ],
+)
+
+cc_test(
+    name = "shared_store_test",
+    size = "small",
+    srcs = ["shared_store_test.cc"],
+    deps = [
+        ":shared_store",
+        ":test_main",
+    ],
+)
+
+cc_test(
+    name = "sentence_features_test",
+    size = "medium",
+    srcs = ["sentence_features_test.cc"],
+    deps = [
+        ":feature_extractor",
+        ":populate_test_inputs",
+        ":sentence_features",
+        ":sentence_proto",
+        ":task_context",
+        ":task_spec_proto",
+        ":term_frequency_map",
+        ":test_main",
+        ":workspace",
+    ],
+)
+
+cc_test(
+    name = "arc_standard_transitions_test",
+    size = "small",
+    srcs = ["arc_standard_transitions_test.cc"],
+    data = [":testdata"],
+    deps = [
+        ":parser_transitions",
+        ":populate_test_inputs",
+        ":test_main",
+    ],
+)
+
+cc_test(
+    name = "tagger_transitions_test",
+    size = "small",
+    srcs = ["tagger_transitions_test.cc"],
+    data = [":testdata"],
+    deps = [
+        ":parser_transitions",
+        ":populate_test_inputs",
+        ":test_main",
+    ],
+)
+
+cc_test(
+    name = "parser_features_test",
+    size = "small",
+    srcs = ["parser_features_test.cc"],
+    deps = [
+        ":feature_extractor",
+        ":parser_features",
+        ":parser_transitions",
+        ":populate_test_inputs",
+        ":sentence_proto",
+        ":task_context",
+        ":task_spec_proto",
+        ":term_frequency_map",
+        ":test_main",
+        ":workspace",
+    ],
+)
+
+# py graph builder and trainer
+
+tf_gen_op_libs(
+    op_lib_names = ["parser_ops"],
+)
+
+tf_gen_op_wrapper_py(
+    name = "parser_ops",
+    deps = [":parser_ops_op_lib"],
+)
+
+py_library(
+    name = "load_parser_ops_py",
+    srcs = ["load_parser_ops.py"],
+    data = [":parser_ops.so"],
+)
+
+py_library(
+    name = "graph_builder",
+    srcs = ["graph_builder.py"],
+    deps = [
+        "@tf//tensorflow:tensorflow_py",
+        "@tf//tensorflow/core:protos_all_py",
+        ":load_parser_ops_py",
+        ":parser_ops",
+    ],
+)
+
+py_library(
+    name = "structured_graph_builder",
+    srcs = ["structured_graph_builder.py"],
+    deps = [
+        ":graph_builder",
+    ],
+)
+
+py_binary(
+    name = "parser_trainer",
+    srcs = ["parser_trainer.py"],
+    deps = [
+        ":graph_builder",
+        ":structured_graph_builder",
+        ":task_spec_py_pb2",
+    ],
+)
+
+py_binary(
+    name = "parser_eval",
+    srcs = ["parser_eval.py"],
+    deps = [
+        ":graph_builder",
+        ":sentence_py_pb2",
+        ":structured_graph_builder",
+    ],
+)
+
+py_binary(
+    name = "conll2tree",
+    srcs = ["conll2tree.py"],
+    deps = [
+        ":graph_builder",
+        ":sentence_py_pb2",
+    ],
+)
+
+# py tests
+
+py_test(
+    name = "lexicon_builder_test",
+    size = "small",
+    srcs = ["lexicon_builder_test.py"],
+    deps = [
+        ":graph_builder",
+        ":sentence_py_pb2",
+        ":task_spec_py_pb2",
+    ],
+)
+
+py_test(
+    name = "text_formats_test",
+    size = "small",
+    srcs = ["text_formats_test.py"],
+    deps = [
+        ":graph_builder",
+        ":sentence_py_pb2",
+        ":task_spec_py_pb2",
+    ],
+)
+
+py_test(
+    name = "reader_ops_test",
+    size = "medium",
+    srcs = ["reader_ops_test.py"],
+    data = [":testdata"],
+    tags = ["notsan"],
+    deps = [
+        ":dictionary_py_pb2",
+        ":graph_builder",
+        ":sparse_py_pb2",
+    ],
+)
+
+py_test(
+    name = "beam_reader_ops_test",
+    size = "medium",
+    srcs = ["beam_reader_ops_test.py"],
+    data = [":testdata"],
+    tags = ["notsan"],
+    deps = [
+        ":structured_graph_builder",
+    ],
+)
+
+py_test(
+    name = "graph_builder_test",
+    size = "medium",
+    srcs = ["graph_builder_test.py"],
+    data = [
+        ":testdata",
+    ],
+    tags = ["notsan"],
+    deps = [
+        ":graph_builder",
+        ":sparse_py_pb2",
+    ],
+)
+
+sh_test(
+    name = "parser_trainer_test",
+    size = "medium",
+    srcs = ["parser_trainer_test.sh"],
+    data = [
+        ":parser_eval",
+        ":parser_trainer",
+        ":testdata",
+    ],
+    tags = ["notsan"],
+)

+ 263 - 0
syntaxnet/syntaxnet/affix.cc

@@ -0,0 +1,263 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/affix.h"
+
+#include <ctype.h>
+#include <string.h>
+#include <functional>
+#include <string>
+
+#include "syntaxnet/shared_store.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "syntaxnet/utils.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/regexp.h"
+#include "util/utf8/unicodetext.h"
+
+namespace syntaxnet {
+
+// Initial number of buckets in term and affix hash maps. This must be a power
+// of two.
+static const int kInitialBuckets = 1024;
+
+// Fill factor for term and affix hash maps.
+static const int kFillFactor = 2;
+
+int TermHash(string term) {
+  return utils::Hash32(term.data(), term.size(), 0xDECAF);
+}
+
+// Copies a substring of a Unicode text to a string.
+static void UnicodeSubstring(UnicodeText::const_iterator start,
+                             UnicodeText::const_iterator end, string *result) {
+  result->clear();
+  result->append(start.utf8_data(), end.utf8_data() - start.utf8_data());
+}
+
+AffixTable::AffixTable(Type type, int max_length) {
+  type_ = type;
+  max_length_ = max_length;
+  Resize(0);
+}
+
+AffixTable::~AffixTable() { Reset(0); }
+
+void AffixTable::Reset(int max_length) {
+  // Save new maximum affix length.
+  max_length_ = max_length;
+
+  // Delete all data.
+  for (size_t i = 0; i < affixes_.size(); ++i) delete affixes_[i];
+  affixes_.clear();
+  buckets_.clear();
+  Resize(0);
+}
+
+void AffixTable::Read(const AffixTableEntry &table_entry) {
+  CHECK_EQ(table_entry.type(), type_ == PREFIX ? "PREFIX" : "SUFFIX");
+  CHECK_GE(table_entry.max_length(), 0);
+  Reset(table_entry.max_length());
+
+  // First, create all affixes.
+  for (int affix_id = 0; affix_id < table_entry.affix_size(); ++affix_id) {
+    const auto &affix_entry = table_entry.affix(affix_id);
+    CHECK_GE(affix_entry.length(), 0);
+    CHECK_LE(affix_entry.length(), max_length_);
+    CHECK(FindAffix(affix_entry.form()) == NULL);  // forbid duplicates
+    Affix *affix = AddNewAffix(affix_entry.form(), affix_entry.length());
+    CHECK_EQ(affix->id(), affix_id);
+  }
+  CHECK_EQ(affixes_.size(), table_entry.affix_size());
+
+  // Next, link the shorter affixes.
+  for (int affix_id = 0; affix_id < table_entry.affix_size(); ++affix_id) {
+    const auto &affix_entry = table_entry.affix(affix_id);
+    if (affix_entry.shorter_id() == -1) {
+      CHECK_EQ(affix_entry.length(), 1);
+      continue;
+    }
+    CHECK_GT(affix_entry.length(), 1);
+    CHECK_GE(affix_entry.shorter_id(), 0);
+    CHECK_LT(affix_entry.shorter_id(), affixes_.size());
+    Affix *affix = affixes_[affix_id];
+    Affix *shorter = affixes_[affix_entry.shorter_id()];
+    CHECK_EQ(affix->length(), shorter->length() + 1);
+    affix->set_shorter(shorter);
+  }
+}
+
+void AffixTable::Read(ProtoRecordReader *reader) {
+  AffixTableEntry table_entry;
+  TF_CHECK_OK(reader->Read(&table_entry));
+  Read(table_entry);
+}
+
+void AffixTable::Write(AffixTableEntry *table_entry) const {
+  table_entry->Clear();
+  table_entry->set_type(type_ == PREFIX ? "PREFIX" : "SUFFIX");
+  table_entry->set_max_length(max_length_);
+  for (const Affix *affix : affixes_) {
+    auto *affix_entry = table_entry->add_affix();
+    affix_entry->set_form(affix->form());
+    affix_entry->set_length(affix->length());
+    affix_entry->set_shorter_id(
+        affix->shorter() == NULL ? -1 : affix->shorter()->id());
+  }
+}
+
+void AffixTable::Write(ProtoRecordWriter *writer) const {
+  AffixTableEntry table_entry;
+  Write(&table_entry);
+  writer->Write(table_entry);
+}
+
+Affix *AffixTable::AddAffixesForWord(const char *word, size_t size) {
+  // The affix length is measured in characters and not bytes so we need to
+  // determine the length in characters.
+  UnicodeText text;
+  text.PointToUTF8(word, size);
+  int length = text.size();
+
+  // Determine longest affix.
+  int affix_len = length;
+  if (affix_len > max_length_) affix_len = max_length_;
+  if (affix_len == 0) return NULL;
+
+  // Find start and end of longest affix.
+  UnicodeText::const_iterator start, end;
+  if (type_ == PREFIX) {
+    start = end = text.begin();
+    for (int i = 0; i < affix_len; ++i) ++end;
+  } else {
+    start = end = text.end();
+    for (int i = 0; i < affix_len; ++i) --start;
+  }
+
+  // Try to find successively shorter affixes.
+  Affix *top = NULL;
+  Affix *ancestor = NULL;
+  string s;
+  while (affix_len > 0) {
+    // Try to find affix in table.
+    UnicodeSubstring(start, end, &s);
+    Affix *affix = FindAffix(s);
+    if (affix == NULL) {
+      // Affix not found, add new one to table.
+      affix = AddNewAffix(s, affix_len);
+
+      // Update ancestor chain.
+      if (ancestor != NULL) ancestor->set_shorter(affix);
+      ancestor = affix;
+      if (top == NULL) top = affix;
+    } else {
+      // Affix found. Update ancestor if needed and return match.
+      if (ancestor != NULL) ancestor->set_shorter(affix);
+      if (top == NULL) top = affix;
+      break;
+    }
+
+    // Next affix.
+    if (type_ == PREFIX) {
+      --end;
+    } else {
+      ++start;
+    }
+
+    affix_len--;
+  }
+
+  return top;
+}
+
+Affix *AffixTable::GetAffix(int id) const {
+  if (id < 0 || id >= static_cast<int>(affixes_.size())) {
+    return NULL;
+  } else {
+    return affixes_[id];
+  }
+}
+
+string AffixTable::AffixForm(int id) const {
+  Affix *affix = GetAffix(id);
+  if (affix == NULL) {
+    return "";
+  } else {
+    return affix->form();
+  }
+}
+
+int AffixTable::AffixId(const string &form) const {
+  Affix *affix = FindAffix(form);
+  if (affix == NULL) {
+    return -1;
+  } else {
+    return affix->id();
+  }
+}
+
+Affix *AffixTable::AddNewAffix(const string &form, int length) {
+  int hash = TermHash(form);
+  int id = affixes_.size();
+  if (id > static_cast<int>(buckets_.size()) * kFillFactor) Resize(id);
+  int b = hash & (buckets_.size() - 1);
+
+  // Create new affix object.
+  Affix *affix = new Affix(id, form.c_str(), length);
+  affixes_.push_back(affix);
+
+  // Insert affix in bucket chain.
+  affix->next_ = buckets_[b];
+  buckets_[b] = affix;
+
+  return affix;
+}
+
+Affix *AffixTable::FindAffix(const string &form) const {
+  // Compute hash value for word.
+  int hash = TermHash(form);
+
+  // Try to find affix in hash table.
+  Affix *affix = buckets_[hash & (buckets_.size() - 1)];
+  while (affix != NULL) {
+    if (strcmp(affix->form_.c_str(), form.c_str()) == 0) return affix;
+    affix = affix->next_;
+  }
+  return NULL;
+}
+
+void AffixTable::Resize(int size_hint) {
+  // Compute new size for bucket array.
+  int new_size = kInitialBuckets;
+  while (new_size < size_hint) new_size *= 2;
+  int mask = new_size - 1;
+
+  // Distribute affixes in new buckets.
+  buckets_.resize(new_size);
+  for (size_t i = 0; i < buckets_.size(); ++i) {
+    buckets_[i] = NULL;
+  }
+  for (size_t i = 0; i < affixes_.size(); ++i) {
+    Affix *affix = affixes_[i];
+    int b = TermHash(affix->form_) & mask;
+    affix->next_ = buckets_[b];
+    buckets_[b] = affix;
+  }
+}
+
+}  // namespace syntaxnet
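
For reference, here is a minimal sketch of how the chain built by AddAffixesForWord can be walked; the word, table type and maximum length are made up for illustration, and the snippet assumes affix.h (below) is included:

    // Illustrative only: collect all suffixes of "parsing" up to length 3.
    syntaxnet::AffixTable table(syntaxnet::AffixTable::SUFFIX, 3 /* max_length */);
    const string word = "parsing";
    syntaxnet::Affix *affix = table.AddAffixesForWord(word.c_str(), word.size());
    while (affix != NULL) {
      // Visits "ing", then "ng", then "g", following the shorter() chain.
      LOG(INFO) << affix->form() << " (id " << affix->id() << ")";
      affix = affix->shorter();
    }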

+ 155 - 0
syntaxnet/syntaxnet/affix.h

@@ -0,0 +1,155 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_AFFIX_H_
+#define $TARGETDIR_AFFIX_H_
+
+#include <stddef.h>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/dictionary.pb.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/proto_io.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+// An affix represents a prefix or suffix of a word of a certain length. Each
+// affix has a unique id and a textual form. An affix also has a pointer to the
+// affix that is one character shorter. This creates a chain of affixes that are
+// successively shorter.
+class Affix {
+ private:
+  friend class AffixTable;
+  Affix(int id, const char *form, int length)
+      : id_(id), length_(length), form_(form), shorter_(NULL), next_(NULL) {}
+
+ public:
+  // Returns unique id of affix.
+  int id() const { return id_; }
+
+  // Returns the textual representation of the affix.
+  string form() const { return form_; }
+
+  // Returns the length of the affix.
+  int length() const { return length_; }
+
+  // Gets/sets the affix that is one character shorter.
+  Affix *shorter() const { return shorter_; }
+  void set_shorter(Affix *next) { shorter_ = next; }
+
+ private:
+  // Affix id.
+  int id_;
+
+  // Length (in characters) of affix.
+  int length_;
+
+  // Text form of affix.
+  string form_;
+
+  // Pointer to affix that is one character shorter.
+  Affix *shorter_;
+
+  // Next affix in bucket chain.
+  Affix *next_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(Affix);
+};
+
+// An affix table holds all prefixes/suffixes of all the words added to the
+// table up to a maximum length. The affixes are chained together to enable
+// fast lookup of all affixes for a word.
+class AffixTable {
+ public:
+  // Affix table type.
+  enum Type { PREFIX, SUFFIX };
+
+  AffixTable(Type type, int max_length);
+  ~AffixTable();
+
+  // Resets the affix table and initializes it for affixes up to the specified
+  // maximum length.
+  void Reset(int max_length);
+
+  // De-serializes this from the given proto.
+  void Read(const AffixTableEntry &table_entry);
+
+  // De-serializes this from the given records.
+  void Read(ProtoRecordReader *reader);
+
+  // Serializes this to the given proto.
+  void Write(AffixTableEntry *table_entry) const;
+
+  // Serializes this to the given records.
+  void Write(ProtoRecordWriter *writer) const;
+
+  // Adds all prefixes/suffixes of the word up to the maximum length to the
+  // table. The longest affix is returned. The pointers in the affix can be
+  // used for getting shorter affixes.
+  Affix *AddAffixesForWord(const char *word, size_t size);
+
+  // Gets the affix information for the affix with a certain id. Returns NULL if
+  // there is no affix in the table with this id.
+  Affix *GetAffix(int id) const;
+
+  // Gets affix form from id. If the affix does not exist in the table, an empty
+  // string is returned.
+  string AffixForm(int id) const;
+
+  // Gets affix id for affix. If the affix does not exist in the table, -1 is
+  // returned.
+  int AffixId(const string &form) const;
+
+  // Returns size of the affix table.
+  int size() const { return affixes_.size(); }
+
+  // Returns the maximum affix length.
+  int max_length() const { return max_length_; }
+
+ private:
+  // Adds a new affix to table.
+  Affix *AddNewAffix(const string &form, int length);
+
+  // Finds existing affix in table.
+  Affix *FindAffix(const string &form) const;
+
+  // Resizes bucket array.
+  void Resize(int size_hint);
+
+  // Affix type (prefix or suffix).
+  Type type_;
+
+  // Maximum length of affix.
+  int max_length_;
+
+  // Index from affix ids to affix items.
+  vector<Affix *> affixes_;
+
+  // Buckets for word-to-affix hash map.
+  vector<Affix *> buckets_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(AffixTable);
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_AFFIX_H_
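
As a quick illustration of the Read/Write pair declared above, a populated AffixTable can be round-tripped through an AffixTableEntry proto; the table named "table" below is hypothetical, and unqualified names assume the syntaxnet namespace:

    // Illustrative only: serialize a populated table and restore a copy.
    AffixTableEntry entry;
    table.Write(&entry);
    AffixTable restored(AffixTable::SUFFIX, table.max_length());
    restored.Read(entry);
    CHECK_EQ(table.size(), restored.size());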

+ 305 - 0
syntaxnet/syntaxnet/arc_standard_transitions.cc

@@ -0,0 +1,305 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Arc-standard transition system.
+//
+// This transition system has three types of actions:
+//  - The SHIFT action pushes the next input token to the stack and
+//    advances to the next input token.
+//  - The LEFT_ARC action adds a dependency relation from the first to the
+//    second token on the stack and removes the second one.
+//  - The RIGHT_ARC action adds a dependency relation from the second to the
+//    first token on the stack and removes the first one.
+//
+// The transition system operates with parser actions encoded as integers:
+//  - A SHIFT action is encoded as 0.
+//  - A LEFT_ARC action is encoded as an odd number starting from 1.
+//  - A RIGHT_ARC action is encoded as an even number starting from 2.
+
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+class ArcStandardTransitionState : public ParserTransitionState {
+ public:
+  // Clones the transition state by returning a new object.
+  ParserTransitionState *Clone() const override {
+    return new ArcStandardTransitionState();
+  }
+
+  // Pushes the root on the stack before using the parser state in parsing.
+  void Init(ParserState *state) override { state->Push(-1); }
+
+  // Adds transition state specific annotations to the document.
+  void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
+                          Sentence *sentence) const override {
+    for (int i = 0; i < state.NumTokens(); ++i) {
+      Token *token = sentence->mutable_token(i);
+      token->set_label(state.LabelAsString(state.Label(i)));
+      if (state.Head(i) != -1) {
+        token->set_head(state.Head(i));
+      } else {
+        token->clear_head();
+        if (rewrite_root_labels) {
+          token->set_label(state.LabelAsString(state.RootLabel()));
+        }
+      }
+    }
+  }
+
+  // Whether a parsed token should be considered correct for evaluation.
+  bool IsTokenCorrect(const ParserState &state, int index) const override {
+    return state.GoldHead(index) == state.Head(index);
+  }
+
+  // Returns a human readable string representation of this state.
+  string ToString(const ParserState &state) const override {
+    string str;
+    str.append("[");
+    for (int i = state.StackSize() - 1; i >= 0; --i) {
+      const string &word = state.GetToken(state.Stack(i)).word();
+      if (i != state.StackSize() - 1) str.append(" ");
+      if (word == "") {
+        str.append(ParserState::kRootLabel);
+      } else {
+        str.append(word);
+      }
+    }
+    str.append("]");
+    for (int i = state.Next(); i < state.NumTokens(); ++i) {
+      tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
+    }
+    return str;
+  }
+};
+
+class ArcStandardTransitionSystem : public ParserTransitionSystem {
+ public:
+  // Action types for the arc-standard transition system.
+  enum ParserActionType {
+    SHIFT = 0,
+    LEFT_ARC = 1,
+    RIGHT_ARC = 2,
+  };
+
+  // The SHIFT action uses the same value as the corresponding action type.
+  static ParserAction ShiftAction() { return SHIFT; }
+
+  // The LEFT_ARC action encodes the label as an odd number greater than or
+  // equal to 1, namely 2 * label + 1.
+  static ParserAction LeftArcAction(int label) { return 1 + (label << 1); }
+
+  // The RIGHT_ARC action encodes the label as an even number greater than or
+  // equal to 2, namely 2 * label + 2.
+  static ParserAction RightArcAction(int label) {
+    return 1 + ((label << 1) | 1);
+  }
+
+  // Extracts the action type from a given parser action.
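+  // For actions >= 1, odd values decode to LEFT_ARC and even values to
+  // RIGHT_ARC.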
+  static ParserActionType ActionType(ParserAction action) {
+    return static_cast<ParserActionType>(action < 1 ? action
+                                                    : 1 + (~action & 1));
+  }
+
+  // Extracts the label from a given parser action. If the action is SHIFT,
+  // returns -1.
+  static int Label(ParserAction action) {
+    return action < 1 ? -1 : (action - 1) >> 1;
+  }
+
+  // Returns the number of action types.
+  int NumActionTypes() const override { return 3; }
+
+  // Returns the number of possible actions.
+  int NumActions(int num_labels) const override { return 1 + 2 * num_labels; }
+
+  // The method returns the default action for a given state.
+  ParserAction GetDefaultAction(const ParserState &state) const override {
+    // If there are further tokens available in the input then Shift.
+    if (!state.EndOfInput()) return ShiftAction();
+
+    // Do a "reduce".
+    return RightArcAction(2);
+  }
+
+  // Returns the next gold action for a given state according to the
+  // underlying annotated sentence.
+  ParserAction GetNextGoldAction(const ParserState &state) const override {
+    // If the stack contains fewer than two tokens, the only valid parser
+    // action is shift.
+    if (state.StackSize() < 2) {
+      DCHECK(!state.EndOfInput());
+      return ShiftAction();
+    }
+
+    // If the second token on the stack is the head of the first one,
+    // return a right arc action.
+    if (state.GoldHead(state.Stack(0)) == state.Stack(1) &&
+        DoneChildrenRightOf(state, state.Stack(0))) {
+      const int gold_label = state.GoldLabel(state.Stack(0));
+      return RightArcAction(gold_label);
+    }
+
+    // If the first token on the stack is the head of the second one,
+    // return a left arc action.
+    if (state.GoldHead(state.Stack(1)) == state.Top()) {
+      const int gold_label = state.GoldLabel(state.Stack(1));
+      return LeftArcAction(gold_label);
+    }
+
+    // Otherwise, shift.
+    return ShiftAction();
+  }
+
+  // Returns true if no token in the remaining input has the given head as its
+  // gold head, i.e. all children to the right of the head have been attached.
+  // Arc-standard parsing is bottom-up and has to finish all sub-trees first.
+  static bool DoneChildrenRightOf(const ParserState &state, int head) {
+    int index = state.Next();
+    int num_tokens = state.sentence().token_size();
+    while (index < num_tokens) {
+      // Check if the token at index is the child of head.
+      int actual_head = state.GoldHead(index);
+      if (actual_head == head) return false;
+
+      // If the head of the token at index is to the right of it there cannot be
+      // any children in-between, so we can skip forward to the head.  Note this
+      // is only true for projective trees.
+      if (actual_head > index) {
+        index = actual_head;
+      } else {
+        ++index;
+      }
+    }
+    return true;
+  }
+
+  // Checks if the action is allowed in a given parser state.
+  bool IsAllowedAction(ParserAction action,
+                       const ParserState &state) const override {
+    switch (ActionType(action)) {
+      case SHIFT:
+        return IsAllowedShift(state);
+      case LEFT_ARC:
+        return IsAllowedLeftArc(state);
+      case RIGHT_ARC:
+        return IsAllowedRightArc(state);
+    }
+
+    return false;
+  }
+
+  // Returns true if a shift is allowed in the given parser state.
+  bool IsAllowedShift(const ParserState &state) const {
+    // We can shift if there are more input tokens.
+    return !state.EndOfInput();
+  }
+
+  // Returns true if a left-arc is allowed in the given parser state.
+  bool IsAllowedLeftArc(const ParserState &state) const {
+    // Left-arc requires more than two tokens on the stack because the bottom
+    // entry is the root, and we do not want a left arc onto the root.
+    return state.StackSize() > 2;
+  }
+
+  // Returns true if a right-arc is allowed in the given parser state.
+  bool IsAllowedRightArc(const ParserState &state) const {
+    // Right-arc requires two or more tokens on the stack.
+    return state.StackSize() > 1;
+  }
+
+  // Performs the specified action on a given parser state, without adding the
+  // action to the state's history.
+  void PerformActionWithoutHistory(ParserAction action,
+                                   ParserState *state) const override {
+    switch (ActionType(action)) {
+      case SHIFT:
+        PerformShift(state);
+        break;
+      case LEFT_ARC:
+        PerformLeftArc(state, Label(action));
+        break;
+      case RIGHT_ARC:
+        PerformRightArc(state, Label(action));
+        break;
+    }
+  }
+
+  // Makes a shift by pushing the next input token on the stack and moving to
+  // the next position.
+  void PerformShift(ParserState *state) const {
+    DCHECK(IsAllowedShift(*state));
+    state->Push(state->Next());
+    state->Advance();
+  }
+
+  // Makes a left-arc between the two top tokens on stack and pops the second
+  // token on stack.
+  void PerformLeftArc(ParserState *state, int label) const {
+    DCHECK(IsAllowedLeftArc(*state));
+    int s0 = state->Pop();
+    state->AddArc(state->Pop(), s0, label);
+    state->Push(s0);
+  }
+
+  // Makes a right-arc between the two top tokens on stack and pops the stack.
+  void PerformRightArc(ParserState *state, int label) const {
+    DCHECK(IsAllowedRightArc(*state));
+    int s0 = state->Pop();
+    int s1 = state->Pop();
+    state->AddArc(s0, s1, label);
+    state->Push(s1);
+  }
+
+  // We are in a deterministic state when the stack holds fewer than two
+  // entries and there is still input left: the only allowed action is a shift.
+  bool IsDeterministicState(const ParserState &state) const override {
+    return state.StackSize() < 2 && !state.EndOfInput();
+  }
+
+  // We are in a final state when we have reached the end of the input and only
+  // the root remains on the stack.
+  bool IsFinalState(const ParserState &state) const override {
+    return state.EndOfInput() && state.StackSize() < 2;
+  }
+
+  // Returns a string representation of a parser action.
+  string ActionAsString(ParserAction action,
+                        const ParserState &state) const override {
+    switch (ActionType(action)) {
+      case SHIFT:
+        return "SHIFT";
+      case LEFT_ARC:
+        return "LEFT_ARC(" + state.LabelAsString(Label(action)) + ")";
+      case RIGHT_ARC:
+        return "RIGHT_ARC(" + state.LabelAsString(Label(action)) + ")";
+    }
+    return "UNKNOWN";
+  }
+
+  // Returns a new transition state to be used to enhance the parser state.
+  ParserTransitionState *NewTransitionState(bool training_mode) const override {
+    return new ArcStandardTransitionState();
+  }
+};
+
+REGISTER_TRANSITION_SYSTEM("arc-standard", ArcStandardTransitionSystem);
+
+}  // namespace syntaxnet
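
To make the integer encoding above concrete, here is a worked example; the label id 3 is arbitrary, and since ArcStandardTransitionSystem is local to this file the snippet is purely illustrative:

    // Illustrative only: with label id 3,
    //   LeftArcAction(3)  == 1 + (3 << 1)       == 7  (odd  -> LEFT_ARC)
    //   RightArcAction(3) == 1 + ((3 << 1) | 1) == 8  (even -> RIGHT_ARC)
    CHECK_EQ(ArcStandardTransitionSystem::LeftArcAction(3), 7);
    CHECK_EQ(ArcStandardTransitionSystem::RightArcAction(3), 8);
    CHECK_EQ(ArcStandardTransitionSystem::Label(7), 3);
    CHECK_EQ(ArcStandardTransitionSystem::Label(8), 3);
    // SHIFT is encoded as 0 and carries no label.
    CHECK_EQ(ArcStandardTransitionSystem::Label(
        ArcStandardTransitionSystem::ShiftAction()), -1);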

+ 117 - 0
syntaxnet/syntaxnet/arc_standard_transitions_test.cc

@@ -0,0 +1,117 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <gmock/gmock.h>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/populate_test_inputs.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+
+class ArcStandardTransitionTest : public ::testing::Test {
+ public:
+  ArcStandardTransitionTest()
+      : transition_system_(ParserTransitionSystem::Create("arc-standard")) {}
+
+ protected:
+  // Creates a label map and a tag map for testing based on the given
+  // document and initializes the transition system appropriately.
+  void SetUpForDocument(const Sentence &document) {
+    input_label_map_ = context_.GetInput("label-map", "text", "");
+    transition_system_->Setup(&context_);
+    PopulateTestInputs::Defaults(document).Populate(&context_);
+    label_map_.Load(TaskContext::InputFile(*input_label_map_),
+                    0 /* minimum frequency */,
+                    -1 /* maximum number of terms */);
+    transition_system_->Init(&context_);
+  }
+
+  // Creates a cloned state from a sentence in order to test that cloning
+  // works correctly for the new parser states.
+  ParserState *NewClonedState(Sentence *sentence) {
+    ParserState state(sentence, transition_system_->NewTransitionState(
+                                    true /* training mode */),
+                      &label_map_);
+    return state.Clone();
+  }
+
+  // Performs gold transitions and checks that the labels and heads recorded
+  // in the parser state match the gold heads and labels.
+  void GoldParse(Sentence *sentence) {
+    ParserState *state = NewClonedState(sentence);
+    LOG(INFO) << "Initial parser state: " << state->ToString();
+    while (!transition_system_->IsFinalState(*state)) {
+      ParserAction action = transition_system_->GetNextGoldAction(*state);
+      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
+      LOG(INFO) << "Performing action: "
+                << transition_system_->ActionAsString(action, *state);
+      transition_system_->PerformActionWithoutHistory(action, state);
+      LOG(INFO) << "Parser state: " << state->ToString();
+    }
+    for (int i = 0; i < sentence->token_size(); ++i) {
+      EXPECT_EQ(state->GoldLabel(i), state->Label(i));
+      EXPECT_EQ(state->GoldHead(i), state->Head(i));
+    }
+    delete state;
+  }
+
+  // Always takes the default action, and verifies that this leads to
+  // a final state through a sequence of allowed actions.
+  void DefaultParse(Sentence *sentence) {
+    ParserState *state = NewClonedState(sentence);
+    LOG(INFO) << "Initial parser state: " << state->ToString();
+    while (!transition_system_->IsFinalState(*state)) {
+      ParserAction action = transition_system_->GetDefaultAction(*state);
+      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
+      LOG(INFO) << "Performing action: "
+                << transition_system_->ActionAsString(action, *state);
+      transition_system_->PerformActionWithoutHistory(action, state);
+      LOG(INFO) << "Parser state: " << state->ToString();
+    }
+    delete state;
+  }
+
+  TaskContext context_;
+  TaskInput *input_label_map_ = nullptr;
+  TermFrequencyMap label_map_;
+  std::unique_ptr<ParserTransitionSystem> transition_system_;
+};
+
+TEST_F(ArcStandardTransitionTest, SingleSentenceDocumentTest) {
+  string document_text;
+  Sentence document;
+  TF_CHECK_OK(ReadFileToString(
+      tensorflow::Env::Default(),
+      "syntaxnet/testdata/document",
+      &document_text));
+  LOG(INFO) << "see doc\n:" << document_text;
+  CHECK(TextFormat::ParseFromString(document_text, &document));
+  SetUpForDocument(document);
+  GoldParse(&document);
+  DefaultParse(&document);
+}
+
+}  // namespace syntaxnet
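
As a worked illustration of the GoldParse test above, consider a hypothetical two-token sentence in which "ate" is the root and "Bob" is its dependent with label nsubj (token and label names are made up); the arc-standard gold sequence would be:

    stack: [<ROOT>]           input: Bob ate   ->  SHIFT
    stack: [<ROOT> Bob]       input: ate       ->  SHIFT
    stack: [<ROOT> Bob ate]   input:           ->  LEFT_ARC(nsubj), attaching Bob to ate
    stack: [<ROOT> ate]       input:           ->  RIGHT_ARC(root label), attaching ate to the root
    stack: [<ROOT>]           input:           ->  final state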

+ 53 - 0
syntaxnet/syntaxnet/base.h

@@ -0,0 +1,53 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_BASE_H_
+#define $TARGETDIR_BASE_H_
+
+#include <functional>
+#include <map>
+#include <string>
+#include <utility>
+#include <vector>
+#include <unordered_map>
+#include <unordered_set>
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/default/integral_types.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/protobuf.h"
+
+
+
+using tensorflow::int32;
+using tensorflow::int64;
+using tensorflow::uint64;
+using tensorflow::uint32;
+using tensorflow::protobuf::TextFormat;
+using tensorflow::mutex_lock;
+using tensorflow::mutex;
+using std::map;
+using std::pair;
+using std::vector;
+using std::unordered_map;
+using std::unordered_set;
+typedef signed int char32;
+
+using tensorflow::StringPiece;
+using std::string;
+
+
+#endif  // $TARGETDIR_BASE_H_

+ 893 - 0
syntaxnet/syntaxnet/beam_reader_ops.cc

@@ -0,0 +1,893 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <algorithm>
+#include <deque>
+#include <map>
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "syntaxnet/base.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/sentence_batch.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/shared_store.h"
+#include "syntaxnet/sparse.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/platform/env.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_STRING;
+using tensorflow::DataType;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::TTypes;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::errors::FailedPrecondition;
+using tensorflow::errors::InvalidArgument;
+
+namespace syntaxnet {
+
+// Wraps ParserState so that the history of transitions (actions
+// performed and the beam slot they were performed in) are recorded.
+struct ParserStateWithHistory {
+ public:
+  // New state with an empty history.
+  explicit ParserStateWithHistory(const ParserState &s) : state(s.Clone()) {}
+
+  // New state obtained by cloning the given state and applying the given
+  // action. The given beam slot and action are appended to the history.
+  ParserStateWithHistory(const ParserStateWithHistory &next,
+                         const ParserTransitionSystem &transitions, int32 slot,
+                         int32 action, float score)
+      : state(next.state->Clone()),
+        slot_history(next.slot_history),
+        action_history(next.action_history),
+        score_history(next.score_history) {
+    transitions.PerformAction(action, state.get());
+    slot_history.push_back(slot);
+    action_history.push_back(action);
+    score_history.push_back(score);
+  }
+
+  std::unique_ptr<ParserState> state;
+  std::vector<int32> slot_history;
+  std::vector<int32> action_history;
+  std::vector<float> score_history;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ParserStateWithHistory);
+};
+
+struct BatchStateOptions {
+  // Maximum number of parser states in a beam.
+  int max_beam_size;
+
+  // Number of parallel sentences to decode.
+  int batch_size;
+
+  // Argument prefix for context parameters.
+  string arg_prefix;
+
+  // Corpus name to read from the context inputs.
+  string corpus_name;
+
+  // Whether we allow weights in SparseFeatures protos.
+  bool allow_feature_weights;
+
+  // Whether beams should be considered alive until all states are final, or
+  // until the gold path falls off.
+  bool continue_until_all_final;
+
+  // Whether to skip to a new sentence after each training step.
+  bool always_start_new_sentences;
+
+  // Parameter for deciding which tokens to score.
+  string scoring_type;
+};
+
+// Encapsulates the environment needed to parse with a beam, keeping a
+// record of path histories.
+class BeamState {
+ public:
+  // The agenda is keyed by a tuple that is the score followed by an
+  // int that is -1 if the path coincides with the gold path and 0
+  // otherwise. The lexicographic ordering of the keys therefore
+  // ensures that for all paths sharing the same score, the gold path
+  // will always be at the bottom. This situation can occur at the
+  // onset of training when all weights are zero and therefore all
+  // paths have an identically zero score.
+  typedef std::pair<double, int> KeyType;
+  typedef std::multimap<KeyType, std::unique_ptr<ParserStateWithHistory>>
+      AgendaType;
+  typedef std::pair<const KeyType, std::unique_ptr<ParserStateWithHistory>>
+      AgendaItem;
+  typedef Eigen::Tensor<float, 2, Eigen::RowMajor, Eigen::DenseIndex>
+      ScoreMatrixType;
+
+  // The beam can be
+  //   - ALIVE: parsing is still active, features are being output for at least
+  //     some slots in the beam.
+  //   - DYING: features should be output for this beam only one more time, then
+  //     the beam will be DEAD. This state is reached when the gold path falls
+  //     out of the beam and features have to be output one last time.
+  //   - DEAD: parsing is not active, features are not being output and no
+  //     actions are taken on the states.
+  enum State { ALIVE = 0, DYING = 1, DEAD = 2 };
+
+  explicit BeamState(const BatchStateOptions &options) : options_(options) {}
+
+  void Reset() {
+    if (options_.always_start_new_sentences ||
+        gold_ == nullptr || transition_system_->IsFinalState(*gold_)) {
+      AdvanceSentence();
+    }
+    slots_.clear();
+    if (gold_ == nullptr) {
+      state_ = DEAD;  // EOF has been reached.
+    } else {
+      gold_->set_is_gold(true);
+      slots_.emplace(KeyType(0.0, -1), std::unique_ptr<ParserStateWithHistory>(
+          new ParserStateWithHistory(*gold_)));
+      state_ = ALIVE;
+    }
+  }
+
+  void UpdateAllFinal() {
+    all_final_ = true;
+    for (const AgendaItem &item : slots_) {
+      if (!transition_system_->IsFinalState(*item.second->state)) {
+        all_final_ = false;
+        break;
+      }
+    }
+    if (all_final_) {
+      state_ = DEAD;
+    }
+  }
+
+  // This method updates the beam. For all elements of the beam, all
+  // allowed transitions are scored and inserted into a new beam. The
+  // beam size is capped by discarding the lowest-scoring slots at any
+  // given time. There is one exception to this process: the gold path
+  // is forced to remain in the beam at all times, even if it scores
+  // low. This is to ensure that the gold path can be used for
+  // training at the moment it would otherwise fall off (and be absent
+  // from) the beam.
+  void Advance(const ScoreMatrixType &scores) {
+    // If the beam was in the state of DYING, it is now DEAD.
+    if (state_ == DYING) state_ = DEAD;
+
+    // When to stop advancing depends on the 'continue_until_all_final' arg.
+    if (!IsAlive() || gold_ == nullptr) return;
+
+    AdvanceGold();
+
+    const int score_rows = scores.dimension(0);
+    const int num_actions = scores.dimension(1);
+
+    // Advance beam.
+    AgendaType previous_slots;
+    previous_slots.swap(slots_);
+
+    CHECK_EQ(state_, ALIVE);
+
+    int slot = 0;
+    for (AgendaItem &item : previous_slots) {
+      {
+        ParserState *current = item.second->state.get();
+        VLOG(2) << "Slot: " << slot;
+        VLOG(2) << "Parser state: " << current->ToString();
+        VLOG(2) << "Parser state cumulative score: " << item.first.first << " "
+                << (item.first.second < 0 ? "golden" : "");
+      }
+      if (!transition_system_->IsFinalState(*item.second->state)) {
+        // Not a final state.
+        for (int action = 0; action < num_actions; ++action) {
+          // Is action allowed?
+          if (!transition_system_->IsAllowedAction(action,
+                                                   *item.second->state)) {
+            continue;
+          }
+          CHECK_LT(slot, score_rows);
+          MaybeInsertWithNewAction(item, slot, scores(slot, action), action);
+          PruneBeam();
+        }
+      } else {
+        // Final state: no need to advance.
+        MaybeInsert(&item);
+        PruneBeam();
+      }
+      ++slot;
+    }
+    UpdateAllFinal();
+  }
+
+  void PopulateFeatureOutputs(
+      std::vector<std::vector<std::vector<SparseFeatures>>> *features) {
+    for (const AgendaItem &item : slots_) {
+      VLOG(2) << "State: " << item.second->state->ToString();
+      std::vector<std::vector<SparseFeatures>> f =
+          features_->ExtractSparseFeatures(*workspace_, *item.second->state);
+      for (size_t i = 0; i < f.size(); ++i) (*features)[i].push_back(f[i]);
+    }
+  }
+
+  int BeamSize() const { return slots_.size(); }
+
+  bool IsAlive() const { return state_ == ALIVE; }
+
+  bool IsDead() const { return state_ == DEAD; }
+
+  bool AllFinal() const { return all_final_; }
+
+  // The current contents of the beam.
+  AgendaType slots_;
+
+  // Which batch this refers to.
+  int beam_id_ = 0;
+
+  // Sentence batch reader.
+  SentenceBatch *sentence_batch_ = nullptr;
+
+  // Label map.
+  const TermFrequencyMap *label_map_ = nullptr;
+
+  // Transition system.
+  const ParserTransitionSystem *transition_system_ = nullptr;
+
+  // Feature extractor.
+  const ParserEmbeddingFeatureExtractor *features_ = nullptr;
+
+  // Feature workspace set.
+  WorkspaceSet *workspace_ = nullptr;
+
+  // Internal workspace registry for use in feature extraction.
+  WorkspaceRegistry *workspace_registry_ = nullptr;
+
+  // ParserState used to get gold actions.
+  std::unique_ptr<ParserState> gold_;
+
+ private:
+  // Creates a new ParserState if there's another sentence to be read.
+  void AdvanceSentence() {
+    gold_.reset();
+    if (sentence_batch_->AdvanceSentence(beam_id_)) {
+      gold_.reset(new ParserState(sentence_batch_->sentence(beam_id_),
+                                  transition_system_->NewTransitionState(true),
+                                  label_map_));
+      workspace_->Reset(*workspace_registry_);
+      features_->Preprocess(workspace_, gold_.get());
+    }
+  }
+
+  void AdvanceGold() {
+    gold_action_ = -1;
+    if (!transition_system_->IsFinalState(*gold_)) {
+      gold_action_ = transition_system_->GetNextGoldAction(*gold_);
+      if (transition_system_->IsAllowedAction(gold_action_, *gold_)) {
+        // In cases where the gold annotation is incompatible with the
+        // transition system, the action returned as gold might not be allowed.
+        transition_system_->PerformAction(gold_action_, gold_.get());
+      }
+    }
+  }
+
+  // Removes the first non-gold beam element if the beam is larger than
+  // the maximum beam size. If the gold element was at the bottom of the
+  // beam, sets the beam state to DYING, otherwise leaves the state alone.
+  void PruneBeam() {
+    if (static_cast<int>(slots_.size()) > options_.max_beam_size) {
+      auto bottom = slots_.begin();
+      if (!options_.continue_until_all_final &&
+          bottom->second->state->is_gold()) {
+        state_ = DYING;
+        ++bottom;
+      }
+      slots_.erase(bottom);
+    }
+  }
+
+  // Inserts an item in the beam if
+  //   - the item is gold,
+  //   - the beam is not full, or
+  //   - the item's new score is greater than the lowest score in the beam after
+  //     the score has been incremented by given delta_score.
+  // Inserted items have slot, delta_score and action appended to their history.
+  void MaybeInsertWithNewAction(const AgendaItem &item, const int slot,
+                                const double delta_score, const int action) {
+    const double score = item.first.first + delta_score;
+    const bool is_gold =
+        item.second->state->is_gold() && action == gold_action_;
+    if (is_gold || static_cast<int>(slots_.size()) < options_.max_beam_size ||
+        score > slots_.begin()->first.first) {
+      const KeyType key{score, -static_cast<int>(is_gold)};
+      slots_.emplace(key, std::unique_ptr<ParserStateWithHistory>(
+                              new ParserStateWithHistory(
+                                  *item.second, *transition_system_, slot,
+                                  action, delta_score)))
+          ->second->state->set_is_gold(is_gold);
+    }
+  }
+
+  // Inserts an item in the beam if
+  //   - the item is gold,
+  //   - the beam is not full, or
+  //   - the item's new score is greater than the lowest score in the beam.
+  // The history of inserted items is left untouched.
+  void MaybeInsert(AgendaItem *item) {
+    const bool is_gold = item->second->state->is_gold();
+    const double score = item->first.first;
+    if (is_gold || static_cast<int>(slots_.size()) < options_.max_beam_size ||
+        score > slots_.begin()->first.first) {
+      slots_.emplace(item->first, std::move(item->second));
+    }
+  }
+
+  // Decoding options, including the maximum number of slots on the beam.
+  const BatchStateOptions &options_;
+
+  int gold_action_ = -1;
+  State state_ = ALIVE;
+  bool all_final_ = false;
+  TF_DISALLOW_COPY_AND_ASSIGN(BeamState);
+};
+
+// Encapsulates the state of a batch of beams. It is an object of this
+// type that will persist through repeated Op evaluations as the
+// multiple steps are computed in sequence.
+class BatchState {
+ public:
+  explicit BatchState(const BatchStateOptions &options)
+      : options_(options), features_(options.arg_prefix) {}
+
+  ~BatchState() { SharedStore::Release(label_map_); }
+
+  void Init(TaskContext *task_context) {
+    // Create sentence batch.
+    sentence_batch_.reset(
+        new SentenceBatch(BatchSize(), options_.corpus_name));
+    sentence_batch_->Init(task_context);
+
+    // Create transition system.
+    transition_system_.reset(ParserTransitionSystem::Create(task_context->Get(
+        tensorflow::strings::StrCat(options_.arg_prefix, "_transition_system"),
+        "arc-standard")));
+    transition_system_->Setup(task_context);
+    transition_system_->Init(task_context);
+
+    // Create label map.
+    string label_map_path =
+        TaskContext::InputFile(*task_context->GetInput("label-map"));
+    label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
+        label_map_path, 0, 0);
+
+    // Setup features.
+    features_.Setup(task_context);
+    features_.Init(task_context);
+    features_.RequestWorkspaces(&workspace_registry_);
+
+    // Create workspaces.
+    workspaces_.resize(BatchSize());
+
+    // Create beams.
+    beams_.clear();
+    for (int beam_id = 0; beam_id < BatchSize(); ++beam_id) {
+      beams_.emplace_back(options_);
+      beams_[beam_id].beam_id_ = beam_id;
+      beams_[beam_id].sentence_batch_ = sentence_batch_.get();
+      beams_[beam_id].transition_system_ = transition_system_.get();
+      beams_[beam_id].label_map_ = label_map_;
+      beams_[beam_id].features_ = &features_;
+      beams_[beam_id].workspace_ = &workspaces_[beam_id];
+      beams_[beam_id].workspace_registry_ = &workspace_registry_;
+    }
+  }
+
+  void ResetBeams() {
+    for (BeamState &beam : beams_) {
+      beam.Reset();
+    }
+
+    // Rewind if no states remain in the batch (we need to rewind the corpus).
+    if (sentence_batch_->size() == 0) {
+      ++epoch_;
+      VLOG(2) << "Starting epoch " << epoch_;
+      sentence_batch_->Rewind();
+    }
+  }
+
+  // Resets the offset vectors required for a single run because we're
+  // starting a new matrix of scores.
+  void ResetOffsets() {
+    beam_offsets_.clear();
+    step_offsets_ = {0};
+    UpdateOffsets();
+  }
+
+  void AdvanceBeam(const int beam_id,
+                   const TTypes<float>::ConstMatrix &scores) {
+    const int offset = beam_offsets_.back()[beam_id];
+    Eigen::array<Eigen::DenseIndex, 2> offsets = {offset, 0};
+    Eigen::array<Eigen::DenseIndex, 2> extents = {
+        beam_offsets_.back()[beam_id + 1] - offset, NumActions()};
+    BeamState::ScoreMatrixType beam_scores = scores.slice(offsets, extents);
+    beams_[beam_id].Advance(beam_scores);
+  }
+
+  void UpdateOffsets() {
+    beam_offsets_.emplace_back(BatchSize() + 1, 0);
+    std::vector<int> &offsets = beam_offsets_.back();
+    for (int beam_id = 0; beam_id < BatchSize(); ++beam_id) {
+      // If the beam is ALIVE or DYING (but not DEAD), we want to
+      // output the activations.
+      const BeamState &beam = beams_[beam_id];
+      const int beam_size = beam.IsDead() ? 0 : beam.BeamSize();
+      offsets[beam_id + 1] = offsets[beam_id] + beam_size;
+    }
+    const int output_size = offsets.back();
+    step_offsets_.push_back(step_offsets_.back() + output_size);
+  }
+
+  tensorflow::Status PopulateFeatureOutputs(OpKernelContext *context) {
+    const int feature_size = FeatureSize();
+    std::vector<std::vector<std::vector<SparseFeatures>>> features(
+        feature_size);
+    for (int beam_id = 0; beam_id < BatchSize(); ++beam_id) {
+      if (!beams_[beam_id].IsDead()) {
+        beams_[beam_id].PopulateFeatureOutputs(&features);
+      }
+    }
+    CHECK_EQ(features.size(), feature_size);
+    Tensor *output;
+    const int total_slots = beam_offsets_.back().back();
+    for (int i = 0; i < feature_size; ++i) {
+      std::vector<std::vector<SparseFeatures>> &f = features[i];
+      CHECK_EQ(total_slots, f.size());
+      if (total_slots == 0) {
+        TF_RETURN_IF_ERROR(
+            context->allocate_output(i, TensorShape({0, 0}), &output));
+      } else {
+        const int size = f[0].size();
+        TF_RETURN_IF_ERROR(context->allocate_output(
+            i, TensorShape({total_slots, size}), &output));
+        for (int j = 0; j < total_slots; ++j) {
+          CHECK_EQ(size, f[j].size());
+          for (int k = 0; k < size; ++k) {
+            if (!options_.allow_feature_weights && f[j][k].weight_size() > 0) {
+              return FailedPrecondition(
+                  "Feature weights are not allowed when allow_feature_weights "
+                  "is set to false.");
+            }
+            output->matrix<string>()(j, k) = f[j][k].SerializeAsString();
+          }
+        }
+      }
+    }
+    return tensorflow::Status::OK();
+  }
+
+  // Returns the offset (i.e. row number) of a particular beam at a
+  // particular step in the final concatenated score matrix.
+  int GetOffset(const int step, const int beam_id) const {
+    return step_offsets_[step] + beam_offsets_[step][beam_id];
+  }
+
+  int FeatureSize() const { return features_.embedding_dims().size(); }
+
+  int NumActions() const {
+    return transition_system_->NumActions(label_map_->Size());
+  }
+
+  int BatchSize() const { return options_.batch_size; }
+
+  const BeamState &Beam(const int i) const { return beams_[i]; }
+
+  int Epoch() const { return epoch_; }
+
+  const string &ScoringType() const { return options_.scoring_type; }
+
+ private:
+  const BatchStateOptions options_;
+
+  // How many times the document source has been rewound.
+  int epoch_ = 0;
+
+  // Batch of sentences, and the corresponding parser states.
+  std::unique_ptr<SentenceBatch> sentence_batch_;
+
+  // Transition system.
+  std::unique_ptr<ParserTransitionSystem> transition_system_;
+
+  // Label map for the transition system.
+  const TermFrequencyMap *label_map_;
+
+  // Typed feature extractor for embeddings.
+  ParserEmbeddingFeatureExtractor features_;
+
+  // Batch: WorkspaceSet objects.
+  std::vector<WorkspaceSet> workspaces_;
+
+  // Internal workspace registry for use in feature extraction.
+  WorkspaceRegistry workspace_registry_;
+
+  std::deque<BeamState> beams_;
+  std::vector<std::vector<int>> beam_offsets_;
+
+  // Keeps track of the slot offset of each step.
+  std::vector<int> step_offsets_;
+  TF_DISALLOW_COPY_AND_ASSIGN(BatchState);
+};
+
+// Creates a BeamState and hooks it up with a parser. This Op needs to
+// remain alive for the duration of the parse.
+class BeamParseReader : public OpKernel {
+ public:
+  explicit BeamParseReader(OpKernelConstruction *context) : OpKernel(context) {
+    string file_path;
+    int feature_size;
+    BatchStateOptions options;
+    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
+    OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("beam_size", &options.max_beam_size));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("batch_size", &options.batch_size));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("arg_prefix", &options.arg_prefix));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("corpus_name", &options.corpus_name));
+    OP_REQUIRES_OK(context, context->GetAttr("allow_feature_weights",
+                                             &options.allow_feature_weights));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("continue_until_all_final",
+                                    &options.continue_until_all_final));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("always_start_new_sentences",
+                                    &options.always_start_new_sentences));
+
+    // Reads task context from file.
+    string data;
+    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
+                                             file_path, &data));
+    TaskContext task_context;
+    OP_REQUIRES(context,
+                TextFormat::ParseFromString(data, task_context.mutable_spec()),
+                InvalidArgument("Could not parse task context at ", file_path));
+    OP_REQUIRES(
+        context, options.batch_size > 0,
+        InvalidArgument("Batch size ", options.batch_size, " too small."));
+    options.scoring_type = task_context.Get(
+        tensorflow::strings::StrCat(options.arg_prefix, "_scoring"), "");
+
+    // Create batch state.
+    batch_state_.reset(new BatchState(options));
+    batch_state_->Init(&task_context);
+
+    // Check number of feature groups matches the task context.
+    const int required_size = batch_state_->FeatureSize();
+    OP_REQUIRES(
+        context, feature_size == required_size,
+        InvalidArgument("Task context requires feature_size=", required_size));
+
+    // Set expected signature.
+    std::vector<DataType> output_types(feature_size, DT_STRING);
+    output_types.push_back(DT_INT64);
+    output_types.push_back(DT_INT32);
+    OP_REQUIRES_OK(context, context->MatchSignature({}, output_types));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    mutex_lock lock(mu_);
+
+    // Write features.
+    batch_state_->ResetBeams();
+    batch_state_->ResetOffsets();
+    OP_REQUIRES_OK(context, batch_state_->PopulateFeatureOutputs(context));
+
+    // Forward the beam state vector.
+    Tensor *output;
+    const int feature_size = batch_state_->FeatureSize();
+    OP_REQUIRES_OK(context, context->allocate_output(feature_size,
+                                                     TensorShape({}), &output));
+    output->scalar<int64>()() = reinterpret_cast<int64>(batch_state_.get());
+
+    // Output number of epochs.
+    OP_REQUIRES_OK(context, context->allocate_output(feature_size + 1,
+                                                     TensorShape({}), &output));
+    output->scalar<int32>()() = batch_state_->Epoch();
+  }
+
+ private:
+  // mutex to synchronize access to Compute.
+  mutex mu_;
+
+  // The object whose handle will be passed among the Ops.
+  std::unique_ptr<BatchState> batch_state_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BeamParseReader);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BeamParseReader").Device(DEVICE_CPU),
+                        BeamParseReader);
+
+// Updates the beam based on incoming scores and outputs new feature vectors
+// based on the updated beam.
+class BeamParser : public OpKernel {
+ public:
+  explicit BeamParser(OpKernelConstruction *context) : OpKernel(context) {
+    int feature_size;
+    OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size));
+
+    // Set expected signature.
+    std::vector<DataType> output_types(feature_size, DT_STRING);
+    output_types.push_back(DT_INT64);
+    output_types.push_back(DT_BOOL);
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_INT64, DT_FLOAT}, output_types));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    BatchState *batch_state =
+        reinterpret_cast<BatchState *>(context->input(0).scalar<int64>()());
+
+    const TTypes<float>::ConstMatrix scores = context->input(1).matrix<float>();
+    VLOG(2) << "Scores: " << scores;
+    CHECK_EQ(scores.dimension(1), batch_state->NumActions());
+
+    // In AdvanceBeam we use beam_offsets_[beam_id] to determine the slice of
+    // scores that should be used for advancing, but beam_offsets_[beam_id] only
+    // exists for beams that have a sentence loaded.
+    const int batch_size = batch_state->BatchSize();
+    for (int beam_id = 0; beam_id < batch_size; ++beam_id) {
+      batch_state->AdvanceBeam(beam_id, scores);
+    }
+    batch_state->UpdateOffsets();
+
+    // Forward the beam state unmodified.
+    Tensor *output;
+    const int feature_size = batch_state->FeatureSize();
+    OP_REQUIRES_OK(context, context->allocate_output(feature_size,
+                                                     TensorShape({}), &output));
+    output->scalar<int64>()() = context->input(0).scalar<int64>()();
+
+    // Output the new features of all the slots in all the beams.
+    OP_REQUIRES_OK(context, batch_state->PopulateFeatureOutputs(context));
+
+    // Output whether the beams are alive.
+    OP_REQUIRES_OK(
+        context, context->allocate_output(feature_size + 1,
+                                          TensorShape({batch_size}), &output));
+    for (int beam_id = 0; beam_id < batch_size; ++beam_id) {
+      output->vec<bool>()(beam_id) = batch_state->Beam(beam_id).IsAlive();
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(BeamParser);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BeamParser").Device(DEVICE_CPU), BeamParser);
+
+// Extracts the paths for the elements of the current beams and returns
+// indices into a scoring matrix that is assumed to have been
+// constructed along with the beam search.
+class BeamParserOutput : public OpKernel {
+ public:
+  explicit BeamParserOutput(OpKernelConstruction *context) : OpKernel(context) {
+    // Set expected signature.
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature(
+                       {DT_INT64}, {DT_INT32, DT_INT32, DT_INT32, DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    BatchState *batch_state =
+        reinterpret_cast<BatchState *>(context->input(0).scalar<int64>()());
+
+    const int num_actions = batch_state->NumActions();
+    const int batch_size = batch_state->BatchSize();
+
+    // Vectors for output.
+    //
+    // Each step of each path in each batch gets its index computed and a
+    // unique path id assigned.
+    std::vector<int32> indices;
+    std::vector<int32> path_ids;
+
+    // Each unique path gets a batch id and a slot (in the beam)
+    // id. These are in effect the row and column of the final
+    // 'logits' matrix going to CrossEntropy.
+    std::vector<int32> beam_ids;
+    std::vector<int32> slot_ids;
+
+    // To compute the cross entropy we also need the slot id of the
+    // gold path, one per batch.
+    std::vector<int32> gold_slot(batch_size, -1);
+
+    // For good measure we also output the path scores as computed by
+    // the beam decoder, so they can be compared in tests with the path
+    // scores computed via the indices in TF. This has the same length
+    // as beam_ids and slot_ids.
+    std::vector<float> path_scores;
+
+    // The scores tensor has, conceptually, four dimensions: 1. number
+    // of steps, 2. batch size, 3. number of paths on the beam at that
+    // step, and 4. the number of actions scored. However, this is not
+    // a true tensor since the size of the beam at each step may differ
+    // across steps and across batches. Only the batch size and the
+    // number of actions are fixed.
+    int path_id = 0;
+    for (int beam_id = 0; beam_id < batch_size; ++beam_id) {
+      // This occurs at the end of the corpus, when there aren't enough
+      // sentences to fill the batch.
+      if (batch_state->Beam(beam_id).gold_ == nullptr) continue;
+
+      // Populate the vectors that will index into the concatenated
+      // scores tensor.
+      int slot = 0;
+      for (const auto &item : batch_state->Beam(beam_id).slots_) {
+        beam_ids.push_back(beam_id);
+        slot_ids.push_back(slot);
+        path_scores.push_back(item.first.first);
+        VLOG(2) << "PATH SCORE @ beam_id:" << beam_id << " slot:" << slot
+                << " : " << item.first.first << " " << item.first.second;
+        VLOG(2) << "SLOT HISTORY: "
+                << utils::Join(item.second->slot_history, " ");
+        VLOG(2) << "SCORE HISTORY: "
+                << utils::Join(item.second->score_history, " ");
+        VLOG(2) << "ACTION HISTORY: "
+                << utils::Join(item.second->action_history, " ");
+
+        // Record where the gold path ended up.
+        if (item.second->state->is_gold()) {
+          CHECK_EQ(gold_slot[beam_id], -1);
+          gold_slot[beam_id] = slot;
+        }
+
+        for (size_t step = 0; step < item.second->slot_history.size(); ++step) {
+          const int step_beam_offset = batch_state->GetOffset(step, beam_id);
+          const int slot_index = item.second->slot_history[step];
+          const int action_index = item.second->action_history[step];
+          indices.push_back(num_actions * (step_beam_offset + slot_index) +
+                            action_index);
+          path_ids.push_back(path_id);
+        }
+        ++slot;
+        ++path_id;
+      }
+
+      // Exactly one path must be the gold one.
+      CHECK_GE(gold_slot[beam_id], 0);
+    }
+
+    const int num_ix_elements = indices.size();
+    Tensor *output;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, TensorShape({2, num_ix_elements}), &output));
+    auto indices_and_path_ids = output->matrix<int32>();
+    for (size_t i = 0; i < indices.size(); ++i) {
+      indices_and_path_ids(0, i) = indices[i];
+      indices_and_path_ids(1, i) = path_ids[i];
+    }
+
+    const int num_path_elements = beam_ids.size();
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({2, num_path_elements}), &output));
+    auto beam_and_slot_ids = output->matrix<int32>();
+    for (size_t i = 0; i < beam_ids.size(); ++i) {
+      beam_and_slot_ids(0, i) = beam_ids[i];
+      beam_and_slot_ids(1, i) = slot_ids[i];
+    }
+
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                2, TensorShape({batch_size}), &output));
+    std::copy(gold_slot.begin(), gold_slot.end(), output->vec<int32>().data());
+
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                3, TensorShape({num_path_elements}), &output));
+    std::copy(path_scores.begin(), path_scores.end(),
+              output->vec<float>().data());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(BeamParserOutput);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BeamParserOutput").Device(DEVICE_CPU),
+                        BeamParserOutput);
+
+// Computes eval metrics for the best path in the input beams.
+class BeamEvalOutput : public OpKernel {
+ public:
+  explicit BeamEvalOutput(OpKernelConstruction *context) : OpKernel(context) {
+    // Set expected signature.
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_INT64}, {DT_INT32, DT_STRING}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    int num_tokens = 0;
+    int num_correct = 0;
+    int all_final = 0;
+    BatchState *batch_state =
+        reinterpret_cast<BatchState *>(context->input(0).scalar<int64>()());
+    const int batch_size = batch_state->BatchSize();
+    vector<Sentence> documents;
+    for (int beam_id = 0; beam_id < batch_size; ++beam_id) {
+      if (batch_state->Beam(beam_id).gold_ != nullptr &&
+          batch_state->Beam(beam_id).AllFinal()) {
+        ++all_final;
+        const auto &item = *batch_state->Beam(beam_id).slots_.rbegin();
+        ComputeTokenAccuracy(*item.second->state, batch_state->ScoringType(),
+                             &num_tokens, &num_correct);
+        documents.push_back(item.second->state->sentence());
+        item.second->state->AddParseToDocument(&documents.back());
+      }
+    }
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({2}), &output));
+    auto eval_metrics = output->vec<int32>();
+    eval_metrics(0) = num_tokens;
+    eval_metrics(1) = num_correct;
+
+    const int output_size = documents.size();
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({output_size}), &output));
+    for (int i = 0; i < output_size; ++i) {
+      output->vec<string>()(i) = documents[i].SerializeAsString();
+    }
+  }
+
+ private:
+  // Tallies the number of scored tokens, and how many of them are correct,
+  // for a given ParserState.
+  void ComputeTokenAccuracy(const ParserState &state,
+                            const string &scoring_type,
+                            int *num_tokens, int *num_correct) {
+    for (int i = 0; i < state.sentence().token_size(); ++i) {
+      const Token &token = state.GetToken(i);
+      if (utils::PunctuationUtil::ScoreToken(token.word(), token.tag(),
+                                             scoring_type)) {
+        ++*num_tokens;
+        if (state.IsTokenCorrect(i)) ++*num_correct;
+      }
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BeamEvalOutput);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BeamEvalOutput").Device(DEVICE_CPU),
+                        BeamEvalOutput);
+
+}  // namespace syntaxnet
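For reference, the tally performed by ComputeTokenAccuracy amounts to the following Python sketch. The token fields and the score_token predicate are hypothetical stand-ins for Token and PunctuationUtil::ScoreToken; correctness is simplified to head attachment, whereas the C++ delegates to ParserState::IsTokenCorrect.

# Sketch only: mirrors ComputeTokenAccuracy above. `tokens` is a list of dicts
# with hypothetical keys; `score_token` stands in for PunctuationUtil::ScoreToken.
def token_accuracy(tokens, scoring_type, score_token):
  num_tokens = 0
  num_correct = 0
  for tok in tokens:
    # Only tokens admitted by the scoring policy (e.g. non-punctuation) count.
    if not score_token(tok['word'], tok['tag'], scoring_type):
      continue
    num_tokens += 1
    # Simplified: the op's notion of correctness comes from ParserState.
    if tok['predicted_head'] == tok['gold_head']:
      num_correct += 1
  return num_tokens, num_correct

The BeamEvalOutput op returns exactly this (num_tokens, num_correct) pair as its first output, so accuracy is num_correct / num_tokens.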

+ 230 - 0
syntaxnet/syntaxnet/beam_reader_ops_test.py

@@ -0,0 +1,230 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for beam_reader_ops."""
+
+
+import os.path
+import time
+
+import tensorflow as tf
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+
+from syntaxnet import structured_graph_builder
+from syntaxnet.ops import gen_parser_ops
+
+FLAGS = tf.app.flags.FLAGS
+if not hasattr(FLAGS, 'test_srcdir'):
+  FLAGS.test_srcdir = ''
+if not hasattr(FLAGS, 'test_tmpdir'):
+  FLAGS.test_tmpdir = tf.test.get_temp_dir()
+
+
+class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    # Creates a task context with the correct testing paths.
+    initial_task_context = os.path.join(
+        FLAGS.test_srcdir,
+        'syntaxnet/'
+        'testdata/context.pbtxt')
+    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    with open(initial_task_context, 'r') as fin:
+      with open(self._task_context, 'w') as fout:
+        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
+                   .replace('OUTPATH', FLAGS.test_tmpdir))
+
+    # Creates necessary term maps.
+    with self.test_session() as sess:
+      gen_parser_ops.lexicon_builder(task_context=self._task_context,
+                                     corpus_name='training-corpus').run()
+      self._num_features, self._num_feature_ids, _, self._num_actions = (
+          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
+                                               arg_prefix='brain_parser')))
+
+  def MakeGraph(self,
+                max_steps=10,
+                beam_size=2,
+                batch_size=1,
+                **kwargs):
+    """Constructs a structured learning graph."""
+    assert max_steps > 0, 'Empty network not supported.'
+
+    logging.info('MakeGraph + %s', kwargs)
+
+    with self.test_session(graph=tf.Graph()) as sess:
+      feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
+          gen_parser_ops.feature_size(task_context=self._task_context))
+    embedding_dims = [8, 8, 8]
+    hidden_layer_sizes = []
+    learning_rate = 0.01
+    builder = structured_graph_builder.StructuredGraphBuilder(
+        num_actions,
+        feature_sizes,
+        domain_sizes,
+        embedding_dims,
+        hidden_layer_sizes,
+        seed=1,
+        max_steps=max_steps,
+        beam_size=beam_size,
+        gate_gradients=True,
+        use_locking=True,
+        use_averaging=False,
+        check_parameters=False,
+        **kwargs)
+    builder.AddTraining(self._task_context,
+                        batch_size,
+                        learning_rate=learning_rate,
+                        decay_steps=1000,
+                        momentum=0.9,
+                        corpus_name='training-corpus')
+    builder.AddEvaluation(self._task_context,
+                          batch_size,
+                          evaluation_max_steps=25,
+                          corpus_name=None)
+    builder.training['inits'] = tf.group(*builder.inits.values(), name='inits')
+    return builder
+
+  def Train(self, **kwargs):
+    with self.test_session(graph=tf.Graph()) as sess:
+      max_steps = 3
+      batch_size = 3
+      beam_size = 3
+      builder = (
+          self.MakeGraph(
+              max_steps=max_steps, beam_size=beam_size,
+              batch_size=batch_size, **kwargs))
+      logging.info('params: %s', builder.params.keys())
+      logging.info('variables: %s', builder.variables.keys())
+
+      t = builder.training
+      sess.run(t['inits'])
+      costs = []
+      gold_slots = []
+      alive_steps_vector = []
+      every_n = 5
+      walltime = time.time()
+      for step in range(10):
+        if step > 0 and step % every_n == 0:
+          new_walltime = time.time()
+          logging.info(
+              'Step: %d <cost>: %f <gold_slot>: %f <alive_steps>: %f <iter '
+              'time>: %f ms',
+              step, sum(costs[-every_n:]) / float(every_n),
+              sum(gold_slots[-every_n:]) / float(every_n),
+              sum(alive_steps_vector[-every_n:]) / float(every_n),
+              1000 * (new_walltime - walltime) / float(every_n))
+          walltime = new_walltime
+
+        cost, gold_slot, alive_steps, _ = sess.run(
+            [t['cost'], t['gold_slot'], t['alive_steps'], t['train_op']])
+        costs.append(cost)
+        gold_slots.append(gold_slot.mean())
+        alive_steps_vector.append(alive_steps.mean())
+
+      if builder._only_train:
+        trainable_param_names = [
+            k for k in builder.params if k in builder._only_train]
+      else:
+        trainable_param_names = builder.params.keys()
+      if builder._use_averaging:
+        for v in trainable_param_names:
+          avg = builder.variables['%s_avg_var' % v].eval()
+          tf.assign(builder.params[v], avg).eval()
+
+      # Reset for pseudo eval.
+      costs = []
+      gold_slots = []
+      alive_stepss = []
+      for step in range(10):
+        cost, gold_slot, alive_steps = sess.run(
+            [t['cost'], t['gold_slot'], t['alive_steps']])
+        costs.append(cost)
+        gold_slots.append(gold_slot.mean())
+        alive_stepss.append(alive_steps.mean())
+
+      logging.info(
+          'Pseudo eval: <cost>: %f <gold_slot>: %f <alive_steps>: %f',
+          sum(costs[-every_n:]) / float(every_n),
+          sum(gold_slots[-every_n:]) / float(every_n),
+          sum(alive_stepss[-every_n:]) / float(every_n))
+
+  def PathScores(self, iterations, beam_size, max_steps, batch_size):
+    with self.test_session(graph=tf.Graph()) as sess:
+      t = self.MakeGraph(beam_size=beam_size, max_steps=max_steps,
+                         batch_size=batch_size).training
+      sess.run(t['inits'])
+      all_path_scores = []
+      beam_path_scores = []
+      for i in range(iterations):
+        logging.info('run %d', i)
+        tensors = (
+            sess.run(
+                [t['alive_steps'], t['concat_scores'],
+                 t['all_path_scores'], t['beam_path_scores'],
+                 t['indices'], t['path_ids']]))
+
+        logging.info('alive for %s, all_path_scores and beam_path_scores, '
+                     'indices and path_ids:'
+                     '\n%s\n%s\n%s\n%s',
+                     tensors[0], tensors[2], tensors[3], tensors[4], tensors[5])
+        logging.info('diff:\n%s', tensors[2] - tensors[3])
+
+        all_path_scores.append(tensors[2])
+        beam_path_scores.append(tensors[3])
+      return all_path_scores, beam_path_scores
+
+  def testParseUntilNotAlive(self):
+    """Ensures that the 'alive' condition works in the Cond ops."""
+    with self.test_session(graph=tf.Graph()) as sess:
+      t = self.MakeGraph(batch_size=3, beam_size=2, max_steps=5).training
+      sess.run(t['inits'])
+      for i in range(5):
+        logging.info('run %d', i)
+        tf_alive = t['alive'].eval()
+        self.assertFalse(any(tf_alive))
+
+  def testParseMomentum(self):
+    """Ensures that Momentum training can be done using the gradients."""
+    self.Train()
+    self.Train(model_cost='perceptron_loss')
+    self.Train(model_cost='perceptron_loss',
+               only_train='softmax_weight,softmax_bias', softmax_init=0)
+    self.Train(only_train='softmax_weight,softmax_bias', softmax_init=0)
+
+  def testPathScoresAgree(self):
+    """Ensures that path scores computed in the beam are same in the net."""
+    all_path_scores, beam_path_scores = self.PathScores(
+        iterations=1, beam_size=130, max_steps=5, batch_size=1)
+    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
+
+  def testBatchPathScoresAgree(self):
+    """Ensures that path scores computed in the beam are same in the net."""
+    all_path_scores, beam_path_scores = self.PathScores(
+        iterations=1, beam_size=130, max_steps=5, batch_size=22)
+    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
+
+  def testBatchOneStepPathScoresAgree(self):
+    """Ensures that path scores computed in the beam are same in the net."""
+    all_path_scores, beam_path_scores = self.PathScores(
+        iterations=1, beam_size=130, max_steps=1, batch_size=22)
+    self.assertArrayNear(all_path_scores[0], beam_path_scores[0], 1e-6)
+
+
+if __name__ == '__main__':
+  googletest.main()

+ 93 - 0
syntaxnet/syntaxnet/conll2tree.py

@@ -0,0 +1,93 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A program to generate ASCII trees from conll files."""
+
+import collections
+
+import asciitree
+import tensorflow as tf
+
+import syntaxnet.load_parser_ops
+
+from tensorflow.python.platform import logging
+from syntaxnet import sentence_pb2
+from syntaxnet.ops import gen_parser_ops
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('task_context',
+                    'syntaxnet/models/parsey_mcparseface/context.pbtxt',
+                    'Path to a task context with inputs and parameters for '
+                    'feature extractors.')
+flags.DEFINE_string('corpus_name', 'stdin-conll',
+                    'Name of the corpus input in the task context from which '
+                    'documents are read.')
+
+
+def to_dict(sentence):
+  """Builds a dictionary representing the parse tree of a sentence.
+
+  Args:
+    sentence: Sentence protocol buffer to represent.
+  Returns:
+    Dictionary mapping tokens to children.
+  """
+  token_str = ['%s %s %s' % (token.word, token.tag, token.label)
+               for token in sentence.token]
+  children = [[] for token in sentence.token]
+  root = -1
+  for i in range(0, len(sentence.token)):
+    token = sentence.token[i]
+    if token.head == -1:
+      root = i
+    else:
+      children[token.head].append(i)
+
+  def _get_dict(i):
+    d = collections.OrderedDict()
+    for c in children[i]:
+      d[token_str[c]] = _get_dict(c)
+    return d
+
+  tree = collections.OrderedDict()
+  tree[token_str[root]] = _get_dict(root)
+  return tree
+
+
+def main(unused_argv):
+  logging.set_verbosity(logging.INFO)
+  with tf.Session() as sess:
+    src = gen_parser_ops.document_source(batch_size=32,
+                                         corpus_name=FLAGS.corpus_name,
+                                         task_context=FLAGS.task_context)
+    sentence = sentence_pb2.Sentence()
+    while True:
+      documents, finished = sess.run(src)
+      logging.info('Read %d documents', len(documents))
+      for d in documents:
+        sentence.ParseFromString(d)
+        tr = asciitree.LeftAligned()
+        d = to_dict(sentence)
+        print 'Input: %s' % sentence.text
+        print 'Parse:'
+        print tr(d)
+
+      if finished:
+        break
+
+
+if __name__ == '__main__':
+  tf.app.run()

These changes were omitted because the diff is too large
+ 156 - 0
syntaxnet/syntaxnet/context.pbtxt


+ 56 - 0
syntaxnet/syntaxnet/demo.sh

@@ -0,0 +1,56 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# A script that runs a tokenizer, a part-of-speech tagger and a dependency
+# parser on an English text file, with one sentence per line.
+#
+# Example usage:
+#  echo "Parsey McParseface is my favorite parser!" | syntaxnet/demo.sh
+
+# To run on a conll formatted file, add the --conll command line argument.
+#
+
+PARSER_EVAL=bazel-bin/syntaxnet/parser_eval
+MODEL_DIR=syntaxnet/models/parsey_mcparseface
+[[ "$1" == "--conll" ]] && INPUT_FORMAT=stdin-conll || INPUT_FORMAT=stdin
+
+$PARSER_EVAL \
+  --input=$INPUT_FORMAT \
+  --output=stdout-conll \
+  --hidden_layer_sizes=64 \
+  --arg_prefix=brain_tagger \
+  --graph_builder=structured \
+  --task_context=$MODEL_DIR/context.pbtxt \
+  --model_path=$MODEL_DIR/tagger-params \
+  --slim_model \
+  --batch_size=1024 \
+  --alsologtostderr \
+   | \
+  $PARSER_EVAL \
+  --input=stdin-conll \
+  --output=stdout-conll \
+  --hidden_layer_sizes=512,512 \
+  --arg_prefix=brain_parser \
+  --graph_builder=structured \
+  --task_context=$MODEL_DIR/context.pbtxt \
+  --model_path=$MODEL_DIR/parser-params \
+  --slim_model \
+  --batch_size=1024 \
+  --alsologtostderr \
+  | \
+  bazel-bin/syntaxnet/conll2tree \
+  --task_context=$MODEL_DIR/context.pbtxt \
+  --alsologtostderr

+ 57 - 0
syntaxnet/syntaxnet/dictionary.proto

@@ -0,0 +1,57 @@
+// Protocol buffers for serializing string<=>index dictionaries.
+
+syntax = "proto2";
+
+package syntaxnet;
+
+// Serializable representation of a string=>string pair.
+message StringToStringPair {
+  // String representing the key.
+  required string key = 1;
+
+  // String representing the value.
+  required string value = 2;
+}
+
+// Serializable representation of a string=>string mapping.
+message StringToStringMap {
+  // Key=>value pairs.
+  repeated StringToStringPair pair = 1;
+}
+
+// Affix table entry, for serialization of the affix tables.
+message AffixTableEntry {
+  // Nested message for serializing a single affix.
+  message AffixEntry {
+    // The affix as a string.
+    required string form = 1;
+
+    // The length of the affix (this is non-trivial to compute due to UTF-8).
+    required int32 length = 2;
+
+    // The ID of the affix that is one character shorter, or -1 if none exists.
+    required int32 shorter_id = 3;
+  }
+
+  // The type of affix table, as a string.
+  required string type = 1;
+
+  // The maximum affix length.
+  required int32 max_length = 2;
+
+  // The list of affixes, in order of affix ID.
+  repeated AffixEntry affix = 3;
+}
+
+// A light-weight proto to store vectors in binary format.
+message TokenEmbedding {
+  required bytes token = 1;  // can be word or phrase, or URL, etc.
+
+  // If available, raw count of this token in the training corpus.
+  optional int64 count = 3;
+
+  message Vector {
+    repeated float values = 1 [packed = true];
+  }
+  optional Vector vector = 2;
+};
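For illustration, a TokenEmbedding record could be built from Python roughly as follows. This is a sketch that assumes the generated proto module for dictionary.proto is importable as syntaxnet.dictionary_pb2.

# Sketch only: assumes dictionary.proto is compiled to syntaxnet.dictionary_pb2.
from syntaxnet import dictionary_pb2

emb = dictionary_pb2.TokenEmbedding()
emb.token = b'parser'                        # word, phrase, URL, etc.
emb.count = 42                               # optional raw corpus count
emb.vector.values.extend([0.1, -0.3, 0.7])   # packed float embedding
record = emb.SerializeToString()             # binary form for storage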

+ 335 - 0
syntaxnet/syntaxnet/document_filters.cc

@@ -0,0 +1,335 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Various utilities for handling documents.
+
+#include <stddef.h>
+#include <algorithm>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/base.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/utils.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::errors::InvalidArgument;
+
+namespace syntaxnet {
+
+namespace {
+
+void GetTaskContext(OpKernelConstruction *context, TaskContext *task_context) {
+  string file_path, data;
+  OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
+  OP_REQUIRES_OK(
+      context, ReadFileToString(tensorflow::Env::Default(), file_path, &data));
+  OP_REQUIRES(context,
+              TextFormat::ParseFromString(data, task_context->mutable_spec()),
+              InvalidArgument("Could not parse task context at ", file_path));
+}
+
+// Outputs the given batch of sentences as a tensor and deletes them.
+void OutputDocuments(OpKernelContext *context,
+                     vector<Sentence *> *document_batch) {
+  const int64 size = document_batch->size();
+  Tensor *output;
+  OP_REQUIRES_OK(context,
+                 context->allocate_output(0, TensorShape({size}), &output));
+  for (int64 i = 0; i < size; ++i) {
+    output->vec<string>()(i) = (*document_batch)[i]->SerializeAsString();
+  }
+  utils::STLDeleteElements(document_batch);
+}
+
+}  // namespace
+
+class DocumentSource : public OpKernel {
+ public:
+  explicit DocumentSource(OpKernelConstruction *context) : OpKernel(context) {
+    GetTaskContext(context, &task_context_);
+    string corpus_name;
+    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
+    OP_REQUIRES_OK(context, context->GetAttr("batch_size", &batch_size_));
+    OP_REQUIRES(context, batch_size_ > 0,
+                InvalidArgument("invalid batch_size provided"));
+    corpus_.reset(new TextReader(*task_context_.GetInput(corpus_name)));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    mutex_lock lock(mu_);
+    Sentence *document;
+    vector<Sentence *> document_batch;
+    while ((document = corpus_->Read()) != NULL) {
+      document_batch.push_back(document);
+      if (static_cast<int>(document_batch.size()) == batch_size_) {
+        OutputDocuments(context, &document_batch);
+        OutputLast(context, false);
+        return;
+      }
+    }
+    OutputDocuments(context, &document_batch);
+    OutputLast(context, true);
+  }
+
+ private:
+  void OutputLast(OpKernelContext *context, bool last) {
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(1, TensorShape({}), &output));
+    output->scalar<bool>()() = last;
+  }
+
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // mutex to synchronize access to Compute.
+  mutex mu_;
+
+  std::unique_ptr<TextReader> corpus_;
+  string documents_path_;
+  int batch_size_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DocumentSource").Device(DEVICE_CPU),
+                        DocumentSource);
+
+class DocumentSink : public OpKernel {
+ public:
+  explicit DocumentSink(OpKernelConstruction *context) : OpKernel(context) {
+    GetTaskContext(context, &task_context_);
+    string corpus_name;
+    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
+    writer_.reset(new TextWriter(*task_context_.GetInput(corpus_name)));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    mutex_lock lock(mu_);
+    auto documents = context->input(0).vec<string>();
+    for (int i = 0; i < documents.size(); ++i) {
+      Sentence document;
+      OP_REQUIRES(context, document.ParseFromString(documents(i)),
+                  InvalidArgument("failed to parse sentence"));
+      writer_->Write(document);
+    }
+  }
+
+ private:
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // mutex to synchronize access to Compute.
+  mutex mu_;
+
+  string documents_path_;
+  std::unique_ptr<TextWriter> writer_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("DocumentSink").Device(DEVICE_CPU),
+                        DocumentSink);
+
+// Sentence filter for filtering out documents where the parse trees are not
+// well-formed, i.e. they contain cycles.
+class WellFormedFilter : public OpKernel {
+ public:
+  explicit WellFormedFilter(OpKernelConstruction *context) : OpKernel(context) {
+    GetTaskContext(context, &task_context_);
+    OP_REQUIRES_OK(context, context->GetAttr("keep_malformed_documents",
+                                             &keep_malformed_));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    auto documents = context->input(0).vec<string>();
+    vector<Sentence *> output_documents;
+    for (int i = 0; i < documents.size(); ++i) {
+      Sentence *document = new Sentence;
+      OP_REQUIRES(context, document->ParseFromString(documents(i)),
+                  InvalidArgument("failed to parse sentence"));
+      if (ShouldKeep(*document)) {
+        output_documents.push_back(document);
+      } else {
+        delete document;
+      }
+    }
+    OutputDocuments(context, &output_documents);
+  }
+
+ private:
+  bool ShouldKeep(const Sentence &doc)  {
+    vector<int> visited(doc.token_size(), -1);
+    for (int i = 0; i < doc.token_size(); ++i) {
+      // Already visited node.
+      if (visited[i] != -1) continue;
+      int t = i;
+      while (t != -1) {
+        if (visited[t] == -1) {
+          // If it is not visited yet, mark it.
+          visited[t] = i;
+        } else if (visited[t] < i) {
+          // If the stored index is smaller than the current start index i,
+          // this chain was already cleared by an earlier traversal, so stop.
+          break;
+        } else {
+          // Loop detected.
+          LOG(ERROR) << "Loop detected in document " << doc.DebugString();
+          return keep_malformed_;
+        }
+        t = doc.token(t).head();
+      }
+    }
+    return true;
+  }
+
+ private:
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  bool keep_malformed_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("WellFormedFilter").Device(DEVICE_CPU),
+                        WellFormedFilter);
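The well-formedness test above boils down to the following check, shown here as an illustrative Python sketch over a plain list of head indices, where -1 marks the root.

# Sketch only: mirrors WellFormedFilter::ShouldKeep. Returns True if following
# head pointers from every token reaches the root (-1) without a cycle.
def is_well_formed(heads):
  visited = [-1] * len(heads)
  for i in range(len(heads)):
    t = i
    while t != -1:
      if visited[t] == -1:
        visited[t] = i      # first visit in this walk
      elif visited[t] < i:
        break               # already cleared by an earlier walk
      else:
        return False        # revisited within the same walk: cycle detected
      t = heads[t]
  return True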
+
+// Sentence filter that modifies dependency trees to make them projective. This
+// could be made more efficient by looping over sentences instead of the entire
+// document. Assumes that the document is well-formed in the sense of having
+// no looping dependencies.
+//
+// Task arguments:
+//   bool discard_non_projective (false) : If true, discards documents with
+//     non-projective trees instead of projectivizing them.
+class ProjectivizeFilter : public OpKernel {
+ public:
+  explicit ProjectivizeFilter(OpKernelConstruction *context)
+      : OpKernel(context) {
+    GetTaskContext(context, &task_context_);
+    OP_REQUIRES_OK(context, context->GetAttr("discard_non_projective",
+                                             &discard_non_projective_));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    auto documents = context->input(0).vec<string>();
+    vector<Sentence *> output_documents;
+    for (int i = 0; i < documents.size(); ++i) {
+      Sentence *document = new Sentence;
+      OP_REQUIRES(context, document->ParseFromString(documents(i)),
+                  InvalidArgument("failed to parse sentence"));
+      if (Process(document)) {
+        output_documents.push_back(document);
+      } else {
+        delete document;
+      }
+    }
+    OutputDocuments(context, &output_documents);
+  }
+
+  bool Process(Sentence *doc) {
+    const int num_tokens = doc->token_size();
+
+    // Left and right boundaries for arcs. The left and right ends of an arc are
+    // bounded by the arcs that pass over it. If an arc exceeds these bounds it
+    // will cross an arc passing over it, making it a non-projective arc.
+    vector<int> left(num_tokens);
+    vector<int> right(num_tokens);
+
+    // Lift the deepest non-projective arc until the document is projective.
+    while (true) {
+      // Initialize boundaries to the whole document for all arcs.
+      for (int i = 0; i < num_tokens; ++i) {
+        left[i] = -1;
+        right[i] = num_tokens - 1;
+      }
+
+      // Find left and right bounds for each token.
+      for (int i = 0; i < num_tokens; ++i) {
+        int head_index = doc->token(i).head();
+
+        // Find left and right end of arc.
+        int l = std::min(i, head_index);
+        int r = std::max(i, head_index);
+
+        // Bound all tokens under the arc.
+        for (int j = l + 1; j < r; ++j) {
+          if (left[j] < l) left[j] = l;
+          if (right[j] > r) right[j] = r;
+        }
+      }
+
+      // Find deepest non-projective arc.
+      int deepest_arc = -1;
+      int max_depth = -1;
+
+      // The non-projective arcs are those that exceed their bounds.
+      for (int i = 0; i < num_tokens; ++i) {
+        int head_index = doc->token(i).head();
+        if (head_index == -1) continue;  // any crossing arc must be deeper
+
+        int l = std::min(i, head_index);
+        int r = std::max(i, head_index);
+
+        int left_bound = std::max(left[l], left[r]);
+        int right_bound = std::min(right[l], right[r]);
+
+        if (l < left_bound || r > right_bound) {
+          // Found non-projective arc.
+          if (discard_non_projective_) return false;
+
+          // Pick the deepest as the best candidate for lifting.
+          int depth = 0;
+          int j = i;
+          while (j != -1) {
+            ++depth;
+            j = doc->token(j).head();
+          }
+          if (depth > max_depth) {
+            deepest_arc = i;
+            max_depth = depth;
+          }
+        }
+      }
+
+      // If there are no more non-projective arcs we are done.
+      if (deepest_arc == -1) return true;
+
+      // Lift non-projective arc.
+      int lifted_head = doc->token(doc->token(deepest_arc).head()).head();
+      doc->mutable_token(deepest_arc)->set_head(lifted_head);
+    }
+  }
+
+ private:
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // Whether or not to throw away non-projective documents.
+  bool discard_non_projective_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("ProjectivizeFilter").Device(DEVICE_CPU),
+                        ProjectivizeFilter);
+
+}  // namespace syntaxnet
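To make the arc-lifting procedure easier to follow, here is an illustrative Python sketch of the same projectivization loop over a list of head indices (-1 for the root). It omits the discard_non_projective option and is not part of the op itself.

# Sketch only: mirrors ProjectivizeFilter::Process. Repeatedly lifts the
# deepest non-projective arc until the tree is projective.
def projectivize(heads):
  n = len(heads)
  while True:
    # Bounds imposed on each token by the arcs that span it.
    left = [-1] * n
    right = [n - 1] * n
    for i in range(n):
      l, r = min(i, heads[i]), max(i, heads[i])
      for j in range(l + 1, r):
        left[j] = max(left[j], l)
        right[j] = min(right[j], r)
    # A non-projective arc is one that escapes its bounds; pick the deepest.
    deepest, max_depth = -1, -1
    for i in range(n):
      if heads[i] == -1:
        continue
      l, r = min(i, heads[i]), max(i, heads[i])
      if l < max(left[l], left[r]) or r > min(right[l], right[r]):
        depth, j = 0, i
        while j != -1:
          depth += 1
          j = heads[j]
        if depth > max_depth:
          deepest, max_depth = i, depth
    if deepest == -1:
      return heads            # no non-projective arcs remain
    # Lift: reattach the dependent to its grandparent.
    heads[deepest] = heads[heads[deepest]]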

+ 23 - 0
syntaxnet/syntaxnet/document_format.cc

@@ -0,0 +1,23 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/document_format.h"
+
+namespace syntaxnet {
+
+// Component registry for document formatters.
+REGISTER_CLASS_REGISTRY("document format", DocumentFormat);
+
+}  // namespace syntaxnet

+ 63 - 0
syntaxnet/syntaxnet/document_format.h

@@ -0,0 +1,63 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// An interface for document formats.
+
+#ifndef $TARGETDIR_DOCUMENT_FORMAT_H__
+#define $TARGETDIR_DOCUMENT_FORMAT_H__
+
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/registry.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+
+namespace syntaxnet {
+
+// A document format component converts a key/value pair from a record to one or
+// more documents. The record format is used for selecting the document format
+// component. A document format component can be registered with the
+// REGISTER_DOCUMENT_FORMAT macro.
+class DocumentFormat : public RegisterableClass<DocumentFormat> {
+ public:
+  DocumentFormat() {}
+  virtual ~DocumentFormat() {}
+
+  // Reads a record from the given input buffer with format specific logic.
+  // Returns false if no record could be read because we reached end of file.
+  virtual bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+                          string *record) = 0;
+
+  // Converts a key/value pair to one or more documents.
+  virtual void ConvertFromString(const string &key, const string &value,
+                                 vector<Sentence *> *documents) = 0;
+
+  // Converts a document to a key/value pair.
+  virtual void ConvertToString(const Sentence &document,
+                               string *key, string *value) = 0;
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(DocumentFormat);
+};
+
+#define REGISTER_DOCUMENT_FORMAT(type, component) \
+  REGISTER_CLASS_COMPONENT(DocumentFormat, type, component)
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_DOCUMENT_FORMAT_H__

+ 80 - 0
syntaxnet/syntaxnet/embedding_feature_extractor.cc

@@ -0,0 +1,80 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/embedding_feature_extractor.h"
+
+#include <vector>
+
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/parser_features.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+void GenericEmbeddingFeatureExtractor::Setup(TaskContext *context) {
+  // Don't use version to determine how to get feature FML.
+  const string features = context->Get(
+      tensorflow::strings::StrCat(ArgPrefix(), "_", "features"), "");
+  const string embedding_names =
+      context->Get(GetParamName("embedding_names"), "");
+  const string embedding_dims =
+      context->Get(GetParamName("embedding_dims"), "");
+  LOG(INFO) << "Features: " << features;
+  LOG(INFO) << "Embedding names: " << embedding_names;
+  LOG(INFO) << "Embedding dims: " << embedding_dims;
+  embedding_fml_ = utils::Split(features, ';');
+  add_strings_ = context->Get(GetParamName("add_varlen_strings"), false);
+  embedding_names_ = utils::Split(embedding_names, ';');
+  for (const string &dim : utils::Split(embedding_dims, ';')) {
+    embedding_dims_.push_back(utils::ParseUsing<int>(dim, utils::ParseInt32));
+  }
+}
+
+void GenericEmbeddingFeatureExtractor::Init(TaskContext *context) {
+}
+
+vector<vector<SparseFeatures>> GenericEmbeddingFeatureExtractor::ConvertExample(
+    const vector<FeatureVector> &feature_vectors) const {
+  // Extract the features.
+  vector<vector<SparseFeatures>> sparse_features(feature_vectors.size());
+  for (size_t i = 0; i < feature_vectors.size(); ++i) {
+    // Convert the nlp_parser::FeatureVector to dist belief format.
+    sparse_features[i] =
+        vector<SparseFeatures>(generic_feature_extractor(i).feature_types());
+
+    for (int j = 0; j < feature_vectors[i].size(); ++j) {
+      const FeatureType &feature_type = *feature_vectors[i].type(j);
+      const FeatureValue value = feature_vectors[i].value(j);
+      const bool is_continuous = feature_type.name().find("continuous") == 0;
+      const int64 id = is_continuous ? FloatFeatureValue(value).id : value;
+      const int base = feature_type.base();
+      if (id >= 0) {
+        sparse_features[i][base].add_id(id);
+        if (is_continuous) {
+          sparse_features[i][base].add_weight(FloatFeatureValue(value).weight);
+        }
+        if (add_strings_) {
+          sparse_features[i][base].add_description(tensorflow::strings::StrCat(
+              feature_type.name(), "=", feature_type.GetFeatureValueName(id)));
+        }
+      }
+    }
+  }
+
+  return sparse_features;
+}
+
+}  // namespace syntaxnet
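The three ';'-separated parameters read in Setup() describe parallel feature groups. The following Python sketch uses hypothetical values (the real FML strings live in the task context, e.g. context.pbtxt) to show how they line up.

# Sketch only: hypothetical parameter values; the actual strings are defined in
# the task context under <prefix>_features, <prefix>_embedding_names and
# <prefix>_embedding_dims.
params = {
    'brain_parser_features': 'input.word;input(1).word;stack.tag',
    'brain_parser_embedding_names': 'words;words;tags',
    'brain_parser_embedding_dims': '8;8;8',
}

embedding_fml = params['brain_parser_features'].split(';')
embedding_names = params['brain_parser_embedding_names'].split(';')
embedding_dims = [int(d) for d in params['brain_parser_embedding_dims'].split(';')]

# Group i is extracted with embedding_fml[i], shares its embedding matrix with
# other groups named embedding_names[i], and uses embedding_dims[i] dimensions.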

+ 222 - 0
syntaxnet/syntaxnet/embedding_feature_extractor.h

@@ -0,0 +1,222 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_EMBEDDING_FEATURE_EXTRACTOR_H_
+#define $TARGETDIR_EMBEDDING_FEATURE_EXTRACTOR_H_
+
+#include <functional>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/parser_features.h"
+#include "syntaxnet/sentence_features.h"
+#include "syntaxnet/sparse.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+// An EmbeddingFeatureExtractor manages the extraction of features for
+// embedding-based models. It wraps a sequence of underlying classes of feature
+// extractors, along with associated predicate maps. Each class of feature
+// extractors is associated with a name, e.g., "words", "labels", "tags".
+//
+// The class is split between a generic abstract version,
+// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
+// signature of the ExtractFeatures method) and a typed version.
+//
+// The predicate maps must be initialized before use: they can be loaded using
+// Read() or updated via UpdateMapsForExample.
+class GenericEmbeddingFeatureExtractor {
+ public:
+  virtual ~GenericEmbeddingFeatureExtractor() {}
+
+  // Get the prefix string to put in front of all arguments, so they don't
+  // conflict with other embedding models.
+  virtual const string ArgPrefix() const = 0;
+
+  // Sets up predicate maps and embedding space names that are common for all
+  // embedding based feature extractors.
+  virtual void Setup(TaskContext *context);
+  virtual void Init(TaskContext *context);
+
+  // Requests workspace for the underlying feature extractors. This is
+  // implemented in the typed class.
+  virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
+
+  // Number of predicates for the embedding at a given index (vocabulary size).
+  int EmbeddingSize(int index) const {
+    return generic_feature_extractor(index).GetDomainSize();
+  }
+
+  // Returns number of embedding spaces.
+  int NumEmbeddings() const { return embedding_dims_.size(); }
+
+  // Returns the number of features in the embedding space.
+  const int FeatureSize(int idx) const {
+    return generic_feature_extractor(idx).feature_types();
+  }
+
+  // Returns the dimensionality of the embedding space.
+  int EmbeddingDims(int index) const { return embedding_dims_[index]; }
+
+  // Accessor for embedding dims (dimensions of the embedding spaces).
+  const vector<int> &embedding_dims() const { return embedding_dims_; }
+
+  const vector<string> &embedding_fml() const { return embedding_fml_; }
+
+  // Get parameter name by concatenating the prefix and the original name.
+  string GetParamName(const string &param_name) const {
+    return tensorflow::strings::StrCat(ArgPrefix(), "_", param_name);
+  }
+
+ protected:
+  // Provides the generic class with access to the templated extractors. This is
+  // used to get the type information out of the feature extractor without
+  // knowing the specific calling arguments of the extractor itself.
+  virtual const GenericFeatureExtractor &generic_feature_extractor(
+      int idx) const = 0;
+
+  // Converts a vector of extracted features into
+  // dist_belief::SparseFeatures. Each feature in each feature vector becomes a
+  // single SparseFeatures. The predicates are mapped through map_fn which
+  // should point to either mutable_map_fn or const_map_fn depending on whether
+  // or not the predicate maps should be updated.
+  vector<vector<SparseFeatures>> ConvertExample(
+      const vector<FeatureVector> &feature_vectors) const;
+
+ private:
+  // Embedding space names for parameter sharing.
+  vector<string> embedding_names_;
+
+  // FML strings for each feature extractor.
+  vector<string> embedding_fml_;
+
+  // Size of each of the embedding spaces (maximum predicate id).
+  vector<int> embedding_sizes_;
+
+  // Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
+  vector<int> embedding_dims_;
+
+  // Whether or not to add string descriptions to converted examples.
+  bool add_strings_;
+};
+
+// Templated, object-specific implementation of the
+// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
+// ARGS...> class that has the appropriate FeatureTraits() to ensure that
+// locator type features work.
+//
+// Note: for backwards compatibility purposes, this always reads the FML spec
+// from "<prefix>_features".
+template <class EXTRACTOR, class OBJ, class... ARGS>
+class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
+ public:
+  // Sets up all predicate maps, feature extractors, and flags.
+  void Setup(TaskContext *context) override {
+    GenericEmbeddingFeatureExtractor::Setup(context);
+    feature_extractors_.resize(embedding_fml().size());
+    for (int i = 0; i < embedding_fml().size(); ++i) {
+      feature_extractors_[i].Parse(embedding_fml()[i]);
+      feature_extractors_[i].Setup(context);
+    }
+  }
+
+  // Initializes resources needed by the feature extractors.
+  void Init(TaskContext *context) override {
+    GenericEmbeddingFeatureExtractor::Init(context);
+    for (auto &feature_extractor : feature_extractors_) {
+      feature_extractor.Init(context);
+    }
+  }
+
+  // Requests workspaces from the registry. Must be called after Init(), and
+  // before Preprocess().
+  void RequestWorkspaces(WorkspaceRegistry *registry) override {
+    for (auto &feature_extractor : feature_extractors_) {
+      feature_extractor.RequestWorkspaces(registry);
+    }
+  }
+
+  // Must be called once on the object for each sentence, before any feature
+  // extraction (e.g., UpdateMapsForExample, ExtractSparseFeatures).
+  void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
+    for (auto &feature_extractor : feature_extractors_) {
+      feature_extractor.Preprocess(workspaces, obj);
+    }
+  }
+
+  // Returns a ragged array of SparseFeatures, for 1) each feature extractor
+  // class e, and 2) each feature f extracted by e. Underlying predicate maps
+  // will not be updated and so unrecognized predicates may occur. In such a
+  // case the SparseFeatures object associated with a given extractor class and
+  // feature will be empty.
+  vector<vector<SparseFeatures>> ExtractSparseFeatures(
+      const WorkspaceSet &workspaces, const OBJ &obj, ARGS... args) const {
+    vector<FeatureVector> features(feature_extractors_.size());
+    ExtractFeatures(workspaces, obj, args..., &features);
+    return ConvertExample(features);
+  }
+
+  // Extracts features using the extractors. Note that features must already
+  // be initialized to the correct number of feature extractors. No predicate
+  // mapping is applied.
+  void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
+                       ARGS... args,
+                       vector<FeatureVector> *features) const {
+    DCHECK(features != nullptr);
+    DCHECK_EQ(features->size(), feature_extractors_.size());
+    for (int i = 0; i < feature_extractors_.size(); ++i) {
+      (*features)[i].clear();
+      feature_extractors_[i].ExtractFeatures(workspaces, obj, args...,
+                                             &(*features)[i]);
+    }
+  }
+
+ protected:
+  // Provides generic access to the feature extractors.
+  const GenericFeatureExtractor &generic_feature_extractor(
+      int idx) const override {
+    DCHECK_LT(idx, feature_extractors_.size());
+    DCHECK_GE(idx, 0);
+    return feature_extractors_[idx];
+  }
+
+ private:
+  // Templated feature extractor class.
+  vector<EXTRACTOR> feature_extractors_;
+};
+
+class ParserEmbeddingFeatureExtractor
+    : public EmbeddingFeatureExtractor<ParserFeatureExtractor, ParserState> {
+ public:
+  explicit ParserEmbeddingFeatureExtractor(const string &arg_prefix)
+      : arg_prefix_(arg_prefix) {}
+
+ private:
+  const string ArgPrefix() const override { return arg_prefix_; }
+
+  // Prefix for context parameters.
+  string arg_prefix_;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_EMBEDDING_FEATURE_EXTRACTOR_H_

+ 122 - 0
syntaxnet/syntaxnet/feature_extractor.cc

@@ -0,0 +1,122 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/feature_extractor.h"
+
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/fml_parser.h"
+
+namespace syntaxnet {
+
+constexpr FeatureValue GenericFeatureFunction::kNone;
+
+GenericFeatureExtractor::GenericFeatureExtractor() {}
+
+GenericFeatureExtractor::~GenericFeatureExtractor() {}
+
+void GenericFeatureExtractor::Parse(const string &source) {
+  // Parse feature specification into descriptor.
+  FMLParser parser;
+  parser.Parse(source, mutable_descriptor());
+
+  // Initialize feature extractor from descriptor.
+  InitializeFeatureFunctions();
+}
+
+void GenericFeatureExtractor::InitializeFeatureTypes() {
+  // Register all feature types.
+  GetFeatureTypes(&feature_types_);
+  for (size_t i = 0; i < feature_types_.size(); ++i) {
+    FeatureType *ft = feature_types_[i];
+    ft->set_base(i);
+
+    // Check for feature space overflow.
+    double domain_size = ft->GetDomainSize();
+    if (domain_size < 0) {
+      LOG(FATAL) << "Illegal domain size for feature " << ft->name()
+                 << domain_size;
+    }
+  }
+
+  vector<string> types_names;
+  GetFeatureTypeNames(&types_names);
+  CHECK_EQ(feature_types_.size(), types_names.size());
+}
+
+void GenericFeatureExtractor::GetFeatureTypeNames(
+    vector<string> *type_names) const {
+  for (size_t i = 0; i < feature_types_.size(); ++i) {
+    FeatureType *ft = feature_types_[i];
+    type_names->push_back(ft->name());
+  }
+}
+
+FeatureValue GenericFeatureExtractor::GetDomainSize() const {
+  // The domain size of the feature set is taken to be the largest domain size
+  // of any of its feature types.
+  FeatureValue max_feature_type_dsize = 0;
+  for (size_t i = 0; i < feature_types_.size(); ++i) {
+    FeatureType *ft = feature_types_[i];
+    const FeatureValue feature_type_dsize = ft->GetDomainSize();
+    if (feature_type_dsize > max_feature_type_dsize) {
+      max_feature_type_dsize = feature_type_dsize;
+    }
+  }
+
+  return max_feature_type_dsize;
+}
+
+string GenericFeatureFunction::GetParameter(const string &name) const {
+  // Find named parameter in feature descriptor.
+  for (int i = 0; i < descriptor_->parameter_size(); ++i) {
+    if (name == descriptor_->parameter(i).name()) {
+      return descriptor_->parameter(i).value();
+    }
+  }
+  return "";
+}
+
+GenericFeatureFunction::GenericFeatureFunction() {}
+
+GenericFeatureFunction::~GenericFeatureFunction() {
+  delete feature_type_;
+}
+
+int GenericFeatureFunction::GetIntParameter(const string &name,
+                                            int default_value) const {
+  string value = GetParameter(name);
+  return utils::ParseUsing<int>(value, default_value,
+                                tensorflow::strings::safe_strto32);
+}
+
+void GenericFeatureFunction::GetFeatureTypes(
+    vector<FeatureType *> *types) const {
+  if (feature_type_ != nullptr) types->push_back(feature_type_);
+}
+
+FeatureType *GenericFeatureFunction::GetFeatureType() const {
+  // If a single feature type has been registered return it.
+  if (feature_type_ != nullptr) return feature_type_;
+
+  // Get feature types for function.
+  vector<FeatureType *> types;
+  GetFeatureTypes(&types);
+
+  // If there is exactly one feature type return this, else return null.
+  if (types.size() == 1) return types[0];
+  return nullptr;
+}
+
+}  // namespace syntaxnet

+ 624 - 0
syntaxnet/syntaxnet/feature_extractor.h

@@ -0,0 +1,624 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Generic feature extractor for extracting features from objects. The feature
+// extractor can be used for extracting features from any object. The feature
+// extractor and feature function classes are template classes that have to
+// be instantiated for extracting feature from a specific object type.
+//
+// A feature extractor consists of a hierarchy of feature functions. Each
+// feature function extracts one or more feature type and value pairs from the
+// object.
+//
+// The feature extractor has a modular design where new feature functions can be
+// registered as components. The feature extractor is initialized from a
+// descriptor represented by a protocol buffer. The feature extractor can also
+// be initialized from a text-based source specification of the feature
+// extractor. Feature specification parsers can be added as components. By
+// default the feature extractor can be read from an ASCII protocol buffer or in
+// a simple feature modeling language (fml).
+
+// A feature function is invoked with a focus. Nested feature functions can be
+// invoked with another focus determined by the parent feature function.
+
+#ifndef $TARGETDIR_FEATURE_EXTRACTOR_H_
+#define $TARGETDIR_FEATURE_EXTRACTOR_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/feature_extractor.pb.h"
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/proto_io.h"
+#include "syntaxnet/registry.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/utils.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace syntaxnet {
+
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// Output feature model in FML format.
+void ToFMLFunction(const FeatureFunctionDescriptor &function, string *output);
+void ToFML(const FeatureFunctionDescriptor &function, string *output);
+
+// A feature vector contains feature type and value pairs.
+class FeatureVector {
+ public:
+  FeatureVector() {}
+
+  // Adds feature type and value pair to feature vector.
+  void add(FeatureType *type, FeatureValue value) {
+    features_.emplace_back(type, value);
+  }
+
+  // Removes all elements from the feature vector.
+  void clear() { features_.clear(); }
+
+  // Returns the number of elements in the feature vector.
+  int size() const { return features_.size(); }
+
+  // Reserves space in the underlying feature vector.
+  void reserve(int n) { features_.reserve(n); }
+
+  // Returns feature type for an element in the feature vector.
+  FeatureType *type(int index) const { return features_[index].type; }
+
+  // Returns feature value for an element in the feature vector.
+  FeatureValue value(int index) const { return features_[index].value; }
+
+ private:
+  // Structure for holding feature type and value pairs.
+  struct Element {
+    Element() : type(NULL), value(-1) {}
+    Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
+
+    FeatureType *type;
+    FeatureValue value;
+  };
+
+  // Array for storing feature vector elements.
+  vector<Element> features_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
+};
+
+// The generic feature extractor is the type-independent part of a feature
+// extractor. This holds the descriptor for the feature extractor and the
+// collection of feature types used in the feature extractor.  The feature
+// types are not available until FeatureExtractor<>::Init() has been called.
+class GenericFeatureExtractor {
+ public:
+  GenericFeatureExtractor();
+  virtual ~GenericFeatureExtractor();
+
+  // Initializes the feature extractor from a source representation of the
+  // feature extractor. The first line is used for determining the feature
+  // specification language. If the first line starts with #! followed by a name
+  // then this name is used for instantiating a feature specification parser
+  // with that name. If the language cannot be detected this way it falls back
+  // to using the default language supplied.
+  void Parse(const string &source);
+
+  // Returns the feature extractor descriptor.
+  const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
+  FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
+
+  // Returns the number of feature types in the feature extractor.  Invalid
+  // before Init() has been called.
+  int feature_types() const { return feature_types_.size(); }
+
+  // Returns all feature type names used by the extractor. The names are
+  // added to the types_names array.  Invalid before Init() has been called.
+  void GetFeatureTypeNames(vector<string> *type_names) const;
+
+  // Returns a feature type used in the extractor.  Invalid before Init() has
+  // been called.
+  const FeatureType *feature_type(int index) const {
+    return feature_types_[index];
+  }
+
+  // Returns the feature domain size of this feature extractor.
+  // NOTE: The way that domain size is calculated is, for some, unintuitive. It
+  // is the largest domain size of any feature type.
+  FeatureValue GetDomainSize() const;
+
+ protected:
+  // Initializes the feature types used by the extractor.  Called from
+  // FeatureExtractor<>::Init().
+  void InitializeFeatureTypes();
+
+ private:
+  // Initializes the top-level feature functions.
+  virtual void InitializeFeatureFunctions() = 0;
+
+  // Returns all feature types used by the extractor. The feature types are
+  // added to the result array.
+  virtual void GetFeatureTypes(vector<FeatureType *> *types) const = 0;
+
+  // Descriptor for the feature extractor. This is a protocol buffer that
+  // contains all the information about the feature extractor. The feature
+  // functions are initialized from the information in the descriptor.
+  FeatureExtractorDescriptor descriptor_;
+
+  // All feature types used by the feature extractor. The collection of all the
+  // feature types describes the feature space of the feature set produced by
+  // the feature extractor.  Not owned.
+  vector<FeatureType *> feature_types_;
+};
+
+// The generic feature function is the type-independent part of a feature
+// function. Each feature function is associated with the descriptor that it is
+// instantiated from.  The feature types associated with this feature function
+// will be established by the time FeatureExtractor<>::Init() completes.
+class GenericFeatureFunction {
+ public:
+  // A feature value that represents the absence of a value.
+  static constexpr FeatureValue kNone = -1;
+
+  GenericFeatureFunction();
+  virtual ~GenericFeatureFunction();
+
+  // Sets up the feature function. NB: FeatureTypes of nested functions are not
+  // guaranteed to be available until Init().
+  virtual void Setup(TaskContext *context) {}
+
+  // Initializes the feature function. NB: The FeatureType of this function must
+  // be established when this method completes.
+  virtual void Init(TaskContext *context) {}
+
+  // Requests workspaces from a registry to obtain indices into a WorkspaceSet
+  // for any Workspace objects used by this feature function. NB: This will be
+  // called after Init(), so it can depend on resources and arguments.
+  virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
+
+  // Appends the feature types produced by the feature function to types.  The
+  // default implementation appends feature_type(), if non-null.  Invalid
+  // before Init() has been called.
+  virtual void GetFeatureTypes(vector<FeatureType *> *types) const;
+
+  // Returns the feature type for feature produced by this feature function. If
+  // the feature function produces features of different types this returns
+  // null.  Invalid before Init() has been called.
+  virtual FeatureType *GetFeatureType() const;
+
+  // Returns the name of the registry used for creating the feature function.
+  // This can be used for checking if two feature functions are of the same
+  // kind.
+  virtual const char *RegistryName() const = 0;
+
+  // Returns the value of a named parameter in the feature function's descriptor.
+  // If the named parameter is not found the global parameters are searched.
+  string GetParameter(const string &name) const;
+  int GetIntParameter(const string &name, int default_value) const;
+
+  // Returns the FML function description for the feature function, i.e. the
+  // name and parameters without the nested features.
+  string FunctionName() const {
+    string output;
+    ToFMLFunction(*descriptor_, &output);
+    return output;
+  }
+
+  // Returns the prefix for nested feature functions. This is the prefix of this
+  // feature function concatenated with the feature function name.
+  string SubPrefix() const {
+    return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
+  }
+
+  // Returns/sets the feature extractor this function belongs to.
+  GenericFeatureExtractor *extractor() const { return extractor_; }
+  void set_extractor(GenericFeatureExtractor *extractor) {
+    extractor_ = extractor;
+  }
+
+  // Returns/sets the feature function descriptor.
+  FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
+  void set_descriptor(FeatureFunctionDescriptor *descriptor) {
+    descriptor_ = descriptor;
+  }
+
+  // Returns a descriptive name for the feature function. The name is taken
+  // from the descriptor for the feature function. If the name is empty or the
+  // feature function is a variable, the name is the FML representation of the
+  // feature, including the prefix.
+  string name() const {
+    string output;
+    if (descriptor_->name().empty()) {
+      if (!prefix_.empty()) {
+        output.append(prefix_);
+        output.append(".");
+      }
+      ToFML(*descriptor_, &output);
+    } else {
+      output = descriptor_->name();
+    }
+    tensorflow::StringPiece stripped(output);
+    utils::RemoveWhitespaceContext(&stripped);
+    return stripped.ToString();
+  }
+
+  // Returns the argument from the feature function descriptor. It defaults to
+  // 0 if the argument has not been specified.
+  int argument() const {
+    return descriptor_->has_argument() ? descriptor_->argument() : 0;
+  }
+
+  // Returns/sets/clears function name prefix.
+  const string &prefix() const { return prefix_; }
+  void set_prefix(const string &prefix) { prefix_ = prefix; }
+
+ protected:
+  // Returns the feature type for single-type feature functions.
+  FeatureType *feature_type() const { return feature_type_; }
+
+  // Sets the feature type for single-type feature functions.  This takes
+  // ownership of feature_type.  Can only be called once.
+  void set_feature_type(FeatureType *feature_type) {
+    CHECK(feature_type_ == nullptr);
+    feature_type_ = feature_type;
+  }
+
+ private:
+  // Feature extractor this feature function belongs to.  Not owned.
+  GenericFeatureExtractor *extractor_ = nullptr;
+
+  // Descriptor for feature function.  Not owned.
+  FeatureFunctionDescriptor *descriptor_ = nullptr;
+
+  // Feature type for features produced by this feature function. If the
+  // feature function produces features of multiple feature types this is null
+  // and the feature function must return its feature types in
+  // GetFeatureTypes().  Owned.
+  FeatureType *feature_type_ = nullptr;
+
+  // Prefix used for sub-feature types of this function.
+  string prefix_;
+};
+
+// Feature function that can extract features from an object.  Templated on
+// two type arguments:
+//
+// OBJ:  The "object" from which features are extracted; e.g., a sentence.  This
+//       should be a plain type, rather than a reference or pointer.
+//
+// ARGS: A set of 0 or more types that are used to "index" into some part of the
+//       object that should be extracted, e.g. an int token index for a sentence
+//       object.  This should not be a reference type.
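+//
+// A minimal, hypothetical single-valued feature function (illustrative only;
+// it assumes a Sentence object whose tokens expose word()) might look like:
+//
+//   class TokenLengthFeature : public FeatureFunction<Sentence, int> {
+//    public:
+//     FeatureValue Compute(const WorkspaceSet &workspaces,
+//                          const Sentence &sentence, int focus,
+//                          const FeatureVector *fv) const override {
+//       return sentence.token(focus).word().size();
+//     }
+//   };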
+template<class OBJ, class ...ARGS>
+class FeatureFunction
+    : public GenericFeatureFunction,
+      public RegisterableClass< FeatureFunction<OBJ, ARGS...> > {
+ public:
+  using Self = FeatureFunction<OBJ, ARGS...>;
+
+  // Preprocesses the object.  This will be called prior to calling Evaluate()
+  // or Compute() on that object.
+  virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {}
+
+  // Appends features computed from the object and focus to the result.  The
+  // default implementation delegates to Compute(), adding a single value if
+  // available.  Multi-valued feature functions must override this method.
+  virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
+                        ARGS... args, FeatureVector *result) const {
+    FeatureValue value = Compute(workspaces, object, args..., result);
+    if (value != kNone) result->add(feature_type(), value);
+  }
+
+  // Returns a feature value computed from the object and focus, or kNone if no
+  // value is computed.  Single-valued feature functions only need to override
+  // this method.
+  virtual FeatureValue Compute(const WorkspaceSet &workspaces,
+                               const OBJ &object,
+                               ARGS... args,
+                               const FeatureVector *fv) const {
+    return kNone;
+  }
+
+  // Instantiates a new feature function in a feature extractor from a feature
+  // descriptor.
+  static Self *Instantiate(GenericFeatureExtractor *extractor,
+                           FeatureFunctionDescriptor *fd,
+                           const string &prefix) {
+    Self *f = Self::Create(fd->type());
+    f->set_extractor(extractor);
+    f->set_descriptor(fd);
+    f->set_prefix(prefix);
+    return f;
+  }
+
+  // Returns the name of the registry for the feature function.
+  const char *RegistryName() const override {
+    return Self::registry()->name;
+  }
+
+ private:
+  // Special feature function class for resolving variable references. The type
+  // of the feature function is used for resolving the variable reference. When
+  // evaluated, it either gets the feature value(s) from the variable portion of
+  // the feature vector, if present, or calls the referenced feature extractor
+  // function directly to extract the feature(s).
+  class Reference;
+};
+
+// Base class for features with nested feature functions. The nested functions
+// are of type NES, which may be different from the type of the parent function.
+// NB: NestedFeatureFunction will ensure that all initialization of nested
+// functions takes place during Setup() and Init() -- after the nested features
+// are initialized, the parent feature is initialized via SetupNested() and
+// InitNested(). Alternatively, a derived class that overrides Setup() and
+// Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
+//
+// Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
+// Compute, since the nested functions may be of a different type.
+template<class NES, class OBJ, class ...ARGS>
+class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
+ public:
+  using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
+
+  // Clean up nested functions.
+  ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
+
+  // By default, just appends the nested feature types.
+  void GetFeatureTypes(vector<FeatureType *> *types) const override {
+    CHECK(!this->nested().empty())
+        << "Nested features require nested features to be defined.";
+    for (auto *function : nested_) function->GetFeatureTypes(types);
+  }
+
+  // Sets up the nested features.
+  void Setup(TaskContext *context) override {
+    CreateNested(this->extractor(), this->descriptor(), &nested_,
+                 this->SubPrefix());
+    for (auto *function : nested_) function->Setup(context);
+    SetupNested(context);
+  }
+
+  // Sets up this NestedFeatureFunction specifically.
+  virtual void SetupNested(TaskContext *context) {}
+
+  // Initializes the nested features.
+  void Init(TaskContext *context) override {
+    for (auto *function : nested_) function->Init(context);
+    InitNested(context);
+  }
+
+  // Initializes this NestedFeatureFunction specifically.
+  virtual void InitNested(TaskContext *context) {}
+
+  // Gets all the workspaces needed for the nested functions.
+  void RequestWorkspaces(WorkspaceRegistry *registry) override {
+    for (auto *function : nested_) function->RequestWorkspaces(registry);
+  }
+
+  // Returns the list of nested feature functions.
+  const vector<NES *> &nested() const { return nested_; }
+
+  // Instantiates nested feature functions for a feature function. Creates and
+  // initializes one feature function for each sub-descriptor in the feature
+  // descriptor.
+  static void CreateNested(GenericFeatureExtractor *extractor,
+                           FeatureFunctionDescriptor *fd,
+                           vector<NES *> *functions,
+                           const string &prefix) {
+    for (int i = 0; i < fd->feature_size(); ++i) {
+      FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
+      NES *f = NES::Instantiate(extractor, sub, prefix);
+      functions->push_back(f);
+    }
+  }
+
+ protected:
+  // The nested feature functions, if any, in order of declaration in the
+  // feature descriptor.  Owned.
+  vector<NES *> nested_;
+};
+
+// Base class for a nested feature function whose nested features have the same
+// signature as the function itself, i.e. a meta feature. For this class, we
+// can provide preprocessing of the nested features.
+template<class OBJ, class ...ARGS>
+class MetaFeatureFunction : public NestedFeatureFunction<
+  FeatureFunction<OBJ, ARGS...>, OBJ, ARGS...> {
+ public:
+  // Preprocesses using the nested features.
+  void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
+    for (auto *function : this->nested_) {
+      function->Preprocess(workspaces, object);
+    }
+  }
+};
+
+// Template for a special type of locator: The locator of type
+// FeatureFunction<OBJ, ARGS...> calls nested functions of type
+// FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
+// responsible for translating by providing the following:
+//
+// // Gets the new additional focus.
+// IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object,
+//              ARGS... args) const;
+//
+// This is useful to e.g. add a token focus to a parser state based on some
+// desired property of that state.
+template<class DER, class OBJ, class IDX, class ...ARGS>
+class FeatureAddFocusLocator : public NestedFeatureFunction<
+  FeatureFunction<OBJ, IDX, ARGS...>, OBJ, ARGS...> {
+ public:
+  void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
+    for (auto *function : this->nested_) {
+      function->Preprocess(workspaces, object);
+    }
+  }
+
+  void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
+                ARGS... args, FeatureVector *result) const override {
+    IDX focus = static_cast<const DER *>(this)->GetFocus(
+        workspaces, object, args...);
+    for (auto *function : this->nested()) {
+      function->Evaluate(workspaces, object, focus, args..., result);
+    }
+  }
+
+  // Returns the first nested feature's computed value.
+  FeatureValue Compute(const WorkspaceSet &workspaces,
+                       const OBJ &object,
+                       ARGS... args,
+                       const FeatureVector *result) const override {
+    IDX focus = static_cast<const DER *>(this)->GetFocus(
+        workspaces, object, args...);
+    return this->nested()[0]->Compute(
+        workspaces, object, focus, args..., result);
+  }
+};
+
+// CRTP feature locator class. This is a meta feature that modifies ARGS and
+// then calls the nested feature functions with the modified ARGS. Note that in
+// order for this template to work correctly, all of ARGS must be types for
+// which applying the address-of operator & yields a pointer to the argument.
+// The derived class DER must implement an UpdateArgs method that takes
+// pointers to the ARGS arguments:
+//
+// // Updates the current arguments.
+// void UpdateArgs(const WorkspaceSet &workspaces, const OBJ &object,
+//                 ARGS *...args) const;
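+//
+// A minimal, hypothetical locator that shifts an int token focus by one before
+// delegating to its nested features (illustrative only):
+//
+//   class NextTokenLocator
+//       : public FeatureLocator<NextTokenLocator, Sentence, int> {
+//    public:
+//     void UpdateArgs(const WorkspaceSet &workspaces, const Sentence &sentence,
+//                     int *focus) const { ++*focus; }
+//   };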
+template<class DER, class OBJ, class ...ARGS>
+class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
+ public:
+  // Feature locators have an additional check that there is no intrinsic type.
+  void GetFeatureTypes(vector<FeatureType *> *types) const override {
+    CHECK(this->feature_type() == nullptr)
+        << "FeatureLocators should not have an intrinsic type.";
+    MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
+  }
+
+  // Evaluates the locator.
+  void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
+                ARGS... args, FeatureVector *result) const override {
+    static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+    for (auto *function : this->nested()) {
+      function->Evaluate(workspaces, object, args..., result);
+    }
+  }
+
+  // Returns the first nested feature's computed value.
+  FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
+                       ARGS... args,
+                       const FeatureVector *result) const override {
+    static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
+    return this->nested()[0]->Compute(workspaces, object, args..., result);
+  }
+};
+
+// Feature extractor for extracting features from objects of a certain class.
+// Template type parameters are as defined for FeatureFunction.
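+//
+// A sketch of the typical call sequence (hypothetical; assumes the descriptor
+// has already been populated, e.g. from an FML spec, and that a WorkspaceSet
+// was created from the same registry):
+//
+//   FeatureExtractor<Sentence, int> extractor;
+//   extractor.Setup(&context);
+//   extractor.Init(&context);
+//   extractor.RequestWorkspaces(&registry);
+//   extractor.Preprocess(&workspaces, &sentence);
+//   extractor.ExtractFeatures(workspaces, sentence, /*focus=*/3, &result);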
+template<class OBJ, class ...ARGS>
+class FeatureExtractor : public GenericFeatureExtractor {
+ public:
+  // Feature function type for top-level functions in the feature extractor.
+  typedef FeatureFunction<OBJ, ARGS...> Function;
+  typedef FeatureExtractor<OBJ, ARGS...> Self;
+
+  // Feature locator type for the feature extractor.
+  template<class DER>
+  using Locator = FeatureLocator<DER, OBJ, ARGS...>;
+
+  // Initializes feature extractor.
+  FeatureExtractor() {}
+
+  ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
+
+  // Sets up the feature extractor. Note that only top-level functions exist
+  // until Setup() is called. This does not take ownership over the context,
+  // which must outlive this.
+  void Setup(TaskContext *context) {
+    for (Function *function : functions_) function->Setup(context);
+  }
+
+  // Initializes the feature extractor.  Must be called after Setup().  This
+  // does not take ownership over the context, which must outlive this.
+  void Init(TaskContext *context) {
+    for (Function *function : functions_) function->Init(context);
+    this->InitializeFeatureTypes();
+  }
+
+  // Requests workspaces from the registry. Must be called after Init(), and
+  // before Preprocess(). Does not take ownership over registry. This should be
+  // the same registry used to initialize the WorkspaceSet used in Preprocess()
+  // and ExtractFeatures(). NB: This is a different ordering from that used in
+  // SentenceFeatureRepresentation style feature computation.
+  void RequestWorkspaces(WorkspaceRegistry *registry) {
+    for (auto *function : functions_) function->RequestWorkspaces(registry);
+  }
+
+  // Preprocesses the object using the feature functions.  Must be called
+  // before any calls to ExtractFeatures() on that object.
+  void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {
+    for (Function *function : functions_) {
+      function->Preprocess(workspaces, object);
+    }
+  }
+
+  // Extracts features from an object with a focus. This invokes all the
+  // top-level feature functions in the feature extractor.
+  void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
+                       ARGS... args, FeatureVector *result) const {
+    result->reserve(this->feature_types());
+
+    // Extract features.
+    for (int i = 0; i < functions_.size(); ++i) {
+      functions_[i]->Evaluate(workspaces, object, args..., result);
+    }
+  }
+
+ private:
+  // Creates and initializes all feature functions in the feature extractor.
+  void InitializeFeatureFunctions() override {
+    // Create all top-level feature functions.
+    for (int i = 0; i < descriptor().feature_size(); ++i) {
+      FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i);
+      Function *function = Function::Instantiate(this, fd, "");
+      functions_.push_back(function);
+    }
+  }
+
+  // Collect all feature types used in the feature extractor.
+  void GetFeatureTypes(vector<FeatureType *> *types) const override {
+    for (int i = 0; i < functions_.size(); ++i) {
+      functions_[i]->GetFeatureTypes(types);
+    }
+  }
+
+  // Top-level feature functions (and variables) in the feature extractor.
+  // Owned.
+  vector<Function *> functions_;
+};
+
+#define REGISTER_FEATURE_FUNCTION(base, name, component) \
+  REGISTER_CLASS_COMPONENT(base, name, component)
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_FEATURE_EXTRACTOR_H_

+ 34 - 0
syntaxnet/syntaxnet/feature_extractor.proto

@@ -0,0 +1,34 @@
+// Protocol buffers for feature extractor.
+
+syntax = "proto2";
+
+package syntaxnet;
+
+message Parameter {
+  optional string name = 1;
+  optional string value = 2;
+}
+
+// Descriptor for feature function.
+message FeatureFunctionDescriptor {
+  // Feature function type.
+  required string type = 1;
+
+  // Feature function name.
+  optional string name = 2;
+
+  // Default argument for feature function.
+  optional int32 argument = 3 [default = 0];
+
+  // Named parameters for feature descriptor.
+  repeated Parameter parameter = 4;
+
+  // Nested sub-feature function descriptors.
+  repeated FeatureFunctionDescriptor feature = 7;
+};
+
+// Descriptor for feature extractor.
+message FeatureExtractorDescriptor {
+  // Top-level feature function for extractor.
+  repeated FeatureFunctionDescriptor feature = 1;
+};
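+
+// For illustration, the FML spec "input(1) { word tag }" would parse into a
+// descriptor like:
+//
+//   feature {
+//     type: "input"
+//     argument: 1
+//     feature { type: "word" }
+//     feature { type: "tag" }
+//   }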

+ 176 - 0
syntaxnet/syntaxnet/feature_types.h

@@ -0,0 +1,176 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Common feature types for parser components.
+
+#ifndef $TARGETDIR_FEATURE_TYPES_H_
+#define $TARGETDIR_FEATURE_TYPES_H_
+
+#include <algorithm>
+#include <map>
+#include <string>
+#include <utility>
+
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+// Use the same type for feature values as is used for predicates.
+typedef int64 Predicate;
+typedef Predicate FeatureValue;
+
+// Each feature value in a feature vector has a feature type. The feature type
+// is used for converting feature type and value pairs to predicate values. The
+// feature type can also return names for feature values and calculate the size
+// of the feature value domain. The FeatureType class is abstract and must be
+// specialized for the concrete feature types.
+class FeatureType {
+ public:
+  // Initializes a feature type.
+  explicit FeatureType(const string &name)
+      : name_(name), base_(0) {}
+
+  virtual ~FeatureType() {}
+
+  // Converts a feature value to a name.
+  virtual string GetFeatureValueName(FeatureValue value) const = 0;
+
+  // Returns the size of the feature values domain.
+  virtual int64 GetDomainSize() const = 0;
+
+  // Returns the feature type name.
+  const string &name() const { return name_; }
+
+  Predicate base() const { return base_; }
+  void set_base(Predicate base) { base_ = base; }
+
+ private:
+  // Feature type name.
+  string name_;
+
+  // "Base" feature value: i.e. a "slot" in a global ordering of features.
+  Predicate base_;
+};
+
+// Templated generic resource-based feature type. This feature type delegates
+// lookup of feature value names to a resource class, which is not owned.
+// Optionally, this type can also store a mapping of extra values which are not
+// in the resource.
+//
+// Note: this class assumes that Resource->GetFeatureValueName() will return
+// successfully only for values in the range [0, Resource->NumValues()).  Any
+// feature value not in the extra value map and not in that range will result
+// in an error being logged and "<INVALID>" being returned.
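+//
+// Hypothetical usage, assuming a resource class exposing NumValues() and
+// GetFeatureValueName(), with one extra "unknown" value placed just past the
+// resource's range:
+//
+//   ResourceBasedFeatureType<MyResource> type(
+//       "words", &resource, {{resource.NumValues(), "<UNKNOWN>"}});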
+template<class Resource>
+class ResourceBasedFeatureType : public FeatureType {
+ public:
+  // Creates a new type with the given name, resource object, and a mapping of
+  // special values. The extra values must be greater than or equal to
+  // resource->NumValues() so as to avoid collisions; this is verified with a
+  // CHECK at creation.
+  ResourceBasedFeatureType(const string &name, const Resource *resource,
+                           const map<FeatureValue, string> &values)
+      : FeatureType(name), resource_(resource), values_(values) {
+    max_value_ = resource->NumValues() - 1;
+    for (const auto &pair : values) {
+      CHECK_GE(pair.first, resource->NumValues()) << "Invalid extra value: "
+               << pair.first << "," << pair.second;
+      max_value_ = pair.first > max_value_ ? pair.first : max_value_;
+    }
+  }
+
+  // Creates a new type with no special values.
+  ResourceBasedFeatureType(const string &name, const Resource *resource)
+      : ResourceBasedFeatureType(name, resource, {}) {}
+
+  // Returns the feature name for a given feature value. First checks the values
+  // map, then checks the resource to look up the name.
+  string GetFeatureValueName(FeatureValue value) const override {
+    if (values_.find(value) != values_.end()) {
+      return values_.find(value)->second;
+    }
+    if (value >= 0 && value < resource_->NumValues()) {
+      return resource_->GetFeatureValueName(value);
+    } else {
+      LOG(ERROR) << "Invalid feature value " << value << " for " << name();
+      return "<INVALID>";
+    }
+  }
+
+  // Returns the number of possible values for this feature type. This is one
+  // greater than the largest value in the resource or the extra values map.
+  FeatureValue GetDomainSize() const override { return max_value_ + 1; }
+
+ protected:
+  // Shared resource. Not owned.
+  const Resource *resource_ = nullptr;
+
+  // Maximum possible value this feature could take.
+  FeatureValue max_value_;
+
+  // Mapping for extra feature values not in the resource.
+  map<FeatureValue, string> values_;
+};
+
+// Feature type that is defined using an explicit map from FeatureValue to
+// string values.  This can reduce some of the boilerplate when defining
+// features that generate enum values.  Example usage:
+//
+//   class BeverageSizeFeature : public FeatureFunction<Beverage> {
+//     enum FeatureValue { SMALL, MEDIUM, LARGE };  // values for this feature
+//     void Init(TaskContext *context) override {
+//       set_feature_type(new EnumFeatureType("beverage_size",
+//           {{SMALL, "SMALL"}, {MEDIUM, "MEDIUM"}, {LARGE, "LARGE"}}));
+//     }
+//     [...]
+//   };
+class EnumFeatureType : public FeatureType {
+ public:
+  EnumFeatureType(const string &name,
+                  const map<FeatureValue, string> &value_names)
+      : FeatureType(name), value_names_(value_names) {
+    for (const auto &pair : value_names) {
+      CHECK_GE(pair.first, 0)
+          << "Invalid feature value: " << pair.first << ", " << pair.second;
+      domain_size_ = std::max(domain_size_, pair.first + 1);
+    }
+  }
+
+  // Returns the feature name for a given feature value.
+  string GetFeatureValueName(FeatureValue value) const override {
+    auto it = value_names_.find(value);
+    if (it == value_names_.end()) {
+      LOG(ERROR)
+          << "Invalid feature value " << value << " for " << name();
+      return "<INVALID>";
+    }
+    return it->second;
+  }
+
+  // Returns the number of possible values for this feature type. This is one
+  // greater than the largest value in the value_names map.
+  FeatureValue GetDomainSize() const override { return domain_size_; }
+
+ protected:
+  // Size of the feature value domain: one greater than the largest value.
+  FeatureValue domain_size_ = 0;
+
+  // Names of feature values.
+  map<FeatureValue, string> value_names_;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_FEATURE_TYPES_H_

+ 291 - 0
syntaxnet/syntaxnet/fml_parser.cc

@@ -0,0 +1,291 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/fml_parser.h"
+
+#include <ctype.h>
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+void FMLParser::Initialize(const string &source) {
+  // Initialize parser state.
+  source_ = source;
+  current_ = source_.begin();
+  item_start_ = line_start_ = current_;
+  line_number_ = item_line_number_ = 1;
+
+  // Read first input item.
+  NextItem();
+}
+
+void FMLParser::Error(const string &error_message) {
+  LOG(FATAL) << "Error in feature model, line " << item_line_number_
+             << ", position " << (item_start_ - line_start_ + 1)
+             << ": " << error_message
+             << "\n    " << string(line_start_, current_) << " <--HERE";
+}
+
+void FMLParser::Next() {
+  // Move to the next input character. If we are at a line break update line
+  // number and line start position.
+  if (*current_ == '\n') {
+    ++line_number_;
+    ++current_;
+    line_start_ = current_;
+  } else {
+    ++current_;
+  }
+}
+
+void FMLParser::NextItem() {
+  // Skip white space and comments.
+  while (!eos()) {
+    if (*current_ == '#') {
+      // Skip comment.
+      while (!eos() && *current_ != '\n') Next();
+    } else if (isspace(*current_)) {
+      // Skip whitespace.
+      while (!eos() && isspace(*current_)) Next();
+    } else {
+      break;
+    }
+  }
+
+  // Record start position for next item.
+  item_start_ = current_;
+  item_line_number_ = line_number_;
+
+  // Check for end of input.
+  if (eos()) {
+    item_type_ = END;
+    return;
+  }
+
+  // Parse number.
+  if (isdigit(*current_) || *current_ == '+' || *current_ == '-') {
+    string::iterator start = current_;
+    Next();
+    while (isdigit(*current_) || *current_ == '.') Next();
+    item_text_.assign(start, current_);
+    item_type_ = NUMBER;
+    return;
+  }
+
+  // Parse string.
+  if (*current_ == '"') {
+    Next();
+    string::iterator start = current_;
+    while (*current_ != '"') {
+      if (eos()) Error("Unterminated string");
+      Next();
+    }
+    item_text_.assign(start, current_);
+    item_type_ = STRING;
+    Next();
+    return;
+  }
+
+  // Parse identifier name.
+  if (isalpha(*current_) || *current_ == '_' || *current_ == '/') {
+    string::iterator start = current_;
+    while (isalnum(*current_) || *current_ == '_' || *current_ == '-' ||
+           *current_ == '/') Next();
+    item_text_.assign(start, current_);
+    item_type_ = NAME;
+    return;
+  }
+
+  // Single character item.
+  item_type_ = *current_;
+  Next();
+}
+
+void FMLParser::Parse(const string &source,
+                      FeatureExtractorDescriptor *result) {
+  // Initialize parser.
+  Initialize(source);
+
+  while (item_type_ != END) {
+    // Parse either a parameter name or a feature.
+    if (item_type_ != NAME) Error("Feature type name expected");
+    string name = item_text_;
+    NextItem();
+
+    if (item_type_ == '=') {
+      Error("Invalid syntax: feature expected");
+    } else {
+      // Parse feature.
+      FeatureFunctionDescriptor *descriptor = result->add_feature();
+      descriptor->set_type(name);
+      ParseFeature(descriptor);
+    }
+  }
+}
+
+void FMLParser::ParseFeature(FeatureFunctionDescriptor *result) {
+  // Parse argument and parameters.
+  if (item_type_ == '(') {
+    NextItem();
+    ParseParameter(result);
+    while (item_type_ == ',') {
+      NextItem();
+      ParseParameter(result);
+    }
+
+    if (item_type_ != ')') Error(") expected");
+    NextItem();
+  }
+
+  // Parse feature name.
+  if (item_type_ == ':') {
+    NextItem();
+    if (item_type_ != NAME && item_type_ != STRING) {
+      Error("Feature name expected");
+    }
+    string name = item_text_;
+    NextItem();
+
+    // Set feature name.
+    result->set_name(name);
+  }
+
+  // Parse sub-features.
+  if (item_type_ == '.') {
+    // Parse dotted sub-feature.
+    NextItem();
+    if (item_type_ != NAME) Error("Feature type name expected");
+    string type = item_text_;
+    NextItem();
+
+    // Parse sub-feature.
+    FeatureFunctionDescriptor *subfeature = result->add_feature();
+    subfeature->set_type(type);
+    ParseFeature(subfeature);
+  } else if (item_type_ == '{') {
+    // Parse sub-feature block.
+    NextItem();
+    while (item_type_ != '}') {
+      if (item_type_ != NAME) Error("Feature type name expected");
+      string type = item_text_;
+      NextItem();
+
+      // Parse sub-feature.
+      FeatureFunctionDescriptor *subfeature = result->add_feature();
+      subfeature->set_type(type);
+      ParseFeature(subfeature);
+    }
+    NextItem();
+  }
+}
+
+void FMLParser::ParseParameter(FeatureFunctionDescriptor *result) {
+  if (item_type_ == NUMBER) {
+    int argument =
+        utils::ParseUsing<int>(item_text_, tensorflow::strings::safe_strto32);
+    NextItem();
+
+    // Set default argument for feature.
+    result->set_argument(argument);
+  } else if (item_type_ == NAME) {
+    string name = item_text_;
+    NextItem();
+    if (item_type_ != '=') Error("= expected");
+    NextItem();
+    if (item_type_ >= END) Error("Parameter value expected");
+    string value = item_text_;
+    NextItem();
+
+    // Add parameter to feature.
+    Parameter *parameter = result->add_parameter();
+    parameter->set_name(name);
+    parameter->set_value(value);
+  } else {
+    Error("Syntax error in parameter list");
+  }
+}
+
+void ToFMLFunction(const FeatureFunctionDescriptor &function, string *output) {
+  output->append(function.type());
+  if (function.argument() != 0 || function.parameter_size() > 0) {
+    output->append("(");
+    bool first = true;
+    if (function.argument() != 0) {
+      tensorflow::strings::StrAppend(output, function.argument());
+      first = false;
+    }
+    for (int i = 0; i < function.parameter_size(); ++i) {
+      if (!first) output->append(",");
+      output->append(function.parameter(i).name());
+      output->append("=");
+      output->append("\"");
+      output->append(function.parameter(i).value());
+      output->append("\"");
+      first = false;
+    }
+    output->append(")");
+  }
+}
+
+void ToFML(const FeatureFunctionDescriptor &function, string *output) {
+  ToFMLFunction(function, output);
+  if (function.feature_size() == 1) {
+    output->append(".");
+    ToFML(function.feature(0), output);
+  } else if (function.feature_size() > 1) {
+    output->append(" { ");
+    for (int i = 0; i < function.feature_size(); ++i) {
+      if (i > 0) output->append(" ");
+      ToFML(function.feature(i), output);
+    }
+    output->append(" } ");
+  }
+}
+
+void ToFML(const FeatureExtractorDescriptor &extractor, string *output) {
+  for (int i = 0; i < extractor.feature_size(); ++i) {
+    ToFML(extractor.feature(i), output);
+    output->append("\n");
+  }
+}
+
+string AsFML(const FeatureFunctionDescriptor &function) {
+  string str;
+  ToFML(function, &str);
+  return str;
+}
+
+string AsFML(const FeatureExtractorDescriptor &extractor) {
+  string str;
+  ToFML(extractor, &str);
+  return str;
+}
+
+void StripFML(string *fml_string) {
+  auto it = fml_string->begin();
+  while (it != fml_string->end()) {
+    if (*it == '"') {
+      it = fml_string->erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+
+}  // namespace syntaxnet

+ 113 - 0
syntaxnet/syntaxnet/fml_parser.h

@@ -0,0 +1,113 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Feature modeling language (fml) parser.
+//
+// BNF grammar for fml:
+//
+// <feature model> ::= { <feature extractor> }
+//
+// <feature extractor> ::= <extractor spec> |
+//                         <extractor spec> '.' <feature extractor> |
+//                         <extractor spec> '{' { <feature extractor> } '}'
+//
+// <extractor spec> ::= <extractor type>
+//                      [ '(' <parameter list> ')' ]
+//                      [ ':' <extractor name> ]
+//
+// <parameter list> ::= ( <parameter> | <argument> ) { ',' <parameter> }
+//
+// <parameter> ::= <parameter name> '=' <parameter value>
+//
+// <extractor type> ::= NAME
+// <extractor name> ::= NAME | STRING
+// <argument> ::= NUMBER
+// <parameter name> ::= NAME
+// <parameter value> ::= NUMBER | STRING | NAME
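+//
+// Examples of feature specs this grammar accepts (illustrative only; the
+// available extractor types are defined by the registered feature functions):
+//
+//   input.word
+//   input(2) { word tag }
+//   stack(1).child(-1).label
+//   offset(5):my_feature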
+
+#ifndef $TARGETDIR_FML_PARSER_H_
+#define $TARGETDIR_FML_PARSER_H_
+
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/feature_extractor.pb.h"
+
+namespace syntaxnet {
+
+class FMLParser {
+ public:
+  // Parses fml specification into feature extractor descriptor.
+  void Parse(const string &source, FeatureExtractorDescriptor *result);
+
+ private:
+  // Initializes the parser with the source text.
+  void Initialize(const string &source);
+
+  // Outputs error message and exits.
+  void Error(const string &error_message);
+
+  // Moves to the next input character.
+  void Next();
+
+  // Moves to the next input item.
+  void NextItem();
+
+  // Parses a feature descriptor.
+  void ParseFeature(FeatureFunctionDescriptor *result);
+
+  // Parses a parameter specification.
+  void ParseParameter(FeatureFunctionDescriptor *result);
+
+  // Returns true if end of source input has been reached.
+  bool eos() { return current_ == source_.end(); }
+
+  // Item types.
+  enum ItemTypes {
+    END = 0,
+    NAME = -1,
+    NUMBER = -2,
+    STRING = -3,
+  };
+
+  // Source text.
+  string source_;
+
+  // Current input position.
+  string::iterator current_;
+
+  // Line number for current input position.
+  int line_number_;
+
+  // Start position for current item.
+  string::iterator item_start_;
+
+  // Start position for current line.
+  string::iterator line_start_;
+
+  // Line number for current item.
+  int item_line_number_;
+
+  // Item type for current item. If this is positive it is interpreted as a
+  // character; if it is zero or negative it is interpreted as one of the
+  // ItemTypes values.
+
+  // Text for current item.
+  string item_text_;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_FML_PARSER_H_

+ 569 - 0
syntaxnet/syntaxnet/graph_builder.py

@@ -0,0 +1,569 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Builds parser models."""
+
+import tensorflow as tf
+
+import syntaxnet.load_parser_ops
+
+from tensorflow.python.ops import control_flow_ops as cf
+from tensorflow.python.ops import state_ops
+from tensorflow.python.platform import logging
+
+from syntaxnet.ops import gen_parser_ops
+
+
+def BatchedSparseToDense(sparse_indices, output_size):
+  """Batch compatible sparse to dense conversion.
+
+  This is useful for one-hot coded target labels.
+
+  Args:
+    sparse_indices: [batch_size] tensor containing one index per batch
+    output_size: needed in order to generate the correct dense output
+
+  Returns:
+    A [batch_size, output_size] dense tensor.
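+
+    For example (illustrative), sparse_indices=[2, 0] with output_size=3
+    yields [[0., 0., 1.], [1., 0., 0.]].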
+  """
+  eye = tf.diag(tf.fill([output_size], tf.constant(1, tf.float32)))
+  return tf.nn.embedding_lookup(eye, sparse_indices)
+
+
+def EmbeddingLookupFeatures(params, sparse_features, allow_weights):
+  """Computes embeddings for each entry of sparse features sparse_features.
+
+  Args:
+    params: list of 2D tensors containing vector embeddings
+    sparse_features: 1D tensor of strings. Each entry is a string encoding of
+      dist_belief.SparseFeatures, and represents a variable length list of
+      feature ids, and optionally, corresponding weights values.
+    allow_weights: boolean to control whether the weights returned from the
+      SparseFeatures are used to multiply the embeddings.
+
+  Returns:
+    A tensor representing the combined embeddings for the sparse features.
+    For each entry s in sparse_features, the function looks up the embeddings
+    for each id and sums them into a single tensor, weighting them by the
+    weight of each id. It returns a tensor with each entry of sparse_features
+    replaced by this combined embedding.
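+
+    For example (illustrative), with params a single [vocab_size, dim] matrix
+    and sparse_features encoding the id lists [1, 3] and [2] with no weights,
+    the result is a [2, dim] tensor whose rows are params[1] + params[3] and
+    params[2].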
+  """
+  if not isinstance(params, list):
+    params = [params]
+  # Lookup embeddings.
+  sparse_features = tf.convert_to_tensor(sparse_features)
+  indices, ids, weights = gen_parser_ops.unpack_sparse_features(sparse_features)
+  embeddings = tf.nn.embedding_lookup(params, ids)
+
+  if allow_weights:
+    # Multiply by weights, reshaping to allow broadcast.
+    broadcast_weights_shape = tf.concat(0, [tf.shape(weights), [1]])
+    embeddings *= tf.reshape(weights, broadcast_weights_shape)
+
+  # Sum embeddings by index.
+  return tf.unsorted_segment_sum(embeddings, indices, tf.size(sparse_features))
+
+
+class GreedyParser(object):
+  """Builds a Chen & Manning style greedy neural net parser.
+
+  Builds a graph with an optional reader op connected at one end and
+  operations needed to train the network on the other. Supports multiple
+  network instantiations sharing the same parameters and network topology.
+
+  The following named nodes are added to the training and eval networks:
+    epochs: a tensor containing the current epoch number
+    cost: a tensor containing the current training step cost
+    gold_actions: a tensor containing actions from gold decoding
+    feature_endpoints: a list of sparse feature vectors
+    logits: output of the final layer before computing softmax
+  The training network also contains:
+    train_op: an op that executes a single training step
+
+  Typical usage:
+
+  parser = graph_builder.GreedyParser(num_actions, num_features,
+                                      num_feature_ids, embedding_sizes,
+                                      hidden_layer_sizes)
+  parser.AddTraining(task_context, batch_size=5)
+  with tf.Session('local') as sess:
+    # This works because the session uses the same default graph as the
+    # GraphBuilder did.
+    sess.run(parser.inits.values())
+    while True:
+      tf_epoch, _ = sess.run([parser.training['epoch'],
+                              parser.training['train_op']])
+      if tf_epoch[0] > 0:
+        break
+  """
+
+  def __init__(self,
+               num_actions,
+               num_features,
+               num_feature_ids,
+               embedding_sizes,
+               hidden_layer_sizes,
+               seed=None,
+               gate_gradients=False,
+               use_locking=False,
+               embedding_init=1.0,
+               relu_init=1e-4,
+               bias_init=0.2,
+               softmax_init=1e-4,
+               averaging_decay=0.9999,
+               use_averaging=True,
+               check_parameters=True,
+               check_every=1,
+               allow_feature_weights=False,
+               only_train='',
+               arg_prefix=None,
+               **unused_kwargs):
+    """Initialize the graph builder with parameters defining the network.
+
+    Args:
+      num_actions: int size of the set of parser actions
+      num_features: int list of dimensions of the feature vectors
+      num_feature_ids: int list of same length as num_features corresponding to
+        the sizes of the input feature spaces
+      embedding_sizes: int list of same length as num_features of the desired
+        embedding layer sizes
+      hidden_layer_sizes: int list of desired relu layer sizes; may be empty
+      seed: optional random initializer seed to enable reproducibility
+      gate_gradients: if True, gradient updates are computed synchronously,
+        ensuring consistency and reproducibility
+      use_locking: if True, use locking to avoid read-write contention when
+        updating Variables
+      embedding_init: sets the std dev of normal initializer of embeddings to
+        embedding_init / embedding_size ** .5
+      relu_init: sets the std dev of normal initializer of relu weights
+        to relu_init
+      bias_init: sets constant initializer of relu bias to bias_init
+      softmax_init: sets the std dev of the normal initializer of the softmax
+        weights to softmax_init
+      averaging_decay: decay for exponential moving average when computing
+        averaged parameters, set to 1 to do vanilla averaging
+      use_averaging: whether to use moving averages of parameters during evals
+      check_parameters: whether to check for NaN/Inf parameters during
+        training
+      check_every: checks numerics every check_every steps.
+      allow_feature_weights: whether feature weights are allowed.
+      only_train: the comma separated set of parameter names to train. If empty,
+        all model parameters will be trained.
+      arg_prefix: prefix for context parameters.
+    """
+    self._num_actions = num_actions
+    self._num_features = num_features
+    self._num_feature_ids = num_feature_ids
+    self._embedding_sizes = embedding_sizes
+    self._hidden_layer_sizes = hidden_layer_sizes
+    self._seed = seed
+    self._gate_gradients = gate_gradients
+    self._use_locking = use_locking
+    self._use_averaging = use_averaging
+    self._check_parameters = check_parameters
+    self._check_every = check_every
+    self._allow_feature_weights = allow_feature_weights
+    self._only_train = set(only_train.split(',')) if only_train else None
+    self._feature_size = len(embedding_sizes)
+    self._embedding_init = embedding_init
+    self._relu_init = relu_init
+    self._softmax_init = softmax_init
+    self._arg_prefix = arg_prefix
+    # Parameters of the network with respect to which training is done.
+    self.params = {}
+    # Other variables, with respect to which no training is done, but which we
+    # nonetheless need to save in order to capture the state of the graph.
+    self.variables = {}
+    # Operations to initialize any nodes that require initialization.
+    self.inits = {}
+    # Training- and eval-related nodes.
+    self.training = {}
+    self.evaluation = {}
+    self.saver = None
+    # Nodes to compute moving averages of parameters, called every train step.
+    self._averaging = {}
+    self._averaging_decay = averaging_decay
+    # Pretrained embeddings that can be used instead of constant initializers.
+    self._pretrained_embeddings = {}
+    # After the following 'with' statement, we'll be able to re-enter the
+    # 'params' scope by re-using the self._param_scope member variable. See for
+    # instance _AddParam.
+    with tf.name_scope('params') as self._param_scope:
+      self._relu_bias_init = tf.constant_initializer(bias_init)
+
+  @property
+  def embedding_size(self):
+    size = 0
+    for i in range(self._feature_size):
+      size += self._num_features[i] * self._embedding_sizes[i]
+    return size
+
+  def _AddParam(self,
+                shape,
+                dtype,
+                name,
+                initializer=None,
+                return_average=False):
+    """Add a model parameter w.r.t. we expect to compute gradients.
+
+    _AddParam creates both regular parameters (usually for training) and
+    averaged nodes (usually for inference). It returns one or the other based
+    on the 'return_average' arg.
+
+    Args:
+      shape: int list, tensor shape of the parameter to create
+      dtype: tf.DataType, data type of the parameter
+      name: string, name of the parameter in the TF graph
+      initializer: optional initializer for the parameter
+      return_average: if False, return the parameter; otherwise return its
+        moving average
+
+    Returns:
+      parameter or averaged parameter
+    """
+    if name not in self.params:
+      step = tf.cast(self.GetStep(), tf.float32)
+      # Put all parameters and their initializing ops in their own scope
+      # irrespective of the current scope (training or eval).
+      with tf.name_scope(self._param_scope):
+        self.params[name] = tf.get_variable(name, shape, dtype, initializer)
+        param = self.params[name]
+        if initializer is not None:
+          self.inits[name] = state_ops.init_variable(param, initializer)
+        if self._averaging_decay == 1:
+          logging.info('Using vanilla averaging of parameters.')
+          ema = tf.train.ExponentialMovingAverage(decay=(step / (step + 1.0)),
+                                                  num_updates=None)
+        else:
+          ema = tf.train.ExponentialMovingAverage(decay=self._averaging_decay,
+                                                  num_updates=step)
+        self._averaging[name + '_avg_update'] = ema.apply([param])
+        self.variables[name + '_avg_var'] = ema.average(param)
+        self.inits[name + '_avg_init'] = state_ops.init_variable(
+            ema.average(param), tf.zeros_initializer)
+    return (self.variables[name + '_avg_var'] if return_average else
+            self.params[name])
+
+  def GetStep(self):
+    def OnesInitializer(shape, dtype=tf.float32):
+      return tf.ones(shape, dtype)
+    return self._AddVariable([], tf.int32, 'step', OnesInitializer)
+
+  def _AddVariable(self, shape, dtype, name, initializer=None):
+    if name in self.variables:
+      return self.variables[name]
+    self.variables[name] = tf.get_variable(name, shape, dtype, initializer)
+    if initializer is not None:
+      self.inits[name] = state_ops.init_variable(self.variables[name],
+                                                 initializer)
+    return self.variables[name]
+
+  def _ReluWeightInitializer(self):
+    with tf.name_scope(self._param_scope):
+      return tf.random_normal_initializer(stddev=self._relu_init,
+                                          seed=self._seed)
+
+  def _EmbeddingMatrixInitializer(self, index, embedding_size):
+    if index in self._pretrained_embeddings:
+      return self._pretrained_embeddings[index]
+    else:
+      return tf.random_normal_initializer(
+          stddev=self._embedding_init / embedding_size**.5,
+          seed=self._seed)
+
+  def _AddEmbedding(self,
+                    features,
+                    num_features,
+                    num_ids,
+                    embedding_size,
+                    index,
+                    return_average=False):
+    """Adds an embedding matrix and passes the `features` vector through it."""
+    embedding_matrix = self._AddParam(
+        [num_ids, embedding_size],
+        tf.float32,
+        'embedding_matrix_%d' % index,
+        self._EmbeddingMatrixInitializer(index, embedding_size),
+        return_average=return_average)
+    embedding = EmbeddingLookupFeatures(embedding_matrix,
+                                        tf.reshape(features,
+                                                   [-1],
+                                                   name='feature_%d' % index),
+                                        self._allow_feature_weights)
+    return tf.reshape(embedding, [-1, num_features * embedding_size])
+
+  def _BuildNetwork(self, feature_endpoints, return_average=False):
+    """Builds a feed-forward part of the net given features as input.
+
+    The network topology is already defined in the constructor, so multiple
+    calls to _BuildNetwork build multiple networks whose parameters are all
+    shared. It is the source of the input features and the use of the output
+    that distinguishes each network.
+
+    Args:
+      feature_endpoints: tensors with input features to the network
+      return_average: whether to use moving averages as model parameters
+
+    Returns:
+      logits: output of the final layer before computing softmax
+    """
+    assert len(feature_endpoints) == self._feature_size
+
+    # Create embedding layer.
+    embeddings = []
+    for i in range(self._feature_size):
+      embeddings.append(self._AddEmbedding(feature_endpoints[i],
+                                           self._num_features[i],
+                                           self._num_feature_ids[i],
+                                           self._embedding_sizes[i],
+                                           i,
+                                           return_average=return_average))
+
+    last_layer = tf.concat(1, embeddings)
+    last_layer_size = self.embedding_size
+
+    # Create ReLU layers.
+    for i, hidden_layer_size in enumerate(self._hidden_layer_sizes):
+      weights = self._AddParam(
+          [last_layer_size, hidden_layer_size],
+          tf.float32,
+          'weights_%d' % i,
+          self._ReluWeightInitializer(),
+          return_average=return_average)
+      bias = self._AddParam([hidden_layer_size],
+                            tf.float32,
+                            'bias_%d' % i,
+                            self._relu_bias_init,
+                            return_average=return_average)
+      last_layer = tf.nn.relu_layer(last_layer,
+                                    weights,
+                                    bias,
+                                    name='layer_%d' % i)
+      last_layer_size = hidden_layer_size
+
+    # Create softmax layer.
+    softmax_weight = self._AddParam(
+        [last_layer_size, self._num_actions],
+        tf.float32,
+        'softmax_weight',
+        tf.random_normal_initializer(stddev=self._softmax_init,
+                                     seed=self._seed),
+        return_average=return_average)
+    softmax_bias = self._AddParam(
+        [self._num_actions],
+        tf.float32,
+        'softmax_bias',
+        tf.zeros_initializer,
+        return_average=return_average)
+    logits = tf.nn.xw_plus_b(last_layer,
+                             softmax_weight,
+                             softmax_bias,
+                             name='logits')
+    return {'logits': logits}
+
+  def _AddGoldReader(self, task_context, batch_size, corpus_name):
+    features, epochs, gold_actions = (
+        gen_parser_ops.gold_parse_reader(task_context,
+                                         self._feature_size,
+                                         batch_size,
+                                         corpus_name=corpus_name,
+                                         arg_prefix=self._arg_prefix))
+    return {'gold_actions': tf.identity(gold_actions,
+                                        name='gold_actions'),
+            'epochs': tf.identity(epochs,
+                                  name='epochs'),
+            'feature_endpoints': features}
+
+  def _AddDecodedReader(self, task_context, batch_size, transition_scores,
+                        corpus_name):
+    features, epochs, eval_metrics, documents = (
+        gen_parser_ops.decoded_parse_reader(transition_scores,
+                                            task_context,
+                                            self._feature_size,
+                                            batch_size,
+                                            corpus_name=corpus_name,
+                                            arg_prefix=self._arg_prefix))
+    return {'eval_metrics': eval_metrics,
+            'epochs': tf.identity(epochs,
+                                  name='epochs'),
+            'feature_endpoints': features,
+            'documents': documents}
+
+  def _AddCostFunction(self, batch_size, gold_actions, logits):
+    """Cross entropy plus L2 loss on weights and biases of the hidden layers."""
+    dense_golden = BatchedSparseToDense(gold_actions, self._num_actions)
+    cross_entropy = tf.div(
+        tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(
+            logits, dense_golden)), batch_size)
+    regularized_params = [tf.nn.l2_loss(p)
+                          for k, p in self.params.items()
+                          if k.startswith('weights') or k.startswith('bias')]
+    l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0
+    return {'cost': tf.add(cross_entropy, l2_loss, name='cost')}
+
+  def AddEvaluation(self,
+                    task_context,
+                    batch_size,
+                    evaluation_max_steps=300,
+                    corpus_name='documents'):
+    """Builds the forward network only without the training operation.
+
+    Args:
+      task_context: file path from which to read the task context.
+      batch_size: batch size to request from reader op.
+      evaluation_max_steps: max number of parsing actions during evaluation,
+          only used in beam parsing.
+      corpus_name: name of the task input to read parses from.
+
+    Returns:
+      Dictionary of named eval nodes.
+    """
+    def _AssignTransitionScores():
+      return tf.assign(nodes['transition_scores'],
+                       nodes['logits'], validate_shape=False)
+    def _Pass():
+      return tf.constant(-1.0)
+    unused_evaluation_max_steps = evaluation_max_steps
+    with tf.name_scope('evaluation'):
+      nodes = self.evaluation
+      nodes['transition_scores'] = self._AddVariable(
+          [batch_size, self._num_actions], tf.float32, 'transition_scores',
+          tf.constant_initializer(-1.0))
+      nodes.update(self._AddDecodedReader(task_context, batch_size, nodes[
+          'transition_scores'], corpus_name))
+      nodes.update(self._BuildNetwork(nodes['feature_endpoints'],
+                                      return_average=self._use_averaging))
+      nodes['eval_metrics'] = cf.with_dependencies(
+          [tf.cond(tf.greater(tf.size(nodes['logits']), 0),
+                   _AssignTransitionScores, _Pass)],
+          nodes['eval_metrics'], name='eval_metrics')
+    return nodes
+
+  def _IncrementCounter(self, counter):
+    return state_ops.assign_add(counter, 1, use_locking=True)
+
+  def _AddLearningRate(self, initial_learning_rate, decay_steps):
+    """Returns a learning rate that decays by 0.96 every decay_steps.
+
+    Args:
+      initial_learning_rate: initial value of the learning rate
+      decay_steps: decay by 0.96 every this many steps
+
+    Returns:
+      learning rate variable.
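+
+    With staircase=True, this is effectively
+    initial_learning_rate * 0.96 ** (step // decay_steps).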
+    """
+    step = self.GetStep()
+    return cf.with_dependencies(
+        [self._IncrementCounter(step)],
+        tf.train.exponential_decay(initial_learning_rate,
+                                   step,
+                                   decay_steps,
+                                   0.96,
+                                   staircase=True))
+
+  def AddPretrainedEmbeddings(self, index, embeddings_path, task_context):
+    """Embeddings at the given index will be set to pretrained values."""
+
+    def _Initializer(shape, dtype=tf.float32):
+      unused_dtype = dtype
+      t = gen_parser_ops.word_embedding_initializer(
+          vectors=embeddings_path,
+          task_context=task_context,
+          embedding_init=self._embedding_init)
+
+      t.set_shape(shape)
+      return t
+
+    self._pretrained_embeddings[index] = _Initializer
+
+  def AddTraining(self,
+                  task_context,
+                  batch_size,
+                  learning_rate=0.1,
+                  decay_steps=4000,
+                  momentum=0.9,
+                  corpus_name='documents'):
+    """Builds a trainer to minimize the cross entropy cost function.
+
+    Args:
+      task_context: file path from which to read the task context
+      batch_size: batch size to request from reader op
+      learning_rate: initial value of the learning rate
+      decay_steps: decay learning rate by 0.96 every this many steps
+      momentum: momentum parameter used when training with momentum
+      corpus_name: name of the task input to read parses from
+
+    Returns:
+      Dictionary of named training nodes.
+    """
+    with tf.name_scope('training'):
+      nodes = self.training
+      nodes.update(self._AddGoldReader(task_context, batch_size, corpus_name))
+      nodes.update(self._BuildNetwork(nodes['feature_endpoints'],
+                                      return_average=False))
+      nodes.update(self._AddCostFunction(batch_size, nodes['gold_actions'],
+                                         nodes['logits']))
+      # Add the optimizer
+      if self._only_train:
+        trainable_params = [v
+                            for k, v in self.params.iteritems()
+                            if k in self._only_train]
+      else:
+        trainable_params = self.params.values()
+      lr = self._AddLearningRate(learning_rate, decay_steps)
+      optimizer = tf.train.MomentumOptimizer(lr,
+                                             momentum,
+                                             use_locking=self._use_locking)
+      train_op = optimizer.minimize(nodes['cost'], var_list=trainable_params)
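+      # Register the optimizer's momentum slots so they are initialized and
+      # saved along with the model parameters.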
+      for param in trainable_params:
+        slot = optimizer.get_slot(param, 'momentum')
+        self.inits[slot.name] = state_ops.init_variable(slot,
+                                                        tf.zeros_initializer)
+        self.variables[slot.name] = slot
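+      # Optionally verify that all floating-point parameters remain finite;
+      # the check is only added to the train op when parameter checking is
+      # enabled.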
+      numerical_checks = [
+          tf.check_numerics(param,
+                            message='Parameter is not finite.')
+          for param in trainable_params
+          if param.dtype.base_dtype in [tf.float32, tf.float64]
+      ]
+      check_op = tf.group(*numerical_checks)
+      avg_update_op = tf.group(*self._averaging.values())
+      train_ops = [train_op]
+      if self._check_parameters:
+        train_ops.append(check_op)
+      if self._use_averaging:
+        train_ops.append(avg_update_op)
+      nodes['train_op'] = tf.group(*train_ops, name='train_op')
+    return nodes
+
+  def AddSaver(self, slim_model=False):
+    """Adds ops to save and restore model parameters.
+
+    Args:
+      slim_model: whether only averaged variables are saved.
+
+    Returns:
+      the saver object.
+    """
+    # We have to put the save op in the root scope; otherwise running
+    # "save/restore_all" won't find the "save/Const" node it expects.
+    with tf.name_scope(None):
+      variables_to_save = self.params.copy()
+      variables_to_save.update(self.variables)
+      if slim_model:
+        for key in variables_to_save.keys():
+          if not key.endswith('avg_var'):
+            del variables_to_save[key]
+      self.saver = tf.train.Saver(variables_to_save)
+    return self.saver
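+
+
+# A minimal usage sketch (values are illustrative; the feature sizes, domain
+# sizes and embedding dims normally come from the FeatureSize op). See
+# graph_builder_test.py and parser_eval.py for runnable variants:
+#
+#   parser = GreedyParser(num_actions, feature_sizes, domain_sizes,
+#                         embedding_dims, hidden_layer_sizes=[200, 200])
+#   parser.AddTraining(task_context, batch_size=32)
+#   parser.AddEvaluation(task_context, batch_size=32)
+#   parser.AddSaver()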

+ 325 - 0
syntaxnet/syntaxnet/graph_builder_test.py

@@ -0,0 +1,325 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for graph_builder."""
+
+
+# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
+import os.path
+
+import tensorflow as tf
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+from syntaxnet import graph_builder
+from syntaxnet import sparse_pb2
+from syntaxnet.ops import gen_parser_ops
+
+FLAGS = tf.app.flags.FLAGS
+if not hasattr(FLAGS, 'test_srcdir'):
+  FLAGS.test_srcdir = ''
+if not hasattr(FLAGS, 'test_tmpdir'):
+  FLAGS.test_tmpdir = tf.test.get_temp_dir()
+
+
+class GraphBuilderTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    # Creates a task context with the correct testing paths.
+    initial_task_context = os.path.join(
+        FLAGS.test_srcdir,
+        'syntaxnet/'
+        'testdata/context.pbtxt')
+    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    with open(initial_task_context, 'r') as fin:
+      with open(self._task_context, 'w') as fout:
+        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
+                   .replace('OUTPATH', FLAGS.test_tmpdir))
+
+    # Creates necessary term maps.
+    with self.test_session() as sess:
+      gen_parser_ops.lexicon_builder(task_context=self._task_context,
+                                     corpus_name='training-corpus').run()
+      self._num_features, self._num_feature_ids, _, self._num_actions = (
+          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
+                                               arg_prefix='brain_parser')))
+
+  def MakeBuilder(self, use_averaging=True, **kw_args):
+    # Set the seed and gate_gradients to ensure reproducibility.
+    return graph_builder.GreedyParser(
+        self._num_actions, self._num_features, self._num_feature_ids,
+        embedding_sizes=[8, 8, 8], hidden_layer_sizes=[32, 32], seed=42,
+        gate_gradients=True, use_averaging=use_averaging, **kw_args)
+
+  def FindNode(self, name):
+    for node in tf.get_default_graph().as_graph_def().node:
+      if node.name == name:
+        return node
+    return None
+
+  def NodeFound(self, name):
+    return self.FindNode(name) is not None
+
+  def testScope(self):
+    # Set up the network topology
+    graph = tf.Graph()
+    with graph.as_default():
+      parser = self.MakeBuilder()
+      parser.AddTraining(self._task_context,
+                         batch_size=10,
+                         corpus_name='training-corpus')
+      parser.AddEvaluation(self._task_context,
+                           batch_size=2,
+                           corpus_name='tuning-corpus')
+      parser.AddSaver()
+
+      # Check that the node ids we may rely on are there with the expected
+      # names.
+      self.assertEqual(parser.training['logits'].name, 'training/logits:0')
+      self.assertTrue(self.NodeFound('training/logits'))
+      self.assertTrue(self.NodeFound('training/feature_0'))
+      self.assertTrue(self.NodeFound('training/feature_1'))
+      self.assertTrue(self.NodeFound('training/feature_2'))
+      self.assertFalse(self.NodeFound('training/feature_3'))
+
+      self.assertEqual(parser.evaluation['logits'].name, 'evaluation/logits:0')
+      self.assertTrue(self.NodeFound('evaluation/logits'))
+
+      # The saver node is expected to be in the root scope.
+      self.assertTrue(self.NodeFound('save/restore_all'))
+
+      # Also check that the parameters have the scope we expect.
+      self.assertTrue(self.NodeFound('embedding_matrix_0'))
+      self.assertTrue(self.NodeFound('embedding_matrix_1'))
+      self.assertTrue(self.NodeFound('embedding_matrix_2'))
+      self.assertFalse(self.NodeFound('embedding_matrix_3'))
+
+  def testNestedScope(self):
+    # It's OK to put the whole graph in a scope of its own.
+    graph = tf.Graph()
+    with graph.as_default():
+      with graph.name_scope('top'):
+        parser = self.MakeBuilder()
+        parser.AddTraining(self._task_context,
+                           batch_size=10,
+                           corpus_name='training-corpus')
+        parser.AddSaver()
+
+      self.assertTrue(self.NodeFound('top/training/logits'))
+      self.assertTrue(self.NodeFound('top/training/feature_0'))
+
+      # The saver node is expected to be in the root scope no matter what.
+      self.assertFalse(self.NodeFound('top/save/restore_all'))
+      self.assertTrue(self.NodeFound('save/restore_all'))
+
+  def testUseCustomGraphs(self):
+    batch_size = 10
+
+    # Use separate custom graphs.
+    custom_train_graph = tf.Graph()
+    with custom_train_graph.as_default():
+      train_parser = self.MakeBuilder()
+      train_parser.AddTraining(self._task_context,
+                               batch_size,
+                               corpus_name='training-corpus')
+
+    custom_eval_graph = tf.Graph()
+    with custom_eval_graph.as_default():
+      eval_parser = self.MakeBuilder()
+      eval_parser.AddEvaluation(self._task_context,
+                                batch_size,
+                                corpus_name='tuning-corpus')
+
+    # The following session runs should not fail.
+    with self.test_session(graph=custom_train_graph) as sess:
+      self.assertTrue(self.NodeFound('training/logits'))
+      sess.run(train_parser.inits.values())
+      sess.run(['training/logits:0'])
+
+    with self.test_session(graph=custom_eval_graph) as sess:
+      self.assertFalse(self.NodeFound('training/logits'))
+      self.assertTrue(self.NodeFound('evaluation/logits'))
+      sess.run(eval_parser.inits.values())
+      sess.run(['evaluation/logits:0'])
+
+  def testTrainingAndEvalAreIndependent(self):
+    batch_size = 10
+    graph = tf.Graph()
+    with graph.as_default():
+      parser = self.MakeBuilder(use_averaging=False)
+      parser.AddTraining(self._task_context,
+                         batch_size,
+                         corpus_name='training-corpus')
+      parser.AddEvaluation(self._task_context,
+                           batch_size,
+                           corpus_name='tuning-corpus')
+    with self.test_session(graph=graph) as sess:
+      sess.run(parser.inits.values())
+      # Before any training updates are performed, the training and eval nets
+      # should produce the same outputs.
+      eval_logits, = sess.run([parser.evaluation['logits']])
+      training_logits, = sess.run([parser.training['logits']])
+      self.assertNear(abs((eval_logits - training_logits).sum()), 0, 1e-6)
+
+      # After training, activations should differ.
+      for _ in range(5):
+        eval_logits = parser.evaluation['logits'].eval()
+      for _ in range(5):
+        training_logits, _ = sess.run([parser.training['logits'],
+                                       parser.training['train_op']])
+      self.assertGreater(abs((eval_logits - training_logits).sum()), 0, 1e-3)
+
+  def testReproducibility(self):
+    batch_size = 10
+
+    def ComputeACost(graph):
+      with graph.as_default():
+        parser = self.MakeBuilder(use_averaging=False)
+        parser.AddTraining(self._task_context,
+                           batch_size,
+                           corpus_name='training-corpus')
+        parser.AddEvaluation(self._task_context,
+                             batch_size,
+                             corpus_name='tuning-corpus')
+      with self.test_session(graph=graph) as sess:
+        sess.run(parser.inits.values())
+        for _ in range(5):
+          cost, _ = sess.run([parser.training['cost'],
+                              parser.training['train_op']])
+      return cost
+
+    cost1 = ComputeACost(tf.Graph())
+    cost2 = ComputeACost(tf.Graph())
+    self.assertNear(cost1, cost2, 1e-8)
+
+  def testAddTrainingAndEvalOrderIndependent(self):
+    batch_size = 10
+
+    graph1 = tf.Graph()
+    with graph1.as_default():
+      parser = self.MakeBuilder(use_averaging=False)
+      parser.AddTraining(self._task_context,
+                         batch_size,
+                         corpus_name='training-corpus')
+      parser.AddEvaluation(self._task_context,
+                           batch_size,
+                           corpus_name='tuning-corpus')
+    with self.test_session(graph=graph1) as sess:
+      sess.run(parser.inits.values())
+      metrics1 = None
+      for _ in range(500):
+        cost1, _ = sess.run([parser.training['cost'],
+                             parser.training['train_op']])
+        em1 = parser.evaluation['eval_metrics'].eval()
+        metrics1 = metrics1 + em1 if metrics1 is not None else em1
+
+    # Reverse the order in which Training and Eval stacks are added.
+    graph2 = tf.Graph()
+    with graph2.as_default():
+      parser = self.MakeBuilder(use_averaging=False)
+      parser.AddEvaluation(self._task_context,
+                           batch_size,
+                           corpus_name='tuning-corpus')
+      parser.AddTraining(self._task_context,
+                         batch_size,
+                         corpus_name='training-corpus')
+    with self.test_session(graph=graph2) as sess:
+      sess.run(parser.inits.values())
+      metrics2 = None
+      for _ in range(500):
+        cost2, _ = sess.run([parser.training['cost'],
+                             parser.training['train_op']])
+        em2 = parser.evaluation['eval_metrics'].eval()
+        metrics2 = metrics2 + em2 if metrics2 is not None else em2
+
+    self.assertNear(cost1, cost2, 1e-8)
+    self.assertEqual(abs(metrics1 - metrics2).sum(), 0)
+
+  def testEvalMetrics(self):
+    batch_size = 10
+    graph = tf.Graph()
+    with graph.as_default():
+      parser = self.MakeBuilder()
+      parser.AddEvaluation(self._task_context,
+                           batch_size,
+                           corpus_name='tuning-corpus')
+    with self.test_session(graph=graph) as sess:
+      sess.run(parser.inits.values())
+      tokens = 0
+      correct_heads = 0
+      for _ in range(100):
+        eval_metrics = sess.run(parser.evaluation['eval_metrics'])
+        tokens += eval_metrics[0]
+        correct_heads += eval_metrics[1]
+      self.assertGreater(tokens, 0)
+      self.assertGreaterEqual(tokens, correct_heads)
+      self.assertGreaterEqual(correct_heads, 0)
+
+  def MakeSparseFeatures(self, ids, weights):
+    f = sparse_pb2.SparseFeatures()
+    for i, w in zip(ids, weights):
+      f.id.append(i)
+      f.weight.append(w)
+    return f.SerializeToString()
+
+  def testEmbeddingOp(self):
+    graph = tf.Graph()
+    with self.test_session(graph=graph):
+      params = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
+                           tf.float32)
+
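+      # EmbeddingLookupFeatures sums the param rows selected by each
+      # SparseFeatures entry, scaled by the entry's weights; an entry with no
+      # ids yields a zero vector (e.g. ids [1, 2] with unit weights gives
+      # rows 1 + 2 = [8.0, 10.0]).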
+      var = variables.Variable([self.MakeSparseFeatures([1, 2], [1.0, 1.0]),
+                                self.MakeSparseFeatures([], [])])
+      var.initializer.run()
+      embeddings = graph_builder.EmbeddingLookupFeatures(params, var,
+                                                         True).eval()
+      self.assertAllClose([[8.0, 10.0], [0.0, 0.0]], embeddings)
+
+      var = variables.Variable([self.MakeSparseFeatures([], []),
+                                self.MakeSparseFeatures([0, 2],
+                                                        [0.5, 2.0])])
+      var.initializer.run()
+      embeddings = graph_builder.EmbeddingLookupFeatures(params, var,
+                                                         True).eval()
+      self.assertAllClose([[0.0, 0.0], [10.5, 13.0]], embeddings)
+
+  def testOnlyTrainSomeParameters(self):
+    batch_size = 10
+    graph = tf.Graph()
+    with graph.as_default():
+      parser = self.MakeBuilder(use_averaging=False, only_train='softmax_bias')
+      parser.AddTraining(self._task_context,
+                         batch_size,
+                         corpus_name='training-corpus')
+    with self.test_session(graph=graph) as sess:
+      sess.run(parser.inits.values())
+      # Before training, save the state of two of the parameters.
+      bias0, weight0 = sess.run([parser.params['softmax_bias'],
+                                 parser.params['softmax_weight']])
+
+      for _ in range(5):
+        bias, weight, _ = sess.run([parser.params['softmax_bias'],
+                                    parser.params['softmax_weight'],
+                                    parser.training['train_op']])
+
+      # After training, only one of the parameters should have changed.
+      self.assertAllEqual(weight, weight0)
+      self.assertGreater(abs(bias - bias0).sum(), 0, 1e-5)
+
+
+if __name__ == '__main__':
+  googletest.main()

+ 82 - 0
syntaxnet/syntaxnet/kbest_syntax.proto

@@ -0,0 +1,82 @@
+// K-best part-of-speech and dependency annotations for tokens.
+
+syntax = "proto2";
+
+import "syntaxnet/sentence.proto";
+
+package syntaxnet;
+
+// A list of alternative (k-best) syntax analyses, grouped by sentences.
+message KBestSyntaxAnalyses {
+  extend Sentence {
+    optional KBestSyntaxAnalyses extension = 60366242;
+  }
+
+  // Alternative analyses for each sentence. Sentences are listed in the
+  // order visited by a SentenceIterator.
+  repeated KBestSyntaxAnalysesForSentence sentence = 1;
+
+  // Alternative analyses for each token.
+  repeated KBestSyntaxAnalysesForToken token = 2;
+}
+
+// A list of alternative (k-best) analyses for a sentence spanning from a start
+// token index to an end token index. The alternative analyses are ordered by
+// decreasing model score from best to worst. The first analysis is the 1-best
+// analysis, which is typically also stored in the document tokens.
+message KBestSyntaxAnalysesForSentence {
+  // First token of sentence.
+  optional int32 start = 1 [default = -1];
+
+  // Last token of sentence.
+  optional int32 end = 2 [default = -1];
+
+  // K-best analyses for the tokens in this sentence. All of the analyses in
+  // the list have the same "type"; e.g., k-best taggings,
+  // k-best {tagging+parse}s, etc.
+  // Note also that the type of analysis stored in this list can change
+  // depending on where we are in the document processing pipeline; e.g., the
+  // list may initially contain taggings, and then switch to parses.  The first
+  // token_analysis is the 1-best analysis, which is typically also stored
+  // in the document.  Note: some post-processors will update the document's
+  // syntax trees, but will leave these unchanged.
+  repeated AlternativeTokenAnalysis token_analysis = 3;
+}
+
+// A list of scored alternative (k-best) analyses for a particular token. These
+// are all distinct from each other and ordered by decreasing model score. The
+// first is the 1-best analysis, which may or may not match the document tokens
+// depending on how the k-best analyses are selected.
+message KBestSyntaxAnalysesForToken {
+  // All token analyses in this repeated field refer to the same token.
+  // Each alternative analysis will contain a single entry for repeated fields
+  // such as head, tag, category and label.
+  repeated AlternativeTokenAnalysis token_analysis = 3;
+}
+
+// An alternative analysis of tokens in the document. The repeated fields
+// are indexed relative to the beginning of a sentence. Fields not
+// represented in the alternative analysis are assumed to be unchanged.
+// Currently only alternatives for tags, categories and (labeled) dependency
+// heads are supported.
+// Each repeated field should either have length=0 or length=number of tokens.
+message AlternativeTokenAnalysis {
+  // Head of this token in the dependency tree: the id of the token which has
+  // an arc going to this one. If it is the root token of a sentence, then it
+  // is set to -1.
+  repeated int32 head = 1;
+
+  // Part-of-speech tag for token.
+  repeated string tag = 2;
+
+  // Coarse-grained word category for token.
+  repeated string category = 3;
+
+  // Label for dependency relation between this token and its head.
+  repeated string label = 4;
+
+  // The score of this analysis, where bigger values typically indicate better
+  // quality, but there are no guarantees and there is also no pre-defined
+  // range.
+  optional double score = 5;
+}

+ 248 - 0
syntaxnet/syntaxnet/lexicon_builder.cc

@@ -0,0 +1,248 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stddef.h>
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/affix.h"
+#include "syntaxnet/dictionary.pb.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/sentence_batch.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+// A task that collects term statistics over a corpus and saves a set of
+// term maps; these saved mappings are used to map strings to ints in both the
+// chunker trainer and the chunker processors.
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_INT32;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::errors::InvalidArgument;
+
+namespace syntaxnet {
+
+// A workflow task that creates term maps (e.g., word, tag, etc.).
+//
+// Non-flag task parameters:
+// int lexicon_max_prefix_length (3):
+//   The maximum prefix length for lexicon words.
+// int lexicon_max_suffix_length (3):
+//   The maximum suffix length for lexicon words.
+class LexiconBuilder : public OpKernel {
+ public:
+  explicit LexiconBuilder(OpKernelConstruction *context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name_));
+    OP_REQUIRES_OK(context, context->GetAttr("lexicon_max_prefix_length",
+                                             &max_prefix_length_));
+    OP_REQUIRES_OK(context, context->GetAttr("lexicon_max_suffix_length",
+                                             &max_suffix_length_));
+
+    string file_path, data;
+    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
+    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
+                                             file_path, &data));
+    OP_REQUIRES(context,
+                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
+                InvalidArgument("Could not parse task context at ", file_path));
+  }
+
+  // Counts term frequencies.
+  void Compute(OpKernelContext *context) override {
+    // Term frequency maps to be populated by the corpus.
+    TermFrequencyMap words;
+    TermFrequencyMap lcwords;
+    TermFrequencyMap tags;
+    TermFrequencyMap categories;
+    TermFrequencyMap labels;
+
+    // Affix tables to be populated by the corpus.
+    AffixTable prefixes(AffixTable::PREFIX, max_prefix_length_);
+    AffixTable suffixes(AffixTable::SUFFIX, max_suffix_length_);
+
+    // Tag-to-category mapping.
+    TagToCategoryMap tag_to_category;
+
+    // Make a pass over the corpus.
+    int64 num_tokens = 0;
+    int64 num_documents = 0;
+    Sentence *document;
+    TextReader corpus(*task_context_.GetInput(corpus_name_));
+    while ((document = corpus.Read()) != NULL) {
+      // Gather token information.
+      for (int t = 0; t < document->token_size(); ++t) {
+        // Get token and lowercased word.
+        const Token &token = document->token(t);
+        string word = token.word();
+        utils::NormalizeDigits(&word);
+        string lcword = tensorflow::str_util::Lowercase(word);
+
+        // Make sure the token does not contain a newline.
+        CHECK(lcword.find('\n') == string::npos);
+
+        // Increment frequencies (only for terms that exist).
+        if (!word.empty() && !HasSpaces(word)) words.Increment(word);
+        if (!lcword.empty() && !HasSpaces(lcword)) lcwords.Increment(lcword);
+        if (!token.tag().empty()) tags.Increment(token.tag());
+        if (!token.category().empty()) categories.Increment(token.category());
+        if (!token.label().empty()) labels.Increment(token.label());
+
+        // Add prefixes/suffixes for the current word.
+        prefixes.AddAffixesForWord(word.c_str(), word.size());
+        suffixes.AddAffixesForWord(word.c_str(), word.size());
+
+        // Add mapping from tag to category.
+        tag_to_category.SetCategory(token.tag(), token.category());
+
+        // Update the number of processed tokens.
+        ++num_tokens;
+      }
+
+      delete document;
+      ++num_documents;
+    }
+    LOG(INFO) << "Term maps collected over " << num_tokens << " tokens from "
+              << num_documents << " documents";
+
+    // Write mappings to disk.
+    words.Save(TaskContext::InputFile(*task_context_.GetInput("word-map")));
+    lcwords.Save(TaskContext::InputFile(*task_context_.GetInput("lcword-map")));
+    tags.Save(TaskContext::InputFile(*task_context_.GetInput("tag-map")));
+    categories.Save(
+        TaskContext::InputFile(*task_context_.GetInput("category-map")));
+    labels.Save(TaskContext::InputFile(*task_context_.GetInput("label-map")));
+
+    // Write affixes to disk.
+    WriteAffixTable(prefixes, TaskContext::InputFile(
+                                  *task_context_.GetInput("prefix-table")));
+    WriteAffixTable(suffixes, TaskContext::InputFile(
+                                  *task_context_.GetInput("suffix-table")));
+
+    // Write tag-to-category mapping to disk.
+    tag_to_category.Save(
+        TaskContext::InputFile(*task_context_.GetInput("tag-to-category")));
+  }
+
+ private:
+  // Returns true if the word contains spaces.
+  static bool HasSpaces(const string &word) {
+    for (char c : word) {
+      if (c == ' ') return true;
+    }
+    return false;
+  }
+
+  // Writes an affix table to a task output.
+  static void WriteAffixTable(const AffixTable &affixes,
+                              const string &output_file) {
+    ProtoRecordWriter writer(output_file);
+    affixes.Write(&writer);
+  }
+
+  // Name of the context input to compute lexicons.
+  string corpus_name_;
+
+  // Max length for prefix table.
+  int max_prefix_length_;
+
+  // Max length for suffix table.
+  int max_suffix_length_;
+
+  // Task context used to configure this op.
+  TaskContext task_context_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("LexiconBuilder").Device(DEVICE_CPU),
+                        LexiconBuilder);
+
+class FeatureSize : public OpKernel {
+ public:
+  explicit FeatureSize(OpKernelConstruction *context) : OpKernel(context) {
+    string task_context_path;
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("task_context", &task_context_path));
+    OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {}, {DT_INT32, DT_INT32, DT_INT32, DT_INT32}));
+    string data;
+    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
+                                             task_context_path, &data));
+    OP_REQUIRES(
+        context,
+        TextFormat::ParseFromString(data, task_context_.mutable_spec()),
+        InvalidArgument("Could not parse task context at ", task_context_path));
+    string label_map_path =
+        TaskContext::InputFile(*task_context_.GetInput("label-map"));
+    label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
+        label_map_path, 0, 0);
+  }
+
+  ~FeatureSize() override { SharedStore::Release(label_map_); }
+
+  void Compute(OpKernelContext *context) override {
+    // Computes feature sizes.
+    ParserEmbeddingFeatureExtractor features(arg_prefix_);
+    features.Setup(&task_context_);
+    features.Init(&task_context_);
+    const int num_embeddings = features.NumEmbeddings();
+    Tensor *feature_sizes = nullptr;
+    Tensor *domain_sizes = nullptr;
+    Tensor *embedding_dims = nullptr;
+    Tensor *num_actions = nullptr;
+    TF_CHECK_OK(context->allocate_output(0, TensorShape({num_embeddings}),
+                                         &feature_sizes));
+    TF_CHECK_OK(context->allocate_output(1, TensorShape({num_embeddings}),
+                                         &domain_sizes));
+    TF_CHECK_OK(context->allocate_output(2, TensorShape({num_embeddings}),
+                                         &embedding_dims));
+    TF_CHECK_OK(context->allocate_output(3, TensorShape({}), &num_actions));
+    for (int i = 0; i < num_embeddings; ++i) {
+      feature_sizes->vec<int32>()(i) = features.FeatureSize(i);
+      domain_sizes->vec<int32>()(i) = features.EmbeddingSize(i);
+      embedding_dims->vec<int32>()(i) = features.EmbeddingDims(i);
+    }
+
+    // Computes number of actions in the transition system.
+    std::unique_ptr<ParserTransitionSystem> transition_system(
+        ParserTransitionSystem::Create(task_context_.Get(
+            features.GetParamName("transition_system"), "arc-standard")));
+    transition_system->Setup(&task_context_);
+    transition_system->Init(&task_context_);
+    num_actions->scalar<int32>()() =
+        transition_system->NumActions(label_map_->Size());
+  }
+
+ private:
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // Dependency label map used in transition system.
+  const TermFrequencyMap *label_map_;
+
+  // Prefix for context parameters.
+  string arg_prefix_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("FeatureSize").Device(DEVICE_CPU), FeatureSize);
+
+}  // namespace syntaxnet

+ 174 - 0
syntaxnet/syntaxnet/lexicon_builder_test.py

@@ -0,0 +1,174 @@
+# coding=utf-8
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for lexicon_builder."""
+
+
+# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
+import os.path
+
+import tensorflow as tf
+
+import syntaxnet.load_parser_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+
+from syntaxnet import sentence_pb2
+from syntaxnet import task_spec_pb2
+from syntaxnet.ops import gen_parser_ops
+
+FLAGS = tf.app.flags.FLAGS
+
+CONLL_DOC1 = u'''1 बात _ n NN _ _ _ _ _
+2 गलत _ adj JJ _ _ _ _ _
+3 हो _ v VM _ _ _ _ _
+4 तो _ avy CC _ _ _ _ _
+5 गुस्सा _ n NN _ _ _ _ _
+6 सेलेब्रिटिज _ n NN _ _ _ _ _
+7 को _ psp PSP _ _ _ _ _
+8 भी _ avy RP _ _ _ _ _
+9 आना _ v VM _ _ _ _ _
+10 लाजमी _ adj JJ _ _ _ _ _
+11 है _ v VM _ _ _ _ _
+12 । _ punc SYM _ _ _ _ _'''
+
+CONLL_DOC2 = u'''1 लेकिन _ avy CC _ _ _ _ _
+2 अभिनेत्री _ n NN _ _ _ _ _
+3 के _ psp PSP _ _ _ _ _
+4 इस _ pn DEM _ _ _ _ _
+5 कदम _ n NN _ _ _ _ _
+6 से _ psp PSP _ _ _ _ _
+7 वहां _ pn PRP _ _ _ _ _
+8 रंग _ n NN _ _ _ _ _
+9 में _ psp PSP _ _ _ _ _
+10 भंग _ adj JJ _ _ _ _ _
+11 पड़ _ v VM _ _ _ _ _
+12 गया _ v VAUX _ _ _ _ _
+13 । _ punc SYM _ _ _ _ _'''
+
+TAGS = ['NN', 'JJ', 'VM', 'CC', 'PSP', 'RP', 'JJ', 'SYM', 'DEM', 'PRP', 'VAUX']
+
+CATEGORIES = ['n', 'adj', 'v', 'avy', 'n', 'psp', 'punc', 'pn']
+
+TOKENIZED_DOCS = u'''बात गलत हो तो गुस्सा सेलेब्रिटिज को भी आना लाजमी है ।
+लेकिन अभिनेत्री के इस कदम से वहां रंग में भंग पड़ गया ।
+'''
+
+COMMENTS = u'# Line with fake comments.'
+
+
+class LexiconBuilderTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    if not hasattr(FLAGS, 'test_srcdir'):
+      FLAGS.test_srcdir = ''
+    if not hasattr(FLAGS, 'test_tmpdir'):
+      FLAGS.test_tmpdir = tf.test.get_temp_dir()
+    self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
+    self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+
+  def AddInput(self, name, file_pattern, record_format, context):
+    inp = context.input.add()
+    inp.name = name
+    inp.record_format.append(record_format)
+    inp.part.add().file_pattern = file_pattern
+
+  def WriteContext(self, corpus_format):
+    context = task_spec_pb2.TaskSpec()
+    self.AddInput('documents', self.corpus_file, corpus_format, context)
+    for name in ('word-map', 'lcword-map', 'tag-map',
+                 'category-map', 'label-map', 'prefix-table',
+                 'suffix-table', 'tag-to-category'):
+      self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
+    logging.info('Writing context to: %s', self.context_file)
+    with open(self.context_file, 'w') as f:
+      f.write(str(context))
+
+  def ReadNextDocument(self, sess, doc_source):
+    doc_str, last = sess.run(doc_source)
+    if doc_str:
+      doc = sentence_pb2.Sentence()
+      doc.ParseFromString(doc_str[0])
+    else:
+      doc = None
+    return doc, last
+
+  def ValidateDocuments(self):
+    doc_source = gen_parser_ops.document_source(self.context_file, batch_size=1)
+    with self.test_session() as sess:
+      logging.info('Reading document1')
+      doc, last = self.ReadNextDocument(sess, doc_source)
+      self.assertEqual(len(doc.token), 12)
+      self.assertEqual(u'लाजमी', doc.token[9].word)
+      self.assertFalse(last)
+      logging.info('Reading document2')
+      doc, last = self.ReadNextDocument(sess, doc_source)
+      self.assertEqual(len(doc.token), 13)
+      self.assertEqual(u'भंग', doc.token[9].word)
+      self.assertFalse(last)
+      logging.info('Hitting end of the dataset')
+      doc, last = self.ReadNextDocument(sess, doc_source)
+      self.assertTrue(doc is None)
+      self.assertTrue(last)
+
+  def ValidateTagToCategoryMap(self):
+    with open(os.path.join(FLAGS.test_tmpdir, 'tag-to-category'), 'r') as f:
+      entries = [line.strip().split('\t') for line in f.readlines()]
+    for tag, category in entries:
+      self.assertIn(tag, TAGS)
+      self.assertIn(category, CATEGORIES)
+
+  def BuildLexicon(self):
+    with self.test_session():
+      gen_parser_ops.lexicon_builder(task_context=self.context_file).run()
+
+  def testCoNLLFormat(self):
+    self.WriteContext('conll-sentence')
+    logging.info('Writing conll file to: %s', self.corpus_file)
+    with open(self.corpus_file, 'w') as f:
+      f.write((CONLL_DOC1 + u'\n\n' + CONLL_DOC2 + u'\n')
+              .replace(' ', '\t').encode('utf-8'))
+    self.ValidateDocuments()
+    self.BuildLexicon()
+    self.ValidateTagToCategoryMap()
+
+  def testCoNLLFormatExtraNewlinesAndComments(self):
+    self.WriteContext('conll-sentence')
+    with open(self.corpus_file, 'w') as f:
+      f.write((u'\n\n\n' + CONLL_DOC1 + u'\n\n\n' + COMMENTS +
+               u'\n\n' + CONLL_DOC2).replace(' ', '\t').encode('utf-8'))
+    self.ValidateDocuments()
+    self.BuildLexicon()
+    self.ValidateTagToCategoryMap()
+
+  def testTokenizedTextFormat(self):
+    self.WriteContext('tokenized-text')
+    with open(self.corpus_file, 'w') as f:
+      f.write(TOKENIZED_DOCS.encode('utf-8'))
+    self.ValidateDocuments()
+    self.BuildLexicon()
+
+  def testTokenizedTextFormatExtraNewlines(self):
+    self.WriteContext('tokenized-text')
+    with open(self.corpus_file, 'w') as f:
+      f.write((u'\n\n\n' + TOKENIZED_DOCS + u'\n\n\n').encode('utf-8'))
+    self.ValidateDocuments()
+    self.BuildLexicon()
+
+if __name__ == '__main__':
+  googletest.main()

+ 23 - 0
syntaxnet/syntaxnet/load_parser_ops.py

@@ -0,0 +1,23 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Loads parser_ops shared library."""
+
+import os.path
+import tensorflow as tf
+
+tf.load_op_library(
+    os.path.join(tf.resource_loader.get_data_files_path(),
+                 'parser_ops.so'))

+ 189 - 0
syntaxnet/syntaxnet/models/parsey_mcparseface/context.pbtxt

@@ -0,0 +1,189 @@
+Parameter {
+  name: "brain_parser_embedding_dims"
+  value: "32;32;64"
+}
+Parameter {
+  name: "brain_parser_embedding_names"
+  value: "labels;tags;words"
+}
+Parameter {
+  name: 'brain_parser_scoring'
+  value: 'default'
+}
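+# Feature groups below are separated by ';' and correspond, in order, to the
+# labels, tags and words embeddings declared above.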
+Parameter {
+  name: "brain_parser_features"
+  value:
+  'stack.child(1).label '
+  'stack.child(1).sibling(-1).label '
+  'stack.child(-1).label '
+  'stack.child(-1).sibling(1).label '
+  'stack.child(2).label '
+  'stack.child(-2).label '
+  'stack(1).child(1).label '
+  'stack(1).child(1).sibling(-1).label '
+  'stack(1).child(-1).label '
+  'stack(1).child(-1).sibling(1).label '
+  'stack(1).child(2).label '
+  'stack(1).child(-2).label; '
+  'input.token.tag '
+  'input(1).token.tag '
+  'input(2).token.tag '
+  'input(3).token.tag '
+  'stack.token.tag '
+  'stack.child(1).token.tag '
+  'stack.child(1).sibling(-1).token.tag '
+  'stack.child(-1).token.tag '
+  'stack.child(-1).sibling(1).token.tag '
+  'stack.child(2).token.tag '
+  'stack.child(-2).token.tag '
+  'stack(1).token.tag '
+  'stack(1).child(1).token.tag '
+  'stack(1).child(1).sibling(-1).token.tag '
+  'stack(1).child(-1).token.tag '
+  'stack(1).child(-1).sibling(1).token.tag '
+  'stack(1).child(2).token.tag '
+  'stack(1).child(-2).token.tag '
+  'stack(2).token.tag '
+  'stack(3).token.tag; '
+  'input.token.word '
+  'input(1).token.word '
+  'input(2).token.word '
+  'input(3).token.word '
+  'stack.token.word '
+  'stack.child(1).token.word '
+  'stack.child(1).sibling(-1).token.word '
+  'stack.child(-1).token.word '
+  'stack.child(-1).sibling(1).token.word '
+  'stack.child(2).token.word '
+  'stack.child(-2).token.word '
+  'stack(1).token.word '
+  'stack(1).child(1).token.word '
+  'stack(1).child(1).sibling(-1).token.word '
+  'stack(1).child(-1).token.word '
+  'stack(1).child(-1).sibling(1).token.word '
+  'stack(1).child(2).token.word '
+  'stack(1).child(-2).token.word '
+  'stack(2).token.word '
+  'stack(3).token.word '
+}
+Parameter {
+  name: "brain_parser_transition_system"
+  value: "arc-standard"
+}
+
+Parameter {
+  name: "brain_tagger_embedding_dims"
+  value: "8;16;16;16;16;64"
+}
+Parameter {
+  name: "brain_tagger_embedding_names"
+  value: "other;prefix2;prefix3;suffix2;suffix3;words"
+}
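+# Feature groups below correspond, in order, to the other (digit/hyphen),
+# prefix2, prefix3, suffix2, suffix3 and words embeddings declared above.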
+Parameter {
+  name: "brain_tagger_features"
+  value:
+  'input.digit '
+  'input.hyphen; '
+  'input.prefix(length="2") '
+  'input(1).prefix(length="2") '
+  'input(2).prefix(length="2") '
+  'input(3).prefix(length="2") '
+  'input(-1).prefix(length="2") '
+  'input(-2).prefix(length="2") '
+  'input(-3).prefix(length="2") '
+  'input(-4).prefix(length="2"); '
+  'input.prefix(length="3") '
+  'input(1).prefix(length="3") '
+  'input(2).prefix(length="3") '
+  'input(3).prefix(length="3") '
+  'input(-1).prefix(length="3") '
+  'input(-2).prefix(length="3") '
+  'input(-3).prefix(length="3") '
+  'input(-4).prefix(length="3"); '
+  'input.suffix(length="2") '
+  'input(1).suffix(length="2") '
+  'input(2).suffix(length="2") '
+  'input(3).suffix(length="2") '
+  'input(-1).suffix(length="2") '
+  'input(-2).suffix(length="2") '
+  'input(-3).suffix(length="2") '
+  'input(-4).suffix(length="2"); '
+  'input.suffix(length="3") '
+  'input(1).suffix(length="3") '
+  'input(2).suffix(length="3") '
+  'input(3).suffix(length="3") '
+  'input(-1).suffix(length="3") '
+  'input(-2).suffix(length="3") '
+  'input(-3).suffix(length="3") '
+  'input(-4).suffix(length="3"); '
+  'input.token.word '
+  'input(1).token.word '
+  'input(2).token.word '
+  'input(3).token.word '
+  'input(-1).token.word '
+  'input(-2).token.word '
+  'input(-3).token.word '
+  'input(-4).token.word '
+}
+Parameter {
+  name: "brain_tagger_transition_system"
+  value: "tagger"
+}
+
+input {
+  name: "tag-map"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/tag-map"
+  }
+}
+input {
+  name: "tag-to-category"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/fine-to-universal.map"
+  }
+}
+input {
+  name: "word-map"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/word-map"
+  }
+}
+input {
+  name: "label-map"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/label-map"
+  }
+}
+input {
+  name: "prefix-table"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/prefix-table"
+  }
+}
+input {
+  name: "suffix-table"
+  Part {
+    file_pattern: "syntaxnet/models/parsey_mcparseface/suffix-table"
+  }
+}
+input {
+  name: 'stdin'
+  record_format: 'english-text'
+  Part {
+    file_pattern: '-'
+  }
+}
+input {
+  name: 'stdin-conll'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: '-'
+  }
+}
+input {
+  name: 'stdout-conll'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: '-'
+  }
+}

+ 52 - 0
syntaxnet/syntaxnet/models/parsey_mcparseface/fine-to-universal.map

@@ -0,0 +1,52 @@
+#	.
+$	.
+''	.
+-LRB-	.
+-RRB-	.
+,	.
+.	.
+:	.
+ADD	X
+AFX	PRT
+CC	CONJ
+CD	NUM
+DT	DET
+EX	DET
+FW	X
+GW	X
+HYPH	.
+IN	ADP
+JJ	ADJ
+JJR	ADJ
+JJS	ADJ
+LS	X
+MD	VERB
+NFP	.
+NN	NOUN
+NNP	NOUN
+NNPS	NOUN
+NNS	NOUN
+PDT	DET
+POS	PRT
+PRP	PRON
+PRP$	PRON
+RB	ADV
+RBR	ADV
+RBS	ADV
+RP	PRT
+SYM	X
+TO	PRT
+UH	X
+VB	VERB
+VBD	VERB
+VBG	VERB
+VBN	VERB
+VBP	VERB
+VBZ	VERB
+WDT	DET
+WP	PRON
+WP$	PRON
+WRB	ADV
+``	.
+X	X
+XX	X

+ 47 - 0
syntaxnet/syntaxnet/models/parsey_mcparseface/label-map

@@ -0,0 +1,47 @@
+46
+punct 243160
+prep 194627
+pobj 186958
+det 170592
+nsubj 144821
+nn 144800
+amod 117242
+ROOT 90592
+dobj 88551
+aux 76523
+advmod 72893
+conj 59384
+cc 57532
+num 36350
+poss 35117
+dep 34986
+ccomp 29470
+cop 25991
+mark 25141
+xcomp 25111
+rcmod 16234
+auxpass 15740
+advcl 14996
+possessive 14866
+nsubjpass 14133
+pcomp 12488
+appos 11112
+partmod 11106
+neg 11090
+number 10658
+prt 7123
+quantmod 6653
+tmod 5418
+infmod 5134
+npadvmod 3213
+parataxis 3012
+mwe 2793
+expl 2712
+iobj 1642
+acomp 1632
+discourse 1381
+csubj 1225
+predet 1160
+preconj 749
+goeswith 146
+csubjpass 41

BIN
syntaxnet/syntaxnet/models/parsey_mcparseface/parser-params


BIN
syntaxnet/syntaxnet/models/parsey_mcparseface/prefix-table


BIN
syntaxnet/syntaxnet/models/parsey_mcparseface/suffix-table


+ 50 - 0
syntaxnet/syntaxnet/models/parsey_mcparseface/tag-map

@@ -0,0 +1,50 @@
+49
+NN 285194
+IN 228165
+DT 179147
+NNP 175147
+JJ 125667
+NNS 115732
+, 97481
+. 85938
+RB 78513
+VB 63952
+CC 57554
+VBD 56635
+CD 55674
+PRP 55244
+VBZ 48126
+VBN 44458
+VBG 34524
+VBP 33669
+TO 28772
+MD 22364
+PRP$ 20706
+HYPH 18526
+POS 14905
+`` 12193
+'' 12154
+WDT 10267
+: 8713
+$ 7993
+WP 7336
+RP 7335
+WRB 6634
+JJR 6295
+NNPS 5917
+-RRB- 3904
+-LRB- 3840
+JJS 3596
+RBR 3186
+EX 2733
+UH 1521
+RBS 1467
+PDT 1271
+FW 928
+NFP 844
+SYM 652
+ADD 476
+LS 392
+WP$ 332
+GW 184
+AFX 42

BIN
syntaxnet/syntaxnet/models/parsey_mcparseface/tagger-params


These changes are not shown because the diff is too large
+ 64037 - 0
syntaxnet/syntaxnet/models/parsey_mcparseface/word-map


+ 274 - 0
syntaxnet/syntaxnet/ops/parser_ops.cc

@@ -0,0 +1,274 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/op.h"
+
+namespace syntaxnet {
+
+// -----------------------------------------------------------------------------
+
+REGISTER_OP("GoldParseReader")
+    .Output("features: feature_size * string")
+    .Output("num_epochs: int32")
+    .Output("gold_actions: int32")
+    .Attr("task_context: string")
+    .Attr("feature_size: int")
+    .Attr("batch_size: int")
+    .Attr("corpus_name: string='documents'")
+    .Attr("arg_prefix: string='brain_parser'")
+    .SetIsStateful()
+    .Doc(R"doc(
+Reads sentences, parses them, and returns (gold action, feature) pairs.
+
+features: features firing at the current parser state, encoded as
+          dist_belief.SparseFeatures protocol buffers.
+num_epochs: number of times this reader went over the training corpus.
+gold_actions: action to perform at the current parser state.
+task_context: file path at which to read the task context.
+feature_size: number of feature outputs emitted by this reader.
+batch_size: number of sentences to parse at a time.
+corpus_name: name of task input in the task context to read parses from.
+arg_prefix: prefix for context parameters.
+)doc");
+
+REGISTER_OP("DecodedParseReader")
+    .Input("transition_scores: float")
+    .Output("features: feature_size * string")
+    .Output("num_epochs: int32")
+    .Output("eval_metrics: int32")
+    .Output("documents: string")
+    .Attr("task_context: string")
+    .Attr("feature_size: int")
+    .Attr("batch_size: int")
+    .Attr("corpus_name: string='documents'")
+    .Attr("arg_prefix: string='brain_parser'")
+    .SetIsStateful()
+    .Doc(R"doc(
+Reads sentences and parses them taking parsing transitions based on the
+input transition scores.
+
+transition_scores: scores for every transition from the current parser state.
+features: features firing at the current parser state encoded as
+          dist_belief.SparseFeatures protocol buffers.
+num_epochs: number of times this reader went over the training corpus.
+eval_metrics: token counts used to compute evaluation metrics.
+task_context: file path at which to read the task context.
+feature_size: number of feature outputs emitted by this reader.
+batch_size: number of sentences to parse at a time.
+corpus_name: name of task input in the task context to read parses from.
+arg_prefix: prefix for context parameters.
+)doc");
+
+REGISTER_OP("BeamParseReader")
+    .Output("features: feature_size * string")
+    .Output("beam_state: int64")
+    .Output("num_epochs: int32")
+    .Attr("task_context: string")
+    .Attr("feature_size: int")
+    .Attr("beam_size: int")
+    .Attr("batch_size: int=1")
+    .Attr("corpus_name: string='documents'")
+    .Attr("allow_feature_weights: bool=true")
+    .Attr("arg_prefix: string='brain_parser'")
+    .Attr("continue_until_all_final: bool=false")
+    .Attr("always_start_new_sentences: bool=false")
+    .SetIsStateful()
+    .Doc(R"doc(
+Reads sentences and creates a beam parser.
+
+features: features firing at the initial parser state encoded as
+          dist_belief.SparseFeatures protocol buffers.
+beam_state: beam state handle.
+task_context: file path at which to read the task context.
+feature_size: number of feature outputs emitted by this reader.
+beam_size: limit on the beam size.
+corpus_name: name of task input in the task context to read parses from.
+allow_feature_weights: whether the op is expected to output weighted features.
+                       If false, it will check that no weights are specified.
+arg_prefix: prefix for context parameters.
+continue_until_all_final: whether to continue parsing after the gold path falls
+                          off the beam.
+always_start_new_sentences: whether to skip to the beginning of a new sentence
+                            after each training step.
+)doc");
+
+REGISTER_OP("BeamParser")
+    .Input("beam_state: int64")
+    .Input("transition_scores: float")
+    .Output("features: feature_size * string")
+    .Output("next_beam_state: int64")
+    .Output("alive: bool")
+    .Attr("feature_size: int")
+    .SetIsStateful()
+    .Doc(R"doc(
+Updates the beam parser based on scores in the input transition scores.
+
+beam_state: beam state.
+transition_scores: scores for every transition from the current parser state.
+features: features firing at the current parser state encoded as
+          dist_belief.SparseFeatures protocol buffers.
+next_beam_state: beam state handle.
+alive: whether the gold state is still in the beam.
+feature_size: number of feature outputs emitted by this reader.
+)doc");
+
+REGISTER_OP("BeamParserOutput")
+    .Input("beam_state: int64")
+    .Output("indices_and_paths: int32")
+    .Output("batches_and_slots: int32")
+    .Output("gold_slot: int32")
+    .Output("path_scores: float")
+    .SetIsStateful()
+    .Doc(R"doc(
+Converts the current state of the beam parser into a set of indices into
+the scoring matrices that lead there.
+
+beam_state: beam state handle.
+indices_and_paths: matrix whose first row is a vector to look up beam paths and
+                   decisions with, and whose second row are the corresponding
+                   path ids.
+batches_and_slots: matrix whose first row is a vector identifying the batch to
+                   which the paths correspond, and whose second row are the
+                   slots.
+gold_slot: location in final beam of the gold path [batch_size].
+path_scores: cumulative sum of scores along each path in each beam. Within each
+             beam, scores are sorted from low to high.
+)doc");
+
+REGISTER_OP("BeamEvalOutput")
+    .Input("beam_state: int64")
+    .Output("eval_metrics: int32")
+    .Output("documents: string")
+    .SetIsStateful()
+    .Doc(R"doc(
+Computes eval metrics for the best paths in the input beams.
+
+beam_state: beam state handle.
+eval_metrics: token counts used to compute evaluation metrics.
+documents: parsed documents.
+)doc");
+
+REGISTER_OP("LexiconBuilder")
+    .Attr("task_context: string")
+    .Attr("corpus_name: string='documents'")
+    .Attr("lexicon_max_prefix_length: int = 3")
+    .Attr("lexicon_max_suffix_length: int = 3")
+    .Doc(R"doc(
+An op that collects term statistics over a corpus and saves a set of term maps.
+
+task_context: file path at which to read the task context.
+corpus_name: name of the context input to compute lexicons.
+lexicon_max_prefix_length: maximum prefix length for lexicon words.
+lexicon_max_suffix_length: maximum suffix length for lexicon words.
+)doc");
+
+REGISTER_OP("FeatureSize")
+    .Attr("task_context: string")
+    .Output("feature_sizes: int32")
+    .Output("domain_sizes: int32")
+    .Output("embedding_dims: int32")
+    .Output("num_actions: int32")
+    .Attr("arg_prefix: string='brain_parser'")
+    .Doc(R"doc(
+An op that returns the number and domain sizes of parser features.
+
+task_context: file path at which to read the task context.
+feature_sizes: number of feature locators in each group of parser features.
+domain_sizes: domain size for each feature group of parser features.
+embedding_dims: embedding dimension for each feature group of parser features.
+num_actions: number of actions a parser can perform.
+arg_prefix: prefix for context parameters.
+)doc");
+
+REGISTER_OP("UnpackSparseFeatures")
+    .Input("sf: string")
+    .Output("indices: int32")
+    .Output("ids: int64")
+    .Output("weights: float")
+    .Doc(R"doc(
+Converts a vector of strings with SparseFeatures to tensors.
+
+Note that indices, ids and weights are vectors of the same size and have
+one-to-one correspondence between their elements. ids and weights are each
+obtained by sequentially concatenating sf[i].id and sf[i].weight, for i in
+1...size(sf). Note that if sf[i].weight is not set, the default value for the
+weight is assumed to be 1.0. Also, for any j, if ids[j] and weights[j] were
+extracted from sf[i], then indices[j] is set to i.
+
+sf: vector of strings, where each element is the string encoding of a
+    SparseFeatures proto.
+indices: vector of indices inside sf.
+ids: vector of ids extracted from the SparseFeatures protos.
+weights: vector of weights extracted from the SparseFeatures protos.
+)doc");
+
+REGISTER_OP("WordEmbeddingInitializer")
+    .Output("word_embeddings: float")
+    .Attr("vectors: string")
+    .Attr("task_context: string")
+    .Attr("embedding_init: float = 1.0")
+    .Doc(R"doc(
+Reads word embeddings from an sstable of dist_belief.TokenEmbedding protos for
+every word specified in a text vocabulary file.
+
+word_embeddings: a tensor containing word embeddings from the specified sstable.
+vectors: path to recordio of word embedding vectors.
+task_context: file path at which to read the task context.
+)doc");
+
+REGISTER_OP("DocumentSource")
+    .Output("documents: string")
+    .Output("last: bool")
+    .Attr("task_context: string")
+    .Attr("corpus_name: string='documents'")
+    .Attr("batch_size: int")
+    .SetIsStateful()
+    .Doc(R"doc(
+Reads documents from documents_path and outputs them.
+
+documents: a vector of documents as serialized protos.
+last: whether this is the last batch of documents from this document path.
+batch_size: how many documents to read at once.
+)doc");
+
+REGISTER_OP("DocumentSink")
+    .Input("documents: string")
+    .Attr("task_context: string")
+    .Attr("corpus_name: string='documents'")
+    .Doc(R"doc(
+Writes documents to documents_path.
+
+documents: documents to write.
+)doc");
+
+REGISTER_OP("WellFormedFilter")
+    .Input("documents: string")
+    .Output("filtered: string")
+    .Attr("task_context: string")
+    .Attr("corpus_name: string='documents'")
+    .Attr("keep_malformed_documents: bool = False")
+    .Doc(R"doc(
+)doc");
+
+REGISTER_OP("ProjectivizeFilter")
+    .Input("documents: string")
+    .Output("filtered: string")
+    .Attr("task_context: string")
+    .Attr("corpus_name: string='documents'")
+    .Attr("discard_non_projective: bool = False")
+    .Doc(R"doc(
+)doc");
+
+}  // namespace syntaxnet

+ 149 - 0
syntaxnet/syntaxnet/parser_eval.py

@@ -0,0 +1,149 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A program to annotate a conll file with a tensorflow neural net parser."""
+
+
+import os
+import os.path
+import time
+
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+from syntaxnet import sentence_pb2
+from syntaxnet import graph_builder
+from syntaxnet import structured_graph_builder
+from syntaxnet.ops import gen_parser_ops
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+
+flags.DEFINE_string('task_context', '',
+                    'Path to a task context with inputs and parameters for '
+                    'feature extractors.')
+flags.DEFINE_string('model_path', '', 'Path to model parameters.')
+flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
+flags.DEFINE_string('graph_builder', 'greedy',
+                    'Which graph builder to use, either greedy or structured.')
+flags.DEFINE_string('input', 'stdin',
+                    'Name of the context input to read data from.')
+flags.DEFINE_string('output', 'stdout',
+                    'Name of the context input to write data to.')
+flags.DEFINE_string('hidden_layer_sizes', '200,200',
+                    'Comma separated list of hidden layer sizes.')
+flags.DEFINE_integer('batch_size', 32,
+                     'Number of sentences to process in parallel.')
+flags.DEFINE_integer('beam_size', 8, 'Number of slots for beam parsing.')
+flags.DEFINE_integer('max_steps', 1000, 'Max number of steps to take.')
+flags.DEFINE_bool('slim_model', False,
+                  'Whether to expect only averaged variables.')
+
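+# Example invocation (illustrative; mirrors the flags used in
+# parser_trainer_test.sh, paths are placeholders):
+#   python parser_eval.py \
+#     --task_context=/tmp/syntaxnet-output/brain_parser/greedy/$PARAMS/context \
+#     --model_path=/tmp/syntaxnet-output/brain_parser/greedy/$PARAMS/model \
+#     --arg_prefix=brain_parser --graph_builder=greedy \
+#     --hidden_layer_sizes=128 --input=tuning-corpus --output=stdout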
+
+def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
+  """Builds and evaluates a network.
+
+  Args:
+    sess: tensorflow session to use
+    num_actions: number of possible golden actions
+    feature_sizes: size of each feature vector
+    domain_sizes: number of possible feature ids in each feature vector
+    embedding_dims: embedding dimension for each feature group
+  """
+  t = time.time()
+  hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
+  logging.info('Building evaluation network with parameters: feature_sizes: %s '
+               'domain_sizes: %s', feature_sizes, domain_sizes)
+  if FLAGS.graph_builder == 'greedy':
+    parser = graph_builder.GreedyParser(num_actions,
+                                        feature_sizes,
+                                        domain_sizes,
+                                        embedding_dims,
+                                        hidden_layer_sizes,
+                                        gate_gradients=True,
+                                        arg_prefix=FLAGS.arg_prefix)
+  else:
+    parser = structured_graph_builder.StructuredGraphBuilder(
+        num_actions,
+        feature_sizes,
+        domain_sizes,
+        embedding_dims,
+        hidden_layer_sizes,
+        gate_gradients=True,
+        arg_prefix=FLAGS.arg_prefix,
+        beam_size=FLAGS.beam_size,
+        max_steps=FLAGS.max_steps)
+  task_context = FLAGS.task_context
+  parser.AddEvaluation(task_context,
+                       FLAGS.batch_size,
+                       corpus_name=FLAGS.input,
+                       evaluation_max_steps=FLAGS.max_steps)
+
+  parser.AddSaver(FLAGS.slim_model)
+  sess.run(parser.inits.values())
+  parser.saver.restore(sess, FLAGS.model_path)
+
+  sink_documents = tf.placeholder(tf.string)
+  sink = gen_parser_ops.document_sink(sink_documents,
+                                      task_context=FLAGS.task_context,
+                                      corpus_name=FLAGS.output)
+  t = time.time()
+  num_epochs = None
+  num_tokens = 0
+  num_correct = 0
+  num_documents = 0
+  while True:
+    tf_eval_epochs, tf_eval_metrics, tf_documents = sess.run([
+        parser.evaluation['epochs'],
+        parser.evaluation['eval_metrics'],
+        parser.evaluation['documents'],
+    ])
+
+    if len(tf_documents):
+      logging.info('Processed %d documents', len(tf_documents))
+      num_documents += len(tf_documents)
+      sess.run(sink, feed_dict={sink_documents: tf_documents})
+
+    num_tokens += tf_eval_metrics[0]
+    num_correct += tf_eval_metrics[1]
+    if num_epochs is None:
+      num_epochs = tf_eval_epochs
+    elif num_epochs < tf_eval_epochs:
+      break
+
+  logging.info('Total processed documents: %d', num_documents)
+  if num_tokens > 0:
+    eval_metric = 100.0 * num_correct / num_tokens
+    logging.info('num correct tokens: %d', num_correct)
+    logging.info('total tokens: %d', num_tokens)
+    logging.info('Seconds elapsed in evaluation: %.2f, '
+                 'eval metric: %.2f%%', time.time() - t, eval_metric)
+
+
+def main(unused_argv):
+  logging.set_verbosity(logging.INFO)
+  with tf.Session() as sess:
+    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
+        gen_parser_ops.feature_size(task_context=FLAGS.task_context,
+                                    arg_prefix=FLAGS.arg_prefix))
+
+  with tf.Session() as sess:
+    Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 213 - 0
syntaxnet/syntaxnet/parser_features.cc

@@ -0,0 +1,213 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/parser_features.h"
+
+#include <string>
+
+#include "syntaxnet/registry.h"
+#include "syntaxnet/sentence_features.h"
+#include "syntaxnet/workspace.h"
+
+namespace syntaxnet {
+
+// Registry for the parser feature functions.
+REGISTER_CLASS_REGISTRY("parser feature function", ParserFeatureFunction);
+
+// Registry for the parser state + token index feature functions.
+REGISTER_CLASS_REGISTRY("parser+index feature function",
+                        ParserIndexFeatureFunction);
+
+RootFeatureType::RootFeatureType(const string &name,
+                                 const FeatureType &wrapped_type,
+                                 int root_value)
+    : FeatureType(name), wrapped_type_(wrapped_type), root_value_(root_value) {}
+
+string RootFeatureType::GetFeatureValueName(FeatureValue value) const {
+  if (value == root_value_) return "<ROOT>";
+  return wrapped_type_.GetFeatureValueName(value);
+}
+
+FeatureValue RootFeatureType::GetDomainSize() const {
+  return wrapped_type_.GetDomainSize() + 1;
+}
+
+// Parser feature locator for accessing tokens relative to the current input
+// token in the parser state. It takes the offset relative to the current input
+// token as argument: negative values select tokens to the left, positive values
+// tokens to the right, and 0 (the default argument value) selects the current
+// input token itself.
+class InputParserLocator : public ParserLocator<InputParserLocator> {
+ public:
+  // Gets the new focus.
+  int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
+    const int offset = argument();
+    return state.Input(offset);
+  }
+};
+
+REGISTER_PARSER_FEATURE_FUNCTION("input", InputParserLocator);
+
+// Parser feature locator for accessing the stack in the parser state. The
+// argument represents the position on the stack, 0 being the top of the stack.
+class StackParserLocator : public ParserLocator<StackParserLocator> {
+ public:
+  // Gets the new focus.
+  int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
+    const int position = argument();
+    return state.Stack(position);
+  }
+};
+
+REGISTER_PARSER_FEATURE_FUNCTION("stack", StackParserLocator);
+
+// Parser feature locator for locating the head of the focus token. The argument
+// specifies the number of times the head function is applied. Please note that
+// this operates on partially built dependency trees.
+class HeadFeatureLocator : public ParserIndexLocator<HeadFeatureLocator> {
+ public:
+  // Updates the current focus to a new location. If the initial focus is
+  // outside the range of the sentence, sets the focus to -2.
+  void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
+                  int *focus) const {
+    if (*focus < -1 || *focus >= state.sentence().token_size()) {
+      *focus = -2;
+      return;
+    }
+    const int levels = argument();
+    *focus = state.Parent(*focus, levels);
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("head", HeadFeatureLocator);
+
+// Parser feature locator for locating children of the focus token. The argument
+// specifies the number of times the leftmost (when the argument is < 0) or
+// rightmost (when the argument > 0) child function is applied. Please note that
+// this operates on partially built dependency trees.
+class ChildFeatureLocator : public ParserIndexLocator<ChildFeatureLocator> {
+ public:
+  // Updates the current focus to a new location. If the initial focus is
+  // outside the range of the sentence, sets the focus to -2.
+  void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
+                  int *focus) const {
+    if (*focus < -1 || *focus >= state.sentence().token_size()) {
+      *focus = -2;
+      return;
+    }
+    const int levels = argument();
+    if (levels < 0) {
+      *focus = state.LeftmostChild(*focus, -levels);
+    } else {
+      *focus = state.RightmostChild(*focus, levels);
+    }
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("child", ChildFeatureLocator);
+
+// Parser feature locator for locating siblings of the focus token. The argument
+// specifies the sibling position relative to the focus token: a negative value
+// triggers a search to the left, while a positive value one to the right.
+// Please note that this operates on partially built dependency trees.
+class SiblingFeatureLocator
+    : public ParserIndexLocator<SiblingFeatureLocator> {
+ public:
+  // Updates the current focus to a new location. If the initial focus is
+  // outside the range of the sentence, sets the focus to -2.
+  void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
+                  int *focus) const {
+    if (*focus < -1 || *focus >= state.sentence().token_size()) {
+      *focus = -2;
+      return;
+    }
+    const int position = argument();
+    if (position < 0) {
+      *focus = state.LeftSibling(*focus, -position);
+    } else {
+      *focus = state.RightSibling(*focus, position);
+    }
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("sibling", SiblingFeatureLocator);
+
+// Feature function for computing the label of the focus token. Note that this
+// does not use the precomputed values, since the labels are read directly from
+// the parser state; sentence_features::Label is used only to obtain the label
+// map.
+class LabelFeatureFunction : public BasicParserSentenceFeatureFunction<Label> {
+ public:
+  // Computes the label of the relation between the focus token and its parent.
+  // Valid focus values range from -1 to sentence->size() - 1, inclusive.
+  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
+                       int focus, const FeatureVector *result) const override {
+    if (focus == -1) return RootValue();
+    if (focus < -1 || focus >= state.sentence().token_size()) {
+      return feature_.NumValues();
+    }
+    const int label = state.Label(focus);
+    return label == -1 ? RootValue() : label;
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("label", LabelFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<Word> WordFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("word", WordFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<Digit> DigitFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("digit", DigitFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<Hyphen> HyphenFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("hyphen", HyphenFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<PrefixFeature> PrefixFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("prefix", PrefixFeatureFunction);
+
+typedef BasicParserSentenceFeatureFunction<SuffixFeature> SuffixFeatureFunction;
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("suffix", SuffixFeatureFunction);
+
+// Parser feature function that can use nested sentence feature functions for
+// feature extraction.
+class ParserTokenFeatureFunction : public NestedFeatureFunction<
+  FeatureFunction<Sentence, int>, ParserState, int> {
+ public:
+  void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
+    for (auto *function : nested_) {
+      function->Preprocess(workspaces, state->mutable_sentence());
+    }
+  }
+
+  void Evaluate(const WorkspaceSet &workspaces, const ParserState &state,
+                int focus, FeatureVector *result) const override {
+    for (auto *function : nested_) {
+      function->Evaluate(workspaces, state.sentence(), focus, result);
+    }
+  }
+
+  // Returns the first nested feature's computed value.
+  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
+                       int focus, const FeatureVector *result) const override {
+    if (nested_.empty()) return -1;
+    return nested_[0]->Compute(workspaces, state.sentence(), focus, result);
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("token",
+                                     ParserTokenFeatureFunction);
+
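+// Illustrative note (not part of the original file): the locators and feature
+// functions registered above compose into FML feature descriptors such as
+// "input(1).tag", "stack.child(-1).label" or "stack(1).sibling(1).word": the
+// locator chain selects a token and the final function extracts its value
+// (see parser_features_test.cc for concrete examples).
+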
+}  // namespace syntaxnet

+ 150 - 0
syntaxnet/syntaxnet/parser_features.h

@@ -0,0 +1,150 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Sentence-based features for the transition parser.
+
+#ifndef $TARGETDIR_PARSER_FEATURES_H_
+#define $TARGETDIR_PARSER_FEATURES_H_
+
+#include <string>
+
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/workspace.h"
+
+namespace syntaxnet {
+
+// A union used to represent discrete and continuous feature values.
+union FloatFeatureValue {
+ public:
+  explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
+  FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
+  FeatureValue discrete_value;
+  struct {
+    uint32 id;
+    float weight;
+  };
+};
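+
+// Illustrative sketch (assuming FeatureValue is a 64-bit integer, as defined
+// in feature_types.h): the union lets an (id, weight) pair be packed into and
+// recovered from a single discrete feature slot, e.g.
+//   FloatFeatureValue packed(/*i=*/42, /*w=*/0.5f);
+//   FeatureValue v = packed.discrete_value;   // store like any feature value
+//   FloatFeatureValue unpacked(v);            // unpacked.id == 42, weight 0.5f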
+
+typedef FeatureFunction<ParserState> ParserFeatureFunction;
+
+// Feature function for the transition parser based on a parser state object and
+// a token index. This typically extracts information from a given token.
+typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction;
+
+// Utilities to register the two types of parser features.
+#define REGISTER_PARSER_FEATURE_FUNCTION(name, component) \
+  REGISTER_FEATURE_FUNCTION(ParserFeatureFunction, name, component)
+
+#define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component) \
+  REGISTER_FEATURE_FUNCTION(ParserIndexFeatureFunction, name, component)
+
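+// For example (as used in parser_features.cc):
+//   REGISTER_PARSER_FEATURE_FUNCTION("input", InputParserLocator);
+//   REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction);
+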
+// Alias for locator type that takes a parser state, and produces a focus
+// integer that can be used on nested ParserIndexFeature objects.
+template<class DER>
+using ParserLocator = FeatureAddFocusLocator<DER, ParserState, int>;
+
+// Alias for Locator type features that take (ParserState, int) signatures and
+// call other ParserIndexFeatures.
+template<class DER>
+using ParserIndexLocator = FeatureLocator<DER, ParserState, int>;
+
+// Feature extractor for the transition parser based on a parser state object.
+typedef FeatureExtractor<ParserState> ParserFeatureExtractor;
+
+// A simple wrapper FeatureType that adds a special "<ROOT>" type.
+class RootFeatureType : public FeatureType {
+ public:
+  // Creates a RootFeatureType that wraps a given type and adds the special
+  // "<ROOT>" value in root_value.
+  RootFeatureType(const string &name, const FeatureType &wrapped_type,
+                  int root_value);
+
+  // Returns the feature value name, but with the special "<ROOT>" value.
+  string GetFeatureValueName(FeatureValue value) const override;
+
+  // Returns the original number of features plus one for the "<ROOT>" value.
+  FeatureValue GetDomainSize() const override;
+
+ private:
+  // A wrapped type that handles everything else besides "<ROOT>".
+  const FeatureType &wrapped_type_;
+
+  // The reserved root value.
+  int root_value_;
+};
+
+// Simple feature function that wraps a Sentence based feature
+// function. It adds a "<ROOT>" feature value that is triggered whenever the
+// focus is the special root token. This class is sub-classed based on the
+// extracted arguments of the nested function.
+template<class F>
+class ParserSentenceFeatureFunction : public ParserIndexFeatureFunction {
+ public:
+  // Instantiates and sets up the nested feature.
+  void Setup(TaskContext *context) override {
+    this->feature_.set_descriptor(this->descriptor());
+    this->feature_.set_prefix(this->prefix());
+    this->feature_.set_extractor(this->extractor());
+    feature_.Setup(context);
+  }
+
+  // Initializes the nested feature and sets feature type.
+  void Init(TaskContext *context) override {
+    feature_.Init(context);
+    num_base_values_ = feature_.GetFeatureType()->GetDomainSize();
+    set_feature_type(new RootFeatureType(
+        name(), *feature_.GetFeatureType(), RootValue()));
+  }
+
+  // Passes workspace requests and preprocessing to the nested feature.
+  void RequestWorkspaces(WorkspaceRegistry *registry) override {
+    feature_.RequestWorkspaces(registry);
+  }
+
+  void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
+    feature_.Preprocess(workspaces, state->mutable_sentence());
+  }
+
+ protected:
+  // Returns the special value to represent a root token.
+  FeatureValue RootValue() const { return num_base_values_; }
+
+  // Stores the number of base values of the wrapped function, used to compute
+  // the root value.
+  int num_base_values_;
+
+  // The wrapped feature.
+  F feature_;
+};
+
+// Specialization of ParserSentenceFeatureFunction that calls the nested feature
+// with (Sentence, int) arguments based on the current integer focus.
+template<class F>
+class BasicParserSentenceFeatureFunction :
+      public ParserSentenceFeatureFunction<F> {
+ public:
+  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
+                       int focus, const FeatureVector *result) const override {
+    if (focus == -1) return this->RootValue();
+    return this->feature_.Compute(workspaces, state.sentence(), focus, result);
+  }
+};
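+
+// For example, parser_features.cc instantiates this template as
+//   typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction;
+// and registers it under the FML name "tag".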
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_PARSER_FEATURES_H_

+ 144 - 0
syntaxnet/syntaxnet/parser_features_test.cc

@@ -0,0 +1,144 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/parser_features.h"
+
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/populate_test_inputs.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+
+// Feature extractor for the transition parser based on a parser state object.
+typedef FeatureExtractor<ParserState, int> ParserIndexFeatureExtractor;
+
+// Test fixture for parser features.
+class ParserFeatureFunctionTest : public ::testing::Test {
+ protected:
+  // Sets up a parser state.
+  void SetUp() override {
+    // Prepare a document.
+    const char *kTaggedDocument =
+        "text: 'I saw a man with a telescope.' "
+        "token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
+        " break_level: NO_BREAK } "
+        "token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
+        " break_level: SPACE_BREAK } "
+        "token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
+        " break_level: SPACE_BREAK } "
+        "token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
+        " break_level: SPACE_BREAK } "
+        "token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
+        " break_level: SPACE_BREAK } "
+        "token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
+        " break_level: SPACE_BREAK } "
+        "token { word: 'telescope' start: 19 end: 27 tag: 'NN' category: 'NOUN'"
+        " break_level: SPACE_BREAK } "
+        "token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
+        " break_level: NO_BREAK }";
+    CHECK(TextFormat::ParseFromString(kTaggedDocument, &sentence_));
+    creators_ = PopulateTestInputs::Defaults(sentence_);
+
+    // Prepare a label map. By adding labels in lexicographic order we make sure
+    // the term indices stay the same after sorting (which happens when the
+    // label map is saved to disk).
+    label_map_.Increment("NULL");
+    label_map_.Increment("ROOT");
+    label_map_.Increment("det");
+    label_map_.Increment("dobj");
+    label_map_.Increment("nsubj");
+    label_map_.Increment("p");
+    label_map_.Increment("pobj");
+    label_map_.Increment("prep");
+    creators_.Add("label-map", "text", "", [this](const string &filename) {
+      label_map_.Save(filename);
+    });
+
+    // Prepare a parser state.
+    state_.reset(new ParserState(&sentence_, nullptr /* no transition state */,
+                                 &label_map_));
+  }
+
+  // Prepares a feature for computations.
+  string ExtractFeature(const string &feature_name) {
+    context_.mutable_spec()->mutable_input()->Clear();
+    context_.mutable_spec()->mutable_output()->Clear();
+    feature_extractor_.reset(new ParserFeatureExtractor());
+    feature_extractor_->Parse(feature_name);
+    feature_extractor_->Setup(&context_);
+    creators_.Populate(&context_);
+    feature_extractor_->Init(&context_);
+    feature_extractor_->RequestWorkspaces(&registry_);
+    workspaces_.Reset(registry_);
+    feature_extractor_->Preprocess(&workspaces_, state_.get());
+    FeatureVector result;
+    feature_extractor_->ExtractFeatures(workspaces_, *state_, &result);
+    return result.type(0)->GetFeatureValueName(result.value(0));
+  }
+
+  std::unique_ptr<ParserState> state_;
+  Sentence sentence_;
+  WorkspaceSet workspaces_;
+  TermFrequencyMap label_map_;
+
+  PopulateTestInputs::CreatorMap creators_;
+  TaskContext context_;
+  WorkspaceRegistry registry_;
+  std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
+};
+
+TEST_F(ParserFeatureFunctionTest, TagFeatureFunction) {
+  state_->Push(-1);
+  state_->Push(0);
+  EXPECT_EQ("PRP", ExtractFeature("input.tag"));
+  EXPECT_EQ("VBD", ExtractFeature("input(1).tag"));
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature("input(10).tag"));
+  EXPECT_EQ("PRP", ExtractFeature("stack(0).tag"));
+  EXPECT_EQ("<ROOT>", ExtractFeature("stack(1).tag"));
+}
+
+TEST_F(ParserFeatureFunctionTest, LabelFeatureFunction) {
+  // Construct a partial dependency tree.
+  state_->AddArc(0, 1, 4);
+  state_->AddArc(1, -1, 1);
+  state_->AddArc(2, 3, 2);
+  state_->AddArc(3, 1, 3);
+  state_->AddArc(5, 6, 2);
+  state_->AddArc(6, 4, 6);
+  state_->AddArc(7, 1, 5);
+
+  // Test the feature function.
+  EXPECT_EQ(label_map_.GetTerm(4), ExtractFeature("input.label"));
+  EXPECT_EQ("ROOT", ExtractFeature("input(1).label"));
+  EXPECT_EQ(label_map_.GetTerm(2), ExtractFeature("input(2).label"));
+
+  // Push the artificial root token onto the stack. This triggers the wrapped
+  // <ROOT> value, rather than indicating a token with the label "ROOT" (which
+  // may or may not be the artificial root token).
+  state_->Push(-1);
+  EXPECT_EQ("<ROOT>", ExtractFeature("stack.label"));
+}
+
+}  // namespace syntaxnet

+ 248 - 0
syntaxnet/syntaxnet/parser_state.cc

@@ -0,0 +1,248 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/parser_state.h"
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/kbest_syntax.pb.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+
+namespace syntaxnet {
+
+const char ParserState::kRootLabel[] = "ROOT";
+
+ParserState::ParserState(Sentence *sentence,
+                         ParserTransitionState *transition_state,
+                         const TermFrequencyMap *label_map)
+    : sentence_(sentence),
+      num_tokens_(sentence->token_size()),
+      transition_state_(transition_state),
+      label_map_(label_map),
+      root_label_(kDefaultRootLabel),
+      next_(0) {
+  // Initialize the stack. Some transition systems could also push the
+  // artificial root on the stack, so we make room for that as well.
+  stack_.reserve(num_tokens_ + 1);
+
+  // Allocate space for head indices and labels. Initialize the head for all
+  // tokens to be the artificial root node, i.e. token -1.
+  head_.resize(num_tokens_, -1);
+  label_.resize(num_tokens_, RootLabel());
+
+  // Transition system-specific preprocessing.
+  if (transition_state_ != nullptr) transition_state_->Init(this);
+}
+
+ParserState::~ParserState() { delete transition_state_; }
+
+ParserState *ParserState::Clone() const {
+  ParserState *new_state = new ParserState();
+  new_state->sentence_ = sentence_;
+  new_state->num_tokens_ = num_tokens_;
+  new_state->alternative_ = alternative_;
+  new_state->transition_state_ =
+      (transition_state_ == nullptr ? nullptr : transition_state_->Clone());
+  new_state->label_map_ = label_map_;
+  new_state->root_label_ = root_label_;
+  new_state->next_ = next_;
+  new_state->stack_.assign(stack_.begin(), stack_.end());
+  new_state->head_.assign(head_.begin(), head_.end());
+  new_state->label_.assign(label_.begin(), label_.end());
+  new_state->score_ = score_;
+  new_state->is_gold_ = is_gold_;
+  return new_state;
+}
+
+int ParserState::RootLabel() const { return root_label_; }
+
+int ParserState::Next() const {
+  DCHECK_GE(next_, -1);
+  DCHECK_LE(next_, num_tokens_);
+  return next_;
+}
+
+int ParserState::Input(int offset) const {
+  int index = next_ + offset;
+  return index >= -1 && index < num_tokens_ ? index : -2;
+}
+
+void ParserState::Advance() {
+  DCHECK_LT(next_, num_tokens_);
+  ++next_;
+}
+
+bool ParserState::EndOfInput() const { return next_ == num_tokens_; }
+
+void ParserState::Push(int index) {
+  DCHECK_LE(stack_.size(), num_tokens_);
+  stack_.push_back(index);
+}
+
+int ParserState::Pop() {
+  DCHECK(!StackEmpty());
+  const int result = stack_.back();
+  stack_.pop_back();
+  return result;
+}
+
+int ParserState::Top() const {
+  DCHECK(!StackEmpty());
+  return stack_.back();
+}
+
+int ParserState::Stack(int position) const {
+  if (position < 0) return -2;
+  const int index = stack_.size() - 1 - position;
+  return (index < 0) ? -2 : stack_[index];
+}
+
+int ParserState::StackSize() const { return stack_.size(); }
+
+bool ParserState::StackEmpty() const { return stack_.empty(); }
+
+int ParserState::Head(int index) const {
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  return index == -1 ? -1 : head_[index];
+}
+
+int ParserState::Label(int index) const {
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  return index == -1 ? RootLabel() : label_[index];
+}
+
+int ParserState::Parent(int index, int n) const {
+  // Find the n-th parent by applying the head function n times.
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  while (n-- > 0) index = Head(index);
+  return index;
+}
+
+int ParserState::LeftmostChild(int index, int n) const {
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  while (n-- > 0) {
+    // Find the leftmost child by scanning from start until a child is
+    // encountered.
+    int i;
+    for (i = -1; i < index; ++i) {
+      if (Head(i) == index) break;
+    }
+    if (i == index) return -2;
+    index = i;
+  }
+  return index;
+}
+
+int ParserState::RightmostChild(int index, int n) const {
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  while (n-- > 0) {
+    // Find the rightmost child by scanning backward from end until a child
+    // is encountered.
+    int i;
+    for (i = num_tokens_ - 1; i > index; --i) {
+      if (Head(i) == index) break;
+    }
+    if (i == index) return -2;
+    index = i;
+  }
+  return index;
+}
+
+int ParserState::LeftSibling(int index, int n) const {
+  // Find the n-th left sibling by scanning left until the n-th child of the
+  // parent is encountered.
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  if (index == -1 && n > 0) return -2;
+  int i = index;
+  while (n > 0) {
+    --i;
+    if (i == -1) return -2;
+    if (Head(i) == Head(index)) --n;
+  }
+  return i;
+}
+
+int ParserState::RightSibling(int index, int n) const {
+  // Find the n-th right sibling by scanning right until the n-th child of the
+  // parent is encountered.
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  if (index == -1 && n > 0) return -2;
+  int i = index;
+  while (n > 0) {
+    ++i;
+    if (i == num_tokens_) return -2;
+    if (Head(i) == Head(index)) --n;
+  }
+  return i;
+}
+
+void ParserState::AddArc(int index, int head, int label) {
+  DCHECK_GE(index, 0);
+  DCHECK_LT(index, num_tokens_);
+  head_[index] = head;
+  label_[index] = label;
+}
+
+int ParserState::GoldHead(int index) const {
+  // A valid ParserState index is transformed to a valid Sentence index,
+  // then the gold head is extracted.
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  if (index == -1) return -1;
+  const int offset = 0;
+  const int gold_head = GetToken(index).head();
+  return gold_head == -1 ? -1 : gold_head - offset;
+}
+
+int ParserState::GoldLabel(int index) const {
+  // A valid ParserState index is transformed to a valid Sentence index,
+  // then the gold label is extracted.
+  DCHECK_GE(index, -1);
+  DCHECK_LT(index, num_tokens_);
+  if (index == -1) return RootLabel();
+  string gold_label;
+  gold_label = GetToken(index).label();
+  return label_map_->LookupIndex(gold_label, RootLabel() /* unknown */);
+}
+
+void ParserState::AddParseToDocument(Sentence *sentence,
+                                     bool rewrite_root_labels) const {
+  transition_state_->AddParseToDocument(*this, rewrite_root_labels, sentence);
+}
+
+bool ParserState::IsTokenCorrect(int index) const {
+  return transition_state_->IsTokenCorrect(*this, index);
+}
+
+string ParserState::LabelAsString(int label) const {
+  if (label == root_label_) return "ROOT";
+  if (label >= 0 && label < label_map_->Size()) {
+    return label_map_->GetTerm(label);
+  }
+  return "";
+}
+
+string ParserState::ToString() const {
+  return transition_state_->ToString(*this);
+}
+
+}  // namespace syntaxnet

+ 233 - 0
syntaxnet/syntaxnet/parser_state.h

@@ -0,0 +1,233 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Parser state for the transition-based dependency parser.
+
+#ifndef $TARGETDIR_PARSER_STATE_H_
+#define $TARGETDIR_PARSER_STATE_H_
+
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/kbest_syntax.pb.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/sentence.pb.h"
+
+namespace syntaxnet {
+
+class TermFrequencyMap;
+
+// A ParserState object represents the state of the parser during the parsing of
+// a sentence. The state consists of a pointer to the next input token and a
+// stack of partially processed tokens. The parser state can be changed by
+// applying a sequence of transitions. Some transitions also add relations
+// to the dependency tree of the sentence. The parser state also records the
+// (partial) parse tree for the sentence by recording the head of each token and
+// the label of this relation. The state is used for both training and parsing.
+class ParserState {
+ public:
+  // String representation of the root label.
+  static const char kRootLabel[];
+
+  // Default value for the root label in case it's not in the label map.
+  static const int kDefaultRootLabel = -1;
+
+  // Initializes the parser state for a sentence, using an additional transition
+  // state for preprocessing and/or additional information specific to the
+  // transition system. The transition state is allowed to be null, in which
+  // case no additional work is performed. The parser state takes ownership of
+  // the transition state. A label map is used for transforming between integer
+  // and string representations of the labels.
+  ParserState(Sentence *sentence,
+              ParserTransitionState *transition_state,
+              const TermFrequencyMap *label_map);
+
+  // Deletes the parser state.
+  ~ParserState();
+
+  // Clones the parser state.
+  ParserState *Clone() const;
+
+  // Returns the root label.
+  int RootLabel() const;
+
+  // Returns the index of the next input token.
+  int Next() const;
+
+  // Returns the number of tokens in the sentence.
+  int NumTokens() const { return num_tokens_; }
+
+  // Returns the token index relative to the next input token. If no such token
+  // exists, returns -2.
+  int Input(int offset) const;
+
+  // Advances to the next input token.
+  void Advance();
+
+  // Returns true if all tokens have been processed.
+  bool EndOfInput() const;
+
+  // Pushes an element to the stack.
+  void Push(int index);
+
+  // Pops the top element from stack and returns it.
+  int Pop();
+
+  // Returns the element from the top of the stack.
+  int Top() const;
+
+  // Returns the element at a certain position in the stack. Stack(0) is the top
+  // stack element. If no such position exists, returns -2.
+  int Stack(int position) const;
+
+  // Returns the number of elements on the stack.
+  int StackSize() const;
+
+  // Returns true if the stack is empty.
+  bool StackEmpty() const;
+
+  // Returns the head index for a given token.
+  int Head(int index) const;
+
+  // Returns the label of the relation to head for a given token.
+  int Label(int index) const;
+
+  // Returns the parent of a given token 'n' levels up in the tree.
+  int Parent(int index, int n) const;
+
+  // Returns the leftmost child of a given token 'n' levels down in the tree. If
+  // no such child exists, returns -2.
+  int LeftmostChild(int index, int n) const;
+
+  // Returns the rightmost child of a given token 'n' levels down in the tree.
+  // If no such child exists, returns -2.
+  int RightmostChild(int index, int n) const;
+
+  // Returns the n-th left sibling of a given token. If no such sibling exists,
+  // returns -2.
+  int LeftSibling(int index, int n) const;
+
+  // Returns the n-th right sibling of a given token. If no such sibling exists,
+  // returns -2.
+  int RightSibling(int index, int n) const;
+
+  // Adds an arc to the partial dependency tree of the state.
+  void AddArc(int index, int head, int label);
+
+  // Returns the gold head index for a given token, based on the underlying
+  // annotated sentence.
+  int GoldHead(int index) const;
+
+  // Returns the gold label for a given token, based on the underlying annotated
+  // sentence.
+  int GoldLabel(int index) const;
+
+  // Gets a reference to the underlying token at index. Returns an empty
+  // default Token when accessing the root.
+  const Token &GetToken(int index) const {
+    if (index == -1) return kRootToken;
+    return sentence().token(index);
+  }
+
+  // Annotates a document with the dependency relations built during parsing for
+  // one of its sentences. If rewrite_root_labels is true, then all tokens with
+  // no heads will be assigned the default root label "ROOT".
+  void AddParseToDocument(Sentence *document, bool rewrite_root_labels) const;
+
+  // As above, but uses the default of rewrite_root_labels = true.
+  void AddParseToDocument(Sentence *document) const {
+    AddParseToDocument(document, true);
+  }
+
+  // Whether a parsed token should be considered correct for evaluation.
+  bool IsTokenCorrect(int index) const;
+
+  // Returns the string representation of a dependency label, or an empty string
+  // if the label is invalid.
+  string LabelAsString(int label) const;
+
+  // Returns a string representation of the parser state.
+  string ToString() const;
+
+  // Returns the underlying sentence instance.
+  const Sentence &sentence() const { return *sentence_; }
+  Sentence *mutable_sentence() const { return sentence_; }
+
+  // Returns the transition system-specific state.
+  const ParserTransitionState *transition_state() const {
+    return transition_state_;
+  }
+  ParserTransitionState *mutable_transition_state() {
+    return transition_state_;
+  }
+
+  // Gets/sets the flag which says that the state was obtained through gold
+  // transitions only.
+  bool is_gold() const { return is_gold_; }
+  void set_is_gold(bool is_gold) { is_gold_ = is_gold; }
+
+ private:
+  // Empty constructor used for the cloning operation.
+  ParserState() {}
+
+  // Default value for the root token.
+  const Token kRootToken;
+
+  // Sentence to parse. Not owned.
+  Sentence *sentence_ = nullptr;
+
+  // Number of tokens in the sentence to parse.
+  int num_tokens_;
+
+  // Which alternative token analysis is used for tag/category/head/label
+  // information. -1 means use default.
+  int alternative_ = -1;
+
+  // Transition system-specific state. Owned.
+  ParserTransitionState *transition_state_ = nullptr;
+
+  // Label map used for conversions between integer and string representations
+  // of the dependency labels. Not owned.
+  const TermFrequencyMap *label_map_ = nullptr;
+
+  // Root label.
+  int root_label_;
+
+  // Index of the next input token.
+  int next_;
+
+  // Parse stack of partially processed tokens.
+  vector<int> stack_;
+
+  // List of head positions for the (partial) dependency tree.
+  vector<int> head_;
+
+  // List of dependency relation labels describing the (partial) dependency
+  // tree.
+  vector<int> label_;
+
+  // Score of the parser state.
+  double score_ = 0.0;
+
+  // True if this is the gold standard sequence (used for structured learning).
+  bool is_gold_ = false;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ParserState);
+};
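+
+// Illustrative usage sketch (not part of the original header; assumes the
+// ParserTransitionSystem interface declared in parser_transitions.h):
+//
+//   ParserState state(&sentence,
+//                     transition_system->NewTransitionState(/*training=*/false),
+//                     &label_map);
+//   while (!transition_system->IsFinalState(state)) {
+//     ParserAction action = /* predicted, or derived from the gold parse */;
+//     transition_system->PerformAction(action, &state);
+//   }
+//   state.AddParseToDocument(&sentence);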
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_PARSER_STATE_H_

+ 303 - 0
syntaxnet/syntaxnet/parser_trainer.py

@@ -0,0 +1,303 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A program to train a tensorflow neural net parser from a a conll file."""
+
+
+
+import os
+import os.path
+import time
+
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import logging
+
+from google.protobuf import text_format
+
+from syntaxnet import graph_builder
+from syntaxnet import structured_graph_builder
+from syntaxnet.ops import gen_parser_ops
+from syntaxnet import task_spec_pb2
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('tf_master', '',
+                    'TensorFlow execution engine to connect to.')
+flags.DEFINE_string('output_path', '', 'Top level for output.')
+flags.DEFINE_string('task_context', '',
+                    'Path to a task context with resource locations and '
+                    'parameters.')
+flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
+flags.DEFINE_string('params', '0', 'Unique identifier of parameter grid point.')
+flags.DEFINE_string('training_corpus', 'training-corpus',
+                    'Name of the context input to read training data from.')
+flags.DEFINE_string('tuning_corpus', 'tuning-corpus',
+                    'Name of the context input to read tuning data from.')
+flags.DEFINE_string('word_embeddings', None,
+                    'Recordio containing pretrained word embeddings, will be '
+                    'loaded as the first embedding matrix.')
+flags.DEFINE_bool('compute_lexicon', False,
+                  'Whether to compute the lexicon (term maps) before training.')
+flags.DEFINE_bool('projectivize_training_set', False,
+                  'Whether to projectivize the training corpus before training.')
+flags.DEFINE_string('hidden_layer_sizes', '200,200',
+                    'Comma separated list of hidden layer sizes.')
+flags.DEFINE_string('graph_builder', 'greedy',
+                    'Graph builder to use, either "greedy" or "structured".')
+flags.DEFINE_integer('batch_size', 32,
+                     'Number of sentences to process in parallel.')
+flags.DEFINE_integer('beam_size', 10, 'Number of slots for beam parsing.')
+flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
+flags.DEFINE_integer('max_steps', 50,
+                     'Max number of parser steps during a training step.')
+flags.DEFINE_integer('report_every', 100,
+                     'Report cost and training accuracy every this many steps.')
+flags.DEFINE_integer('checkpoint_every', 5000,
+                     'Measure tuning UAS and checkpoint every this many steps.')
+flags.DEFINE_bool('slim_model', False,
+                  'Whether to remove non-averaged variables, for compactness.')
+flags.DEFINE_float('learning_rate', 0.1, 'Initial learning rate parameter.')
+flags.DEFINE_integer('decay_steps', 4000,
+                     'Decay learning rate by 0.96 every this many steps.')
+flags.DEFINE_float('momentum', 0.9,
+                   'Momentum parameter for momentum optimizer.')
+flags.DEFINE_string('seed', '0', 'Initialization seed for TF variables.')
+flags.DEFINE_string('pretrained_params', None,
+                    'Path to model from which to load params.')
+flags.DEFINE_string('pretrained_params_names', None,
+                    'List of names of tensors to load from pretrained model.')
+flags.DEFINE_float('averaging_decay', 0.9999,
+                   'Decay for exponential moving average when computing '
+                   'averaged parameters; set to 1 to do vanilla averaging.')
+
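+# Example invocation (illustrative; mirrors parser_trainer_test.sh):
+#   python parser_trainer.py \
+#     --arg_prefix=brain_parser --graph_builder=greedy \
+#     --task_context=$TMP_DIR/context --output_path=$TMP_DIR \
+#     --compute_lexicon --hidden_layer_sizes=128 --batch_size=32 \
+#     --training_corpus=training-corpus --tuning_corpus=tuning-corpus \
+#     --params=128-0.08-3600-0.9-0 --num_epochs=12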
+
+def StageName():
+  return os.path.join(FLAGS.arg_prefix, FLAGS.graph_builder)
+
+
+def OutputPath(path):
+  return os.path.join(FLAGS.output_path, StageName(), FLAGS.params, path)
+
+
+def RewriteContext():
+  context = task_spec_pb2.TaskSpec()
+  with gfile.FastGFile(FLAGS.task_context) as fin:
+    text_format.Merge(fin.read(), context)
+  for resource in context.input:
+    if resource.creator == StageName():
+      del resource.part[:]
+      part = resource.part.add()
+      part.file_pattern = os.path.join(OutputPath(resource.name))
+  with gfile.FastGFile(OutputPath('context'), 'w') as fout:
+    fout.write(str(context))
+
+
+def WriteStatus(num_steps, eval_metric, best_eval_metric):
+  status = os.path.join(os.getenv('GOOGLE_STATUS_DIR') or '/tmp', 'STATUS')
+  message = ('Parameters: %s | Steps: %d | Tuning score: %.2f%% | '
+             'Best tuning score: %.2f%%' % (FLAGS.params, num_steps,
+                                            eval_metric, best_eval_metric))
+  with gfile.FastGFile(status, 'w') as fout:
+    fout.write(message)
+  with gfile.FastGFile(OutputPath('status'), 'a') as fout:
+    fout.write(message + '\n')
+
+
+def Eval(sess, parser, num_steps, best_eval_metric):
+  """Evaluates a network and checkpoints it to disk.
+
+  Args:
+    sess: tensorflow session to use
+    parser: graph builder containing all ops references
+    num_steps: number of training steps taken, for logging
+    best_eval_metric: current best eval metric, to decide whether this model is
+        the best so far
+
+  Returns:
+    new best eval metric
+  """
+  logging.info('Evaluating training network.')
+  t = time.time()
+  num_epochs = None
+  num_tokens = 0
+  num_correct = 0
+  while True:
+    tf_eval_epochs, tf_eval_metrics = sess.run([
+        parser.evaluation['epochs'], parser.evaluation['eval_metrics']
+    ])
+    num_tokens += tf_eval_metrics[0]
+    num_correct += tf_eval_metrics[1]
+    if num_epochs is None:
+      num_epochs = tf_eval_epochs
+    elif num_epochs < tf_eval_epochs:
+      break
+  eval_metric = 0 if num_tokens == 0 else (100.0 * num_correct / num_tokens)
+  logging.info('Seconds elapsed in evaluation: %.2f, '
+               'eval metric: %.2f%%', time.time() - t, eval_metric)
+  WriteStatus(num_steps, eval_metric, max(eval_metric, best_eval_metric))
+
+  # Save parameters.
+  if FLAGS.output_path:
+    logging.info('Writing out trained parameters.')
+    parser.saver.save(sess, OutputPath('latest-model'))
+    if eval_metric > best_eval_metric:
+      parser.saver.save(sess, OutputPath('model'))
+
+  return max(eval_metric, best_eval_metric)
+
+
+def Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
+  """Builds and trains the network.
+
+  Args:
+    sess: tensorflow session to use.
+    num_actions: number of possible golden actions.
+    feature_sizes: size of each feature vector.
+    domain_sizes: number of possible feature ids in each feature vector.
+    embedding_dims: embedding dimension to use for each feature group.
+  """
+  t = time.time()
+  hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
+  logging.info('Building training network with parameters: feature_sizes: %s '
+               'domain_sizes: %s', feature_sizes, domain_sizes)
+
+  if FLAGS.graph_builder == 'greedy':
+    parser = graph_builder.GreedyParser(num_actions,
+                                        feature_sizes,
+                                        domain_sizes,
+                                        embedding_dims,
+                                        hidden_layer_sizes,
+                                        seed=int(FLAGS.seed),
+                                        gate_gradients=True,
+                                        averaging_decay=FLAGS.averaging_decay,
+                                        arg_prefix=FLAGS.arg_prefix)
+  else:
+    parser = structured_graph_builder.StructuredGraphBuilder(
+        num_actions,
+        feature_sizes,
+        domain_sizes,
+        embedding_dims,
+        hidden_layer_sizes,
+        seed=int(FLAGS.seed),
+        gate_gradients=True,
+        averaging_decay=FLAGS.averaging_decay,
+        arg_prefix=FLAGS.arg_prefix,
+        beam_size=FLAGS.beam_size,
+        max_steps=FLAGS.max_steps)
+
+  task_context = OutputPath('context')
+  if FLAGS.word_embeddings is not None:
+    parser.AddPretrainedEmbeddings(0, FLAGS.word_embeddings, task_context)
+
+  corpus_name = ('projectivized-training-corpus' if
+                 FLAGS.projectivize_training_set else FLAGS.training_corpus)
+  parser.AddTraining(task_context,
+                     FLAGS.batch_size,
+                     learning_rate=FLAGS.learning_rate,
+                     momentum=FLAGS.momentum,
+                     decay_steps=FLAGS.decay_steps,
+                     corpus_name=corpus_name)
+  parser.AddEvaluation(task_context,
+                       FLAGS.batch_size,
+                       corpus_name=FLAGS.tuning_corpus)
+  parser.AddSaver(FLAGS.slim_model)
+
+  # Save graph.
+  if FLAGS.output_path:
+    with gfile.FastGFile(OutputPath('graph'), 'w') as f:
+      f.write(sess.graph_def.SerializeToString())
+
+  logging.info('Initializing...')
+  num_epochs = 0
+  cost_sum = 0.0
+  num_steps = 0
+  best_eval_metric = 0.0
+  sess.run(parser.inits.values())
+
+  if FLAGS.pretrained_params is not None:
+    logging.info('Loading pretrained params from %s', FLAGS.pretrained_params)
+    feed_dict = {'save/Const:0': FLAGS.pretrained_params}
+    targets = []
+    for node in sess.graph_def.node:
+      if (node.name.startswith('save/Assign') and
+          node.input[0] in FLAGS.pretrained_params_names.split(',')):
+        logging.info('Loading %s with op %s', node.input[0], node.name)
+        targets.append(node.name)
+    sess.run(targets, feed_dict=feed_dict)
+
+  logging.info('Training...')
+  while num_epochs < FLAGS.num_epochs:
+    tf_epochs, tf_cost, _ = sess.run([parser.training['epochs'],
+                                      parser.training['cost'],
+                                      parser.training['train_op']])
+    num_epochs = tf_epochs
+    num_steps += 1
+    cost_sum += tf_cost
+    if num_steps % FLAGS.report_every == 0:
+      logging.info('Epochs: %d, num steps: %d, '
+                   'seconds elapsed: %.2f, avg cost: %.2f, ', num_epochs,
+                   num_steps, time.time() - t, cost_sum / FLAGS.report_every)
+      cost_sum = 0.0
+    if num_steps % FLAGS.checkpoint_every == 0:
+      best_eval_metric = Eval(sess, parser, num_steps, best_eval_metric)
+
+
+def main(unused_argv):
+  logging.set_verbosity(logging.INFO)
+  if not gfile.IsDirectory(OutputPath('')):
+    gfile.MakeDirs(OutputPath(''))
+
+  # Rewrite context.
+  RewriteContext()
+
+  # Creates necessary term maps.
+  if FLAGS.compute_lexicon:
+    logging.info('Computing lexicon...')
+    with tf.Session(FLAGS.tf_master) as sess:
+      gen_parser_ops.lexicon_builder(task_context=OutputPath('context'),
+                                     corpus_name=FLAGS.training_corpus).run()
+  with tf.Session(FLAGS.tf_master) as sess:
+    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
+        gen_parser_ops.feature_size(task_context=OutputPath('context'),
+                                    arg_prefix=FLAGS.arg_prefix))
+
+  # Filter out malformed sentences and projectivize the training corpus.
+  if FLAGS.projectivize_training_set:
+    logging.info('Preprocessing...')
+    with tf.Session(FLAGS.tf_master) as sess:
+      source, last = gen_parser_ops.document_source(
+          task_context=OutputPath('context'),
+          batch_size=FLAGS.batch_size,
+          corpus_name=FLAGS.training_corpus)
+      sink = gen_parser_ops.document_sink(
+          task_context=OutputPath('context'),
+          corpus_name='projectivized-training-corpus',
+          documents=gen_parser_ops.projectivize_filter(
+              gen_parser_ops.well_formed_filter(source,
+                                                task_context=OutputPath(
+                                                    'context')),
+              task_context=OutputPath('context')))
+      while True:
+        tf_last, _ = sess.run([last, sink])
+        if tf_last:
+          break
+
+  logging.info('Training...')
+  with tf.Session(FLAGS.tf_master) as sess:
+    Train(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 115 - 0
syntaxnet/syntaxnet/parser_trainer_test.sh

@@ -0,0 +1,115 @@
+#!/bin/bash
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This test trains a parser on a small dataset, then runs it in greedy mode and
+# in structured mode with beam 1, and checks that the result is identical.
+
+
+
+
+set -eux
+
+BINDIR=$TEST_SRCDIR/syntaxnet
+CONTEXT=$BINDIR/testdata/context.pbtxt
+TMP_DIR=/tmp/syntaxnet-output
+
+mkdir -p $TMP_DIR
+sed "s=SRCDIR=$TEST_SRCDIR=" "$CONTEXT" | \
+  sed "s=OUTPATH=$TMP_DIR=" > $TMP_DIR/context
+
+PARAMS=128-0.08-3600-0.9-0
+
+"$BINDIR/parser_trainer" \
+  --arg_prefix=brain_parser \
+  --batch_size=32 \
+  --compute_lexicon \
+  --decay_steps=3600 \
+  --graph_builder=greedy \
+  --hidden_layer_sizes=128 \
+  --learning_rate=0.08 \
+  --momentum=0.9 \
+  --output_path=$TMP_DIR \
+  --task_context=$TMP_DIR/context \
+  --training_corpus=training-corpus \
+  --tuning_corpus=tuning-corpus \
+  --params=$PARAMS \
+  --num_epochs=12 \
+  --report_every=100 \
+  --checkpoint_every=1000 \
+  --logtostderr
+
+"$BINDIR/parser_eval" \
+  --task_context=$TMP_DIR/brain_parser/greedy/$PARAMS/context \
+  --hidden_layer_sizes=128 \
+  --input=tuning-corpus \
+  --output=stdout \
+  --arg_prefix=brain_parser \
+  --graph_builder=greedy \
+  --model_path=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
+  --logtostderr \
+  > $TMP_DIR/greedy-out
+
+"$BINDIR/parser_eval" \
+  --task_context=$TMP_DIR/context \
+  --hidden_layer_sizes=128 \
+  --beam_size=1 \
+  --input=tuning-corpus \
+  --output=stdout \
+  --arg_prefix=brain_parser \
+  --graph_builder=structured \
+  --model_path=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
+  --logtostderr \
+  > $TMP_DIR/struct-beam1-out
+
+diff $TMP_DIR/greedy-out $TMP_DIR/struct-beam1-out
+
+STRUCT_PARAMS=128-0.001-3600-0.9-0
+
+"$BINDIR/parser_trainer" \
+  --arg_prefix=brain_parser \
+  --batch_size=8 \
+  --compute_lexicon \
+  --decay_steps=3600 \
+  --graph_builder=structured \
+  --hidden_layer_sizes=128 \
+  --learning_rate=0.001 \
+  --momentum=0.9 \
+  --pretrained_params=$TMP_DIR/brain_parser/greedy/$PARAMS/model \
+  --pretrained_params_names=\
+embedding_matrix_0,embedding_matrix_1,embedding_matrix_2,bias_0,weights_0 \
+  --output_path=$TMP_DIR \
+  --task_context=$TMP_DIR/context \
+  --training_corpus=training-corpus \
+  --tuning_corpus=tuning-corpus \
+  --params=$STRUCT_PARAMS \
+  --num_epochs=20 \
+  --report_every=25 \
+  --checkpoint_every=200 \
+  --logtostderr
+
+"$BINDIR/parser_eval" \
+  --task_context=$TMP_DIR/context \
+  --hidden_layer_sizes=128 \
+  --beam_size=8 \
+  --input=tuning-corpus \
+  --output=stdout \
+  --arg_prefix=brain_parser \
+  --graph_builder=structured \
+  --model_path=$TMP_DIR/brain_parser/structured/$STRUCT_PARAMS/model \
+  --logtostderr \
+  > $TMP_DIR/struct-beam8-out
+
+echo "PASS"

+ 30 - 0
syntaxnet/syntaxnet/parser_transitions.cc

@@ -0,0 +1,30 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/parser_transitions.h"
+
+#include "syntaxnet/parser_state.h"
+
+namespace syntaxnet {
+
+// Transition system registry.
+REGISTER_CLASS_REGISTRY("transition system", ParserTransitionSystem);
+
+void ParserTransitionSystem::PerformAction(ParserAction action,
+                                           ParserState *state) const {
+  PerformActionWithoutHistory(action, state);
+}
+
+}  // namespace syntaxnet

+ 208 - 0
syntaxnet/syntaxnet/parser_transitions.h

@@ -0,0 +1,208 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Transition system for the transition-based dependency parser.
+
+#ifndef $TARGETDIR_PARSER_TRANSITIONS_H_
+#define $TARGETDIR_PARSER_TRANSITIONS_H_
+
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/registry.h"
+
+namespace tensorflow {
+namespace io {
+class RecordReader;
+class RecordWriter;
+}
+}
+
+namespace syntaxnet {
+
+class Sentence;
+class ParserState;
+class TaskContext;
+
+// Parser actions for the transition system are encoded as integers.
+typedef int ParserAction;
+
+// Label type for the parser action.
+enum class LabelType {
+  NO_LABEL = 0,
+  LEFT_LABEL = 1,
+  RIGHT_LABEL = 2,
+};
+
+// Transition system-specific state. Transition systems can subclass this to
+// preprocess the parser state and/or to keep additional information during
+// parsing.
+class ParserTransitionState {
+ public:
+  virtual ~ParserTransitionState() {}
+
+  // Clones the transition state.
+  virtual ParserTransitionState *Clone() const = 0;
+
+  // Initializes a parser state for the transition system.
+  virtual void Init(ParserState *state) = 0;
+
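+  // Adds transition state specific annotations from the parser state to the
+  // given sentence.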
+  virtual void AddParseToDocument(const ParserState &state,
+                                  bool rewrite_root_labels,
+                                  Sentence *sentence) const {}
+
+  // Whether a parsed token should be considered correct for evaluation.
+  virtual bool IsTokenCorrect(const ParserState &state, int index) const = 0;
+
+  // Returns a human readable string representation of this state.
+  virtual string ToString(const ParserState &state) const = 0;
+};
+
+// A transition system is used for handling the parser state transitions. During
+// training the transition system is used for extracting a canonical sequence of
+// transitions for an annotated sentence. During parsing the transition system
+// is used for applying the predicted transitions to the parse state, thereby
+// building the parse tree for the sentence. Transition systems can be
+// implemented by subclassing this abstract class and registered using the
+// REGISTER_TRANSITION_SYSTEM macro.
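+//
+// For example, a new system would typically be declared and registered as
+// follows (the names here are illustrative only):
+//
+//   class MyTransitionSystem : public ParserTransitionSystem { ... };
+//   REGISTER_TRANSITION_SYSTEM("my-system", MyTransitionSystem);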
+class ParserTransitionSystem
+    : public RegisterableClass<ParserTransitionSystem> {
+ public:
+  // Construction and cleanup.
+  ParserTransitionSystem() {}
+  virtual ~ParserTransitionSystem() {}
+
+  // Sets up the transition system. If inputs are needed, this is the place to
+  // specify them.
+  virtual void Setup(TaskContext *context) {}
+
+  // Initializes the transition system.
+  virtual void Init(TaskContext *context) {}
+
+  // Reads the transition system from disk.
+  virtual void Read(tensorflow::io::RecordReader *reader) {}
+
+  // Writes the transition system to disk.
+  virtual void Write(tensorflow::io::RecordWriter *writer) const {}
+
+  // Returns the number of action types.
+  virtual int NumActionTypes() const = 0;
+
+  // Returns the number of actions.
+  virtual int NumActions(int num_labels) const = 0;
+
+  // Internally creates the set of outcomes (when transition systems support a
+  // variable number of actions).
+  virtual void CreateOutcomeSet(int num_labels) {}
+
+  // Returns the default action for a given state.
+  virtual ParserAction GetDefaultAction(const ParserState &state) const = 0;
+
+  // Returns the next gold action for the parser during training using the
+  // dependency relations found in the underlying annotated sentence.
+  virtual ParserAction GetNextGoldAction(const ParserState &state) const = 0;
+
+  // Returns all next gold actions for the parser during training using the
+  // dependency relations found in the underlying annotated sentence.
+  virtual void GetAllNextGoldActions(const ParserState &state,
+                                     vector<ParserAction> *actions) const {
+    ParserAction action = GetNextGoldAction(state);
+    *actions = {action};
+  }
+
+  // Internally counts all next gold actions from the current parser state.
+  virtual void CountAllNextGoldActions(const ParserState &state) {}
+
+  // Returns the number of atomic actions within the specified ParserAction.
+  virtual int ActionLength(ParserAction action) const { return 1; }
+
+  // Returns true if the action is allowed in the given parser state.
+  virtual bool IsAllowedAction(ParserAction action,
+                               const ParserState &state) const = 0;
+
+  // Performs the specified action on a given parser state. The action is not
+  // saved in the state's history.
+  virtual void PerformActionWithoutHistory(ParserAction action,
+                                           ParserState *state) const = 0;
+
+  // Performs the specified action on a given parser state. The action is saved
+  // in the state's history.
+  void PerformAction(ParserAction action, ParserState *state) const;
+
+  // Returns true if a given state is deterministic.
+  virtual bool IsDeterministicState(const ParserState &state) const = 0;
+
+  // Returns true if no more actions can be applied to a given parser state.
+  virtual bool IsFinalState(const ParserState &state) const = 0;
+
+  // Returns a string representation of a parser action.
+  virtual string ActionAsString(ParserAction action,
+                                const ParserState &state) const = 0;
+
+  // Returns a new transition state that can be used to put additional
+  // information in a parser state. By specifying if we are in training_mode
+  // (true) or not (false), we can construct a different transition state
+  // depending on whether we are training a model or parsing new documents. A
+  // null return value means we don't need to add anything to the parser state.
+  virtual ParserTransitionState *NewTransitionState(bool training_mode) const {
+    return nullptr;
+  }
+
+  // Whether to back off to the best allowable transition rather than the
+  // default action when the highest scoring action is not allowed.  Some
+  // transition systems do not degrade gracefully to the default action and so
+  // should return true for this function.
+  virtual bool BackOffToBestAllowableTransition() const { return false; }
+
+  // Whether the system returns multiple gold transitions from a single
+  // configuration.
+  virtual bool ReturnsMultipleGoldTransitions() const { return false; }
+
+  // Whether the system allows non-projective trees.
+  virtual bool AllowsNonProjective() const { return false; }
+
+  // Action meta data: get pointers to token indices based on meta-info about
+  // (state, action) pairs. NOTE: the following interface is somewhat
+  // experimental and may be subject to change. Use with caution and ask
+  // djweiss@ for details.
+
+  // Whether or not the system supports computing meta-data about actions.
+  virtual bool SupportsActionMetaData() const { return false; }
+
+  // Get the index of the child that would be created by this action. -1 for
+  // no child created.
+  virtual int ChildIndex(const ParserState &state,
+                         const ParserAction &action) const {
+    return -1;
+  }
+
+  // Get the index of the parent that would gain a new child by this action. -1
+  // for no parent modified.
+  virtual int ParentIndex(const ParserState &state,
+                          const ParserAction &action) const {
+    return -1;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ParserTransitionSystem);
+};
+
+#define REGISTER_TRANSITION_SYSTEM(type, component) \
+  REGISTER_CLASS_COMPONENT(ParserTransitionSystem, type, component)
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_PARSER_TRANSITIONS_H_

+ 151 - 0
syntaxnet/syntaxnet/populate_test_inputs.cc

@@ -0,0 +1,151 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/populate_test_inputs.h"
+
+#include <map>
+#include <utility>
+
+#include "gtest/gtest.h"
+#include "syntaxnet/utils.h"
+#include "syntaxnet/dictionary.pb.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+
+void PopulateTestInputs::CreatorMap::Add(
+    const string &name, const string &file_format, const string &record_format,
+    PopulateTestInputs::CreateFile makefile) {
+  (*this)[name] = [name, file_format, record_format,
+                   makefile](TaskInput *input) {
+    makefile(AddPart(input, file_format, record_format));
+  };
+}
+
+bool PopulateTestInputs::CreatorMap::Populate(TaskContext *context) const {
+  return PopulateTestInputs::Populate(*this, context);
+}
+
+PopulateTestInputs::CreatorMap PopulateTestInputs::Defaults(
+    const Sentence &document) {
+  CreatorMap creators;
+  creators["category-map"] =
+      CreateTFMapFromDocumentTokens(document, TokenCategory);
+  creators["label-map"] = CreateTFMapFromDocumentTokens(document, TokenLabel);
+  creators["tag-map"] = CreateTFMapFromDocumentTokens(document, TokenTag);
+  creators["tag-to-category"] = CreateTagToCategoryFromTokens(document);
+  creators["word-map"] = CreateTFMapFromDocumentTokens(document, TokenWord);
+  return creators;
+}
+
+bool PopulateTestInputs::Populate(
+    const std::unordered_map<string, Create> &creator_map,
+    TaskContext *context) {
+  TaskSpec *spec = context->mutable_spec();
+  bool found_all_inputs = true;
+
+  // Fail if a mandatory input is not found.
+  auto name_not_found = [&found_all_inputs](TaskInput *input) {
+    found_all_inputs = false;
+  };
+
+  for (TaskInput &input : *spec->mutable_input()) {
+    auto it = creator_map.find(input.name());
+    (it == creator_map.end() ? name_not_found : it->second)(&input);
+
+    // Check for compatibility with declared supported formats.
+    for (const auto &part : input.part()) {
+      if (!TaskContext::Supports(input, part.file_format(),
+                                 part.record_format())) {
+        LOG(FATAL) << "Input " << input.name()
+                   << " does not support file of type " << part.file_format()
+                   << "/" << part.record_format();
+      }
+    }
+  }
+  return found_all_inputs;
+}
+
+PopulateTestInputs::Create PopulateTestInputs::CreateTFMapFromDocumentTokens(
+    const Sentence &document,
+    std::function<vector<string>(const Token &)> token2str) {
+  return [document, token2str](TaskInput *input) {
+    TermFrequencyMap map;
+
+    // Build and write the dummy term frequency map.
+    for (const Token &token : document.token()) {
+      vector<string> strings_for_token = token2str(token);
+      for (const string &s : strings_for_token) map.Increment(s);
+    }
+    string file_name = AddPart(input, "text", "");
+    map.Save(file_name);
+  };
+}
+
+PopulateTestInputs::Create PopulateTestInputs::CreateTagToCategoryFromTokens(
+    const Sentence &document) {
+  return [document](TaskInput *input) {
+    TagToCategoryMap tag_to_category;
+    for (auto &token : document.token()) {
+      if (token.has_tag()) {
+        tag_to_category.SetCategory(token.tag(), token.category());
+      }
+    }
+    const string file_name = AddPart(input, "text", "");
+    tag_to_category.Save(file_name);
+  };
+}
+
+vector<string> PopulateTestInputs::TokenCategory(const Token &token) {
+  if (token.has_category()) return {token.category()};
+  return {};
+}
+
+vector<string> PopulateTestInputs::TokenLabel(const Token &token) {
+  if (token.has_label()) return {token.label()};
+  return {};
+}
+
+vector<string> PopulateTestInputs::TokenTag(const Token &token) {
+  if (token.has_tag()) return {token.tag()};
+  return {};
+}
+
+vector<string> PopulateTestInputs::TokenWord(const Token &token) {
+  if (token.has_word()) return {token.word()};
+  return {};
+}
+
+string PopulateTestInputs::AddPart(TaskInput *input, const string &file_format,
+                                   const string &record_format) {
+  string file_name =
+      tensorflow::strings::StrCat(
+          tensorflow::testing::TmpDir(), input->name());
+  auto *part = CHECK_NOTNULL(input)->add_part();
+  part->set_file_pattern(file_name);
+  part->set_file_format(file_format);
+  part->set_record_format(record_format);
+  return file_name;
+}
+
+}  // namespace syntaxnet

+ 153 - 0
syntaxnet/syntaxnet/populate_test_inputs.h

@@ -0,0 +1,153 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A utility for populating a set of inputs of a task.  This knows how to
+// create tag-map, category-map, label-map, word-map and tag-to-category
+// inputs, and has hooks to populate other kinds of inputs.  The expected set
+// of operations is:
+//
+// Sentence document_for_init = ...;
+// TaskContext context;
+// context->SetParameter("my_parameter", "true");
+// MyDocumentProcessor processor;
+// processor.Setup(&context);
+// PopulateTestInputs::Defaults(document_for_init).Populate(&context);
+// processor.Init(&context);
+//
+// This will check the inputs requested by the processor's Setup(TaskContext *)
+// function and create files corresponding to them.  For example, if the
+// processor asked for a "tag-map" input, it will create a TermFrequencyMap,
+// populate it with the POS tags found in the Sentence document_for_init, save
+// it to disk and update the TaskContext with the location of the file.  By
+// convention, the location is the name of the input.  Conceptually, the logic
+// is very simple:
+//
+// for (TaskInput &input : context->mutable_spec()->mutable_input()) {
+//   creators[input.name()](&input);
+//   // check for missing inputs, incompatible formats, etc...
+// }
+//
+// The Populate() routine will also check compatibility between requested and
+// supplied formats. The Default mapping knows how to populate the following
+// inputs:
+//
+//  - category-map: TermFrequencyMap containing POS categories.
+//
+//  - label-map: TermFrequencyMap containing parser labels.
+//
+//  - tag-map: TermFrequencyMap containing POS tags.
+//
+//  - tag-to-category: StringToStringMap mapping POS tags to categories.
+//
+//  - word-map: TermFrequencyMap containing words.
+//
+// Clients can add creation routines by defining a std::function:
+//
+// auto creators = PopulateTestInputs::Defaults(document_for_init);
+// creators["my-input"] = [](TaskInput *input) { ...; }
+//
+// See also creators.Add() for more convenience functions.
+
+#ifndef $TARGETDIR_POPULATE_TEST_INPUTS_H_
+#define $TARGETDIR_POPULATE_TEST_INPUTS_H_
+
+#include <functional>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+class Sentence;
+class TaskContext;
+class TaskInput;
+class TaskOutput;
+class Token;
+
+class PopulateTestInputs {
+ public:
+  // When called, Create() should populate an input by creating a file and
+  // adding one or more parts to the TaskInput.
+  typedef std::function<void(TaskInput *)> Create;
+
+  // When called, CreateFile() should create a file resource at the given
+  // path.  These are typically more convenient to write than Create functions.
+  typedef std::function<void(const string &)> CreateFile;
+
+  // A set of creators, one for each input in a TaskContext.
+  class CreatorMap : public std::unordered_map<string, Create> {
+   public:
+    // A simplified way to add a single-file creator.  The name of the file
+    // location will be file::JoinPath(FLAGS_test_tmpdir, name).
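+    //
+    // For example (illustrative names only):
+    //   creators.Add("my-input", "text", "",
+    //                [](const string &path) { /* write a file at `path` */ });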
+    void Add(const string &name, const string &file_format,
+             const string &record_format, CreateFile makefile);
+
+    // Convenience method to populate the inputs in context.  Returns true if it
+    // was possible to populate each input, and false otherwise.  If a mandatory
+    // input does not have a creator, then we LOG(FATAL).
+    bool Populate(TaskContext *context) const;
+  };
+
+  // Default creator set.  This knows how to generate from a given Document
+  //  - category-map
+  //  - label-map
+  //  - tag-map
+  //  - tag-to-category
+  //  - word-map
+  //
+  //  Note: the default creators capture the document input by value: this means
+  //  that subsequent modifications to the document will NOT be
+  //  reflected in the inputs. However, the following is perfectly valid:
+  //
+  //  CreatorMap creators;
+  //  {
+  //    Sentence document;
+  //    creators = PopulateTestInputs::Defaults(document);
+  //  }
+  //  creators.Populate(context);
+  static CreatorMap Defaults(const Sentence &document);
+
+  // Populates the TaskContext object from a map of creator functions. Note that
+  // this static version is compatible with any hash map of the correct type.
+  static bool Populate(const std::unordered_map<string, Create> &creator_map,
+                       TaskContext *context);
+
+  // Helper function for creating a term frequency map from a document.  This
+  // iterates over all the tokens in the document, calls token2str on each
+  // token, and adds each returned string to the term frequency map.  The map is
+  // then saved to FLAGS_test_tmpdir/name.
+  static Create CreateTFMapFromDocumentTokens(
+      const Sentence &document,
+      std::function<vector<string>(const Token &)> token2str);
+
+  // Creates a StringToStringMap protocol buffer input that maps tags to
+  // categories. Uses whatever mapping is present in the document.
+  static Create CreateTagToCategoryFromTokens(const Sentence &document);
+
+  // Default implementations for "token2str" above.
+  static vector<string> TokenCategory(const Token &token);
+  static vector<string> TokenLabel(const Token &token);
+  static vector<string> TokenTag(const Token &token);
+  static vector<string> TokenWord(const Token &token);
+
+  // Utility function. Sets the TaskInput->part() fields for a new input part.
+  // Returns the file name.
+  static string AddPart(TaskInput *input, const string &file_format,
+                        const string &record_format);
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_POPULATE_TEST_INPUTS_H_

+ 242 - 0
syntaxnet/syntaxnet/proto_io.h

@@ -0,0 +1,242 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_PROTO_IO_H_
+#define $TARGETDIR_PROTO_IO_H_
+
+#include <iostream>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/document_format.h"
+#include "syntaxnet/feature_extractor.pb.h"
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/registry.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/utils.h"
+#include "syntaxnet/workspace.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/record_reader.h"
+#include "tensorflow/core/lib/io/record_writer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace syntaxnet {
+
+// A convenience wrapper to read protos with a RecordReader.
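+//
+// Example usage (sketch):
+//   ProtoRecordReader reader("/path/to/records");
+//   Sentence sentence;
+//   while (reader.Read(&sentence).ok()) { /* process sentence */ }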
+class ProtoRecordReader {
+ public:
+  explicit ProtoRecordReader(tensorflow::RandomAccessFile *file)
+      : file_(file), reader_(new tensorflow::io::RecordReader(file_)) {}
+
+  explicit ProtoRecordReader(const string &filename) {
+    TF_CHECK_OK(
+        tensorflow::Env::Default()->NewRandomAccessFile(filename, &file_));
+    reader_.reset(new tensorflow::io::RecordReader(file_));
+  }
+
+  ~ProtoRecordReader() {
+    reader_.reset();
+    delete file_;
+  }
+
+  template <typename T>
+  tensorflow::Status Read(T *proto) {
+    string buffer;
+    tensorflow::Status status = reader_->ReadRecord(&offset_, &buffer);
+    if (status.ok()) {
+      CHECK(proto->ParseFromString(buffer));
+      return tensorflow::Status::OK();
+    } else {
+      return status;
+    }
+  }
+
+ private:
+  tensorflow::RandomAccessFile *file_ = nullptr;
+  uint64 offset_ = 0;
+  std::unique_ptr<tensorflow::io::RecordReader> reader_;
+};
+
+// A convenience wrapper to write protos with a RecordWriter.
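+//
+// Example usage (sketch):
+//   ProtoRecordWriter writer("/path/to/records");
+//   writer.Write(my_proto);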
+class ProtoRecordWriter {
+ public:
+  explicit ProtoRecordWriter(const string &filename) {
+    TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file_));
+    writer_.reset(new tensorflow::io::RecordWriter(file_));
+  }
+
+  ~ProtoRecordWriter() {
+    writer_.reset();
+    delete file_;
+  }
+
+  template <typename T>
+  void Write(const T &proto) {
+    TF_CHECK_OK(writer_->WriteRecord(proto.SerializeAsString()));
+  }
+
+ private:
+  tensorflow::WritableFile *file_ = nullptr;
+  std::unique_ptr<tensorflow::io::RecordWriter> writer_;
+};
+
+// A file implementation to read from stdin.
+class StdIn : public tensorflow::RandomAccessFile {
+ public:
+  StdIn() {}
+  ~StdIn() override {}
+
+  // Reads up to n bytes from standard input.  Returns `OUT_OF_RANGE` if fewer
+  // than n bytes were stored in `*result` because of EOF.
+  tensorflow::Status Read(uint64 offset, size_t n,
+                          tensorflow::StringPiece *result,
+                          char *scratch) const override {
+    CHECK_EQ(expected_offset_, offset);
+    if (!eof_) {
+      string line;
+      eof_ = !std::getline(std::cin, line);
+      buffer_.append(line);
+      buffer_.append("\n");
+    }
+    CopyFromBuffer(std::min(buffer_.size(), n), result, scratch);
+    if (eof_) {
+      return tensorflow::errors::OutOfRange("End of file reached");
+    } else {
+      return tensorflow::Status::OK();
+    }
+  }
+
+ private:
+  void CopyFromBuffer(size_t n, tensorflow::StringPiece *result,
+                      char *scratch) const {
+    memcpy(scratch, buffer_.data(), buffer_.size());
+    buffer_ = buffer_.substr(n);
+    result->set(scratch, n);
+    expected_offset_ += n;
+  }
+
+  mutable bool eof_ = false;
+  mutable int64 expected_offset_ = 0;
+  mutable string buffer_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(StdIn);
+};
+
+// Reads sentence protos from a text file.
+class TextReader {
+ public:
+  explicit TextReader(const TaskInput &input) {
+    CHECK_EQ(input.record_format_size(), 1)
+        << "TextReader only supports inputs with one record format: "
+        << input.DebugString();
+    CHECK_EQ(input.part_size(), 1)
+        << "TextReader only supports inputs with one part: "
+        << input.DebugString();
+    filename_ = TaskContext::InputFile(input);
+    format_.reset(DocumentFormat::Create(input.record_format(0)));
+    Reset();
+  }
+
+  Sentence *Read() {
+    // Skips empty sentences, e.g., blank lines at the beginning of a file or
+    // commented out blocks.
+    vector<Sentence *> sentences;
+    string key, value;
+    while (sentences.empty() && format_->ReadRecord(buffer_.get(), &value)) {
+      key = tensorflow::strings::StrCat(filename_, ":", sentence_count_);
+      format_->ConvertFromString(key, value, &sentences);
+      CHECK_LE(sentences.size(), 1);
+    }
+    if (sentences.empty()) {
+      // End of file reached.
+      return nullptr;
+    } else {
+      ++sentence_count_;
+      return sentences[0];
+    }
+  }
+
+  void Reset() {
+    sentence_count_ = 0;
+    tensorflow::RandomAccessFile *file;
+    if (filename_ == "-") {
+      static const int kInputBufferSize = 8 * 1024; /* bytes */
+      file = new StdIn();
+      buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
+    } else {
+      static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
+      TF_CHECK_OK(
+          tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file));
+      buffer_.reset(new tensorflow::io::InputBuffer(file, kInputBufferSize));
+    }
+  }
+
+ private:
+  string filename_;
+  int sentence_count_ = 0;
+  std::unique_ptr<tensorflow::io::InputBuffer> buffer_;
+  std::unique_ptr<DocumentFormat> format_;
+};
+
+// Writes sentence protos to a text conll file.
+class TextWriter {
+ public:
+  explicit TextWriter(const TaskInput &input) {
+    CHECK_EQ(input.record_format_size(), 1)
+        << "TextWriter only supports files with one record format: "
+        << input.DebugString();
+    CHECK_EQ(input.part_size(), 1)
+        << "TextWriter only supports files with one part: "
+        << input.DebugString();
+    filename_ = TaskContext::InputFile(input);
+    format_.reset(DocumentFormat::Create(input.record_format(0)));
+    if (filename_ != "-") {
+      TF_CHECK_OK(
+          tensorflow::Env::Default()->NewWritableFile(filename_, &file_));
+    }
+  }
+
+  ~TextWriter() {
+    if (file_) {
+      file_->Close();
+      delete file_;
+    }
+  }
+
+  void Write(const Sentence &sentence) {
+    string key, value;
+    format_->ConvertToString(sentence, &key, &value);
+    if (file_) {
+      TF_CHECK_OK(file_->Append(value));
+    } else {
+      std::cout << value;
+    }
+  }
+
+ private:
+  string filename_;
+  std::unique_ptr<DocumentFormat> format_;
+  tensorflow::WritableFile *file_ = nullptr;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_PROTO_IO_H_

+ 563 - 0
syntaxnet/syntaxnet/reader_ops.cc

@@ -0,0 +1,563 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <math.h>
+#include <deque>
+#include <unordered_map>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/base.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/sentence_batch.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/shared_store.h"
+#include "syntaxnet/sparse.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/table.h"
+#include "tensorflow/core/lib/io/table_options.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/env.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_STRING;
+using tensorflow::DataType;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::error::OUT_OF_RANGE;
+using tensorflow::errors::InvalidArgument;
+
+namespace syntaxnet {
+
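+// Abstract base class for ops that read sentences into a batch of parser
+// states, extract sparse features for each active state, and output the
+// serialized feature strings for every feature group together with the
+// current epoch count. Subclasses define how states are advanced (see
+// PerformActions) and may emit additional outputs (see AddAdditionalOutputs).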
+class ParsingReader : public OpKernel {
+ public:
+  explicit ParsingReader(OpKernelConstruction *context) : OpKernel(context) {
+    string file_path, corpus_name;
+    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
+    OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
+    OP_REQUIRES_OK(context, context->GetAttr("batch_size", &max_batch_size_));
+    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
+    OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
+
+    // Reads task context from file.
+    string data;
+    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
+                                             file_path, &data));
+    OP_REQUIRES(context,
+                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
+                InvalidArgument("Could not parse task context at ", file_path));
+
+    // Set up the batch reader.
+    sentence_batch_.reset(
+        new SentenceBatch(max_batch_size_, corpus_name));
+    sentence_batch_->Init(&task_context_);
+
+    // Set up the parsing features and transition system.
+    states_.resize(max_batch_size_);
+    workspaces_.resize(max_batch_size_);
+    features_.reset(new ParserEmbeddingFeatureExtractor(arg_prefix_));
+    features_->Setup(&task_context_);
+    transition_system_.reset(ParserTransitionSystem::Create(task_context_.Get(
+        features_->GetParamName("transition_system"), "arc-standard")));
+    transition_system_->Setup(&task_context_);
+    features_->Init(&task_context_);
+    features_->RequestWorkspaces(&workspace_registry_);
+    transition_system_->Init(&task_context_);
+    string label_map_path =
+        TaskContext::InputFile(*task_context_.GetInput("label-map"));
+    label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
+        label_map_path, 0, 0);
+
+    // Checks number of feature groups matches the task context.
+    const int required_size = features_->embedding_dims().size();
+    OP_REQUIRES(
+        context, feature_size_ == required_size,
+        InvalidArgument("Task context requires feature_size=", required_size));
+  }
+
+  ~ParsingReader() override { SharedStore::Release(label_map_); }
+
+  // Creates a new ParserState if there's another sentence to be read.
+  virtual void AdvanceSentence(int index) {
+    states_[index].reset();
+    if (sentence_batch_->AdvanceSentence(index)) {
+      states_[index].reset(new ParserState(
+          sentence_batch_->sentence(index),
+          transition_system_->NewTransitionState(true), label_map_));
+      workspaces_[index].Reset(workspace_registry_);
+      features_->Preprocess(&workspaces_[index], states_[index].get());
+    }
+  }
+
+  void Compute(OpKernelContext *context) override {
+    mutex_lock lock(mu_);
+
+    // Advances states to the next positions.
+    PerformActions(context);
+
+    // Advances any final states to the next sentences.
+    for (int i = 0; i < max_batch_size_; ++i) {
+      if (state(i) == nullptr) continue;
+
+      // Switches to the next sentence if we're at a final state.
+      while (transition_system_->IsFinalState(*state(i))) {
+        VLOG(2) << "Advancing sentence " << i;
+        AdvanceSentence(i);
+        if (state(i) == nullptr) break;  // EOF has been reached
+      }
+    }
+
+    // Rewinds the corpus if no states remain in the batch.
+    if (sentence_batch_->size() == 0) {
+      ++num_epochs_;
+      LOG(INFO) << "Starting epoch " << num_epochs_;
+      sentence_batch_->Rewind();
+      for (int i = 0; i < max_batch_size_; ++i) AdvanceSentence(i);
+    }
+
+    // Create the outputs for each feature space.
+    vector<Tensor *> feature_outputs(features_->NumEmbeddings());
+    for (size_t i = 0; i < feature_outputs.size(); ++i) {
+      OP_REQUIRES_OK(context, context->allocate_output(
+                                  i, TensorShape({sentence_batch_->size(),
+                                                  features_->FeatureSize(i)}),
+                                  &feature_outputs[i]));
+    }
+
+    // Populate feature outputs.
+    for (int i = 0, index = 0; i < max_batch_size_; ++i) {
+      if (states_[i] == nullptr) continue;
+
+      // Extract features from the current parser state, and fill up the
+      // available batch slots.
+      std::vector<std::vector<SparseFeatures>> features =
+          features_->ExtractSparseFeatures(workspaces_[i], *states_[i]);
+
+      for (size_t feature_space = 0; feature_space < features.size();
+           ++feature_space) {
+        int feature_size = features[feature_space].size();
+        CHECK(feature_size == features_->FeatureSize(feature_space));
+        auto features_output = feature_outputs[feature_space]->matrix<string>();
+        for (int k = 0; k < feature_size; ++k) {
+          features_output(index, k) =
+              features[feature_space][k].SerializeAsString();
+        }
+      }
+      ++index;
+    }
+
+    // Return the number of epochs.
+    Tensor *epoch_output;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                feature_size_, TensorShape({}), &epoch_output));
+    auto num_epochs = epoch_output->scalar<int32>();
+    num_epochs() = num_epochs_;
+
+    // Create outputs specific to this reader.
+    AddAdditionalOutputs(context);
+  }
+
+ protected:
+  // Performs any relevant actions on the parser states, typically either
+  // the gold action or a predicted action from decoding.
+  virtual void PerformActions(OpKernelContext *context) = 0;
+
+  // Adds outputs specific to this reader starting at additional_output_index().
+  virtual void AddAdditionalOutputs(OpKernelContext *context) const = 0;
+
+  // Returns the output type specification of this base class.
+  std::vector<DataType> default_outputs() const {
+    std::vector<DataType> output_types(feature_size_, DT_STRING);
+    output_types.push_back(DT_INT32);
+    return output_types;
+  }
+
+  // Accessors.
+  int max_batch_size() const { return max_batch_size_; }
+  int batch_size() const { return sentence_batch_->size(); }
+  int additional_output_index() const { return feature_size_ + 1; }
+  ParserState *state(int i) const { return states_[i].get(); }
+  const ParserTransitionSystem &transition_system() const {
+    return *transition_system_.get();
+  }
+
+  // Parser task context.
+  const TaskContext &task_context() const { return task_context_; }
+
+  const string &arg_prefix() const { return arg_prefix_; }
+
+ private:
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // Prefix for context parameters.
+  string arg_prefix_;
+
+  // mutex to synchronize access to Compute.
+  mutex mu_;
+
+  // How many times the document source has been rewound.
+  int num_epochs_ = 0;
+
+  // How many sentences this op can be processing at any given time.
+  int max_batch_size_ = 1;
+
+  // Number of feature groups in the brain parser features.
+  int feature_size_ = -1;
+
+  // Batch of sentences, and the corresponding parser states.
+  std::unique_ptr<SentenceBatch> sentence_batch_;
+
+  // Batch: ParserState objects.
+  std::vector<std::unique_ptr<ParserState>> states_;
+
+  // Batch: WorkspaceSet objects.
+  std::vector<WorkspaceSet> workspaces_;
+
+  // Dependency label map used in transition system.
+  const TermFrequencyMap *label_map_;
+
+  // Transition system.
+  std::unique_ptr<ParserTransitionSystem> transition_system_;
+
+  // Typed feature extractor for embeddings.
+  std::unique_ptr<ParserEmbeddingFeatureExtractor> features_;
+
+  // Internal workspace registry for use in feature extraction.
+  WorkspaceRegistry workspace_registry_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(ParsingReader);
+};
+
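+// GoldParseReader advances each parser state in the batch with its gold
+// action and, alongside the default reader outputs, emits the next gold
+// action for every active state.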
+class GoldParseReader : public ParsingReader {
+ public:
+  explicit GoldParseReader(OpKernelConstruction *context)
+      : ParsingReader(context) {
+    // Sets up number and type of inputs and outputs.
+    std::vector<DataType> output_types = default_outputs();
+    output_types.push_back(DT_INT32);
+    OP_REQUIRES_OK(context, context->MatchSignature({}, output_types));
+  }
+
+ private:
+  // Always performs the next gold action for each state.
+  void PerformActions(OpKernelContext *context) override {
+    for (int i = 0; i < max_batch_size(); ++i) {
+      if (state(i) != nullptr) {
+        transition_system().PerformAction(
+            transition_system().GetNextGoldAction(*state(i)), state(i));
+      }
+    }
+  }
+
+  // Adds the list of gold actions for each state as an additional output.
+  void AddAdditionalOutputs(OpKernelContext *context) const override {
+    Tensor *actions_output;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                additional_output_index(),
+                                TensorShape({batch_size()}), &actions_output));
+
+    // Add all gold actions for non-null states as an additional output.
+    auto gold_actions = actions_output->vec<int32>();
+    for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
+      if (state(i) != nullptr) {
+        const int gold_action =
+            transition_system().GetNextGoldAction(*state(i));
+        gold_actions(batch_index++) = gold_action;
+      }
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GoldParseReader);
+};
+
+REGISTER_KERNEL_BUILDER(Name("GoldParseReader").Device(DEVICE_CPU),
+                        GoldParseReader);
+
+// DecodedParseReader parses sentences using transition scores computed
+// by a TensorFlow network. This op additionally computes a token correctness
+// evaluation metric which can be used to select hyperparameter settings and
+// training stopping point.
+//
+// The notion of correct token is determined by the transition system, e.g.
+// a tagger will return POS tag accuracy, while an arc-standard parser will
+// return UAS.
+//
+// Which tokens should be scored is controlled by the '<arg_prefix>_scoring'
+// task parameter.  Possible values are
+//   - 'default': skips tokens with only punctuation in the tag name.
+//   - 'conllx': skips tokens with only punctuation in the surface form.
+//   - 'ignore_parens': same as conllx, but skipping parentheses as well.
+//   - '': scores all tokens.
+class DecodedParseReader : public ParsingReader {
+ public:
+  explicit DecodedParseReader(OpKernelConstruction *context)
+      : ParsingReader(context) {
+    // Sets up number and type of inputs and outputs.
+    std::vector<DataType> output_types = default_outputs();
+    output_types.push_back(DT_INT32);
+    output_types.push_back(DT_STRING);
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT}, output_types));
+
+    // Gets scoring parameters.
+    scoring_type_ = task_context().Get(
+        tensorflow::strings::StrCat(arg_prefix(), "_scoring"), "");
+  }
+
+ private:
+  void AdvanceSentence(int index) override {
+    ParsingReader::AdvanceSentence(index);
+    if (state(index)) {
+      docids_.push_front(state(index)->sentence().docid());
+    }
+  }
+
+  // Tallies the # of correct and incorrect tokens for a given ParserState.
+  void ComputeTokenAccuracy(const ParserState &state) {
+    for (int i = 0; i < state.sentence().token_size(); ++i) {
+      const Token &token = state.GetToken(i);
+      if (utils::PunctuationUtil::ScoreToken(token.word(), token.tag(),
+                                             scoring_type_)) {
+        ++num_tokens_;
+        if (state.IsTokenCorrect(i)) ++num_correct_;
+      }
+    }
+  }
+
+  // Performs the allowed action with the highest score on the given state.
+  // Also records the accuracy whenever a terminal action is taken.
+  void PerformActions(OpKernelContext *context) override {
+    auto scores_matrix = context->input(0).matrix<float>();
+    num_tokens_ = 0;
+    num_correct_ = 0;
+    for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
+      ParserState *state = this->state(i);
+      if (state != nullptr) {
+        int best_action = 0;
+        float best_score = -INFINITY;
+        for (int action = 0; action < scores_matrix.dimension(1); ++action) {
+          float score = scores_matrix(batch_index, action);
+          if (score > best_score &&
+              transition_system().IsAllowedAction(action, *state)) {
+            best_action = action;
+            best_score = score;
+          }
+        }
+        transition_system().PerformAction(best_action, state);
+
+        // Update the # of scored correct tokens if this is the last state
+        // in the sentence and save the annotated document.
+        if (transition_system().IsFinalState(*state)) {
+          ComputeTokenAccuracy(*state);
+          sentence_map_[state->sentence().docid()] = state->sentence();
+          state->AddParseToDocument(&sentence_map_[state->sentence().docid()]);
+        }
+        ++batch_index;
+      }
+    }
+  }
+
+  // Adds the evaluation metrics and annotated documents as additional outputs,
+  // if there were any terminal states.
+  void AddAdditionalOutputs(OpKernelContext *context) const override {
+    Tensor *counts_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(additional_output_index(),
+                                            TensorShape({2}), &counts_output));
+    auto eval_metrics = counts_output->vec<int32>();
+    eval_metrics(0) = num_tokens_;
+    eval_metrics(1) = num_correct_;
+
+    // Output annotated documents for each state. To preserve order, repeatedly
+    // pull from the back of the docids queue as long as the sentences have been
+    // completely processed. If the next document has not been completely
+    // processed yet, then the docid will not be found in 'sentence_map_'.
+    vector<Sentence> sentences;
+    while (!docids_.empty() &&
+           sentence_map_.find(docids_.back()) != sentence_map_.end()) {
+      sentences.emplace_back(sentence_map_[docids_.back()]);
+      sentence_map_.erase(docids_.back());
+      docids_.pop_back();
+    }
+    Tensor *annotated_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       additional_output_index() + 1,
+                       TensorShape({static_cast<int64>(sentences.size())}),
+                       &annotated_output));
+
+    auto document_output = annotated_output->vec<string>();
+    for (size_t i = 0; i < sentences.size(); ++i) {
+      document_output(i) = sentences[i].SerializeAsString();
+    }
+  }
+
+  // State for eval metric computation.
+  int num_tokens_ = 0;
+  int num_correct_ = 0;
+
+  // Parameter for deciding which tokens to score.
+  string scoring_type_;
+
+  mutable std::deque<string> docids_;
+  mutable map<string, Sentence> sentence_map_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
+};
+
+REGISTER_KERNEL_BUILDER(Name("DecodedParseReader").Device(DEVICE_CPU),
+                        DecodedParseReader);
+
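+
+// Reads pretrained word embeddings from a recordio of TokenEmbedding protos
+// and outputs an embedding matrix for the words in the task's word-map: rows
+// for words found in the recordio are set to their L2-normalized pretrained
+// vectors, while all remaining rows keep a random normal initialization
+// scaled by embedding_init / sqrt(embedding_size).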
+class WordEmbeddingInitializer : public OpKernel {
+ public:
+  explicit WordEmbeddingInitializer(OpKernelConstruction *context)
+      : OpKernel(context) {
+    string file_path, data;
+    OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
+    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
+                                             file_path, &data));
+    OP_REQUIRES(context,
+                TextFormat::ParseFromString(data, task_context_.mutable_spec()),
+                InvalidArgument("Could not parse task context at ", file_path));
+    OP_REQUIRES_OK(context, context->GetAttr("vectors", &vectors_path_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("embedding_init", &embedding_init_));
+
+    // Sets up number and type of inputs and outputs.
+    OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    // Loads words from vocabulary with mapping to ids.
+    string path = TaskContext::InputFile(*task_context_.GetInput("word-map"));
+    const TermFrequencyMap *word_map =
+        SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
+    unordered_map<string, int64> vocab;
+    for (int i = 0; i < word_map->Size(); ++i) {
+      vocab[word_map->GetTerm(i)] = i;
+    }
+
+    // Creates a reader pointing to a local copy of the vectors recordio.
+    string tmp_vectors_path;
+    OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &tmp_vectors_path));
+    ProtoRecordReader reader(tmp_vectors_path);
+
+    // Loads the embedding vectors into a matrix.
+    Tensor *embedding_matrix = nullptr;
+    TokenEmbedding embedding;
+    while (reader.Read(&embedding) == tensorflow::Status::OK()) {
+      if (embedding_matrix == nullptr) {
+        const int embedding_size = embedding.vector().values_size();
+        OP_REQUIRES_OK(
+            context, context->allocate_output(
+                         0, TensorShape({word_map->Size() + 3, embedding_size}),
+                         &embedding_matrix));
+        embedding_matrix->matrix<float>()
+            .setRandom<Eigen::internal::NormalRandomGenerator<float>>();
+        embedding_matrix->matrix<float>() =
+            embedding_matrix->matrix<float>() * static_cast<float>(
+                embedding_init_ / sqrt(embedding_size));
+      }
+      if (vocab.find(embedding.token()) != vocab.end()) {
+        SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
+                         embedding_matrix);
+      }
+    }
+  }
+
+ private:
+  // Sets embedding_matrix[row] to a normalized version of the given vector.
+  void SetNormalizedRow(const TokenEmbedding::Vector &vector, const int row,
+                        Tensor *embedding_matrix) {
+    float norm = 0.0f;
+    for (int col = 0; col < vector.values_size(); ++col) {
+      float val = vector.values(col);
+      norm += val * val;
+    }
+    norm = sqrt(norm);
+    for (int col = 0; col < vector.values_size(); ++col) {
+      embedding_matrix->matrix<float>()(row, col) = vector.values(col) / norm;
+    }
+  }
+
+  // Copies the file at source_path to a temporary file and sets tmp_path to the
+  // temporary file's location. This is helpful since reading from non-local
+  // files with a record reader can be very slow.
+  static tensorflow::Status CopyToTmpPath(const string &source_path,
+                                          string *tmp_path) {
+    // Opens source file.
+    tensorflow::RandomAccessFile *source_file;
+    TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewRandomAccessFile(
+        source_path, &source_file));
+    std::unique_ptr<tensorflow::RandomAccessFile> source_file_deleter(
+        source_file);
+
+    // Creates destination file.
+    tensorflow::WritableFile *target_file;
+    *tmp_path = tensorflow::strings::Printf(
+        "/tmp/%d.%lld", getpid(), tensorflow::Env::Default()->NowMicros());
+    TF_RETURN_IF_ERROR(
+        tensorflow::Env::Default()->NewWritableFile(*tmp_path, &target_file));
+    std::unique_ptr<tensorflow::WritableFile> target_file_deleter(target_file);
+
+    // Performs copy.
+    tensorflow::Status s;
+    const size_t kBytesToRead = 10 << 20;  // 10MB at a time.
+    string scratch;
+    scratch.resize(kBytesToRead);
+    for (uint64 offset = 0; s.ok(); offset += kBytesToRead) {
+      tensorflow::StringPiece data;
+      s.Update(source_file->Read(offset, kBytesToRead, &data, &scratch[0]));
+      target_file->Append(data);
+    }
+    if (s.code() == OUT_OF_RANGE) {
+      return tensorflow::Status::OK();
+    } else {
+      return s;
+    }
+  }
+
+  // Task context used to configure this op.
+  TaskContext task_context_;
+
+  // Embedding vectors that are not found in the input recordio are initialized
+  // randomly from a normal distribution with zero mean and
+  //   std dev = embedding_init_ / sqrt(embedding_size).
+  float embedding_init_ = 1.f;
+
+  // Path to recordio with word embedding vectors.
+  string vectors_path_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("WordEmbeddingInitializer").Device(DEVICE_CPU),
+                        WordEmbeddingInitializer);
+
+}  // namespace syntaxnet

+ 198 - 0
syntaxnet/syntaxnet/reader_ops_test.py

@@ -0,0 +1,198 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for reader_ops."""
+
+
+import os.path
+
+import numpy as np
+import tensorflow as tf
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import control_flow_ops as cf
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+
+from syntaxnet import dictionary_pb2
+from syntaxnet import graph_builder
+from syntaxnet import sparse_pb2
+from syntaxnet.ops import gen_parser_ops
+
+
+FLAGS = tf.app.flags.FLAGS
+if not hasattr(FLAGS, 'test_srcdir'):
+  FLAGS.test_srcdir = ''
+if not hasattr(FLAGS, 'test_tmpdir'):
+  FLAGS.test_tmpdir = tf.test.get_temp_dir()
+
+
+class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    # Creates a task context with the correct testing paths.
+    initial_task_context = os.path.join(
+        FLAGS.test_srcdir,
+        'syntaxnet/'
+        'testdata/context.pbtxt')
+    self._task_context = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+    with open(initial_task_context, 'r') as fin:
+      with open(self._task_context, 'w') as fout:
+        fout.write(fin.read().replace('SRCDIR', FLAGS.test_srcdir)
+                   .replace('OUTPATH', FLAGS.test_tmpdir))
+
+    # Creates necessary term maps.
+    with self.test_session() as sess:
+      gen_parser_ops.lexicon_builder(task_context=self._task_context,
+                                     corpus_name='training-corpus').run()
+      self._num_features, self._num_feature_ids, _, self._num_actions = (
+          sess.run(gen_parser_ops.feature_size(task_context=self._task_context,
+                                               arg_prefix='brain_parser')))
+
+  def GetMaxId(self, sparse_features):
+    max_id = 0
+    for x in sparse_features:
+      for y in x:
+        f = sparse_pb2.SparseFeatures()
+        f.ParseFromString(y)
+        for i in f.id:
+          max_id = max(i, max_id)
+    return max_id
+
+  def testParsingReaderOp(self):
+    # Runs the reader over the test input for two epochs.
+    num_steps_a = 0
+    num_actions = 0
+    num_word_ids = 0
+    num_tag_ids = 0
+    num_label_ids = 0
+    batch_size = 10
+    with self.test_session() as sess:
+      (words, tags, labels), epochs, gold_actions = (
+          gen_parser_ops.gold_parse_reader(self._task_context,
+                                           3,
+                                           batch_size,
+                                           corpus_name='training-corpus'))
+      while True:
+        tf_gold_actions, tf_epochs, tf_words, tf_tags, tf_labels = (
+            sess.run([gold_actions, epochs, words, tags, labels]))
+        num_steps_a += 1
+        num_actions = max(num_actions, max(tf_gold_actions) + 1)
+        num_word_ids = max(num_word_ids, self.GetMaxId(tf_words) + 1)
+        num_tag_ids = max(num_tag_ids, self.GetMaxId(tf_tags) + 1)
+        num_label_ids = max(num_label_ids, self.GetMaxId(tf_labels) + 1)
+        self.assertIn(tf_epochs, [0, 1, 2])
+        if tf_epochs > 1:
+          break
+
+    # Runs the reader again, this time with a lot of added graph nodes.
+    num_steps_b = 0
+    with self.test_session() as sess:
+      num_features = [6, 6, 4]
+      num_feature_ids = [num_word_ids, num_tag_ids, num_label_ids]
+      embedding_sizes = [8, 8, 8]
+      hidden_layer_sizes = [32, 32]
+      # Here we aim to test the iteration of the reader op in a complex network,
+      # not the GraphBuilder.
+      parser = graph_builder.GreedyParser(
+          num_actions, num_features, num_feature_ids, embedding_sizes,
+          hidden_layer_sizes)
+      parser.AddTraining(self._task_context,
+                         batch_size,
+                         corpus_name='training-corpus')
+      sess.run(parser.inits.values())
+      while True:
+        tf_epochs, tf_cost, _ = sess.run(
+            [parser.training['epochs'], parser.training['cost'],
+             parser.training['train_op']])
+        num_steps_b += 1
+        self.assertGreaterEqual(tf_cost, 0)
+        self.assertIn(tf_epochs, [0, 1, 2])
+        if tf_epochs > 1:
+          break
+
+    # Assert that the two runs made the exact same number of steps.
+    logging.info('Number of steps in the two runs: %d, %d',
+                 num_steps_a, num_steps_b)
+    self.assertEqual(num_steps_a, num_steps_b)
+
+  def testParsingReaderOpWhileLoop(self):
+    feature_size = 3
+    batch_size = 5
+
+    def ParserEndpoints():
+      return gen_parser_ops.gold_parse_reader(self._task_context,
+                                              feature_size,
+                                              batch_size,
+                                              corpus_name='training-corpus')
+
+    with self.test_session() as sess:
+      # The 'condition' and 'body' functions expect as many arguments as there
+      # are loop variables. 'condition' depends on the 'epoch' loop variable
+      # only, so we disregard the remaining unused function arguments. 'body'
+      # returns a list of updated loop variables.
+      def Condition(epoch, *unused_args):
+        return tf.less(epoch, 2)
+
+      def Body(epoch, num_actions, *feature_args):
+        # By adding one of the outputs of the reader op ('epoch') as a control
+        # dependency to the reader op we force the repeated evaluation of the
+        # reader op.
+        with epoch.graph.control_dependencies([epoch]):
+          features, epoch, gold_actions = ParserEndpoints()
+        num_actions = tf.maximum(num_actions,
+                                 tf.reduce_max(gold_actions, [0], False) + 1)
+        feature_ids = []
+        for i in range(len(feature_args)):
+          feature_ids.append(features[i])
+        return [epoch, num_actions] + feature_ids
+
+      epoch = ParserEndpoints()[-2]
+      num_actions = tf.constant(0)
+      loop_vars = [epoch, num_actions]
+
+      res = sess.run(
+          cf.While(Condition, Body, loop_vars, parallel_iterations=1))
+      logging.info('Result: %s', res)
+      self.assertEqual(res[0], 2)
+
+  def testWordEmbeddingInitializer(self):
+    def _TokenEmbedding(token, embedding):
+      e = dictionary_pb2.TokenEmbedding()
+      e.token = token
+      e.vector.values.extend(embedding)
+      return e.SerializeToString()
+
+    # Provide embeddings for the first three words in the word map.
+    records_path = os.path.join(FLAGS.test_tmpdir, 'sstable-00000-of-00001')
+    writer = tf.python_io.TFRecordWriter(records_path)
+    writer.write(_TokenEmbedding('.', [1, 2]))
+    writer.write(_TokenEmbedding(',', [3, 4]))
+    writer.write(_TokenEmbedding('the', [5, 6]))
+    del writer
+
+    with self.test_session():
+      embeddings = gen_parser_ops.word_embedding_initializer(
+          vectors=records_path,
+          task_context=self._task_context).eval()
+    self.assertAllClose(
+        np.array([[1. / (1 + 4) ** .5, 2. / (1 + 4) ** .5],
+                  [3. / (9 + 16) ** .5, 4. / (9 + 16) ** .5],
+                  [5. / (25 + 36) ** .5, 6. / (25 + 36) ** .5]]),
+        embeddings[:3,])
+
+
+if __name__ == '__main__':
+  googletest.main()
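
For reference, the expected values asserted in testWordEmbeddingInitializer
above are simply each supplied vector rescaled to unit L2 norm, e.g.
[1, 2] / sqrt(1^2 + 2^2) = [1/sqrt(5), 2/sqrt(5)] ~= [0.447, 0.894].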

+ 28 - 0
syntaxnet/syntaxnet/registry.cc

@@ -0,0 +1,28 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/registry.h"
+
+namespace syntaxnet {
+
+// Global list of all component registries.
+RegistryMetadata *global_registry_list = NULL;
+
+void RegistryMetadata::Register(RegistryMetadata *registry) {
+  registry->set_link(global_registry_list);
+  global_registry_list = registry;
+}
+
+}  // namespace syntaxnet

+ 243 - 0
syntaxnet/syntaxnet/registry.h

@@ -0,0 +1,243 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Registry for component registration. These classes can be used for creating
+// registries of components conforming to the same interface. This is useful
+// for building a component-based architecture where the specific
+// implementation classes are selected at runtime. Both class-based and
+// instance-based registries are supported.
+//
+// Example:
+//  function.h:
+//
+//   class Function : public RegisterableInstance<Function> {
+//    public:
+//     virtual double Evaluate(double x) = 0;
+//   };
+//
+//   #define REGISTER_FUNCTION(type, component)
+//     REGISTER_INSTANCE_COMPONENT(Function, type, component);
+//
+//  function.cc:
+//
+//   REGISTER_INSTANCE_REGISTRY("function", Function);
+//
+//   class Cos : public Function {
+//    public:
+//     double Evaluate(double x) { return cos(x); }
+//   };
+//
+//   class Exp : public Function {
+//    public:
+//     double Evaluate(double x) { return exp(x); }
+//   };
+//
+//   REGISTER_FUNCTION("cos", Cos);
+//   REGISTER_FUNCTION("exp", Exp);
+//
+//   Function *f = Function::Lookup("cos");
+//   double result = f->Evaluate(arg);
+
+#ifndef $TARGETDIR_REGISTRY_H_
+#define $TARGETDIR_REGISTRY_H_
+
+#include <string.h>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+// Component metadata with information about name, class, and code location.
+class ComponentMetadata {
+ public:
+  ComponentMetadata(const char *name, const char *class_name, const char *file,
+                    int line)
+      : name_(name),
+        class_name_(class_name),
+        file_(file),
+        line_(line),
+        link_(NULL) {}
+
+  // Returns component name.
+  const char *name() const { return name_; }
+
+  // Metadata objects can be linked in a list.
+  ComponentMetadata *link() const { return link_; }
+  void set_link(ComponentMetadata *link) { link_ = link; }
+
+ private:
+  // Component name.
+  const char *name_;
+
+  // Name of class for component.
+  const char *class_name_;
+
+  // Code file and location where the component was registered.
+  const char *file_;
+  int line_;
+
+  // Link to next metadata object in list.
+  ComponentMetadata *link_;
+};
+
+// The master registry contains all registered component registries. A registry
+// is not registered in the master registry until the first component of that
+// type is registered.
+class RegistryMetadata : public ComponentMetadata {
+ public:
+  RegistryMetadata(const char *name, const char *class_name, const char *file,
+                   int line, ComponentMetadata **components)
+      : ComponentMetadata(name, class_name, file, line),
+        components_(components) {}
+
+  // Registers a component registry in the master registry.
+  static void Register(RegistryMetadata *registry);
+
+ private:
+  // Location of list of components in registry.
+  ComponentMetadata **components_;
+};
+
+// Registry for components. An object can be registered with a type name in the
+// registry. The named instances in the registry can be returned using the
+// Lookup() method. The components in the registry are put into a linked list
+// of components. It is important that the component registry can be statically
+// initialized in order not to depend on initialization order.
+template <class T>
+struct ComponentRegistry {
+  typedef ComponentRegistry<T> Self;
+
+  // Component registration class.
+  class Registrar : public ComponentMetadata {
+   public:
+    // Registers new component by linking itself into the component list of
+    // the registry.
+    Registrar(Self *registry, const char *type, const char *class_name,
+              const char *file, int line, T *object)
+        : ComponentMetadata(type, class_name, file, line), object_(object) {
+      // Register registry in master registry if this is the first registered
+      // component of this type.
+      if (registry->components == NULL) {
+        RegistryMetadata::Register(new RegistryMetadata(
+            registry->name, registry->class_name, registry->file,
+            registry->line,
+            reinterpret_cast<ComponentMetadata **>(&registry->components)));
+      }
+
+      // Register component in registry.
+      set_link(registry->components);
+      registry->components = this;
+    }
+
+    // Returns component type.
+    const char *type() const { return name(); }
+
+    // Returns component object.
+    T *object() const { return object_; }
+
+    // Returns the next component in the component list.
+    Registrar *next() const { return static_cast<Registrar *>(link()); }
+
+   private:
+    // Component object.
+    T *object_;
+  };
+
+  // Finds registrar for named component in registry.
+  const Registrar *GetComponent(const char *type) const {
+    Registrar *r = components;
+    while (r != NULL && strcmp(type, r->type()) != 0) r = r->next();
+    if (r == NULL) {
+      LOG(FATAL) << "Unknown " << name << " component: '" << type << "'.";
+    }
+    return r;
+  }
+
+  // Finds a named component in the registry.
+  T *Lookup(const char *type) const { return GetComponent(type)->object(); }
+  T *Lookup(const string &type) const { return Lookup(type.c_str()); }
+
+  // Textual description of the kind of components in the registry.
+  const char *name;
+
+  // Base class name of component type.
+  const char *class_name;
+
+  // File and line where the registry is defined.
+  const char *file;
+  int line;
+
+  // Linked list of registered components.
+  Registrar *components;
+};
+
+// Base class for registerable class-based components.
+template <class T>
+class RegisterableClass {
+ public:
+  // Factory function type.
+  typedef T *(Factory)();
+
+  // Registry type.
+  typedef ComponentRegistry<Factory> Registry;
+
+  // Creates a new component instance.
+  static T *Create(const string &type) { return registry()->Lookup(type)(); }
+
+  // Returns registry for class.
+  static Registry *registry() { return &registry_; }
+
+ private:
+  // Registry for class.
+  static Registry registry_;
+};
+
+// Base class for registerable instance-based components.
+template <class T>
+class RegisterableInstance {
+ public:
+  // Registry type.
+  typedef ComponentRegistry<T> Registry;
+
+ private:
+  // Registry for class.
+  static Registry registry_;
+};
+
+#define REGISTER_CLASS_COMPONENT(base, type, component)             \
+  static base *__##component##__factory() { return new component; } \
+  static base::Registry::Registrar __##component##__##registrar(    \
+      base::registry(), type, #component, __FILE__, __LINE__,       \
+      __##component##__factory)
+
+#define REGISTER_CLASS_REGISTRY(type, classname)                  \
+  template <>                                                     \
+  classname::Registry RegisterableClass<classname>::registry_ = { \
+      type, #classname, __FILE__, __LINE__, NULL}
+
+#define REGISTER_INSTANCE_COMPONENT(base, type, component)       \
+  static base::Registry::Registrar __##component##__##registrar( \
+      base::registry(), type, #component, __FILE__, __LINE__, new component)
+
+#define REGISTER_INSTANCE_REGISTRY(type, classname)                  \
+  template <>                                                        \
+  classname::Registry RegisterableInstance<classname>::registry_ = { \
+      type, #classname, __FILE__, __LINE__, NULL}
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_REGISTRY_H_
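
The header comment above documents only the instance-based flavour. A minimal
sketch of the class-based flavour is shown below; the Shape and UnitSquare
names are hypothetical and only illustrate the REGISTER_CLASS_REGISTRY /
REGISTER_CLASS_COMPONENT macros and the Create() factory lookup declared in
this header.

// shape.h (hypothetical)
#include "syntaxnet/registry.h"

namespace syntaxnet {

class Shape : public RegisterableClass<Shape> {
 public:
  virtual ~Shape() {}
  virtual double Area() const = 0;
};

#define REGISTER_SHAPE(type, component) \
  REGISTER_CLASS_COMPONENT(Shape, type, component)

}  // namespace syntaxnet

// shape.cc (hypothetical)
namespace syntaxnet {

// Defines the static registry object for the Shape base class.
REGISTER_CLASS_REGISTRY("shape", Shape);

class UnitSquare : public Shape {
 public:
  double Area() const override { return 1.0; }
};

// Links a factory for UnitSquare into the "shape" registry during static
// initialization.
REGISTER_SHAPE("unit-square", UnitSquare);

}  // namespace syntaxnet

Since Create() invokes the registered factory, the caller owns the returned
instance:

  Shape *shape = syntaxnet::Shape::Create("unit-square");
  double area = shape->Area();
  delete shape;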

+ 61 - 0
syntaxnet/syntaxnet/sentence.proto

@@ -0,0 +1,61 @@
+// Protocol buffer specification for document analysis.
+
+syntax = "proto2";
+
+package syntaxnet;
+
+// A Sentence contains the raw text contents of a sentence, as well as an
+// analysis.
+message Sentence {
+  // Identifier for document.
+  optional string docid = 1;
+
+  // Raw text contents of the sentence.
+  optional string text = 2;
+
+  // Tokenization of the sentence.
+  repeated Token token = 3;
+
+  extensions 1000 to max;
+}
+
+// A document token marks a span of bytes in the document text as a token
+// or word.
+message Token {
+  // Token word form.
+  required string word = 1;
+
+  // Start position of token in text.
+  required int32 start = 2;
+
+  // End position of token in text. Gives index of last byte, not one past
+  // the last byte. If token came from lexer, excludes any trailing HTML tags.
+  required int32 end = 3;
+
+  // Head of this token in the dependency tree: the id of the token which has an
+  // arc going to this one. If it is the root token of a sentence, then it is
+  // set to -1.
+  optional int32 head = 4 [default = -1];
+
+  // Part-of-speech tag for token.
+  optional string tag = 5;
+
+  // Coarse-grained word category for token.
+  optional string category = 6;
+
+  // Label for dependency relation between this token and its head.
+  optional string label = 7;
+
+  // Break level for the token, indicating how it was separated from the
+  // previous token in the text.
+  enum BreakLevel {
+    NO_BREAK = 0;         // No separation between tokens.
+    SPACE_BREAK = 1;      // Tokens separated by space.
+    LINE_BREAK = 2;       // Tokens separated by line break.
+    SENTENCE_BREAK = 3;   // Tokens separated by sentence break.
+  }
+
+  optional BreakLevel break_level = 8 [default = SPACE_BREAK];
+
+  extensions 1000 to max;
+}
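
A small, hypothetical sketch of these conventions using the generated C++ API
(sentence.pb.h): byte offsets are inclusive, a root token keeps the default
head of -1, and a dependent stores the token index of its head.

#include "syntaxnet/sentence.pb.h"

syntaxnet::Sentence MakeTinySentence() {
  syntaxnet::Sentence sentence;
  sentence.set_docid("doc-0");
  sentence.set_text("He runs");

  syntaxnet::Token *he = sentence.add_token();
  he->set_word("He");
  he->set_start(0);
  he->set_end(1);           // Index of the last byte of "He".
  he->set_head(1);          // Depends on token 1 ("runs").
  he->set_label("nsubj");

  syntaxnet::Token *runs = sentence.add_token();
  runs->set_word("runs");
  runs->set_start(3);
  runs->set_end(6);         // Index of the last byte of "runs".
  runs->set_label("ROOT");  // Head stays at the default -1 (root of the tree).
  return sentence;
}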

+ 45 - 0
syntaxnet/syntaxnet/sentence_batch.cc

@@ -0,0 +1,45 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/sentence_batch.h"
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/task_context.h"
+
+namespace syntaxnet {
+
+void SentenceBatch::Init(TaskContext *context) {
+  reader_.reset(new TextReader(*context->GetInput(input_name_)));
+  size_ = 0;
+}
+
+bool SentenceBatch::AdvanceSentence(int index) {
+  if (sentences_[index] == nullptr) ++size_;
+  sentences_[index].reset();
+  std::unique_ptr<Sentence> sentence(reader_->Read());
+  if (sentence == nullptr) {
+    --size_;
+    return false;
+  }
+
+  // Preprocess the new sentence for the parser state.
+  sentences_[index] = std::move(sentence);
+  return true;
+}
+
+}  // namespace syntaxnet

+ 78 - 0
syntaxnet/syntaxnet/sentence_batch.h

@@ -0,0 +1,78 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_SENTENCE_BATCH_H_
+#define $TARGETDIR_SENTENCE_BATCH_H_
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/embedding_feature_extractor.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/sparse.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+
+namespace syntaxnet {
+
+// Helper class to manage generating batches of preprocessed ParserState objects
+// by reading in multiple sentences in parallel.
+class SentenceBatch {
+ public:
+  SentenceBatch(int batch_size, string input_name)
+      : batch_size_(batch_size),
+        input_name_(input_name),
+        sentences_(batch_size) {}
+
+  // Initializes all resources and opens the corpus file.
+  void Init(TaskContext *context);
+
+  // Advances the index'th sentence in the batch to the next sentence. This will
+  // create and preprocess a new ParserState for that element. Returns false
+  // if EOF is reached (in that case, the state is also set to nullptr).
+  bool AdvanceSentence(int index);
+
+  // Rewinds the corpus reader.
+  void Rewind() { reader_->Reset(); }
+
+  int size() const { return size_; }
+
+  Sentence *sentence(int index) { return sentences_[index].get(); }
+
+ private:
+  // Running tally of non-nullptr states in the batch.
+  int size_;
+
+  // Maximum number of states in the batch.
+  int batch_size_;
+
+  // Input to read from the TaskContext.
+  string input_name_;
+
+  // Reader for the corpus.
+  std::unique_ptr<TextReader> reader_;
+
+  // Batch: Sentence objects.
+  std::vector<std::unique_ptr<Sentence>> sentences_;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_SENTENCE_BATCH_H_
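
A minimal, hypothetical driver showing the intended call sequence (it assumes a
TaskContext whose 'training-corpus' input is already configured): Init() opens
the corpus, AdvanceSentence(i) refills slot i and returns false at EOF, and
size() reports how many slots still hold a sentence.

#include "syntaxnet/sentence_batch.h"
#include "syntaxnet/task_context.h"

void ReadCorpusInBatches(syntaxnet::TaskContext *context, int batch_size) {
  syntaxnet::SentenceBatch batch(batch_size, "training-corpus");
  batch.Init(context);

  // Fill every slot; slots that hit EOF right away simply stay empty.
  for (int i = 0; i < batch_size; ++i) batch.AdvanceSentence(i);

  while (batch.size() > 0) {
    for (int i = 0; i < batch_size; ++i) {
      syntaxnet::Sentence *sentence = batch.sentence(i);
      if (sentence == nullptr) continue;  // This slot already reached EOF.
      // ... consume *sentence here ...
      batch.AdvanceSentence(i);  // Replace it with the next sentence, if any.
    }
  }
}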

+ 192 - 0
syntaxnet/syntaxnet/sentence_features.cc

@@ -0,0 +1,192 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/sentence_features.h"
+
+#include "syntaxnet/registry.h"
+#include "util/utf8/unicodetext.h"
+
+namespace syntaxnet {
+
+TermFrequencyMapFeature::~TermFrequencyMapFeature() {
+  if (term_map_ != nullptr) {
+    SharedStore::Release(term_map_);
+    term_map_ = nullptr;
+  }
+}
+
+void TermFrequencyMapFeature::Setup(TaskContext *context) {
+  TokenLookupFeature::Setup(context);
+  context->GetInput(input_name_, "text", "");
+}
+
+void TermFrequencyMapFeature::Init(TaskContext *context) {
+  min_freq_ = GetIntParameter("min-freq", 0);
+  max_num_terms_ = GetIntParameter("max-num-terms", 0);
+  file_name_ = context->InputFile(*context->GetInput(input_name_));
+  term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
+      file_name_, min_freq_, max_num_terms_);
+  TokenLookupFeature::Init(context);
+}
+
+string TermFrequencyMapFeature::GetFeatureValueName(FeatureValue value) const {
+  if (value == UnknownValue()) return "<UNKNOWN>";
+  if (value >= 0 && value < (NumValues() - 1)) {
+    return term_map_->GetTerm(value);
+  }
+  LOG(ERROR) << "Invalid feature value: " << value;
+  return "<INVALID>";
+}
+
+string TermFrequencyMapFeature::WorkspaceName() const {
+  return SharedStoreUtils::CreateDefaultName("term-frequency-map", input_name_,
+                                             min_freq_, max_num_terms_);
+}
+
+string Hyphen::GetFeatureValueName(FeatureValue value) const {
+  switch (value) {
+    case NO_HYPHEN:
+      return "NO_HYPHEN";
+    case HAS_HYPHEN:
+      return "HAS_HYPHEN";
+  }
+  return "<INVALID>";
+}
+
+FeatureValue Hyphen::ComputeValue(const Token &token) const {
+  const string &word = token.word();
+  return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
+}
+
+string Digit::GetFeatureValueName(FeatureValue value) const {
+  switch (value) {
+    case NO_DIGIT:
+      return "NO_DIGIT";
+    case SOME_DIGIT:
+      return "SOME_DIGIT";
+    case ALL_DIGIT:
+      return "ALL_DIGIT";
+  }
+  return "<INVALID>";
+}
+
+FeatureValue Digit::ComputeValue(const Token &token) const {
+  const string &word = token.word();
+  bool has_digit = isdigit(word[0]);
+  bool all_digit = has_digit;
+  for (size_t i = 1; i < word.length(); ++i) {
+    bool char_is_digit = isdigit(word[i]);
+    all_digit = all_digit && char_is_digit;
+    has_digit = has_digit || char_is_digit;
+    if (!all_digit && has_digit) return SOME_DIGIT;
+  }
+  if (!all_digit) return NO_DIGIT;
+  return ALL_DIGIT;
+}
+
+AffixTableFeature::AffixTableFeature(AffixTable::Type type)
+    : type_(type) {
+  if (type == AffixTable::PREFIX) {
+    input_name_ = "prefix-table";
+  } else {
+    input_name_ = "suffix-table";
+  }
+}
+
+AffixTableFeature::~AffixTableFeature() {
+  SharedStore::Release(affix_table_);
+  affix_table_ = nullptr;
+}
+
+string AffixTableFeature::WorkspaceName() const {
+  return SharedStoreUtils::CreateDefaultName(
+      "affix-table", input_name_, type_, affix_length_);
+}
+
+// Utility function to create a new affix table without changing constructors,
+// to be called by the SharedStore.
+static AffixTable *CreateAffixTable(const string &filename,
+                                    AffixTable::Type type) {
+  AffixTable *affix_table = new AffixTable(type, 1);
+  tensorflow::RandomAccessFile *file;
+  TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
+  ProtoRecordReader reader(file);
+  affix_table->Read(&reader);
+  return affix_table;
+}
+
+void AffixTableFeature::Setup(TaskContext *context) {
+  context->GetInput(input_name_, "recordio", "affix-table");
+  affix_length_ = GetIntParameter("length", 0);
+  CHECK_GE(affix_length_, 0)
+      << "Length must be specified for affix preprocessor.";
+  TokenLookupFeature::Setup(context);
+}
+
+void AffixTableFeature::Init(TaskContext *context) {
+  string filename = context->InputFile(*context->GetInput(input_name_));
+
+  // Get the shared AffixTable object.
+  std::function<AffixTable *()> closure =
+      std::bind(CreateAffixTable, filename, type_);
+  affix_table_ = SharedStore::ClosureGetOrDie(filename, &closure);
+  CHECK_GE(affix_table_->max_length(), affix_length_)
+      << "Affixes of length " << affix_length_ << " needed, but the affix "
+      <<"table only provides affixes of length <= "
+      << affix_table_->max_length() << ".";
+  TokenLookupFeature::Init(context);
+}
+
+FeatureValue AffixTableFeature::ComputeValue(const Token &token) const {
+  const string &word = token.word();
+  UnicodeText text;
+  text.PointToUTF8(word.c_str(), word.size());
+  if (affix_length_ > text.size()) return UnknownValue();
+  UnicodeText::const_iterator start, end;
+  if (type_ == AffixTable::PREFIX) {
+    start = end = text.begin();
+    for (int i = 0; i < affix_length_; ++i) ++end;
+  } else {
+    start = end = text.end();
+    for (int i = 0; i < affix_length_; ++i) --start;
+  }
+  string affix(start.utf8_data(), end.utf8_data() - start.utf8_data());
+  int affix_id = affix_table_->AffixId(affix);
+  return affix_id == -1 ? UnknownValue() : affix_id;
+}
+
+string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
+  if (value == UnknownValue()) return "<UNKNOWN>";
+  if (value >= 0 && value < UnknownValue()) {
+    return affix_table_->AffixForm(value);
+  }
+  LOG(ERROR) << "Invalid feature value: " << value;
+  return "<INVALID>";
+}
+
+// Registry for the Sentence + token index feature functions.
+REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);
+
+// Register the features defined in the header.
+REGISTER_SENTENCE_IDX_FEATURE("word", Word);
+REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
+REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
+REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
+REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
+REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
+REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
+REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);
+
+}  // namespace syntaxnet

+ 317 - 0
syntaxnet/syntaxnet/sentence_features.h

@@ -0,0 +1,317 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Features that operate on Sentence objects. Most features are defined
+// in this header so they may be re-used via composition into other more
+// advanced feature classes.
+
+#ifndef $TARGETDIR_SENTENCE_FEATURES_H_
+#define $TARGETDIR_SENTENCE_FEATURES_H_
+
+#include "syntaxnet/affix.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/feature_types.h"
+#include "syntaxnet/shared_store.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/workspace.h"
+
+namespace syntaxnet {
+
+// Feature function for any component that processes Sentences, whose
+// focus is a token index into the sentence.
+typedef FeatureFunction<Sentence, int> SentenceFeature;
+
+// Alias for Locator type features that take (Sentence, int) signatures
+// and call other (Sentence, int) features.
+template <class DER>
+using Locator = FeatureLocator<DER, Sentence, int>;
+
+class TokenLookupFeature : public SentenceFeature {
+ public:
+  void Init(TaskContext *context) override {
+    set_feature_type(new ResourceBasedFeatureType<TokenLookupFeature>(
+        name(), this, {{NumValues(), "<OUTSIDE>"}}));
+  }
+
+  // Given a position in a sentence and workspaces, looks up the corresponding
+  // feature value. The index is relative to the start of the sentence.
+  virtual FeatureValue ComputeValue(const Token &token) const = 0;
+
+  // Number of unique values.
+  virtual int64 NumValues() const = 0;
+
+  // Convert the numeric value of the feature to a human readable string.
+  virtual string GetFeatureValueName(FeatureValue value) const = 0;
+
+  // Name of the shared workspace.
+  virtual string WorkspaceName() const = 0;
+
+  // Runs ComputeValue for each token in the sentence.
+  void Preprocess(WorkspaceSet *workspaces,
+                  Sentence *sentence) const override {
+    if (workspaces->Has<VectorIntWorkspace>(workspace_)) return;
+    VectorIntWorkspace *workspace = new VectorIntWorkspace(
+        sentence->token_size());
+    for (int i = 0; i < sentence->token_size(); ++i) {
+      const int value = ComputeValue(sentence->token(i));
+      workspace->set_element(i, value);
+    }
+    workspaces->Set<VectorIntWorkspace>(workspace_, workspace);
+  }
+
+  // Requests a vector of int's to store in the workspace registry.
+  void RequestWorkspaces(WorkspaceRegistry *registry) override {
+    workspace_ = registry->Request<VectorIntWorkspace>(WorkspaceName());
+  }
+
+  // Returns the precomputed value, or NumValues() for features outside
+  // the sentence.
+  FeatureValue Compute(const WorkspaceSet &workspaces,
+                       const Sentence &sentence, int focus,
+                       const FeatureVector *result) const override {
+    if (focus < 0 || focus >= sentence.token_size()) return NumValues();
+    return workspaces.Get<VectorIntWorkspace>(workspace_).element(focus);
+  }
+
+ private:
+  int workspace_;
+};
+
+// Lookup feature that uses a TermFrequencyMap to store a string->int mapping.
+class TermFrequencyMapFeature : public TokenLookupFeature {
+ public:
+  explicit TermFrequencyMapFeature(const string &input_name)
+      : input_name_(input_name), min_freq_(0), max_num_terms_(0) {}
+  ~TermFrequencyMapFeature() override;
+
+  // Requests the input map as a resource.
+  void Setup(TaskContext *context) override;
+
+  // Loads the input map into memory (using SharedStore to avoid redundancy).
+  void Init(TaskContext *context) override;
+
+  // Number of unique values.
+  int64 NumValues() const override { return term_map_->Size() + 1; }
+
+  // Special value for strings not in the map.
+  FeatureValue UnknownValue() const { return term_map_->Size(); }
+
+  // Uses the TermFrequencyMap to lookup the string associated with a value.
+  string GetFeatureValueName(FeatureValue value) const override;
+
+  // Name of the shared workspace.
+  string WorkspaceName() const override;
+
+ protected:
+  const TermFrequencyMap &term_map() const { return *term_map_; }
+
+ private:
+  // Shortcut pointer to shared map. Not owned.
+  const TermFrequencyMap *term_map_ = nullptr;
+
+  // Name of the input for the term map.
+  string input_name_;
+
+  // Filename of the underlying resource.
+  string file_name_;
+
+  // Minimum frequency for term map.
+  int min_freq_;
+
+  // Maximum number of terms for term map.
+  int max_num_terms_;
+};
+
+class Word : public TermFrequencyMapFeature {
+ public:
+  Word() : TermFrequencyMapFeature("word-map") {}
+
+  FeatureValue ComputeValue(const Token &token) const override {
+    string form = token.word();
+    return term_map().LookupIndex(form, UnknownValue());
+  }
+};
+
+class LowercaseWord : public TermFrequencyMapFeature {
+ public:
+  LowercaseWord() : TermFrequencyMapFeature("lc-word-map") {}
+
+  FeatureValue ComputeValue(const Token &token) const override {
+    const string lcword = utils::Lowercase(token.word());
+    return term_map().LookupIndex(lcword, UnknownValue());
+  }
+};
+
+class Tag : public TermFrequencyMapFeature {
+ public:
+  Tag() : TermFrequencyMapFeature("tag-map") {}
+
+  FeatureValue ComputeValue(const Token &token) const override {
+    return term_map().LookupIndex(token.tag(), UnknownValue());
+  }
+};
+
+class Label : public TermFrequencyMapFeature {
+ public:
+  Label() : TermFrequencyMapFeature("label-map") {}
+
+  FeatureValue ComputeValue(const Token &token) const override {
+    return term_map().LookupIndex(token.label(), UnknownValue());
+  }
+};
+
+class LexicalCategoryFeature : public TokenLookupFeature {
+ public:
+  LexicalCategoryFeature(const string &name, int cardinality)
+      : name_(name), cardinality_(cardinality) {}
+  ~LexicalCategoryFeature() override {}
+
+  FeatureValue NumValues() const override { return cardinality_; }
+
+  // Returns the identifier for the workspace for this preprocessor.
+  string WorkspaceName() const override {
+    return tensorflow::strings::StrCat(name_, ":", cardinality_);
+  }
+
+ private:
+  // Name of the category type.
+  const string name_;
+
+  // Number of values.
+  const int cardinality_;
+};
+
+// Preprocessor that computes whether a word has a hyphen or not.
+class Hyphen : public LexicalCategoryFeature {
+ public:
+  // Enumeration of values.
+  enum Category {
+    NO_HYPHEN = 0,
+    HAS_HYPHEN = 1,
+    CARDINALITY = 2,
+  };
+
+  // Default constructor.
+  Hyphen() : LexicalCategoryFeature("hyphen", CARDINALITY) {}
+
+  // Returns a string representation of the enum value.
+  string GetFeatureValueName(FeatureValue value) const override;
+
+  // Returns the category value for the token.
+  FeatureValue ComputeValue(const Token &token) const override;
+};
+
+// Preprocessor that computes whether a word contains digits.
+class Digit : public LexicalCategoryFeature {
+ public:
+  // Enumeration of values.
+  enum Category {
+    NO_DIGIT = 0,
+    SOME_DIGIT = 1,
+    ALL_DIGIT = 2,
+    CARDINALITY = 3,
+  };
+
+  // Default constructor.
+  Digit() : LexicalCategoryFeature("digit", CARDINALITY) {}
+
+  // Returns a string representation of the enum value.
+  string GetFeatureValueName(FeatureValue value) const override;
+
+  // Returns the category value for the token.
+  FeatureValue ComputeValue(const Token &token) const override;
+};
+
+// TokenLookupFeature object to compute prefixes and suffixes of words. The
+// AffixTable is stored in the SharedStore. This is very similar to the
+// implementation of TermFrequencyMapFeature, but uses an AffixTable to
+// perform the lookups. There are only two specializations, for prefixes and
+// suffixes.
+class AffixTableFeature : public TokenLookupFeature {
+ public:
+  // Explicit constructor to set the type of the table. This determines the
+  // requested input.
+  explicit AffixTableFeature(AffixTable::Type type);
+  ~AffixTableFeature() override;
+
+  // Requests inputs for the affix table.
+  void Setup(TaskContext *context) override;
+
+  // Loads the affix table from the SharedStore.
+  void Init(TaskContext *context) override;
+
+  // The workspace name is specific to which affix length we are computing.
+  string WorkspaceName() const override;
+
+  // Returns the total number of affixes in the table, regardless of specified
+  // length.
+  FeatureValue NumValues() const override { return affix_table_->size() + 1; }
+
+  // Special value for strings not in the map.
+  FeatureValue UnknownValue() const { return affix_table_->size(); }
+
+  // Looks up the affix for a given word.
+  FeatureValue ComputeValue(const Token &token) const override;
+
+  // Returns the string associated with a value.
+  string GetFeatureValueName(FeatureValue value) const override;
+
+ private:
+  // Size parameter for the affix table.
+  int affix_length_;
+
+  // Name of the input for the table.
+  string input_name_;
+
+  // The type of the affix table.
+  const AffixTable::Type type_;
+
+  // Affix table used for indexing. This comes from the shared store, and is not
+  // owned directly.
+  const AffixTable *affix_table_ = nullptr;
+};
+
+// Specific instantiation for computing prefixes. This requires the input
+// "prefix-table".
+class PrefixFeature : public AffixTableFeature {
+ public:
+  PrefixFeature() : AffixTableFeature(AffixTable::PREFIX) {}
+};
+
+// Specific instantiation for computing suffixes. Requires the input
+// "suffix-table."
+class SuffixFeature : public AffixTableFeature {
+ public:
+  SuffixFeature() : AffixTableFeature(AffixTable::SUFFIX) {}
+};
+
+// Offset locator. Simple locator: just changes the focus by some offset.
+class Offset : public Locator<Offset> {
+ public:
+  void UpdateArgs(const WorkspaceSet &workspaces,
+                  const Sentence &sentence, int *focus) const {
+    *focus += argument();
+  }
+};
+
+typedef FeatureExtractor<Sentence, int> SentenceExtractor;
+
+// Utility to register the sentence_instance::Feature functions.
+#define REGISTER_SENTENCE_IDX_FEATURE(name, type) \
+  REGISTER_FEATURE_FUNCTION(SentenceFeature, name, type)
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_SENTENCE_FEATURES_H_
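
As a sketch of how a new per-token feature could be added with the pieces
above, the hypothetical Capitalized class below follows the Hyphen/Digit
pattern and registers itself under an FML name:

#include "syntaxnet/sentence_features.h"

namespace syntaxnet {

// Hypothetical lexical feature: fires when the first byte of the word form is
// an ASCII uppercase letter.
class Capitalized : public LexicalCategoryFeature {
 public:
  enum Category {
    NOT_CAPITALIZED = 0,
    CAPITALIZED = 1,
    CARDINALITY = 2,
  };

  Capitalized() : LexicalCategoryFeature("capitalized", CARDINALITY) {}

  string GetFeatureValueName(FeatureValue value) const override {
    return value == CAPITALIZED ? "CAPITALIZED" : "NOT_CAPITALIZED";
  }

  FeatureValue ComputeValue(const Token &token) const override {
    const string &word = token.word();
    return (!word.empty() && word[0] >= 'A' && word[0] <= 'Z')
               ? CAPITALIZED
               : NOT_CAPITALIZED;
  }
};

REGISTER_SENTENCE_IDX_FEATURE("capitalized", Capitalized);

}  // namespace syntaxnet

Once registered, the feature can be referenced from FML just like the built-in
ones, e.g. "capitalized" or "offset(-1).capitalized".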

+ 155 - 0
syntaxnet/syntaxnet/sentence_features_test.cc

@@ -0,0 +1,155 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/sentence_features.h"
+
+#include <string>
+#include <vector>
+
+#include <gmock/gmock.h>
+#include "syntaxnet/utils.h"
+#include "syntaxnet/feature_extractor.h"
+#include "syntaxnet/populate_test_inputs.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/workspace.h"
+
+using testing::UnorderedElementsAreArray;
+
+namespace syntaxnet {
+
+// A basic fixture for testing Features. Takes a string of a
+// Sentence protobuf that is used as the test data in the constructor.
+class SentenceFeaturesTest : public ::testing::Test {
+ protected:
+  explicit SentenceFeaturesTest(const string &prototxt)
+      : sentence_(ParseASCII(prototxt)),
+        creators_(PopulateTestInputs::Defaults(sentence_)) {}
+
+  static Sentence ParseASCII(const string &prototxt) {
+    Sentence document;
+    CHECK(TextFormat::ParseFromString(prototxt, &document));
+    return document;
+  }
+
+  // Prepares a new feature for extracting from the attached sentence,
+  // regenerating the TaskContext and all resources. Any inputs requested
+  // during Setup() are populated automatically via the creators_ map.
+  virtual void PrepareFeature(const string &fml) {
+    context_.mutable_spec()->mutable_input()->Clear();
+    context_.mutable_spec()->mutable_output()->Clear();
+    extractor_.reset(new SentenceExtractor());
+    extractor_->Parse(fml);
+    extractor_->Setup(&context_);
+    creators_.Populate(&context_);
+    extractor_->Init(&context_);
+    extractor_->RequestWorkspaces(&registry_);
+    workspaces_.Reset(registry_);
+    extractor_->Preprocess(&workspaces_, &sentence_);
+  }
+
+  // Returns the string representation of the prepared feature extracted at the
+  // given index.
+  virtual string ExtractFeature(int index) {
+    FeatureVector result;
+    extractor_->ExtractFeatures(workspaces_, sentence_, index,
+                                &result);
+    return result.type(0)->GetFeatureValueName(result.value(0));
+  }
+
+  // Extracts a vector of string representations from evaluating the prepared
+  // set feature (returning multiple values) at the given index.
+  virtual vector<string> ExtractMultiFeature(int index) {
+    vector<string> values;
+    FeatureVector result;
+    extractor_->ExtractFeatures(workspaces_, sentence_, index,
+                                &result);
+    for (int i = 0; i < result.size(); ++i) {
+      values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
+    }
+    return values;
+  }
+
+  Sentence sentence_;
+  WorkspaceSet workspaces_;
+
+  PopulateTestInputs::CreatorMap creators_;
+  TaskContext context_;
+  WorkspaceRegistry registry_;
+  std::unique_ptr<SentenceExtractor> extractor_;
+};
+
+// Test fixture for simple common features that operate on just a sentence.
+class CommonSentenceFeaturesTest : public SentenceFeaturesTest {
+ protected:
+  CommonSentenceFeaturesTest()
+      : SentenceFeaturesTest(
+            "text: 'I saw a man with a telescope.' "
+            "token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
+            " head: 1 label: 'nsubj' break_level: NO_BREAK } "
+            "token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
+            " label: 'ROOT' break_level: SPACE_BREAK } "
+            "token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
+            " head: 3 label: 'det' break_level: SPACE_BREAK } "
+            "token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
+            " head: 1 label: 'dobj' break_level: SPACE_BREAK } "
+            "token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
+            " head: 1 label: 'prep' break_level: SPACE_BREAK } "
+            "token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
+            " head: 6 label: 'det' break_level: SPACE_BREAK } "
+            "token { word: 'telescope' start: 19 end: 27 tag: 'NN' category: "
+            "'NOUN'"
+            " head: 4 label: 'pobj'  break_level: SPACE_BREAK } "
+            "token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
+            " head: 1 label: 'p' break_level: NO_BREAK }") {}
+};
+
+TEST_F(CommonSentenceFeaturesTest, TagFeature) {
+  PrepareFeature("tag");
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
+  EXPECT_EQ("PRP", ExtractFeature(0));
+  EXPECT_EQ("VBD", ExtractFeature(1));
+  EXPECT_EQ("DT", ExtractFeature(2));
+  EXPECT_EQ("NN", ExtractFeature(3));
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature(8));
+}
+
+TEST_F(CommonSentenceFeaturesTest, TagFeaturePassesArgs) {
+  PrepareFeature("tag(min-freq=5)");  // don't load any tags
+  EXPECT_EQ(ExtractFeature(-1), "<OUTSIDE>");
+  EXPECT_EQ(ExtractFeature(0), "<UNKNOWN>");
+  EXPECT_EQ(ExtractFeature(8), "<OUTSIDE>");
+
+  // Only 2 features: <UNKNOWN> and <OUTSIDE>.
+  EXPECT_EQ(2, extractor_->feature_type(0)->GetDomainSize());
+}
+
+TEST_F(CommonSentenceFeaturesTest, OffsetPlusTag) {
+  PrepareFeature("offset(-1).tag(min-freq=2)");
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature(-1));
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature(0));
+  EXPECT_EQ("<UNKNOWN>", ExtractFeature(1));
+  EXPECT_EQ("<UNKNOWN>", ExtractFeature(2));
+  EXPECT_EQ("DT", ExtractFeature(3));  // DT, NN are the only freq tags
+  EXPECT_EQ("NN", ExtractFeature(4));
+  EXPECT_EQ("<UNKNOWN>", ExtractFeature(5));
+  EXPECT_EQ("DT", ExtractFeature(6));
+  EXPECT_EQ("NN", ExtractFeature(7));
+  EXPECT_EQ("<UNKNOWN>", ExtractFeature(8));
+  EXPECT_EQ("<OUTSIDE>", ExtractFeature(9));
+}
+
+}  // namespace syntaxnet

+ 91 - 0
syntaxnet/syntaxnet/shared_store.cc

@@ -0,0 +1,91 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/shared_store.h"
+
+#include <unordered_map>
+
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace syntaxnet {
+
+SharedStore::SharedObjectMap *SharedStore::shared_object_map_ =
+    new SharedObjectMap;
+
+mutex SharedStore::shared_object_map_mutex_(tensorflow::LINKER_INITIALIZED);
+
+SharedStore::SharedObjectMap *SharedStore::shared_object_map() {
+  return shared_object_map_;
+}
+
+bool SharedStore::Release(const void *object) {
+  if (object == nullptr) {
+    return true;
+  }
+  mutex_lock l(shared_object_map_mutex_);
+  for (SharedObjectMap::iterator it = shared_object_map()->begin();
+       it != shared_object_map()->end(); ++it) {
+    if (it->second.object == object) {
+      // Check the invariant that reference counts are positive. A violation
+      // likely implies memory corruption.
+      CHECK_GE(it->second.refcount, 1);
+      it->second.refcount--;
+      if (it->second.refcount == 0) {
+        it->second.delete_callback();
+        shared_object_map()->erase(it);
+      }
+      return true;
+    }
+  }
+  return false;
+}
+
+void SharedStore::Clear() {
+  mutex_lock l(shared_object_map_mutex_);
+  for (SharedObjectMap::iterator it = shared_object_map()->begin();
+       it != shared_object_map()->end(); ++it) {
+    it->second.delete_callback();
+  }
+  shared_object_map()->clear();
+}
+
+string SharedStoreUtils::CreateDefaultName() { return string(); }
+
+string SharedStoreUtils::ToString(const string &input) {
+  return ToString(tensorflow::StringPiece(input));
+}
+
+string SharedStoreUtils::ToString(const char *input) {
+  return ToString(tensorflow::StringPiece(input));
+}
+
+string SharedStoreUtils::ToString(tensorflow::StringPiece input) {
+  return tensorflow::strings::StrCat("\"", utils::CEscape(input.ToString()),
+                                     "\"");
+}
+
+string SharedStoreUtils::ToString(bool input) {
+  return input ? "true" : "false";
+}
+
+string SharedStoreUtils::ToString(float input) {
+  return tensorflow::strings::Printf("%af", input);
+}
+
+string SharedStoreUtils::ToString(double input) {
+  return tensorflow::strings::Printf("%a", input);
+}
+
+}  // namespace syntaxnet

+ 234 - 0
syntaxnet/syntaxnet/shared_store.h

@@ -0,0 +1,234 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Utility for creating read-only objects once and sharing them across threads.
+
+#ifndef $TARGETDIR_SHARED_STORE_H_
+#define $TARGETDIR_SHARED_STORE_H_
+
+#include <functional>
+#include <string>
+#include <typeindex>
+#include <unordered_map>
+#include <utility>
+
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+class SharedStore {
+ public:
+  // Returns an existing object with type T and name 'name' if it exists, else
+  // creates one with "new T(args...)".  Note: Objects will be indexed under
+  // their typeid + name, so names only have to be unique within a given type.
+  template <typename T, typename ...Args>
+  static const T *Get(const string &name,
+                      Args &&...args);  // NOLINT(build/c++11)
+
+  // Like Get(), but creates the object with "closure->Run()". If the closure
+  // returns null, we store a null in the SharedStore, but note that Release()
+  // cannot be used to remove it. This is because Release() finds the object
+  // by associative lookup, and there may be more than one null value, so we
+  // don't know which one to release. If the closure returns a duplicate value
+  // (one that is pointer-equal to an object already in the SharedStore),
+  // we disregard it and store null instead -- otherwise associative lookup
+  // would again fail (and the reference counts would be wrong).
+  template <typename T>
+  static const T *ClosureGet(const string &name, std::function<T *()> *closure);
+
+  // Like ClosureGet(), but check-fails if ClosureGet() would return null.
+  template <typename T>
+  static const T *ClosureGetOrDie(const string &name,
+                                  std::function<T *()> *closure);
+
+  // Release an object that was acquired by Get(). When its reference count
+  // hits 0, the object will be deleted. Returns true if the object was found.
+  // Does nothing and returns true if the object is null.
+  static bool Release(const void *object);
+
+  // Delete all objects in the shared store.
+  static void Clear();
+
+ private:
+  // A shared object.
+  struct SharedObject {
+    void *object;
+    std::function<void()> delete_callback;
+    int refcount;
+
+    SharedObject(void *o, std::function<void()> d)
+        : object(o), delete_callback(d), refcount(1) {}
+  };
+
+  // A map from keys to shared objects.
+  typedef std::unordered_map<string, SharedObject> SharedObjectMap;
+
+  // Return the shared object map.
+  static SharedObjectMap *shared_object_map();
+
+  // Return the string to use for indexing an object in the shared store.
+  template <typename T>
+  static string GetSharedKey(const string &name);
+
+  // Delete an object of type T.
+  template <typename T>
+  static void DeleteObject(T *object);
+
+  // Add an object to the shared object map. Return the object.
+  template <typename T>
+  static T *StoreObject(const string &key, T *object);
+
+  // Increment the reference count of an object in the map. Return the object.
+  template <typename T>
+  static T *IncrementRefCountOfObject(SharedObjectMap::iterator it);
+
+  // Map from keys to shared objects.
+  static SharedObjectMap *shared_object_map_;
+  static mutex shared_object_map_mutex_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SharedStore);
+};
+
+template <typename T>
+string SharedStore::GetSharedKey(const string &name) {
+  const std::type_index id = std::type_index(typeid(T));
+  return tensorflow::strings::StrCat(id.name(), "_", name);
+}
+
+template <typename T>
+void SharedStore::DeleteObject(T *object) {
+  delete object;
+}
+
+template <typename T>
+T *SharedStore::StoreObject(const string &key, T *object) {
+  std::function<void()> delete_cb =
+      std::bind(SharedStore::DeleteObject<T>, object);
+  SharedObject so(object, delete_cb);
+  shared_object_map()->insert(std::make_pair(key, so));
+  return object;
+}
+
+template <typename T>
+T *SharedStore::IncrementRefCountOfObject(SharedObjectMap::iterator it) {
+  it->second.refcount++;
+  return static_cast<T *>(it->second.object);
+}
+
+template <typename T, typename ...Args>
+const T *SharedStore::Get(const string &name,
+                          Args &&...args) {  // NOLINT(build/c++11)
+  mutex_lock l(shared_object_map_mutex_);
+  const string key = GetSharedKey<T>(name);
+  SharedObjectMap::iterator it = shared_object_map()->find(key);
+  return (it == shared_object_map()->end()) ?
+      StoreObject<T>(key, new T(std::forward<Args>(args)...)) :
+      IncrementRefCountOfObject<T>(it);
+}
+
+template <typename T>
+const T *SharedStore::ClosureGet(const string &name,
+                                 std::function<T *()> *closure) {
+  mutex_lock l(shared_object_map_mutex_);
+  const string key = GetSharedKey<T>(name);
+  SharedObjectMap::iterator it = shared_object_map()->find(key);
+  if (it == shared_object_map()->end()) {
+    // Creates a new object by calling the closure.
+    T *object = (*closure)();
+    if (object == nullptr) {
+      LOG(ERROR) << "Closure returned a null pointer";
+    } else {
+      for (SharedObjectMap::iterator it = shared_object_map()->begin();
+           it != shared_object_map()->end(); ++it) {
+        if (it->second.object == object) {
+          LOG(ERROR)
+              << "Closure returned duplicate pointer: "
+              << "keys " << it->first << " and " << key;
+
+          // Not a memory leak to discard pointer, since we have another copy.
+          object = nullptr;
+          break;
+        }
+      }
+    }
+    return StoreObject<T>(key, object);
+  } else {
+    return IncrementRefCountOfObject<T>(it);
+  }
+}
+
+template <typename T>
+const T *SharedStore::ClosureGetOrDie(const string &name,
+                                      std::function<T *()> *closure) {
+  const T *object = ClosureGet<T>(name, closure);
+  CHECK(object != nullptr);
+  return object;
+}
+
+// A collection of utility functions for working with the shared store.
+class SharedStoreUtils {
+ public:
+  // Returns a shared object registered using a default name that is created
+  // from the constructor args.
+  //
+  // NB: This function does not guarantee a one-to-one relationship between
+  // sets of constructor args and names.  See warnings on CreateDefaultName().
+  // It is the caller's responsibility to ensure that the args provided will
+  // result in unique names.
+  template <class T, class... Args>
+  static const T *GetWithDefaultName(Args &&... args) {  // NOLINT(build/c++11)
+    return SharedStore::Get<T>(CreateDefaultName(std::forward<Args>(args)...),
+                               std::forward<Args>(args)...);
+  }
+
+  // Returns a string name representing the args.  Implemented via a pair of
+  // overloaded functions to achieve compile-time recursion.
+  //
+  // WARNING: It is possible for instances of different types to have the same
+  // string representation.  For example,
+  //
+  // CreateDefaultName(1) == CreateDefaultName(1ULL)
+  //
+  template <class First, class... Rest>
+  static string CreateDefaultName(First &&first,
+                                  Rest &&... rest) {  // NOLINT(build/c++11)
+    return tensorflow::strings::StrCat(
+        ToString<First>(std::forward<First>(first)), ",",
+        CreateDefaultName(std::forward<Rest>(rest)...));
+  }
+  static string CreateDefaultName();
+
+ private:
+  // Returns a string representing the input.  The generic implementation uses
+  // StrCat(), and overloads are provided for selected types.
+  template <class T>
+  static string ToString(T input) {
+    return tensorflow::strings::StrCat(input);
+  }
+  static string ToString(const string &input);
+  static string ToString(const char *input);
+  static string ToString(tensorflow::StringPiece input);
+  static string ToString(bool input);
+  static string ToString(float input);
+  static string ToString(double input);
+
+  TF_DISALLOW_COPY_AND_ASSIGN(SharedStoreUtils);
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_SHARED_STORE_H_
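
The feature classes earlier in this change acquire their shared resources
during Init() and drop the reference in their destructor. A stripped-down
sketch of that discipline (BigTableHolder and its vector<string> payload are
hypothetical):

#include <string>
#include <vector>

#include "syntaxnet/shared_store.h"

namespace syntaxnet {

class BigTableHolder {
 public:
  // Constructs the shared vector only if no other holder has done so yet;
  // otherwise just bumps the reference count.
  explicit BigTableHolder(const std::string &filename)
      : table_(SharedStore::Get<std::vector<std::string>>(filename)) {}

  // Drops the reference; the vector is deleted when the last holder is gone.
  ~BigTableHolder() { SharedStore::Release(table_); }

  const std::vector<std::string> &table() const { return *table_; }

 private:
  const std::vector<std::string> *table_;  // Shared, not owned.
};

}  // namespace syntaxnet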

+ 242 - 0
syntaxnet/syntaxnet/shared_store_test.cc

@@ -0,0 +1,242 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/shared_store.h"
+
+#include <string>
+
+#include <gmock/gmock.h>
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+
+using ::testing::_;
+
+namespace syntaxnet {
+
+struct NoArgs {
+  NoArgs() {
+    LOG(INFO) << "Calling NoArgs()";
+  }
+};
+
+struct OneArg {
+  string name;
+  explicit OneArg(const string &n) : name(n) {
+    LOG(INFO) << "Calling OneArg(" << name << ")";
+  }
+};
+
+struct TwoArgs {
+  string name;
+  int age;
+  TwoArgs(const string &n, int a) : name(n), age(a) {
+    LOG(INFO) << "Calling TwoArgs(" << name << ", " << age << ")";
+  }
+};
+
+struct Slow {
+  string lengthy;
+  Slow() {
+    LOG(INFO) << "Calling Slow()";
+    lengthy.assign(50 << 20, 'L');  // 50MB of the letter 'L'
+  }
+};
+
+struct CountCalls {
+  CountCalls() {
+    LOG(INFO) << "Calling CountCalls()";
+    ++constructor_calls;
+  }
+
+  ~CountCalls() {
+    LOG(INFO) << "Calling ~CountCalls()";
+    ++destructor_calls;
+  }
+
+  static void Reset() {
+    constructor_calls = 0;
+    destructor_calls = 0;
+  }
+
+  static int constructor_calls;
+  static int destructor_calls;
+};
+
+int CountCalls::constructor_calls = 0;
+int CountCalls::destructor_calls = 0;
+
+class PointerSet {
+ public:
+  PointerSet() { }
+
+  void Add(const void *p) {
+    mutex_lock l(mu_);
+    pointers_.insert(p);
+  }
+
+  int size() {
+    mutex_lock l(mu_);
+    return pointers_.size();
+  }
+
+ private:
+  mutex mu_;
+  unordered_set<const void *> pointers_;
+};
+
+class SharedStoreTest : public testing::Test {
+ protected:
+  ~SharedStoreTest() {
+    // Clear the shared store after each test, otherwise objects created
+    // in one test may interfere with other tests.
+    SharedStore::Clear();
+  }
+};
+
+// Verify that we can call constructors with varying numbers and types of args.
+TEST_F(SharedStoreTest, ConstructorArgs) {
+  SharedStore::Get<NoArgs>("no args");
+  SharedStore::Get<OneArg>("one arg", "Fred");
+  SharedStore::Get<TwoArgs>("two args", "Pebbles", 2);
+}
+
+// Verify that an object with a given key is created only once.
+TEST_F(SharedStoreTest, Shared) {
+  const NoArgs *ob1 = SharedStore::Get<NoArgs>("first");
+  const NoArgs *ob2 = SharedStore::Get<NoArgs>("second");
+  const NoArgs *ob3 = SharedStore::Get<NoArgs>("first");
+  EXPECT_EQ(ob1, ob3);
+  EXPECT_NE(ob1, ob2);
+  EXPECT_NE(ob2, ob3);
+}
+
+// Verify that objects with the same name but different types do not collide.
+TEST_F(SharedStoreTest, DifferentTypes) {
+  const NoArgs *ob1 = SharedStore::Get<NoArgs>("same");
+  const OneArg *ob2 = SharedStore::Get<OneArg>("same", "foo");
+  const TwoArgs *ob3 = SharedStore::Get<TwoArgs>("same", "bar", 5);
+  EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob2));
+  EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob3));
+  EXPECT_NE(static_cast<const void *>(ob2), static_cast<const void *>(ob3));
+}
+
+// Factory method to make a OneArg.
+OneArg *MakeOneArg(const string &n) {
+  return new OneArg(n);
+}
+
+TEST_F(SharedStoreTest, ClosureGet) {
+  std::function<OneArg *()> closure1 = std::bind(MakeOneArg, "Al");
+  std::function<OneArg *()> closure2 = std::bind(MakeOneArg, "Al");
+  const OneArg *ob1 = SharedStore::ClosureGet("first", &closure1);
+  const OneArg *ob2 = SharedStore::ClosureGet("first", &closure2);
+  EXPECT_EQ("Al", ob1->name);
+  EXPECT_EQ(ob1, ob2);
+}
+
+TEST_F(SharedStoreTest, PermanentCallback) {
+  std::function<OneArg *()> closure = std::bind(MakeOneArg, "Al");
+  const OneArg *ob1 = SharedStore::ClosureGet("first", &closure);
+  const OneArg *ob2 = SharedStore::ClosureGet("first", &closure);
+  EXPECT_EQ("Al", ob1->name);
+  EXPECT_EQ(ob1, ob2);
+}
+
+// Factory method to "make" a NoArgs by simply returning an input pointer.
+NoArgs *BogusMakeNoArgs(NoArgs *ob) {
+  return ob;
+}
+
+// Create a CountCalls object, pretend it failed, and return null.
+CountCalls *MakeFailedCountCalls() {
+  CountCalls *ob = new CountCalls;
+  delete ob;
+  return nullptr;
+}
+
+// Verify that ClosureGet() only calls the closure for a given key once,
+// even if the closure fails.
+TEST_F(SharedStoreTest, FailedClosureGet) {
+  CountCalls::Reset();
+  std::function<CountCalls *()> closure1(MakeFailedCountCalls);
+  std::function<CountCalls *()> closure2(MakeFailedCountCalls);
+  const CountCalls *ob1 = SharedStore::ClosureGet("first", &closure1);
+  const CountCalls *ob2 = SharedStore::ClosureGet("first", &closure2);
+  EXPECT_EQ(nullptr, ob1);
+  EXPECT_EQ(nullptr, ob2);
+  EXPECT_EQ(1, CountCalls::constructor_calls);
+}
+
+typedef SharedStoreTest SharedStoreDeathTest;
+
+TEST_F(SharedStoreDeathTest, ClosureGetOrDie) {
+  NoArgs *empty = nullptr;
+  std::function<NoArgs *()> closure = std::bind(BogusMakeNoArgs, empty);
+  EXPECT_DEATH(SharedStore::ClosureGetOrDie("first", &closure), "nullptr");
+}
+
+TEST_F(SharedStoreTest, Release) {
+  const OneArg *ob1 = SharedStore::Get<OneArg>("first", "Fred");
+  const OneArg *ob2 = SharedStore::Get<OneArg>("first", "Fred");
+  EXPECT_EQ(ob1, ob2);
+  EXPECT_TRUE(SharedStore::Release(ob1));      // now refcount = 1
+  EXPECT_TRUE(SharedStore::Release(ob1));      // now object is deleted
+  EXPECT_FALSE(SharedStore::Release(ob1));     // now object is not found
+  EXPECT_TRUE(SharedStore::Release(nullptr));  // release(nullptr) returns true
+}
+
+TEST_F(SharedStoreTest, Clear) {
+  CountCalls::Reset();
+
+  SharedStore::Get<CountCalls>("first");
+  SharedStore::Get<CountCalls>("second");
+  SharedStore::Get<CountCalls>("first");
+
+  // Test that the constructor and destructor are each called exactly once
+  // for each key in the shared store.
+  SharedStore::Clear();
+  EXPECT_EQ(2, CountCalls::constructor_calls);
+  EXPECT_EQ(2, CountCalls::destructor_calls);
+}
+
+void GetSharedObject(PointerSet *ps) {
+  // Gets a shared object whose constructor takes a long time.
+  const Slow *ob = SharedStore::Get<Slow>("first");
+
+  // Collects the pointer we got. Later, we'll check whether SharedStore
+  // mistakenly called the constructor more than once.
+  ps->Add(static_cast<const void *>(ob));
+}
+
+// If multiple parallel threads all access an object with the same key,
+// only one object is created.
+TEST_F(SharedStoreTest, ThreadSafety) {
+  const int kNumThreads = 20;
+  tensorflow::thread::ThreadPool *pool = new tensorflow::thread::ThreadPool(
+      tensorflow::Env::Default(), "ThreadSafetyPool", kNumThreads);
+  PointerSet ps;
+  for (int i = 0; i < kNumThreads; ++i) {
+    std::function<void()> closure = std::bind(GetSharedObject, &ps);
+    pool->Schedule(closure);
+  }
+
+  // Deleting the pool waits for all scheduled closures to finish.
+  delete pool;
+
+  // Expects only one object to have been created across all threads.
+  EXPECT_EQ(1, ps.size());
+}
+
+}  // namespace syntaxnet

+ 19 - 0
syntaxnet/syntaxnet/sparse.proto

@@ -0,0 +1,19 @@
+// Protocol for passing around sparse sets of features.
+
+syntax = "proto2";
+
+package syntaxnet;
+
+// A sparse set of features.
+//
+// If using SparseStringToIdTransformer, description is required and id should
+// be omitted; otherwise, id is required and description is optional.
+//
+// The id, weight, and description fields are aligned if present (i.e., any of
+// these that are non-empty should have the same number of items). If weight is
+// omitted, a weight of 1.0 is assumed for every feature.
+message SparseFeatures {
+  repeated uint64 id = 1;
+  repeated float weight = 2;
+  repeated string description = 3;
+};
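+
+// A purely illustrative example of this message in text format: two features,
+// the second with a non-default weight (ids and descriptions are made up).
+//
+//   id: 10
+//   id: 42
+//   weight: 1.0
+//   weight: 0.5
+//   description: "word=the"
+//   description: "tag=DT"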

+ 240 - 0
syntaxnet/syntaxnet/structured_graph_builder.py

@@ -0,0 +1,240 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Build structured parser models."""
+
+import tensorflow as tf
+
+from tensorflow.python.ops import control_flow_ops as cf
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import tensor_array_ops
+
+from syntaxnet import graph_builder
+from syntaxnet.ops import gen_parser_ops
+
+tf.NoGradient('BeamParseReader')
+tf.NoGradient('BeamParser')
+tf.NoGradient('BeamParserOutput')
+
+
+def AddCrossEntropy(batch_size, n):
+  """Adds a cross entropy cost function."""
+  cross_entropies = []
+  def _Pass():
+    return tf.constant(0, dtype=tf.float32, shape=[1])
+
+  for beam_id in range(batch_size):
+    beam_gold_slot = tf.reshape(tf.slice(n['gold_slot'], [beam_id], [1]), [1])
+    def _ComputeCrossEntropy():
+      """Adds ops to compute cross entropy of the gold path in a beam."""
+      # Requires a cast so that UnsortedSegmentSum, in the gradient,
+      # is happy with the type of its input 'segment_ids', which
+      # must be int32.
+      idx = tf.cast(
+          tf.reshape(
+              tf.where(tf.equal(n['beam_ids'], beam_id)), [-1]), tf.int32)
+      beam_scores = tf.reshape(tf.gather(n['all_path_scores'], idx), [1, -1])
+      num = tf.shape(idx)
+      return tf.nn.softmax_cross_entropy_with_logits(
+          beam_scores, tf.expand_dims(
+              tf.sparse_to_dense(beam_gold_slot, num, [1.], 0.), 0))
+    # The conditional here is needed to deal with the last few batches of the
+    # corpus which can contain -1 in beam_gold_slot for empty batch slots.
+    cross_entropies.append(cf.cond(
+        beam_gold_slot[0] >= 0, _ComputeCrossEntropy, _Pass))
+  return {'cross_entropy': tf.div(tf.add_n(cross_entropies), batch_size)}
+
+
+class StructuredGraphBuilder(graph_builder.GreedyParser):
+  """Extends the standard GreedyParser with a CRF objective using a beam.
+
+  The constructor takes two additional keyword arguments:
+    beam_size: the maximum size the beam can grow to.
+    max_steps: the maximum number of steps in any particular beam.
+
+  The model supports batch training with the batch_size argument to the
+  AddTraining method.
+  """
+
+  def __init__(self, *args, **kwargs):
+    self._beam_size = kwargs.pop('beam_size', 10)
+    self._max_steps = kwargs.pop('max_steps', 25)
+    super(StructuredGraphBuilder, self).__init__(*args, **kwargs)
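+
+  # A minimal usage sketch (illustrative only): the positional arguments are
+  # those of graph_builder.GreedyParser and are elided here, and the task
+  # context path and corpus names are placeholders.
+  #
+  #   parser = StructuredGraphBuilder(..., beam_size=8, max_steps=25)
+  #   training = parser.AddTraining('context.pbtxt', batch_size=32,
+  #                                 momentum=0.9,
+  #                                 corpus_name='training-corpus')
+  #   evaluation = parser.AddEvaluation('context.pbtxt', batch_size=32,
+  #                                     corpus_name='tuning-corpus')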
+
+  def _AddBeamReader(self,
+                     task_context,
+                     batch_size,
+                     corpus_name,
+                     until_all_final=False,
+                     always_start_new_sentences=False):
+    """Adds an op capable of reading sentences and parsing them with a beam."""
+    features, state, epochs = gen_parser_ops.beam_parse_reader(
+        task_context=task_context,
+        feature_size=self._feature_size,
+        beam_size=self._beam_size,
+        batch_size=batch_size,
+        corpus_name=corpus_name,
+        allow_feature_weights=self._allow_feature_weights,
+        arg_prefix=self._arg_prefix,
+        continue_until_all_final=until_all_final,
+        always_start_new_sentences=always_start_new_sentences)
+    return {'state': state, 'features': features, 'epochs': epochs}
+
+  def _BuildSequence(self,
+                     batch_size,
+                     max_steps,
+                     features,
+                     state,
+                     use_average=False):
+    """Adds a sequence of beam parsing steps."""
+    def Advance(state, step, scores_array, alive, alive_steps, *features):
+      scores = self._BuildNetwork(features,
+                                  return_average=use_average)['logits']
+      scores_array = scores_array.write(step, scores)
+      features, state, alive = (
+          gen_parser_ops.beam_parser(state, scores, self._feature_size))
+      return [state, step + 1, scores_array, alive, alive_steps + tf.cast(
+          alive, tf.int32)] + list(features)
+
+    # args: (state, step, scores_array, alive, alive_steps, *features)
+    def KeepGoing(*args):
+      return tf.logical_and(args[1] < max_steps, tf.reduce_any(args[3]))
+
+    step = tf.constant(0, tf.int32, [])
+    scores_array = tensor_array_ops.TensorArray(dtype=tf.float32,
+                                                size=0,
+                                                dynamic_size=True)
+    alive = tf.constant(True, tf.bool, [batch_size])
+    alive_steps = tf.constant(0, tf.int32, [batch_size])
+    t = tf.while_loop(
+        KeepGoing,
+        Advance,
+        [state, step, scores_array, alive, alive_steps] + list(features),
+        parallel_iterations=100)
+
+    # Link to the final nodes/values of ops that have passed through While:
+    return {'state': t[0],
+            'concat_scores': t[2].concat(),
+            'alive': t[3],
+            'alive_steps': t[4]}
+
+  def AddTraining(self,
+                  task_context,
+                  batch_size,
+                  learning_rate=0.1,
+                  decay_steps=4000,
+                  momentum=None,
+                  corpus_name='documents'):
+    with tf.name_scope('training'):
+      n = self.training
+      n['accumulated_alive_steps'] = self._AddVariable(
+          [batch_size], tf.int32, 'accumulated_alive_steps',
+          tf.zeros_initializer)
+      n.update(self._AddBeamReader(task_context, batch_size, corpus_name))
+      # This adds a required 'step' node too:
+      learning_rate = tf.constant(learning_rate, dtype=tf.float32)
+      n['learning_rate'] = self._AddLearningRate(learning_rate, decay_steps)
+      # Call BuildNetwork *only* to set up the params outside of the main loop.
+      self._BuildNetwork(list(n['features']))
+
+      n.update(self._BuildSequence(batch_size, self._max_steps, n['features'],
+                                   n['state']))
+
+      flat_concat_scores = tf.reshape(n['concat_scores'], [-1])
+      (indices_and_paths, beams_and_slots, n['gold_slot'],
+       n['beam_path_scores']) = gen_parser_ops.beam_parser_output(n['state'])
+      n['indices'] = tf.reshape(tf.gather(indices_and_paths, [0]), [-1])
+      n['path_ids'] = tf.reshape(tf.gather(indices_and_paths, [1]), [-1])
+      n['all_path_scores'] = tf.sparse_segment_sum(
+          flat_concat_scores, n['indices'], n['path_ids'])
+      n['beam_ids'] = tf.reshape(tf.gather(beams_and_slots, [0]), [-1])
+      n.update(AddCrossEntropy(batch_size, n))
+
+      if self._only_train:
+        trainable_params = {k: v for k, v in self.params.iteritems()
+                            if k in self._only_train}
+      else:
+        trainable_params = self.params
+      for p in trainable_params:
+        tf.logging.info('trainable_param: %s', p)
+
+      regularized_params = [
+          tf.nn.l2_loss(p) for k, p in trainable_params.iteritems()
+          if k.startswith('weights') or k.startswith('bias')]
+      l2_loss = 1e-4 * tf.add_n(regularized_params) if regularized_params else 0
+
+      n['cost'] = tf.add(n['cross_entropy'], l2_loss, name='cost')
+
+      n['gradients'] = tf.gradients(n['cost'], trainable_params.values())
+
+      with tf.control_dependencies([n['alive_steps']]):
+        update_accumulators = tf.group(
+            tf.assign_add(n['accumulated_alive_steps'], n['alive_steps']))
+
+      def ResetAccumulators():
+        return tf.assign(
+            n['accumulated_alive_steps'], tf.zeros([batch_size], tf.int32))
+      n['reset_accumulators_func'] = ResetAccumulators
+
+      optimizer = tf.train.MomentumOptimizer(n['learning_rate'],
+                                             momentum,
+                                             use_locking=self._use_locking)
+      train_op = optimizer.minimize(n['cost'],
+                                    var_list=trainable_params.values())
+      for param in trainable_params.values():
+        slot = optimizer.get_slot(param, 'momentum')
+        self.inits[slot.name] = state_ops.init_variable(slot,
+                                                        tf.zeros_initializer)
+        self.variables[slot.name] = slot
+
+      def NumericalChecks():
+        return tf.group(*[
+            tf.check_numerics(param, message='Parameter is not finite.')
+            for param in trainable_params.values()
+            if param.dtype.base_dtype in [tf.float32, tf.float64]])
+      check_op = cf.cond(tf.equal(tf.mod(self.GetStep(), self._check_every), 0),
+                         NumericalChecks, tf.no_op)
+      avg_update_op = tf.group(*self._averaging.values())
+      train_ops = [train_op]
+      if self._check_parameters:
+        train_ops.append(check_op)
+      if self._use_averaging:
+        train_ops.append(avg_update_op)
+      with tf.control_dependencies([update_accumulators]):
+        n['train_op'] = tf.group(*train_ops, name='train_op')
+      n['alive_steps'] = tf.identity(n['alive_steps'], name='alive_steps')
+    return n
+
+  def AddEvaluation(self,
+                    task_context,
+                    batch_size,
+                    evaluation_max_steps=300,
+                    corpus_name=None):
+    with tf.name_scope('evaluation'):
+      n = self.evaluation
+      n.update(self._AddBeamReader(task_context,
+                                   batch_size,
+                                   corpus_name,
+                                   until_all_final=True,
+                                   always_start_new_sentences=True))
+      self._BuildNetwork(
+          list(n['features']),
+          return_average=self._use_averaging)
+      n.update(self._BuildSequence(batch_size, evaluation_max_steps,
+                                   n['features'], n['state'],
+                                   use_average=self._use_averaging))
+      n['eval_metrics'], n['documents'] = (
+          gen_parser_ops.beam_eval_output(n['state']))
+    return n

+ 107 - 0
syntaxnet/syntaxnet/syntaxnet.bzl

@@ -0,0 +1,107 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+load("@tf//google/protobuf:protobuf.bzl", "cc_proto_library")
+load("@tf//google/protobuf:protobuf.bzl", "py_proto_library")
+
+def if_cuda(a, b=[]):
+  return select({
+      "@tf//third_party/gpus/cuda:cuda_crosstool_condition": a,
+      "//conditions:default": b,
+  })
+
+def tf_copts():
+  return (["-fno-exceptions", "-DEIGEN_AVOID_STL_ARRAY",] +
+          if_cuda(["-DGOOGLE_CUDA=1"]) +
+          select({"@tf//tensorflow:darwin": [],
+                  "//conditions:default": ["-pthread"]}))
+
+def tf_proto_library(name, srcs=[], has_services=False,
+                     deps=[], visibility=None, testonly=0,
+                     cc_api_version=2, go_api_version=2,
+                     java_api_version=2,
+                     py_api_version=2):
+  native.filegroup(name=name + "_proto_srcs",
+                   srcs=srcs,
+                   testonly=testonly,)
+
+  cc_proto_library(name=name,
+                   srcs=srcs,
+                   deps=deps,
+                   cc_libs = ["@tf//google/protobuf:protobuf"],
+                   protoc="@tf//google/protobuf:protoc",
+                   default_runtime="@tf//google/protobuf:protobuf",
+                   testonly=testonly,
+                   visibility=visibility,)
+
+def tf_proto_library_py(name, srcs=[], deps=[], visibility=None, testonly=0):
+  py_proto_library(name=name,
+                   srcs=srcs,
+                   srcs_version = "PY2AND3",
+                   deps=deps,
+                   default_runtime="@tf//google/protobuf:protobuf_python",
+                   protoc="@tf//google/protobuf:protoc",
+                   visibility=visibility,
+                   testonly=testonly,)
+
+# Given a list of "op_lib_names" (a list of files in the ops directory
+# without their .cc extensions), generate a cc_library for each of them.
+def tf_gen_op_libs(op_lib_names):
+  # Make library out of each op so it can also be used to generate wrappers
+  # for various languages.
+  for n in op_lib_names:
+    native.cc_library(name=n + "_op_lib",
+                      copts=tf_copts(),
+                      srcs=["ops/" + n + ".cc"],
+                      deps=(["@tf//tensorflow/core:framework"]),
+                      visibility=["//visibility:public"],
+                      alwayslink=1,
+                      linkstatic=1,)
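+
+# A sketch of how tf_gen_op_libs might be invoked from a BUILD file, assuming
+# the op definitions live in ops/parser_ops.cc (illustrative only):
+#
+#   tf_gen_op_libs(op_lib_names = ["parser_ops"])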
+
+# Invoke this rule in .../tensorflow/python to build the wrapper library.
+def tf_gen_op_wrapper_py(name, out=None, hidden=[], visibility=None, deps=[],
+                         require_shape_functions=False):
+  # Construct a cc_binary containing the specified ops.
+  tool_name = "gen_" + name + "_py_wrappers_cc"
+  if not deps:
+    deps = ["//tensorflow/core:" + name + "_op_lib"]
+  native.cc_binary(
+      name = tool_name,
+      linkopts = ["-lm"],
+      copts = tf_copts(),
+      linkstatic = 1,   # Link this one-time-use tool statically.
+      deps = (["@tf//tensorflow/core:framework",
+               "@tf//tensorflow/python:python_op_gen_main"] + deps),
+  )
+
+  # Invoke the previous cc_binary to generate a python file.
+  if not out:
+    out = "ops/gen_" + name + ".py"
+
+  native.genrule(
+      name=name + "_pygenrule",
+      outs=[out],
+      tools=[tool_name],
+      cmd=("$(location " + tool_name + ") " + ",".join(hidden)
+           + " " + ("1" if require_shape_functions else "0") + " > $@"))
+
+  # Make a py_library out of the generated python file.
+  native.py_library(name=name,
+                    srcs=[out],
+                    srcs_version="PY2AND3",
+                    visibility=visibility,
+                    deps=[
+                        "@tf//tensorflow/python:framework_for_generated_wrappers",
+                    ],)
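+
+# A sketch of how tf_gen_op_wrapper_py might be invoked from a BUILD file to
+# wrap the op library generated above (illustrative only):
+#
+#   tf_gen_op_wrapper_py(
+#       name = "parser_ops",
+#       deps = [":parser_ops_op_lib"],
+#   )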

+ 258 - 0
syntaxnet/syntaxnet/tagger_transitions.cc

@@ -0,0 +1,258 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Tagger transition system.
+//
+// This transition system has one type of action:
+//  - The SHIFT action pushes the next input token to the stack and
+//    advances to the next input token, assigning a part-of-speech tag to the
+//    token that was shifted.
+//
+// The transition system operates with parser actions encoded as integers:
+//  - A SHIFT action is encoded as a number starting from 0, namely the index
+//    of the part-of-speech tag that it assigns.
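+//
+// For example (purely illustrative tag indices), with a tag map in which
+// "NN" has index 0, "VB" has index 1 and "DT" has index 2, performing action
+// 1 shifts the next input token and tags it "VB".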
+
+#include <string>
+
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/shared_store.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+class TaggerTransitionState : public ParserTransitionState {
+ public:
+  explicit TaggerTransitionState(const TermFrequencyMap *tag_map,
+                                 const TagToCategoryMap *tag_to_category)
+      : tag_map_(tag_map), tag_to_category_(tag_to_category) {}
+
+  explicit TaggerTransitionState(const TaggerTransitionState *state)
+      : TaggerTransitionState(state->tag_map_, state->tag_to_category_) {
+    tag_ = state->tag_;
+    gold_tag_ = state->gold_tag_;
+  }
+
+  // Clones the transition state by returning a new object.
+  ParserTransitionState *Clone() const override {
+    return new TaggerTransitionState(this);
+  }
+
+  // Reads gold tags for each token.
+  void Init(ParserState *state) {
+    tag_.resize(state->sentence().token_size(), -1);
+    gold_tag_.resize(state->sentence().token_size(), -1);
+    for (int pos = 0; pos < state->sentence().token_size(); ++pos) {
+      int tag = tag_map_->LookupIndex(state->GetToken(pos).tag(), -1);
+      gold_tag_[pos] = tag;
+    }
+  }
+
+  // Returns the tag assigned to a given token.
+  int Tag(int index) const {
+    DCHECK_GE(index, -1);
+    DCHECK_LT(index, tag_.size());
+    return index == -1 ? -1 : tag_[index];
+  }
+
+  // Sets this tag on the token at index.
+  void SetTag(int index, int tag) {
+    DCHECK_GE(index, 0);
+    DCHECK_LT(index, tag_.size());
+    tag_[index] = tag;
+  }
+
+  // Returns the gold tag for a given token.
+  int GoldTag(int index) const {
+    DCHECK_GE(index, -1);
+    DCHECK_LT(index, gold_tag_.size());
+    return index == -1 ? -1 : gold_tag_[index];
+  }
+
+  // Returns the string representation of a POS tag, or an empty string
+  // if the tag is invalid.
+  string TagAsString(int tag) const {
+    if (tag >= 0 && tag < tag_map_->Size()) {
+      return tag_map_->GetTerm(tag);
+    }
+    return "";
+  }
+
+  // Adds transition state specific annotations to the document.
+  void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
+                          Sentence *sentence) const override {
+    for (size_t i = 0; i < tag_.size(); ++i) {
+      Token *token = sentence->mutable_token(i);
+      token->set_tag(TagAsString(Tag(i)));
+      token->set_category(tag_to_category_->GetCategory(token->tag()));
+    }
+  }
+
+  // Whether a parsed token should be considered correct for evaluation.
+  bool IsTokenCorrect(const ParserState &state, int index) const override {
+    return GoldTag(index) == Tag(index);
+  }
+
+  // Returns a human readable string representation of this state.
+  string ToString(const ParserState &state) const override {
+    string str;
+    for (int i = state.StackSize(); i > 0; --i) {
+      const string &word = state.GetToken(state.Stack(i - 1)).word();
+      if (i != state.StackSize() - 1) str.append(" ");
+      tensorflow::strings::StrAppend(
+          &str, word, "[", TagAsString(Tag(state.StackSize() - i)), "]");
+    }
+    for (int i = state.Next(); i < state.NumTokens(); ++i) {
+      tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
+    }
+    return str;
+  }
+
+ private:
+  // Currently assigned POS tags for each token in this sentence.
+  vector<int> tag_;
+
+  // Gold POS tags from the input document.
+  vector<int> gold_tag_;
+
+  // Tag map used for conversions between integer and string representations
+  // of part-of-speech tags. Not owned.
+  const TermFrequencyMap *tag_map_ = nullptr;
+
+  // Tag to category map. Not owned.
+  const TagToCategoryMap *tag_to_category_ = nullptr;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TaggerTransitionState);
+};
+
+class TaggerTransitionSystem : public ParserTransitionSystem {
+ public:
+  ~TaggerTransitionSystem() override { SharedStore::Release(tag_map_); }
+
+  // Determines tag map location.
+  void Setup(TaskContext *context) override {
+    input_tag_map_ = context->GetInput("tag-map", "text", "");
+    input_tag_to_category_ = context->GetInput("tag-to-category", "text", "");
+  }
+
+  // Reads tag map and tag to category map.
+  void Init(TaskContext *context) {
+    const string tag_map_path = TaskContext::InputFile(*input_tag_map_);
+    tag_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
+        tag_map_path, 0, 0);
+    const string tag_to_category_path =
+        TaskContext::InputFile(*input_tag_to_category_);
+    tag_to_category_ = SharedStoreUtils::GetWithDefaultName<TagToCategoryMap>(
+        tag_to_category_path);
+  }
+
+  // A SHIFT action is encoded directly as the index of the tag it assigns.
+  static ParserAction ShiftAction(int tag) { return tag; }
+
+  // Returns the number of action types.
+  int NumActionTypes() const override { return 1; }
+
+  // Returns the number of possible actions.
+  int NumActions(int num_labels) const override { return tag_map_->Size(); }
+
+  // The default action for a given state is assigning the most frequent tag.
+  ParserAction GetDefaultAction(const ParserState &state) const override {
+    return ShiftAction(0);
+  }
+
+  // Returns the next gold action for a given state according to the
+  // underlying annotated sentence.
+  ParserAction GetNextGoldAction(const ParserState &state) const override {
+    if (!state.EndOfInput()) {
+      return ShiftAction(TransitionState(state).GoldTag(state.Next()));
+    }
+    return ShiftAction(0);
+  }
+
+  // Checks if the action is allowed in a given parser state.
+  bool IsAllowedAction(ParserAction action,
+                       const ParserState &state) const override {
+    return !state.EndOfInput();
+  }
+
+  // Makes a shift by pushing the next input token on the stack and moving to
+  // the next position.
+  void PerformActionWithoutHistory(ParserAction action,
+                                   ParserState *state) const override {
+    DCHECK(!state->EndOfInput());
+    if (!state->EndOfInput()) {
+      MutableTransitionState(state)->SetTag(state->Next(), action);
+      state->Push(state->Next());
+      state->Advance();
+    }
+  }
+
+  // We are in a final state when we have reached the end of the input.
+  bool IsFinalState(const ParserState &state) const override {
+    return state.EndOfInput();
+  }
+
+  // Returns a string representation of a parser action.
+  string ActionAsString(ParserAction action,
+                        const ParserState &state) const override {
+    return tensorflow::strings::StrCat("SHIFT(", tag_map_->GetTerm(action),
+                                       ")");
+  }
+
+  // No state is deterministic in this transition system.
+  bool IsDeterministicState(const ParserState &state) const override {
+    return false;
+  }
+
+  // Returns a new transition state to be used to enhance the parser state.
+  ParserTransitionState *NewTransitionState(bool training_mode) const override {
+    return new TaggerTransitionState(tag_map_, tag_to_category_);
+  }
+
+  // Downcasts the const ParserTransitionState in ParserState to a const
+  // TaggerTransitionState.
+  static const TaggerTransitionState &TransitionState(
+      const ParserState &state) {
+    return *static_cast<const TaggerTransitionState *>(
+        state.transition_state());
+  }
+
+  // Downcasts the ParserTransitionState in ParserState to a
+  // TaggerTransitionState.
+  static TaggerTransitionState *MutableTransitionState(ParserState *state) {
+    return static_cast<TaggerTransitionState *>(
+        state->mutable_transition_state());
+  }
+
+  // Input for the tag map. Not owned.
+  TaskInput *input_tag_map_ = nullptr;
+
+  // Tag map used for conversions between integer and string representations
+  // of part-of-speech tags. Owned through SharedStore.
+  const TermFrequencyMap *tag_map_ = nullptr;
+
+  // Input for the tag to category map. Not owned.
+  TaskInput *input_tag_to_category_ = nullptr;
+
+  // Tag to category map. Owned through SharedStore.
+  const TagToCategoryMap *tag_to_category_ = nullptr;
+};
+
+REGISTER_TRANSITION_SYSTEM("tagger", TaggerTransitionSystem);
+
+}  // namespace syntaxnet

+ 113 - 0
syntaxnet/syntaxnet/tagger_transitions_test.cc

@@ -0,0 +1,113 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/populate_test_inputs.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/task_context.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/term_frequency_map.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+
+class TaggerTransitionTest : public ::testing::Test {
+ public:
+  TaggerTransitionTest()
+      : transition_system_(ParserTransitionSystem::Create("tagger")) {}
+
+ protected:
+  // Creates a label map and a tag map for testing based on the given
+  // document and initializes the transition system appropriately.
+  void SetUpForDocument(const Sentence &document) {
+    input_label_map_ = context_.GetInput("label-map", "text", "");
+    input_label_map_ = context_.GetInput("tag-map", "text", "");
+    transition_system_->Setup(&context_);
+    PopulateTestInputs::Defaults(document).Populate(&context_);
+    label_map_.Load(TaskContext::InputFile(*input_label_map_),
+                    0 /* minimum frequency */,
+                    -1 /* maximum number of terms */);
+    transition_system_->Init(&context_);
+  }
+
+  // Creates a cloned state from a sentence in order to test that cloning
+  // works correctly for the new parser states.
+  ParserState *NewClonedState(Sentence *sentence) {
+    ParserState state(sentence, transition_system_->NewTransitionState(
+                                    true /* training mode */),
+                      &label_map_);
+    return state.Clone();
+  }
+
+  // Performs the sequence of gold transitions, checking at each step that
+  // the gold action is allowed in the current state.
+  void GoldParse(Sentence *sentence) {
+    ParserState *state = NewClonedState(sentence);
+    LOG(INFO) << "Initial parser state: " << state->ToString();
+    while (!transition_system_->IsFinalState(*state)) {
+      ParserAction action = transition_system_->GetNextGoldAction(*state);
+      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
+      LOG(INFO) << "Performing action: "
+                << transition_system_->ActionAsString(action, *state);
+      transition_system_->PerformActionWithoutHistory(action, state);
+      LOG(INFO) << "Parser state: " << state->ToString();
+    }
+    delete state;
+  }
+
+  // Always takes the default action, and verifies that this leads to
+  // a final state through a sequence of allowed actions.
+  void DefaultParse(Sentence *sentence) {
+    ParserState *state = NewClonedState(sentence);
+    LOG(INFO) << "Initial parser state: " << state->ToString();
+    while (!transition_system_->IsFinalState(*state)) {
+      ParserAction action = transition_system_->GetDefaultAction(*state);
+      EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state));
+      LOG(INFO) << "Performing action: "
+                << transition_system_->ActionAsString(action, *state);
+      transition_system_->PerformActionWithoutHistory(action, state);
+      LOG(INFO) << "Parser state: " << state->ToString();
+    }
+    delete state;
+  }
+
+  TaskContext context_;
+  TaskInput *input_label_map_ = nullptr;
+  TermFrequencyMap label_map_;
+  std::unique_ptr<ParserTransitionSystem> transition_system_;
+};
+
+TEST_F(TaggerTransitionTest, SingleSentenceDocumentTest) {
+  string document_text;
+  Sentence document;
+  TF_CHECK_OK(ReadFileToString(
+      tensorflow::Env::Default(),
+      "syntaxnet/testdata/document",
+      &document_text));
+  LOG(INFO) << "see doc\n:" << document_text;
+  CHECK(TextFormat::ParseFromString(document_text, &document));
+  SetUpForDocument(document);
+  GoldParse(&document);
+  DefaultParse(&document);
+}
+
+}  // namespace syntaxnet

+ 173 - 0
syntaxnet/syntaxnet/task_context.cc

@@ -0,0 +1,173 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/task_context.h"
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace syntaxnet {
+namespace {
+
+const char *const kShardPrintFormat = "%05d";
+
+}  // namespace
+
+TaskInput *TaskContext::GetInput(const string &name) {
+  // Return existing input if it exists.
+  for (int i = 0; i < spec_.input_size(); ++i) {
+    if (spec_.input(i).name() == name) return spec_.mutable_input(i);
+  }
+
+  // Create new input.
+  TaskInput *input = spec_.add_input();
+  input->set_name(name);
+  return input;
+}
+
+TaskInput *TaskContext::GetInput(const string &name, const string &file_format,
+                                 const string &record_format) {
+  TaskInput *input = GetInput(name);
+  if (!file_format.empty()) {
+    bool found = false;
+    for (int i = 0; i < input->file_format_size(); ++i) {
+      if (input->file_format(i) == file_format) found = true;
+    }
+    if (!found) input->add_file_format(file_format);
+  }
+  if (!record_format.empty()) {
+    bool found = false;
+    for (int i = 0; i < input->record_format_size(); ++i) {
+      if (input->record_format(i) == record_format) found = true;
+    }
+    if (!found) input->add_record_format(record_format);
+  }
+  return input;
+}
+
+void TaskContext::SetParameter(const string &name, const string &value) {
+  // If the parameter already exists update the value.
+  for (int i = 0; i < spec_.parameter_size(); ++i) {
+    if (spec_.parameter(i).name() == name) {
+      spec_.mutable_parameter(i)->set_value(value);
+      return;
+    }
+  }
+
+  // Add new parameter.
+  TaskSpec::Parameter *param = spec_.add_parameter();
+  param->set_name(name);
+  param->set_value(value);
+}
+
+string TaskContext::GetParameter(const string &name) const {
+  // First try to find parameter in task specification.
+  for (int i = 0; i < spec_.parameter_size(); ++i) {
+    if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
+  }
+
+  // Parameter not found, return empty string.
+  return "";
+}
+
+int TaskContext::GetIntParameter(const string &name) const {
+  string value = GetParameter(name);
+  return utils::ParseUsing<int>(value, 0, utils::ParseInt32);
+}
+
+int64 TaskContext::GetInt64Parameter(const string &name) const {
+  string value = GetParameter(name);
+  return utils::ParseUsing<int64>(value, 0ll, utils::ParseInt64);
+}
+
+bool TaskContext::GetBoolParameter(const string &name) const {
+  string value = GetParameter(name);
+  return value == "true";
+}
+
+double TaskContext::GetFloatParameter(const string &name) const {
+  string value = GetParameter(name);
+  return utils::ParseUsing<double>(value, .0, utils::ParseDouble);
+}
+
+string TaskContext::Get(const string &name, const char *defval) const {
+  // First try to find parameter in task specification.
+  for (int i = 0; i < spec_.parameter_size(); ++i) {
+    if (spec_.parameter(i).name() == name) return spec_.parameter(i).value();
+  }
+
+  // Parameter not found, return default value.
+  return defval;
+}
+
+string TaskContext::Get(const string &name, const string &defval) const {
+  return Get(name, defval.c_str());
+}
+
+int TaskContext::Get(const string &name, int defval) const {
+  string value = Get(name, "");
+  return utils::ParseUsing<int>(value, defval, utils::ParseInt32);
+}
+
+int64 TaskContext::Get(const string &name, int64 defval) const {
+  string value = Get(name, "");
+  return utils::ParseUsing<int64>(value, defval, utils::ParseInt64);
+}
+
+double TaskContext::Get(const string &name, double defval) const {
+  string value = Get(name, "");
+  return utils::ParseUsing<double>(value, defval, utils::ParseDouble);
+}
+
+bool TaskContext::Get(const string &name, bool defval) const {
+  string value = Get(name, "");
+  return value.empty() ? defval : value == "true";
+}
+
+string TaskContext::InputFile(const TaskInput &input) {
+  CHECK_EQ(input.part_size(), 1) << input.name();
+  return input.part(0).file_pattern();
+}
+
+bool TaskContext::Supports(const TaskInput &input, const string &file_format,
+                           const string &record_format) {
+  // Check file format.
+  if (input.file_format_size() > 0) {
+    bool found = false;
+    for (int i = 0; i < input.file_format_size(); ++i) {
+      if (input.file_format(i) == file_format) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) return false;
+  }
+
+  // Check record format.
+  if (input.record_format_size() > 0) {
+    bool found = false;
+    for (int i = 0; i < input.record_format_size(); ++i) {
+      if (input.record_format(i) == record_format) {
+        found = true;
+        break;
+      }
+    }
+    if (!found) return false;
+  }
+
+  return true;
+}
+
+}  // namespace syntaxnet

+ 80 - 0
syntaxnet/syntaxnet/task_context.h

@@ -0,0 +1,80 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_TASK_CONTEXT_H_
+#define $TARGETDIR_TASK_CONTEXT_H_
+
+#include <string>
+#include <vector>
+
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+// A task context holds configuration information for a task. It is basically a
+// wrapper around a TaskSpec protocol buffer.
+class TaskContext {
+ public:
+  // Returns the underlying task specification protocol buffer for the context.
+  const TaskSpec &spec() const { return spec_; }
+  TaskSpec *mutable_spec() { return &spec_; }
+
+  // Returns a named input descriptor for the task. A new input is created if
+  // the task context does not already have an input with that name.
+  TaskInput *GetInput(const string &name);
+  TaskInput *GetInput(const string &name, const string &file_format,
+                      const string &record_format);
+
+  // Sets task parameter.
+  void SetParameter(const string &name, const string &value);
+
+  // Returns a task parameter. If the parameter is not in the task
+  // configuration, an empty string (or the type's zero value) is returned.
+  string GetParameter(const string &name) const;
+  int GetIntParameter(const string &name) const;
+  int64 GetInt64Parameter(const string &name) const;
+  bool GetBoolParameter(const string &name) const;
+  double GetFloatParameter(const string &name) const;
+
+  // Returns a task parameter. If the parameter is not in the task
+  // configuration, the given default value is returned. Parameters retrieved
+  // with these methods do not need to be declared anywhere else.
+  string Get(const string &name, const string &defval) const;
+  string Get(const string &name, const char *defval) const;
+  int Get(const string &name, int defval) const;
+  int64 Get(const string &name, int64 defval) const;
+  double Get(const string &name, double defval) const;
+  bool Get(const string &name, bool defval) const;
+
+  // Returns input file name for a single-file task input.
+  static string InputFile(const TaskInput &input);
+
+  // Returns true if task input supports the file and record format.
+  static bool Supports(const TaskInput &input, const string &file_format,
+                       const string &record_format);
+
+ private:
+  // Underlying task specification protocol buffer.
+  TaskSpec spec_;
+
+  // Vector of parameters required by this task.  These must be specified in the
+  // task rather than relying on default values.
+  vector<string> required_parameters_;
+};
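+
+// A minimal usage sketch (illustrative only; the parameter name and values are
+// made up):
+//
+//   TaskContext context;
+//   context.SetParameter("beam_size", "8");
+//   int beam_size = context.Get("beam_size", 1);   // parses the string -> 8
+//   bool debug = context.Get("debug", false);      // not set -> false
+//   TaskInput *tags = context.GetInput("tag-map", "text", "");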
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_TASK_CONTEXT_H_

+ 82 - 0
syntaxnet/syntaxnet/task_spec.proto

@@ -0,0 +1,82 @@
+// LINT: ALLOW_GROUPS
+// Protocol buffer specifications for task configuration.
+
+syntax = "proto2";
+
+package syntaxnet;
+
+// Task input descriptor.
+message TaskInput {
+  // Name of input resource.
+  required string name = 1;
+
+  // Name of the stage responsible for creating this resource.
+  optional string creator = 2;
+
+  // File format for resource.
+  repeated string file_format = 3;
+
+  // Record format for resource.
+  repeated string record_format = 4;
+
+  // Is this resource multi-file?
+  optional bool multi_file = 5 [default = false];
+
+  // An input can consist of multiple file sets.
+  repeated group Part = 6 {
+    // File pattern for file set.
+    optional string file_pattern = 7;
+
+    // File format for file set.
+    optional string file_format = 8;
+
+    // Record format for file set.
+    optional string record_format = 9;
+  }
+}
+
+// Task output descriptor.
+message TaskOutput {
+  // Name of output resource.
+  required string name = 1;
+
+  // File format for output resource.
+  optional string file_format = 2;
+
+  // Record format for output resource.
+  optional string record_format = 3;
+
+  // Number of shards in the output. If it is non-zero, the output is sharded.
+  // If it is set to -1, the output is sharded but the number of shards is
+  // unknown; the files are then named 'base-*-of-*'.
+  optional int32 shards = 4 [default = 0];
+
+  // Base file name for output resource. If this is not set by the task
+  // component it is set to a default value by the workflow engine.
+  optional string file_base = 5;
+
+  // Optional extension added to the file name.
+  optional string file_extension = 6;
+}
+
+// A task specification describes the parameters for executing a task.
+message TaskSpec {
+  // Name of task.
+  optional string task_name = 1;
+
+  // Workflow task type.
+  optional string task_type = 2;
+
+  // Task parameters.
+  repeated group Parameter = 3 {
+    required string name = 4;
+    optional string value = 5;
+  }
+
+  // Task inputs.
+  repeated TaskInput input = 6;
+
+  // Task outputs.
+  repeated TaskOutput output = 7;
+}
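+
+// A small illustrative example of a TaskSpec in text format (the file pattern
+// is a placeholder; see syntaxnet/testdata/context.pbtxt for a full example):
+//
+//   Parameter {
+//     name: 'brain_parser_embedding_dims'
+//     value: '8;8;8'
+//   }
+//   input {
+//     name: 'training-corpus'
+//     record_format: 'conll-sentence'
+//     Part {
+//       file_pattern: '/path/to/training-set'
+//     }
+//   }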

+ 188 - 0
syntaxnet/syntaxnet/term_frequency_map.cc

@@ -0,0 +1,188 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/term_frequency_map.h"
+
+#include <stddef.h>
+#include <algorithm>
+#include <limits>
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace syntaxnet {
+
+int TermFrequencyMap::Increment(const string &term) {
+  CHECK_EQ(term_index_.size(), term_data_.size());
+  const TermIndex::const_iterator it = term_index_.find(term);
+  if (it != term_index_.end()) {
+    // Increment the existing term.
+    pair<string, int64> &data = term_data_[it->second];
+    CHECK_EQ(term, data.first);
+    ++(data.second);
+    return it->second;
+  } else {
+    // Add a new term.
+    const int index = term_index_.size();
+    CHECK_LT(index, std::numeric_limits<int32>::max());  // overflow
+    term_index_[term] = index;
+    term_data_.push_back(pair<string, int64>(term, 1));
+    return index;
+  }
+}
+
+void TermFrequencyMap::Clear() {
+  term_index_.clear();
+  term_data_.clear();
+}
+
+void TermFrequencyMap::Load(const string &filename, int min_frequency,
+                            int max_num_terms) {
+  Clear();
+
+  // If max_num_terms is non-positive, replace it with INT_MAX.
+  if (max_num_terms <= 0) max_num_terms = std::numeric_limits<int>::max();
+
+  // Read the first line (total # of terms in the mapping).
+  tensorflow::RandomAccessFile *file;
+  TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
+  static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
+  tensorflow::io::InputBuffer input(file, kInputBufferSize);
+  string line;
+  TF_CHECK_OK(input.ReadLine(&line));
+  int32 total = -1;
+  CHECK(utils::ParseInt32(line.c_str(), &total));
+  CHECK_GE(total, 0);
+
+  // Read the mapping.
+  int64 last_frequency = -1;
+  for (int i = 0; i < total && i < max_num_terms; ++i) {
+    TF_CHECK_OK(input.ReadLine(&line));
+    vector<string> elements = utils::Split(line, ' ');
+    CHECK_EQ(2, elements.size());
+    CHECK(!elements[0].empty());
+    CHECK(!elements[1].empty());
+    int64 frequency = 0;
+    CHECK(utils::ParseInt64(elements[1].c_str(), &frequency));
+    CHECK_GT(frequency, 0);
+    const string &term = elements[0];
+
+    // Check frequency sorting (descending order).
+    if (i > 0) CHECK_GE(last_frequency, frequency);
+    last_frequency = frequency;
+
+    // Ignore low-frequency items.
+    if (frequency < min_frequency) continue;
+
+    // Check uniqueness of the mapped terms.
+    CHECK(term_index_.find(term) == term_index_.end())
+        << "File " << filename << " has duplicate term: " << term;
+
+    // Assign the next available index.
+    const int index = term_index_.size();
+    term_index_[term] = index;
+    term_data_.push_back(pair<string, int64>(term, frequency));
+  }
+  CHECK_EQ(term_index_.size(), term_data_.size());
+  LOG(INFO) << "Loaded " << term_index_.size() << " terms from " << filename
+            << ".";
+}
+
+struct TermFrequencyMap::SortByFrequencyThenTerm {
+  // Sorts in descending order of frequency, breaking ties by ascending
+  // lexicographic order of term.
+  bool operator()(const pair<string, int64> &a,
+                  const pair<string, int64> &b) const {
+    return (a.second > b.second || (a.second == b.second && a.first < b.first));
+  }
+};
+
+void TermFrequencyMap::Save(const string &filename) const {
+  CHECK_EQ(term_index_.size(), term_data_.size());
+
+  // Copy and sort the term data.
+  vector<pair<string, int64>> sorted_data(term_data_);
+  std::sort(sorted_data.begin(), sorted_data.end(), SortByFrequencyThenTerm());
+
+  // Write the number of terms.
+  tensorflow::WritableFile *file;
+  TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file));
+  CHECK_LE(term_index_.size(), std::numeric_limits<int32>::max());  // overflow
+  const int32 num_terms = term_index_.size();
+  const string header = tensorflow::strings::StrCat(num_terms, "\n");
+  TF_CHECK_OK(file->Append(header));
+
+  // Write each term and frequency.
+  for (size_t i = 0; i < sorted_data.size(); ++i) {
+    if (i > 0) CHECK_GE(sorted_data[i - 1].second, sorted_data[i].second);
+    const string line = tensorflow::strings::StrCat(
+        sorted_data[i].first, " ", sorted_data[i].second, "\n");
+    TF_CHECK_OK(file->Append(line));
+  }
+  TF_CHECK_OK(file->Close()) << "for file " << filename;
+  LOG(INFO) << "Saved " << term_index_.size() << " terms to " << filename
+            << ".";
+  delete file;
+}
+
+TagToCategoryMap::TagToCategoryMap(const string &filename) {
+  // Load the mapping.
+  tensorflow::RandomAccessFile *file;
+  TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
+  static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
+  tensorflow::io::InputBuffer input(file, kInputBufferSize);
+  string line;
+  while (input.ReadLine(&line) == tensorflow::Status::OK()) {
+    if (line.empty()) continue;  // Skip blank lines, e.g. a trailing newline.
+    vector<string> pair = utils::Split(line, '\t');
+    CHECK_EQ(2, pair.size()) << line;
+    tag_to_category_[pair[0]] = pair[1];
+  }
+}
+
+// Returns the category associated with the given tag.
+const string &TagToCategoryMap::GetCategory(const string &tag) const {
+  const auto it = tag_to_category_.find(tag);
+  CHECK(it != tag_to_category_.end()) << "No category found for tag " << tag;
+  return it->second;
+}
+
+void TagToCategoryMap::SetCategory(const string &tag, const string &category) {
+  const auto it = tag_to_category_.find(tag);
+  if (it != tag_to_category_.end()) {
+    CHECK_EQ(category, it->second)
+        << "POS tag cannot be mapped to multiple coarse POS tags. "
+        << "'" << tag << "' is mapped to: '" << category << "' and '"
+        << it->second << "'";
+  } else {
+    tag_to_category_[tag] = category;
+  }
+}
+
+void TagToCategoryMap::Save(const string &filename) const {
+  // Write tag and category on each line.
+  tensorflow::WritableFile *file;
+  TF_CHECK_OK(tensorflow::Env::Default()->NewWritableFile(filename, &file));
+  for (const auto &pair : tag_to_category_) {
+    const string line =
+        tensorflow::strings::StrCat(pair.first, "\t", pair.second, "\n");
+    TF_CHECK_OK(file->Append(line));
+  }
+  TF_CHECK_OK(file->Close()) << "for file " << filename;
+  delete file;
+}
+
+}  // namespace syntaxnet

+ 117 - 0
syntaxnet/syntaxnet/term_frequency_map.h

@@ -0,0 +1,117 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_TERM_FREQUENCY_MAP_H_
+#define $TARGETDIR_TERM_FREQUENCY_MAP_H_
+
+#include <stddef.h>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+// A mapping from strings to frequencies with save and load functionality.
+class TermFrequencyMap {
+ public:
+  // Creates an empty frequency map.
+  TermFrequencyMap() {}
+
+  // Creates a term frequency map by calling Load.
+  TermFrequencyMap(const string &file, int min_frequency, int max_num_terms) {
+    Load(file, min_frequency, max_num_terms);
+  }
+
+  // Returns the number of terms with positive frequency.
+  int Size() const { return term_index_.size(); }
+
+  // Returns the index associated with the given term.  If the term does not
+  // exist, the unknown index is returned instead.
+  int LookupIndex(const string &term, int unknown) const {
+    const TermIndex::const_iterator it = term_index_.find(term);
+    return (it != term_index_.end() ? it->second : unknown);
+  }
+
+  // Returns the term associated with the given index.
+  const string &GetTerm(int index) const { return term_data_[index].first; }
+
+  // Increases the frequency of the given term by 1, creating a new entry if
+  // necessary, and returns the index of the term.
+  int Increment(const string &term);
+
+  // Clears all frequencies.
+  void Clear();
+
+  // Loads a frequency mapping from the given file, which must have been created
+  // by an earlier call to Save().  After loading, the term indices are
+  // guaranteed to be ordered in descending order of frequency (breaking ties
+  // arbitrarily).  However, any new terms inserted after loading do not
+  // maintain this sorting invariant.
+  //
+  // Only loads terms with frequency >= min_frequency.  If max_num_terms <= 0,
+  // then all qualifying terms are loaded; otherwise, max_num_terms terms with
+  // maximal frequency are loaded (breaking ties arbitrarily).
+  void Load(const string &filename, int min_frequency, int max_num_terms);
+
+  // Saves a frequency mapping to the given file.
+  void Save(const string &filename) const;
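+
+  // The on-disk format written by Save() and read by Load() is plain text: the
+  // first line holds the number of terms, and each following line holds a term
+  // and its frequency, in decreasing order of frequency. For example (with
+  // made-up terms and counts):
+  //
+  //   3
+  //   the 1093
+  //   fox 12
+  //   jumps 7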
+
+ private:
+  // Hashtable for term-to-index mapping.
+  typedef std::unordered_map<string, int> TermIndex;
+
+  // Sorting functor for term data.
+  struct SortByFrequencyThenTerm;
+
+  // Mapping from terms to indices.
+  TermIndex term_index_;
+
+  // Mapping from indices to term and frequency.
+  vector<pair<string, int64>> term_data_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TermFrequencyMap);
+};
+
+// A mapping from tags to categories.
+class TagToCategoryMap {
+ public:
+  TagToCategoryMap() {}
+  ~TagToCategoryMap() {}
+
+  // Loads a tag to category map from a text file.
+  explicit TagToCategoryMap(const string &filename);
+
+  // Sets the category for the given tag.
+  void SetCategory(const string &tag, const string &category);
+
+  // Returns the category associated with the given tag.
+  const string &GetCategory(const string &tag) const;
+
+  // Saves a tag to category map to the given file.
+  void Save(const string &filename) const;
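+
+  // The text format read and written by this class holds one tab-separated
+  // "tag<TAB>category" pair per line, for example (made-up tags):
+  //
+  //   NN<TAB>NOUN
+  //   VBD<TAB>VERB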
+
+ private:
+  map<string, string> tag_to_category_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(TagToCategoryMap);
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_TERM_FREQUENCY_MAP_H_

+ 45 - 0
syntaxnet/syntaxnet/test_main.cc

@@ -0,0 +1,45 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// A program with a main that is suitable for unittests, including those
+// that also define microbenchmarks.  If the user passes the --benchmarks
+// flag, which selects the benchmarks to run, the benchmarks are run;
+// otherwise the gtest tests in the program are run.
+
+#include "tensorflow/core/platform/platform.h"
+#include "tensorflow/core/platform/types.h"
+
+#if defined(PLATFORM_GOOGLE) || defined(__ANDROID__)
+
+// main() is supplied by gunit_main
+#else
+#include "gtest/gtest.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test_benchmark.h"
+
+GTEST_API_ int main(int argc, char **argv) {
+  std::cout << "Running main() from test_main.cc\n";
+
+  testing::InitGoogleTest(&argc, argv);
+  for (int i = 1; i < argc; i++) {
+    if (tensorflow::StringPiece(argv[i]).starts_with("--benchmarks=")) {
+      const char *pattern = argv[i] + strlen("--benchmarks=");
+      tensorflow::testing::Benchmark::Run(pattern);
+      return 0;
+    }
+  }
+  return RUN_ALL_TESTS();
+}
+#endif

+ 87 - 0
syntaxnet/syntaxnet/testdata/context.pbtxt

@@ -0,0 +1,87 @@
+Parameter {
+  name: 'brain_parser_embedding_dims'
+  value: '8;8;8'
+}
+Parameter {
+  name: 'brain_parser_features'
+  value: 'input.token.word input(1).token.word input(2).token.word stack.token.word stack(1).token.word stack(2).token.word;input.tag input(1).tag input(2).tag stack.tag stack(1).tag stack(2).tag;stack.child(1).label stack.child(1).sibling(-1).label stack.child(-1).label stack.child(-1).sibling(1).label'
+}
+Parameter {
+  name: 'brain_parser_embedding_names'
+  value: 'words;tags;labels'
+}
+input {
+  name: 'training-corpus'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: 'syntaxnet/testdata/mini-training-set'
+  }
+}
+input {
+  name: 'tuning-corpus'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: 'syntaxnet/testdata/mini-training-set'
+  }
+}
+input {
+  name: 'parsed-tuning-corpus'
+  creator: 'brain_parser/greedy'
+  record_format: 'conll-sentence'
+}
+input {
+  name: 'label-map'
+  file_format: 'text'
+  Part {
+    file_pattern: 'OUTPATH/label-map'
+  }
+}
+input {
+  name: 'word-map'
+  Part {
+    file_pattern: 'OUTPATH/word-map'
+  }
+}
+input {
+  name: 'lcword-map'
+  Part {
+    file_pattern: 'OUTPATH/lcword-map'
+  }
+}
+input {
+  name: 'tag-map'
+  Part {
+    file_pattern: 'OUTPATH/tag-map'
+  }
+}
+input {
+  name: 'category-map'
+  Part {
+    file_pattern: 'OUTPATH/category-map'
+  }
+}
+input {
+  name: 'prefix-table'
+  Part {
+    file_pattern: 'OUTPATH/prefix-table'
+  }
+}
+input {
+  name: 'suffix-table'
+  Part {
+    file_pattern: 'OUTPATH/suffix-table'
+  }
+}
+input {
+  name: 'tag-to-category'
+  Part {
+    file_pattern: 'OUTPATH/tag-to-category'
+  }
+}
+input {
+  name: 'stdout'
+  record_format: 'conll-sentence'
+  Part {
+    file_pattern: '-'
+  }
+}

+ 145 - 0
syntaxnet/syntaxnet/testdata/document

@@ -0,0 +1,145 @@
+text       : "I can not recall any disorder in currency markets since the 1974 guidelines were adopted ."
+token: {
+  word    : "I"
+  start   : 0
+  end     : 0
+  head    : 3
+  tag     : "PRP"
+  category: "PRON"
+  label   : "nsubj"
+  break_level       : SENTENCE_BREAK
+}
+token: {
+  word    : "can"
+  start   : 2
+  end     : 4
+  head    : 3
+  tag     : "MD"
+  category: "VERB"
+  label   : "aux"
+}
+token: {
+  word    : "not"
+  start   : 6
+  end     : 8
+  head    : 3
+  tag     : "RB"
+  category: "ADV"
+  label   : "neg"
+}
+token: {
+  word    : "recall"
+  start   : 10
+  end     : 15
+  tag     : "VB"
+  category: "VERB"
+  label   : "ROOT"
+}
+token: {
+  word    : "any"
+  start   : 17
+  end     : 19
+  head    : 5
+  tag     : "DT"
+  category: "DET"
+  label   : "det"
+}
+token: {
+  word    : "disorder"
+  start   : 21
+  end     : 28
+  head    : 3
+  tag     : "NN"
+  category: "NOUN"
+  label   : "dobj"
+}
+token: {
+  word    : "in"
+  start   : 30
+  end     : 31
+  head    : 5
+  tag     : "IN"
+  category: "ADP"
+  label   : "prep"
+}
+token: {
+  word    : "currency"
+  start   : 33
+  end     : 40
+  head    : 8
+  tag     : "NN"
+  category: "NOUN"
+  label   : "nn"
+}
+token: {
+  word    : "markets"
+  start   : 42
+  end     : 48
+  head    : 6
+  tag     : "NNS"
+  category: "NOUN"
+  label   : "pobj"
+}
+token: {
+  word    : "since"
+  start   : 50
+  end     : 54
+  head    : 14
+  tag     : "IN"
+  category: "ADP"
+  label   : "mark"
+}
+token: {
+  word    : "the"
+  start   : 56
+  end     : 58
+  head    : 12
+  tag     : "DT"
+  category: "DET"
+  label   : "det"
+}
+token: {
+  word    : "1974"
+  start   : 60
+  end     : 63
+  head    : 12
+  tag     : "CD"
+  category: "NUM"
+  label   : "num"
+}
+token: {
+  word    : "guidelines"
+  start   : 65
+  end     : 74
+  head    : 14
+  tag     : "NNS"
+  category: "NOUN"
+  label   : "nsubjpass"
+}
+token: {
+  word    : "were"
+  start   : 76
+  end     : 79
+  head    : 14
+  tag     : "VBD"
+  category: "VERB"
+  label   : "auxpass"
+}
+token: {
+  word    : "adopted"
+  start   : 81
+  end     : 87
+  head    : 3
+  tag     : "VBN"
+  category: "VERB"
+  label   : "advcl"
+}
+token: {
+  word    : "."
+  start   : 89
+  end     : 89
+  head    : 3
+  tag     : "."
+  category: "."
+  label   : "p"
+}

File diff suppressed because it is too large
+ 1017 - 0
syntaxnet/syntaxnet/testdata/mini-training-set


+ 399 - 0
syntaxnet/syntaxnet/text_formats.cc

@@ -0,0 +1,399 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "syntaxnet/document_format.h"
+#include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+#include "tensorflow/core/platform/regexp.h"
+
+namespace syntaxnet {
+
+// CoNLL document format reader for dependency annotated corpora.
+// The expected format is described e.g. at http://ilk.uvt.nl/conll/#dataformat
+//
+// Data should adhere to the following rules:
+//   - Data files contain sentences separated by a blank line.
+//   - A sentence consists of one or more tokens, each one starting on a new line.
+//   - A token consists of ten fields described in the table below.
+//   - Fields are separated by a single tab character.
+//   - All data files will contain these ten fields, although only the ID
+//     column is required to contain non-dummy (i.e. non-underscore) values.
+// Data files should be UTF-8 encoded (Unicode).
+//
+// Fields:
+// 1  ID:      Token counter, starting at 1 for each new sentence and increasing
+//             by 1 for every new token.
+// 2  FORM:    Word form or punctuation symbol.
+// 3  LEMMA:   Lemma or stem.
+// 4  CPOSTAG: Coarse-grained part-of-speech tag or category.
+// 5  POSTAG:  Fine-grained part-of-speech tag. Note that the same POS tag
+//             cannot appear with multiple coarse-grained POS tags.
+// 6  FEATS:   Unordered set of syntactic and/or morphological features.
+// 7  HEAD:    Head of the current token, which is either a value of ID or '0'.
+// 8  DEPREL:  Dependency relation to the HEAD.
+// 9  PHEAD:   Projective head of current token.
+// 10 PDEPREL: Dependency relation to the PHEAD.
+//
+// This CoNLL reader is compatible with the CoNLL-U format described at
+//   http://universaldependencies.org/format.html
+// Note that this reader skips CoNLL-U multiword tokens and ignores the last two
+// fields of every line, which are PHEAD and PDEPREL in CoNLL format, but are
+// replaced by DEPS and MISC in CoNLL-U.
+//
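+// For illustration (a constructed example, not taken from any corpus), a
+// two-token sentence "Hello world" could be written as the following two
+// lines, each with ten tab-separated fields, followed by a blank line:
+//
+//   1  Hello  _  UH  UH  _  2  dep   _  _
+//   2  world  _  NN  NN  _  0  ROOT  _  _
+//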
+class CoNLLSyntaxFormat : public DocumentFormat {
+ public:
+  CoNLLSyntaxFormat() {}
+
+  // Reads up to the first empty line; returns false if end of file is reached
+  // before any content is read.
+  bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+                  string *record) override {
+    string line;
+    record->clear();
+    tensorflow::Status status = buffer->ReadLine(&line);
+    while (!line.empty() && status.ok()) {
+      tensorflow::strings::StrAppend(record, line, "\n");
+      status = buffer->ReadLine(&line);
+    }
+    return status.ok() || !record->empty();
+  }
+
+  void ConvertFromString(const string &key, const string &value,
+                         vector<Sentence *> *sentences) override {
+    // Create new sentence.
+    Sentence *sentence = new Sentence();
+
+    // Each line corresponds to one token.
+    string text;
+    vector<string> lines = utils::Split(value, '\n');
+
+    // Add each token to the sentence.
+    vector<string> fields;
+    int expected_id = 1;
+    for (size_t i = 0; i < lines.size(); ++i) {
+      // Split line into tab-separated fields.
+      fields.clear();
+      fields = utils::Split(lines[i], '\t');
+      if (fields.size() == 0) continue;
+
+      // Skip comment lines.
+      if (fields[0][0] == '#') continue;
+
+      // Skip CoNLLU lines for multiword tokens which are indicated by
+      // hyphenated line numbers, e.g., "2-4".
+      // http://universaldependencies.github.io/docs/format.html
+      if (RE2::FullMatch(fields[0], "[0-9]+-[0-9]+")) continue;
+
+      // Clear all optional fields equal to '_'.
+      for (size_t j = 2; j < fields.size(); ++j) {
+        if (fields[j].length() == 1 && fields[j][0] == '_') fields[j].clear();
+      }
+
+      // Check that the line is valid.
+      CHECK_GE(fields.size(), 8)
+          << "Every line has to have at least 8 tab separated fields.";
+
+      // Check that the ids follow the expected format.
+      const int id = utils::ParseUsing<int>(fields[0], 0, utils::ParseInt32);
+      CHECK_EQ(expected_id++, id)
+          << "Token ids start at 1 for each new sentence and increase by 1 "
+          << "on each new token. Sentences are separated by an empty line.";
+
+      // Get relevant fields.
+      const string &word = fields[1];
+      const string &cpostag = fields[3];
+      const string &tag = fields[4];
+      const int head = utils::ParseUsing<int>(fields[6], 0, utils::ParseInt32);
+      const string &label = fields[7];
+
+      // Add token to sentence text.
+      if (!text.empty()) text.append(" ");
+      const int start = text.size();
+      const int end = start + word.size() - 1;
+      text.append(word);
+
+      // Add token to sentence.
+      Token *token = sentence->add_token();
+      token->set_word(word);
+      token->set_start(start);
+      token->set_end(end);
+      if (head > 0) token->set_head(head - 1);
+      if (!tag.empty()) token->set_tag(tag);
+      if (!cpostag.empty()) token->set_category(cpostag);
+      if (!label.empty()) token->set_label(label);
+    }
+
+    if (sentence->token_size() > 0) {
+      sentence->set_docid(key);
+      sentence->set_text(text);
+      sentences->push_back(sentence);
+    } else {
+      // If the sentence was empty (e.g., blank lines at the beginning of a
+      // file), then don't save it.
+      delete sentence;
+    }
+  }
+
+  // Converts a sentence to a key/value pair.
+  void ConvertToString(const Sentence &sentence, string *key,
+                       string *value) override {
+    *key = sentence.docid();
+    vector<string> lines;
+    for (int i = 0; i < sentence.token_size(); ++i) {
+      vector<string> fields(10);
+      fields[0] = tensorflow::strings::Printf("%d", i + 1);
+      fields[1] = sentence.token(i).word();
+      fields[2] = "_";
+      fields[3] = sentence.token(i).category();
+      fields[4] = sentence.token(i).tag();
+      fields[5] = "_";
+      fields[6] =
+          tensorflow::strings::Printf("%d", sentence.token(i).head() + 1);
+      fields[7] = sentence.token(i).label();
+      fields[8] = "_";
+      fields[9] = "_";
+      lines.push_back(utils::Join(fields, "\t"));
+    }
+    *value = tensorflow::strings::StrCat(utils::Join(lines, "\n"), "\n\n");
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(CoNLLSyntaxFormat);
+};
+
+REGISTER_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat);
+
+// Reader for tokenized text. This reader expects every sentence to be on a
+// single line and tokens on that line to be separated by single spaces.
+//
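+// For example (an illustrative line, not taken from the test data), the input
+// line "Hello , world !" yields a four-token Sentence with the words "Hello",
+// ",", "world" and "!".
+//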
+class TokenizedTextFormat : public DocumentFormat {
+ public:
+  TokenizedTextFormat() {}
+
+  // Reads a line and returns false if end of file is reached.
+  bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+                  string *record) override {
+    return buffer->ReadLine(record).ok();
+  }
+
+  void ConvertFromString(const string &key, const string &value,
+                         vector<Sentence *> *sentences) override {
+    Sentence *sentence = new Sentence();
+    string text;
+    for (const string &word : utils::Split(value, ' ')) {
+      if (word.empty()) continue;
+      const int start = text.size();
+      const int end = start + word.size() - 1;
+      if (!text.empty()) text.append(" ");
+      text.append(word);
+      Token *token = sentence->add_token();
+      token->set_word(word);
+      token->set_start(start);
+      token->set_end(end);
+    }
+
+    if (sentence->token_size() > 0) {
+      sentence->set_docid(key);
+      sentence->set_text(text);
+      sentences->push_back(sentence);
+    } else {
+      // If the sentence was empty (e.g., blank lines at the beginning of a
+      // file), then don't save it.
+      delete sentence;
+    }
+  }
+
+  void ConvertToString(const Sentence &sentence, string *key,
+                       string *value) override {
+    *key = sentence.docid();
+    value->clear();
+    for (const Token &token : sentence.token()) {
+      if (!value->empty()) value->append(" ");
+      value->append(token.word());
+      if (token.has_tag()) {
+        value->append("_");
+        value->append(token.tag());
+      }
+      if (token.has_head()) {
+        value->append("_");
+        value->append(tensorflow::strings::StrCat(token.head()));
+      }
+    }
+    value->append("\n");
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(TokenizedTextFormat);
+};
+
+REGISTER_DOCUMENT_FORMAT("tokenized-text", TokenizedTextFormat);
+
+// Text reader that attempts to perform Penn Treebank tokenization on arbitrary
+// raw text. Adapted from https://www.cis.upenn.edu/~treebank/tokenizer.sed
+// by Robert MacIntyre, University of Pennsylvania, late 1995.
+// Expected input: raw text with one sentence per line.
+//
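+// For example (an illustrative sentence, not taken from the test data), the
+// raw line
+//   I can't pay $50.
+// is rewritten to approximately
+//   I ca n't pay $ 50 .
+// before TokenizedTextFormat::ConvertFromString splits it on spaces.
+//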
+class EnglishTextFormat : public TokenizedTextFormat {
+ public:
+  EnglishTextFormat() {}
+
+  void ConvertFromString(const string &key, const string &value,
+                         vector<Sentence *> *sentences) override {
+    vector<pair<string, string>> preproc_rules = {
+        // Punctuation.
+        {"’", "'"},
+        {"…", "..."},
+        {"---", "--"},
+        {"—", "--"},
+        {"–", "--"},
+        {",", ","},
+        {"。", "."},
+        {"!", "!"},
+        {"?", "?"},
+        {":", ":"},
+        {";", ";"},
+        {"&", "&"},
+
+        // Brackets.
+        {"\\[", "("},
+        {"]", ")"},
+        {"{", "("},
+        {"}", ")"},
+        {"【", "("},
+        {"】", ")"},
+        {"(", "("},
+        {")", ")"},
+
+        // Quotation marks.
+        {"\"", "\""},
+        {"″", "\""},
+        {"“", "\""},
+        {"„", "\""},
+        {"‵‵", "\""},
+        {"”", "\""},
+        {"’", "\""},
+        {"‘", "\""},
+        {"′′", "\""},
+        {"‹", "\""},
+        {"›", "\""},
+        {"«", "\""},
+        {"»", "\""},
+
+        // Discarded punctuation that breaks sentences.
+        {"|", ""},
+        {"·", ""},
+        {"•", ""},
+        {"●", ""},
+        {"▪", ""},
+        {"■", ""},
+        {"□", ""},
+        {"❑", ""},
+        {"◆", ""},
+        {"★", ""},
+        {"*", ""},
+        {"♦", ""},
+    };
+
+    vector<pair<string, string>> rules = {
+        // attempt to get correct directional quotes
+        {R"re(^")re", "`` "},
+        {R"re(([ \([{<])")re", "\\1 `` "},
+        // close quotes handled at end
+
+        {R"re(\.\.\.)re", " ... "},
+        {"[,;:@#$%&]", " \\0 "},
+
+        // Assume sentence tokenization has been done first, so split FINAL
+        // periods only.
+        {R"re(([^.])(\.)([\]\)}>"']*)[ ]*$)re", "\\1 \\2\\3 "},
+        // however, we may as well split ALL question marks and exclamation
+        // points, since they shouldn't have the abbrev.-marker ambiguity
+        // problem
+        {"[?!]", " \\0 "},
+
+        // parentheses, brackets, etc.
+        {R"re([\]\[\(\){}<>])re", " \\0 "},
+
+        // Like Adwait Ratnaparkhi's MXPOST, we use the parsed-file version of
+        // these symbols.
+        {"\\(", "-LRB-"},
+        {"\\)", "-RRB-"},
+        {"\\[", "-LSB-"},
+        {"\\]", "-RSB-"},
+        {"{", "-LCB-"},
+        {"}", "-RCB-"},
+
+        {"--", " -- "},
+
+        // First off, add a space to the beginning and end of each line, to
+        // reduce necessary number of regexps.
+        {"$", " "},
+        {"^", " "},
+
+        {"\"", " '' "},
+        // possessive or close-single-quote
+        {"([^'])' ", "\\1 ' "},
+        // as in it's, I'm, we'd
+        {"'([sSmMdD]) ", " '\\1 "},
+        {"'ll ", " 'll "},
+        {"'re ", " 're "},
+        {"'ve ", " 've "},
+        {"n't ", " n't "},
+        {"'LL ", " 'LL "},
+        {"'RE ", " 'RE "},
+        {"'VE ", " 'VE "},
+        {"N'T ", " N'T "},
+
+        {" ([Cc])annot ", " \\1an not "},
+        {" ([Dd])'ye ", " \\1' ye "},
+        {" ([Gg])imme ", " \\1im me "},
+        {" ([Gg])onna ", " \\1on na "},
+        {" ([Gg])otta ", " \\1ot ta "},
+        {" ([Ll])emme ", " \\1em me "},
+        {" ([Mm])ore'n ", " \\1ore 'n "},
+        {" '([Tt])is ", " '\\1 is "},
+        {" '([Tt])was ", " '\\1 was "},
+        {" ([Ww])anna ", " \\1an na "},
+        {" ([Ww])haddya ", " \\1ha dd ya "},
+        {" ([Ww])hatcha ", " \\1ha t cha "},
+
+        // clean out extra spaces
+        {"  *", " "},
+        {"^ *", ""},
+    };
+
+    string rewritten = value;
+    for (const pair<string, string> &rule : preproc_rules) {
+      RE2::GlobalReplace(&rewritten, rule.first, rule.second);
+    }
+    for (const pair<string, string> &rule : rules) {
+      RE2::GlobalReplace(&rewritten, rule.first, rule.second);
+    }
+    TokenizedTextFormat::ConvertFromString(key, rewritten, sentences);
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EnglishTextFormat);
+};
+
+REGISTER_DOCUMENT_FORMAT("english-text", EnglishTextFormat);
+
+}  // namespace syntaxnet

+ 108 - 0
syntaxnet/syntaxnet/text_formats_test.py

@@ -0,0 +1,108 @@
+# coding=utf-8
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for english_tokenizer."""
+
+
+# pylint: disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
+import os.path
+
+import tensorflow as tf
+
+import syntaxnet.load_parser_ops
+
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import logging
+
+from syntaxnet import sentence_pb2
+from syntaxnet import task_spec_pb2
+from syntaxnet.ops import gen_parser_ops
+
+FLAGS = tf.app.flags.FLAGS
+
+
+class TextFormatsTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    if not hasattr(FLAGS, 'test_srcdir'):
+      FLAGS.test_srcdir = ''
+    if not hasattr(FLAGS, 'test_tmpdir'):
+      FLAGS.test_tmpdir = tf.test.get_temp_dir()
+    self.corpus_file = os.path.join(FLAGS.test_tmpdir, 'documents.conll')
+    self.context_file = os.path.join(FLAGS.test_tmpdir, 'context.pbtxt')
+
+  def AddInput(self, name, file_pattern, record_format, context):
+    inp = context.input.add()
+    inp.name = name
+    inp.record_format.append(record_format)
+    inp.part.add().file_pattern = file_pattern
+
+  def WriteContext(self, corpus_format):
+    context = task_spec_pb2.TaskSpec()
+    self.AddInput('documents', self.corpus_file, corpus_format, context)
+    for name in ('word-map', 'lcword-map', 'tag-map',
+                 'category-map', 'label-map', 'prefix-table',
+                 'suffix-table', 'tag-to-category'):
+      self.AddInput(name, os.path.join(FLAGS.test_tmpdir, name), '', context)
+    logging.info('Writing context to: %s', self.context_file)
+    with open(self.context_file, 'w') as f:
+      f.write(str(context))
+
+  def ReadNextDocument(self, sess, sentence):
+    sentence_str, = sess.run([sentence])
+    if sentence_str:
+      sentence_doc = sentence_pb2.Sentence()
+      sentence_doc.ParseFromString(sentence_str[0])
+    else:
+      sentence_doc = None
+    return sentence_doc
+
+  def CheckTokenization(self, sentence, tokenization):
+    self.WriteContext('english-text')
+    logging.info('Writing text file to: %s', self.corpus_file)
+    with open(self.corpus_file, 'w') as f:
+      f.write(sentence)
+    sentence, _ = gen_parser_ops.document_source(
+        self.context_file, batch_size=1)
+    with self.test_session() as sess:
+      sentence_doc = self.ReadNextDocument(sess, sentence)
+      self.assertEqual(' '.join([t.word for t in sentence_doc.token]),
+                       tokenization)
+
+  def testSimple(self):
+    self.CheckTokenization('Hello, world!', 'Hello , world !')
+    self.CheckTokenization('"Hello"', "`` Hello ''")
+    self.CheckTokenization('{"Hello@#$', '-LRB- `` Hello @ # $')
+    self.CheckTokenization('"Hello..."', "`` Hello ... ''")
+    self.CheckTokenization('()[]{}<>',
+                           '-LRB- -RRB- -LRB- -RRB- -LRB- -RRB- < >')
+    self.CheckTokenization('Hello--world', 'Hello -- world')
+    self.CheckTokenization("Isn't", "Is n't")
+    self.CheckTokenization("n't", "n't")
+    self.CheckTokenization('Hello Mr. Smith.', 'Hello Mr. Smith .')
+    self.CheckTokenization("It's Mr. Smith's.", "It 's Mr. Smith 's .")
+    self.CheckTokenization("It's the Smiths'.", "It 's the Smiths ' .")
+    self.CheckTokenization('Gotta go', 'Got ta go')
+    self.CheckTokenization('50-year-old', '50-year-old')
+
+  def testUrl(self):
+    self.CheckTokenization('http://www.google.com/news is down',
+                           'http : //www.google.com/news is down')
+
+
+if __name__ == '__main__':
+  googletest.main()

+ 111 - 0
syntaxnet/syntaxnet/unpack_sparse_features.cc

@@ -0,0 +1,111 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+#include "syntaxnet/sparse.pb.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/types.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_STRING;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::errors::InvalidArgument;
+
+namespace syntaxnet {
+
+// Operator to unpack ids and weights stored in SparseFeatures proto.
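+//
+// For example (illustrative values), an input vector of two serialized
+// SparseFeatures protos, the first with ids {3, 7} and no weights and the
+// second with id {5} and weight {0.5}, produces
+//   indices = [0, 0, 1], ids = [3, 7, 5], weights = [1.0, 1.0, 0.5].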
+class UnpackSparseFeatures : public OpKernel {
+ public:
+  explicit UnpackSparseFeatures(OpKernelConstruction *context)
+      : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {DT_STRING}, {DT_INT32, DT_INT64, DT_FLOAT}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    const Tensor &input = context->input(0);
+    OP_REQUIRES(context, IsLegacyVector(input.shape()),
+                InvalidArgument("input should be a vector."));
+
+    const int64 n = input.NumElements();
+    const auto input_vec = input.flat<string>();
+    SparseFeatures sf;
+    int output_size = 0;
+    std::vector<std::pair<int64, float> > id_and_weight;
+
+    // Guess that we'll be averaging a handful of ids per SparseFeatures record.
+    id_and_weight.reserve(n * 4);
+    std::vector<int> num_ids(n);
+    for (int64 i = 0; i < n; ++i) {
+      OP_REQUIRES(context, sf.ParseFromString(input_vec(i)),
+                  InvalidArgument("Couldn't parse as SparseFeature"));
+      OP_REQUIRES(context,
+                  sf.weight_size() == 0 || sf.weight_size() == sf.id_size(),
+                  InvalidArgument(tensorflow::strings::StrCat(
+                      "Incorrect number of weights", sf.DebugString())));
+      int n_ids = sf.id_size();
+      num_ids[i] = n_ids;
+      output_size += n_ids;
+      for (int j = 0; j < n_ids; j++) {
+        float w = (sf.weight_size() > 0) ? sf.weight(j) : 1.0f;
+        id_and_weight.push_back(std::make_pair(sf.id(j), w));
+      }
+    }
+
+    Tensor *indices_t;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, TensorShape({output_size}), &indices_t));
+    Tensor *ids_t;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({output_size}), &ids_t));
+    Tensor *weights_t;
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                2, TensorShape({output_size}), &weights_t));
+
+    auto indices = indices_t->vec<int32>();
+    auto ids = ids_t->vec<int64>();
+    auto weights = weights_t->vec<float>();
+    int c = 0;
+    for (int64 i = 0; i < n; ++i) {
+      for (int j = 0; j < num_ids[i]; ++j) {
+        indices(c) = i;
+        ids(c) = id_and_weight[c].first;
+        weights(c) = id_and_weight[c].second;
+        ++c;
+      }
+    }
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("UnpackSparseFeatures").Device(DEVICE_CPU),
+                        UnpackSparseFeatures);
+
+}  // namespace syntaxnet

+ 260 - 0
syntaxnet/syntaxnet/utils.cc

@@ -0,0 +1,260 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/platform/macros.h"
+
+namespace syntaxnet {
+namespace utils {
+
+bool ParseInt32(const char *c_str, int *value) {
+  char *temp;
+  *value = strtol(c_str, &temp, 0);  // NOLINT
+  return (*temp == '\0');
+}
+
+bool ParseInt64(const char *c_str, int64 *value) {
+  char *temp;
+  *value = strtol(c_str, &temp, 0);  // NOLINT
+  return (*temp == '\0');
+}
+
+bool ParseDouble(const char *c_str, double *value) {
+  char *temp;
+  *value = strtod(c_str, &temp);
+  return (*temp == '\0');
+}
+
+static char hex_char[] = "0123456789abcdef";
+
+string CEscape(const string &src) {
+  string dest;
+
+  for (unsigned char c : src) {
+    switch (c) {
+      case '\n':
+        dest.append("\\n");
+        break;
+      case '\r':
+        dest.append("\\r");
+        break;
+      case '\t':
+        dest.append("\\t");
+        break;
+      case '\"':
+        dest.append("\\\"");
+        break;
+      case '\'':
+        dest.append("\\'");
+        break;
+      case '\\':
+        dest.append("\\\\");
+        break;
+      default:
+        // Note that if we emit \xNN and the src character after that is a hex
+        // digit then that digit must be escaped too to prevent it being
+        // interpreted as part of the character code by C.
+        if ((c >= 0x80) || !isprint(c)) {
+          dest.append("\\");
+          dest.push_back(hex_char[c / 64]);
+          dest.push_back(hex_char[(c % 64) / 8]);
+          dest.push_back(hex_char[c % 8]);
+        } else {
+          dest.push_back(c);
+          break;
+        }
+    }
+  }
+
+  return dest;
+}
+
+std::vector<string> Split(const string &text, char delim) {
+  std::vector<string> result;
+  int token_start = 0;
+  if (!text.empty()) {
+    for (size_t i = 0; i < text.size() + 1; i++) {
+      if ((i == text.size()) || (text[i] == delim)) {
+        result.push_back(string(text.data() + token_start, i - token_start));
+        token_start = i + 1;
+      }
+    }
+  }
+  return result;
+}
+
+bool IsAbsolutePath(tensorflow::StringPiece path) {
+  return !path.empty() && path[0] == '/';
+}
+
+// For an array of paths of length count, append them all together,
+// ensuring that the proper path separators are inserted between them.
+string JoinPath(std::initializer_list<tensorflow::StringPiece> paths) {
+  string result;
+
+  for (tensorflow::StringPiece path : paths) {
+    if (path.empty()) {
+      continue;
+    }
+
+    if (result.empty()) {
+      result = path.ToString();
+      continue;
+    }
+
+    if (result[result.size() - 1] == '/') {
+      if (IsAbsolutePath(path)) {
+        tensorflow::strings::StrAppend(&result, path.substr(1));
+      } else {
+        tensorflow::strings::StrAppend(&result, path);
+      }
+    } else {
+      if (IsAbsolutePath(path)) {
+        tensorflow::strings::StrAppend(&result, path);
+      } else {
+        tensorflow::strings::StrAppend(&result, "/", path);
+      }
+    }
+  }
+
+  return result;
+}
+
+size_t RemoveLeadingWhitespace(tensorflow::StringPiece *text) {
+  size_t count = 0;
+  const char *ptr = text->data();
+  while (count < text->size() && isspace(*ptr)) {
+    count++;
+    ptr++;
+  }
+  text->remove_prefix(count);
+  return count;
+}
+
+size_t RemoveTrailingWhitespace(tensorflow::StringPiece *text) {
+  size_t count = 0;
+  const char *ptr = text->data() + text->size() - 1;
+  while (count < text->size() && isspace(*ptr)) {
+    ++count;
+    --ptr;
+  }
+  text->remove_suffix(count);
+  return count;
+}
+
+size_t RemoveWhitespaceContext(tensorflow::StringPiece *text) {
+  // use RemoveLeadingWhitespace() and RemoveTrailingWhitespace() to do the job
+  return RemoveLeadingWhitespace(text) + RemoveTrailingWhitespace(text);
+}
+
+namespace {
+// Lower-level versions of Get... that read directly from a character buffer
+// without any bounds checking.
+inline uint32 DecodeFixed32(const char *ptr) {
+  return ((static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
+          (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24));
+}
+
+// 0xff is in case char is signed.
+static inline uint32 ByteAs32(char c) { return static_cast<uint32>(c) & 0xff; }
+}  // namespace
+
+uint32 Hash32(const char *data, size_t n, uint32 seed) {
+  // 'm' and 'r' are mixing constants generated offline.
+  // They're not really 'magic', they just happen to work well.
+  const uint32 m = 0x5bd1e995;
+  const int r = 24;
+
+  // Initialize the hash to a 'random' value
+  uint32 h = seed ^ n;
+
+  // Mix 4 bytes at a time into the hash
+  while (n >= 4) {
+    uint32 k = DecodeFixed32(data);
+    k *= m;
+    k ^= k >> r;
+    k *= m;
+    h *= m;
+    h ^= k;
+    data += 4;
+    n -= 4;
+  }
+
+  // Handle the last few bytes of the input array
+  switch (n) {
+    case 3:
+      h ^= ByteAs32(data[2]) << 16;
+      TF_FALLTHROUGH_INTENDED;
+    case 2:
+      h ^= ByteAs32(data[1]) << 8;
+      TF_FALLTHROUGH_INTENDED;
+    case 1:
+      h ^= ByteAs32(data[0]);
+      h *= m;
+  }
+
+  // Do a few final mixes of the hash to ensure the last few
+  // bytes are well-incorporated.
+  h ^= h >> 13;
+  h *= m;
+  h ^= h >> 15;
+  return h;
+}
+
+string Lowercase(tensorflow::StringPiece s) {
+  string result(s.data(), s.size());
+  for (char &c : result) {
+    c = tolower(c);
+  }
+  return result;
+}
+
+PunctuationUtil::CharacterRange PunctuationUtil::kPunctuation[] = {
+    {33, 35},       {37, 42},       {44, 47},       {58, 59},
+    {63, 64},       {91, 93},       {95, 95},       {123, 123},
+    {125, 125},     {161, 161},     {171, 171},     {183, 183},
+    {187, 187},     {191, 191},     {894, 894},     {903, 903},
+    {1370, 1375},   {1417, 1418},   {1470, 1470},   {1472, 1472},
+    {1475, 1475},   {1478, 1478},   {1523, 1524},   {1548, 1549},
+    {1563, 1563},   {1566, 1567},   {1642, 1645},   {1748, 1748},
+    {1792, 1805},   {2404, 2405},   {2416, 2416},   {3572, 3572},
+    {3663, 3663},   {3674, 3675},   {3844, 3858},   {3898, 3901},
+    {3973, 3973},   {4048, 4049},   {4170, 4175},   {4347, 4347},
+    {4961, 4968},   {5741, 5742},   {5787, 5788},   {5867, 5869},
+    {5941, 5942},   {6100, 6102},   {6104, 6106},   {6144, 6154},
+    {6468, 6469},   {6622, 6623},   {6686, 6687},   {8208, 8231},
+    {8240, 8259},   {8261, 8273},   {8275, 8286},   {8317, 8318},
+    {8333, 8334},   {9001, 9002},   {9140, 9142},   {10088, 10101},
+    {10181, 10182}, {10214, 10219}, {10627, 10648}, {10712, 10715},
+    {10748, 10749}, {11513, 11516}, {11518, 11519}, {11776, 11799},
+    {11804, 11805}, {12289, 12291}, {12296, 12305}, {12308, 12319},
+    {12336, 12336}, {12349, 12349}, {12448, 12448}, {12539, 12539},
+    {64830, 64831}, {65040, 65049}, {65072, 65106}, {65108, 65121},
+    {65123, 65123}, {65128, 65128}, {65130, 65131}, {65281, 65283},
+    {65285, 65290}, {65292, 65295}, {65306, 65307}, {65311, 65312},
+    {65339, 65341}, {65343, 65343}, {65371, 65371}, {65373, 65373},
+    {65375, 65381}, {65792, 65793}, {66463, 66463}, {68176, 68184},
+    {-1, -1}};
+
+void NormalizeDigits(string *form) {
+  for (size_t i = 0; i < form->size(); ++i) {
+    if ((*form)[i] >= '0' && (*form)[i] <= '9') (*form)[i] = '9';
+  }
+}
+
+}  // namespace utils
+}  // namespace syntaxnet

+ 171 - 0
syntaxnet/syntaxnet/utils.h

@@ -0,0 +1,171 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef $TARGETDIR_UTILS_H_
+#define $TARGETDIR_UTILS_H_
+
+#include <functional>
+#include <string>
+#include <vector>
+#include <unordered_set>
+#include "syntaxnet/base.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/default/integral_types.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "util/utf8/unicodetext.h"
+
+namespace syntaxnet {
+namespace utils {
+
+bool ParseInt32(const char *c_str, int *value);
+bool ParseInt64(const char *c_str, int64 *value);
+bool ParseDouble(const char *c_str, double *value);
+
+template <typename T>
+T ParseUsing(const string &str, std::function<bool(const char *, T *)> func) {
+  T value;
+  CHECK(func(str.c_str(), &value)) << "Failed to convert: " << str;
+  return value;
+}
+
+template <typename T>
+T ParseUsing(const string &str, T defval,
+             std::function<bool(const char *, T *)> func) {
+  return str.empty() ? defval : ParseUsing<T>(str, func);
+}
+
+string CEscape(const string &src);
+
+std::vector<string> Split(const string &text, char delim);
+
+template <typename T>
+string Join(const std::vector<T> &s, const char *sep) {
+  string result;
+  bool first = true;
+  for (const auto &x : s) {
+    tensorflow::strings::StrAppend(&result, (first ? "" : sep), x);
+    first = false;
+  }
+  return result;
+}
+
+string JoinPath(std::initializer_list<StringPiece> paths);
+
+size_t RemoveLeadingWhitespace(tensorflow::StringPiece *text);
+
+size_t RemoveTrailingWhitespace(tensorflow::StringPiece *text);
+
+size_t RemoveWhitespaceContext(tensorflow::StringPiece *text);
+
+uint32 Hash32(const char *data, size_t n, uint32 seed);
+
+// Deletes all the elements in an STL container and clears the container. This
+// function is suitable for use with a vector, set, hash_set, or any other STL
+// container which defines sensible begin(), end(), and clear() methods.
+// If container is NULL, this function is a no-op.
+template <typename T>
+void STLDeleteElements(T *container) {
+  if (!container) return;
+  auto it = container->begin();
+  while (it != container->end()) {
+    auto temp = it;
+    ++it;
+    delete *temp;
+  }
+  container->clear();
+}
+
+// Returns lower-cased version of s.
+string Lowercase(tensorflow::StringPiece s);
+
+class PunctuationUtil {
+ public:
+  // Unicode character ranges for punctuation characters according to CoNLL.
+  struct CharacterRange {
+    int first;
+    int last;
+  };
+  static CharacterRange kPunctuation[];
+
+  // Returns true if Unicode character is a punctuation character.
+  static bool IsPunctuation(int u) {
+    int i = 0;
+    while (kPunctuation[i].first > 0) {
+      if (u < kPunctuation[i].first) return false;
+      if (u <= kPunctuation[i].last) return true;
+      ++i;
+    }
+    return false;
+  }
+
+  // Determine if tag is a punctuation tag.
+  static bool IsPunctuationTag(const string &tag) {
+    for (size_t i = 0; i < tag.length(); ++i) {
+      int c = tag[i];
+      if (c != ',' && c != ':' && c != '.' && c != '\'' && c != '`') {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Returns true if word consists of punctuation characters.
+  static bool IsPunctuationToken(const string &word) {
+    UnicodeText text;
+    text.PointToUTF8(word.c_str(), word.length());
+    UnicodeText::const_iterator it;
+    for (it = text.begin(); it != text.end(); ++it) {
+      if (!IsPunctuation(*it)) return false;
+    }
+    return true;
+  }
+
+  // Returns true if tag is non-empty and has only punctuation or parens
+  // symbols.
+  static bool IsPunctuationTagOrParens(const string &tag) {
+    if (tag.empty()) return false;
+    for (size_t i = 0; i < tag.length(); ++i) {
+      int c = tag[i];
+      if (c != '(' && c != ')' && c != ',' && c != ':' && c != '.' &&
+          c != '\'' && c != '`') {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Decides whether to score a token, given the word, the POS tag and the
+  // scoring type.
+  static bool ScoreToken(const string &word, const string &tag,
+                         const string &scoring_type) {
+    if (scoring_type == "default") {
+      return tag.empty() || !IsPunctuationTag(tag);
+    } else if (scoring_type == "conllx") {
+      return !IsPunctuationToken(word);
+    } else if (scoring_type == "ignore_parens") {
+      return !IsPunctuationTagOrParens(tag);
+    }
+    CHECK(scoring_type.empty()) << "Unknown scoring strategy " << scoring_type;
+    return true;
+  }
+};
+
+void NormalizeDigits(string *form);
+
+}  // namespace utils
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_UTILS_H_

+ 50 - 0
syntaxnet/syntaxnet/workspace.cc

@@ -0,0 +1,50 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "syntaxnet/workspace.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+
+namespace syntaxnet {
+
+string WorkspaceRegistry::DebugString() const {
+  string str;
+  for (auto &it : workspace_names_) {
+    const string &type_name = workspace_types_.at(it.first);
+    for (size_t index = 0; index < it.second.size(); ++index) {
+      const string &workspace_name = it.second[index];
+      tensorflow::strings::StrAppend(&str, "\n  ", type_name, " :: ",
+                                     workspace_name);
+    }
+  }
+  return str;
+}
+
+VectorIntWorkspace::VectorIntWorkspace(int size) : elements_(size) {}
+
+VectorIntWorkspace::VectorIntWorkspace(int size, int value)
+    : elements_(size, value) {}
+
+VectorIntWorkspace::VectorIntWorkspace(const vector<int> &elements)
+    : elements_(elements) {}
+
+string VectorIntWorkspace::TypeName() { return "Vector"; }
+
+VectorVectorIntWorkspace::VectorVectorIntWorkspace(int size)
+    : elements_(size) {}
+
+string VectorVectorIntWorkspace::TypeName() { return "VectorVector"; }
+
+}  // namespace syntaxnet

+ 215 - 0
syntaxnet/syntaxnet/workspace.h

@@ -0,0 +1,215 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Notes on thread-safety: All of the classes here are thread-compatible.  More
+// specifically, the registry machinery is thread-safe, as long as each thread
+// performs feature extraction on a different Sentence object.
+
+#ifndef $TARGETDIR_WORKSPACE_H_
+#define $TARGETDIR_WORKSPACE_H_
+
+#include <string>
+#include <typeindex>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "syntaxnet/utils.h"
+
+namespace syntaxnet {
+
+// A base class for shared workspaces. Derived classes implement a static member
+// function TypeName() which returns a human readable string name for the class.
+class Workspace {
+ public:
+  // Polymorphic destructor.
+  virtual ~Workspace() {}
+
+ protected:
+  // Create an empty workspace.
+  Workspace() {}
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(Workspace);
+};
+
+// A registry that keeps track of workspaces.
+class WorkspaceRegistry {
+ public:
+  // Create an empty registry.
+  WorkspaceRegistry() {}
+
+  // Returns the index of a named workspace, adding it to the registry first
+  // if necessary.
+  template <class W>
+  int Request(const string &name) {
+    const std::type_index id = std::type_index(typeid(W));
+    workspace_types_[id] = W::TypeName();
+    vector<string> &names = workspace_names_[id];
+    for (int i = 0; i < names.size(); ++i) {
+      if (names[i] == name) return i;
+    }
+    names.push_back(name);
+    return names.size() - 1;
+  }
+
+  const std::unordered_map<std::type_index, vector<string> > &WorkspaceNames()
+      const {
+    return workspace_names_;
+  }
+
+  // Returns a string describing the registered workspaces.
+  string DebugString() const;
+
+ private:
+  // Workspace type names, indexed as workspace_types_[typeid].
+  std::unordered_map<std::type_index, string> workspace_types_;
+
+  // Workspace names, indexed as workspace_names_[typeid][workspace].
+  std::unordered_map<std::type_index, vector<string> > workspace_names_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
+};
+
+// A typed collection of workspaces. The workspaces are indexed according to an
+// external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
+// also immutable.
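+//
+// A typical usage pattern (illustrative; the workspace name and sizes are
+// hypothetical) looks like:
+//
+//   WorkspaceRegistry registry;
+//   int slot = registry.Request<VectorIntWorkspace>("tags");
+//   WorkspaceSet workspaces;
+//   workspaces.Reset(registry);
+//   workspaces.Set<VectorIntWorkspace>(slot, new VectorIntWorkspace(10));
+//   const VectorIntWorkspace &tags = workspaces.Get<VectorIntWorkspace>(slot);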
+class WorkspaceSet {
+ public:
+  ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
+
+  // Returns true if a workspace has been set.
+  template <class W>
+  bool Has(int index) const {
+    const std::type_index id = std::type_index(typeid(W));
+    DCHECK(workspaces_.find(id) != workspaces_.end());
+    DCHECK_LT(index, workspaces_.find(id)->second.size());
+    return workspaces_.find(id)->second[index] != nullptr;
+  }
+
+  // Returns an indexed workspace; the workspace must have been set.
+  template <class W>
+  const W &Get(int index) const {
+    DCHECK(Has<W>(index));
+    const Workspace *w =
+        workspaces_.find(std::type_index(typeid(W)))->second[index];
+    return reinterpret_cast<const W &>(*w);
+  }
+
+  // Sets an indexed workspace; this takes ownership of the workspace, which
+  // must have been new-allocated.  It is an error to set a workspace twice.
+  template <class W>
+  void Set(int index, W *workspace) {
+    const std::type_index id = std::type_index(typeid(W));
+    DCHECK(workspaces_.find(id) != workspaces_.end());
+    DCHECK_LT(index, workspaces_[id].size());
+    DCHECK(workspaces_[id][index] == nullptr);
+    DCHECK(workspace != nullptr);
+    workspaces_[id][index] = workspace;
+  }
+
+  void Reset(const WorkspaceRegistry &registry) {
+    // Deallocate current workspaces.
+    for (auto &it : workspaces_) {
+      for (size_t index = 0; index < it.second.size(); ++index) {
+        delete it.second[index];
+      }
+    }
+    workspaces_.clear();
+
+    // Allocate space for new workspaces.
+    for (auto &it : registry.WorkspaceNames()) {
+      workspaces_[it.first].resize(it.second.size());
+    }
+  }
+
+ private:
+  // The set of workspaces, indexed as workspaces_[typeid][index].
+  std::unordered_map<std::type_index, vector<Workspace *> > workspaces_;
+};
+
+// A workspace that wraps around a single int.
+class SingletonIntWorkspace : public Workspace {
+ public:
+  // Default-initializes the int value.
+  SingletonIntWorkspace() {}
+
+  // Initializes the int with the given value.
+  explicit SingletonIntWorkspace(int value) : value_(value) {}
+
+  // Returns the name of this type of workspace.
+  static string TypeName() { return "SingletonInt"; }
+
+  // Returns the int value.
+  int get() const { return value_; }
+
+  // Sets the int value.
+  void set(int value) { value_ = value; }
+
+ private:
+  // The enclosed int.
+  int value_ = 0;
+};
+
+// A workspace that wraps around a vector of int.
+class VectorIntWorkspace : public Workspace {
+ public:
+  // Creates a vector of the given size.
+  explicit VectorIntWorkspace(int size);
+
+  // Creates a vector initialized with the given array.
+  explicit VectorIntWorkspace(const vector<int> &elements);
+
+  // Creates a vector of the given size, with each element initialized to the
+  // given value.
+  VectorIntWorkspace(int size, int value);
+
+  // Returns the name of this type of workspace.
+  static string TypeName();
+
+  // Returns the i'th element.
+  int element(int i) const { return elements_[i]; }
+
+  // Sets the i'th element.
+  void set_element(int i, int value) { elements_[i] = value; }
+
+ private:
+  // The enclosed vector.
+  vector<int> elements_;
+};
+
+// A workspace that wraps around a vector of vector of int.
+class VectorVectorIntWorkspace : public Workspace {
+ public:
+  // Creates a vector of empty vectors of the given size.
+  explicit VectorVectorIntWorkspace(int size);
+
+  // Returns the name of this type of workspace.
+  static string TypeName();
+
+  // Returns the i'th vector of elements.
+  const vector<int> &elements(int i) const { return elements_[i]; }
+
+  // Mutable access to the i'th vector of elements.
+  vector<int> *mutable_elements(int i) { return &(elements_[i]); }
+
+ private:
+  // The enclosed vector of vector of elements.
+  vector<vector<int> > elements_;
+};
+
+}  // namespace syntaxnet
+
+#endif  // $TARGETDIR_WORKSPACE_H_

+ 1 - 0
syntaxnet/tensorflow

@@ -0,0 +1 @@
+Subproject commit 3402f51ecd11a26d0c071b1d06b4edab1b0ef351

+ 34 - 0
syntaxnet/third_party/utf/BUILD

@@ -0,0 +1,34 @@
+licenses(["notice"])
+
+cc_library(
+    name = "utf",
+    srcs = [
+        "rune.c",
+        "runestrcat.c",
+        "runestrchr.c",
+        "runestrcmp.c",
+        "runestrcpy.c",
+        "runestrdup.c",
+        "runestrecpy.c",
+        "runestrlen.c",
+        "runestrncat.c",
+        "runestrncmp.c",
+        "runestrncpy.c",
+        "runestrrchr.c",
+        "runestrstr.c",
+        "runetype.c",
+        "utfecpy.c",
+        "utflen.c",
+        "utfnlen.c",
+        "utfrrune.c",
+        "utfrune.c",
+        "utfutf.c",
+    ],
+    hdrs = [
+        "runetypebody.c",
+        "utf.h",
+        "utfdef.h",
+    ],
+    includes = ["."],
+    visibility = ["//visibility:public"],
+)

+ 13 - 0
syntaxnet/third_party/utf/README

@@ -0,0 +1,13 @@
+/*
+ * The authors of this software are Rob Pike and Ken Thompson.
+ *              Copyright (c) 1998-2002 by Lucent Technologies.
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose without fee is hereby granted, provided that this entire notice
+ * is included in all copies of any software which is or includes a copy
+ * or modification of this software and in all copies of the supporting
+ * documentation for such software.
+ * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
+ * WARRANTY.  IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
+ * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
+ * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
+ */

+ 357 - 0
syntaxnet/third_party/utf/rune.c

@@ -0,0 +1,357 @@
+/*
+ * The authors of this software are Rob Pike and Ken Thompson.
+ *              Copyright (c) 2002 by Lucent Technologies.
+ * Permission to use, copy, modify, and distribute this software for any
+ * purpose without fee is hereby granted, provided that this entire notice
+ * is included in all copies of any software which is or includes a copy
+ * or modification of this software and in all copies of the supporting
+ * documentation for such software.
+ * THIS SOFTWARE IS BEING PROVIDED "AS IS", WITHOUT ANY EXPRESS OR IMPLIED
+ * WARRANTY.  IN PARTICULAR, NEITHER THE AUTHORS NOR LUCENT TECHNOLOGIES MAKE ANY
+ * REPRESENTATION OR WARRANTY OF ANY KIND CONCERNING THE MERCHANTABILITY
+ * OF THIS SOFTWARE OR ITS FITNESS FOR ANY PARTICULAR PURPOSE.
+ */
+#include <stdarg.h>
+#include <string.h>
+#include "third_party/utf/utf.h"
+#include "third_party/utf/utfdef.h"
+
+enum
+{
+	Bit1	= 7,
+	Bitx	= 6,
+	Bit2	= 5,
+	Bit3	= 4,
+	Bit4	= 3,
+	Bit5	= 2, 
+
+	T1	= ((1<<(Bit1+1))-1) ^ 0xFF,	/* 0000 0000 */
+	Tx	= ((1<<(Bitx+1))-1) ^ 0xFF,	/* 1000 0000 */
+	T2	= ((1<<(Bit2+1))-1) ^ 0xFF,	/* 1100 0000 */
+	T3	= ((1<<(Bit3+1))-1) ^ 0xFF,	/* 1110 0000 */
+	T4	= ((1<<(Bit4+1))-1) ^ 0xFF,	/* 1111 0000 */
+	T5	= ((1<<(Bit5+1))-1) ^ 0xFF,	/* 1111 1000 */
+
+	Rune1	= (1<<(Bit1+0*Bitx))-1,		/* 0000 0000 0111 1111 */
+	Rune2	= (1<<(Bit2+1*Bitx))-1,		/* 0000 0111 1111 1111 */
+	Rune3	= (1<<(Bit3+2*Bitx))-1,		/* 1111 1111 1111 1111 */
+	Rune4	= (1<<(Bit4+3*Bitx))-1,
+                                        /* 0001 1111 1111 1111 1111 1111 */
+
+	Maskx	= (1<<Bitx)-1,			/* 0011 1111 */
+	Testx	= Maskx ^ 0xFF,			/* 1100 0000 */
+
+	Bad	= Runeerror,
+};
+
+/*
+ * Modified by Wei-Hwa Huang, Google Inc., on 2004-09-24
+ * This is a slower but "safe" version of the old chartorune 
+ * that works on strings that are not necessarily null-terminated.
+ * 
+ * If you know for sure that your string is null-terminated,
+ * chartorune will be a bit faster.
+ *
+ * It is guaranteed not to attempt to access "length"
+ * past the incoming pointer.  This is to avoid
+ * possible access violations.  If the string appears to be
+ * well-formed but incomplete (i.e., to get the whole Rune
+ * we'd need to read past str+length) then we'll set the Rune
+ * to Bad and return 0.
+ *
+ * Note that if we have decoding problems for other
+ * reasons, we return 1 instead of 0.
+ */
+int
+charntorune(Rune *rune, const char *str, int length)
+{
+	int c, c1, c2, c3;
+	long l;
+
+	/* When we're not allowed to read anything */
+	if(length <= 0) {
+		goto badlen;
+	}
+
+	/*
+	 * one character sequence (7-bit value)
+	 *	00000-0007F => T1
+	 */
+	c = *(uchar*)str;
+	if(c < Tx) {
+		*rune = c;
+		return 1;
+	}
+
+	// If we can't read more than one character we must stop
+	if(length <= 1) {
+		goto badlen;
+	}
+
+	/*
+	 * two character sequence (11-bit value)
+	 *	0080-07FF => T2 Tx
+	 */
+	c1 = *(uchar*)(str+1) ^ Tx;
+	if(c1 & Testx)
+		goto bad;
+	if(c < T3) {
+		if(c < T2)
+			goto bad;
+		l = ((c << Bitx) | c1) & Rune2;
+		if(l <= Rune1)
+			goto bad;
+		*rune = l;
+		return 2;
+	}
+
+	// If we can't read more than two characters we must stop
+	if(length <= 2) {
+		goto badlen;
+	}
+
+	/*
+	 * three character sequence (16-bit value)
+	 *	0800-FFFF => T3 Tx Tx
+	 */
+	c2 = *(uchar*)(str+2) ^ Tx;
+	if(c2 & Testx)
+		goto bad;
+	if(c < T4) {
+		l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
+		if(l <= Rune2)
+			goto bad;
+		*rune = l;
+		return 3;
+	}
+
+	if (length <= 3)
+		goto badlen;
+
+	/*
+	 * four character sequence (21-bit value)
+	 *	10000-1FFFFF => T4 Tx Tx Tx
+	 */
+	c3 = *(uchar*)(str+3) ^ Tx;
+	if (c3 & Testx)
+		goto bad;
+	if (c < T5) {
+		l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
+		if (l <= Rune3)
+			goto bad;
+		if (l > Runemax)
+			goto bad;
+		*rune = l;
+		return 4;
+	}
+
+	// Support for 5-byte or longer UTF-8 would go here, but
+	// since we don't have that, we'll just fall through to bad.
+
+	/*
+	 * bad decoding
+	 */
+bad:
+	*rune = Bad;
+	return 1;
+badlen:
+	*rune = Bad;
+	return 0;
+
+}
+
+
+/*
+ * This is the older "unsafe" version, which works fine on 
+ * null-terminated strings.
+ */
+int
+chartorune(Rune *rune, const char *str)
+{
+	int c, c1, c2, c3;
+	long l;
+
+	/*
+	 * one character sequence
+	 *	00000-0007F => T1
+	 */
+	c = *(uchar*)str;
+	if(c < Tx) {
+		*rune = c;
+		return 1;
+	}
+
+	/*
+	 * two character sequence
+	 *	0080-07FF => T2 Tx
+	 */
+	c1 = *(uchar*)(str+1) ^ Tx;
+	if(c1 & Testx)
+		goto bad;
+	if(c < T3) {
+		if(c < T2)
+			goto bad;
+		l = ((c << Bitx) | c1) & Rune2;
+		if(l <= Rune1)
+			goto bad;
+		*rune = l;
+		return 2;
+	}
+
+	/*
+	 * three character sequence
+	 *	0800-FFFF => T3 Tx Tx
+	 */
+	c2 = *(uchar*)(str+2) ^ Tx;
+	if(c2 & Testx)
+		goto bad;
+	if(c < T4) {
+		l = ((((c << Bitx) | c1) << Bitx) | c2) & Rune3;
+		if(l <= Rune2)
+			goto bad;
+		*rune = l;
+		return 3;
+	}
+
+	/*
+	 * four character sequence (21-bit value)
+	 *	10000-1FFFFF => T4 Tx Tx Tx
+	 */
+	c3 = *(uchar*)(str+3) ^ Tx;
+	if (c3 & Testx)
+		goto bad;
+	if (c < T5) {
+		l = ((((((c << Bitx) | c1) << Bitx) | c2) << Bitx) | c3) & Rune4;
+		if (l <= Rune3)
+			goto bad;
+		if (l > Runemax)
+			goto bad;
+		*rune = l;
+		return 4;
+	}
+
+	/*
+	 * Support for 5-byte or longer UTF-8 would go here, but
+	 * since we don't have that, we'll just fall through to bad.
+	 */
+
+	/*
+	 * bad decoding
+	 */
+bad:
+	*rune = Bad;
+	return 1;
+}
+
+int
+isvalidcharntorune(const char* str, int length, Rune* rune, int* consumed) {
+	*consumed = charntorune(rune, str, length);
+	return *rune != Runeerror || *consumed == 3;
+}
+    
+int
+runetochar(char *str, const Rune *rune)
+{
+	/* Runes are signed, so convert to unsigned for range check. */
+	unsigned long c;
+
+	/*
+	 * one character sequence
+	 *	00000-0007F => 00-7F
+	 */
+	c = *rune;
+	if(c <= Rune1) {
+		str[0] = c;
+		return 1;
+	}
+
+	/*
+	 * two character sequence
+	 *	0080-07FF => T2 Tx
+	 */
+	if(c <= Rune2) {
+		str[0] = T2 | (c >> 1*Bitx);
+		str[1] = Tx | (c & Maskx);
+		return 2;
+	}
+
+	/*
+	 * If the Rune is out of range, convert it to the error rune.
+	 * Do this test here because the error rune encodes to three bytes.
+	 * Doing it earlier would duplicate work, since an out of range
+	 * Rune wouldn't have fit in one or two bytes.
+	 */
+	if (c > Runemax)
+		c = Runeerror;
+
+	/*
+	 * three character sequence
+	 *	0800-FFFF => T3 Tx Tx
+	 */
+	if (c <= Rune3) {
+		str[0] = T3 |  (c >> 2*Bitx);
+		str[1] = Tx | ((c >> 1*Bitx) & Maskx);
+		str[2] = Tx |  (c & Maskx);
+		return 3;
+	}
+
+	/*
+	 * four character sequence (21-bit value)
+	 *     10000-1FFFFF => T4 Tx Tx Tx
+	 */
+	str[0] = T4 | (c >> 3*Bitx);
+	str[1] = Tx | ((c >> 2*Bitx) & Maskx);
+	str[2] = Tx | ((c >> 1*Bitx) & Maskx);
+	str[3] = Tx | (c & Maskx);
+	return 4;
+}
+
+int
+runelen(Rune rune)
+{
+	char str[10];
+
+	return runetochar(str, &rune);
+}
+
+int
+runenlen(const Rune *r, int nrune)
+{
+	int nb;
+	ulong c;	/* Rune is signed, so use unsigned for range check. */
+
+	nb = 0;
+	while(nrune--) {
+		c = *r++;
+		if (c <= Rune1)
+			nb++;
+		else if (c <= Rune2)
+			nb += 2;
+		else if (c <= Rune3)
+			nb += 3;
+		else if (c <= Runemax)
+			nb += 4;
+		else
+			nb += 3;	/* Runeerror = 0xFFFD, see runetochar */
+	}
+	return nb;
+}
+
+int
+fullrune(const char *str, int n)
+{
+	if (n > 0) {
+		int c = *(uchar*)str;
+		if (c < Tx)
+			return 1;
+		if (n > 1) {
+			if (c < T3)
+				return 1;
+			if (n > 2) {
+				if (c < T4 || n > 3)
+					return 1;
+			}
+		}
+	}
+	return 0;
+}

+ 0 - 0
syntaxnet/third_party/utf/runestrcat.c


Some files were not shown because too many files changed in this diff