
Release DRAGNN (#1177)

* Release DRAGNN

* Update CoNLL evaluation table & evaluator.py
Ivan Bogatyy, 8 years ago
parent commit 7d30a017fe
100 changed files with 65,140 additions and 479 deletions
  1. 4 0
      syntaxnet/.dockerignore
  2. 79 21
      syntaxnet/Dockerfile
  3. 120 456
      syntaxnet/README.md
  4. 15 2
      syntaxnet/WORKSPACE
  5. 27 0
      syntaxnet/docker-devel/build_devel.sh
  6. 35 0
      syntaxnet/docker-devel/build_wheels.sh
  7. 5 0
      syntaxnet/dragnn/BUILD
  8. 116 0
      syntaxnet/dragnn/components/syntaxnet/BUILD
  9. 779 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
  10. 183 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
  11. 1174 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
  12. 49 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
  13. 55 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
  14. 63 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
  15. 85 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
  16. 144 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
  17. 276 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
  18. 48 0
      syntaxnet/dragnn/components/syntaxnet/testdata/master_spec.textproto
  19. 47 0
      syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.label-map
  20. 65 0
      syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.master-spec
  21. 50 0
      syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.tag-map
  22. 4 0
      syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.word-map
  23. 9 0
      syntaxnet/dragnn/components/util/BUILD
  24. 95 0
      syntaxnet/dragnn/components/util/bulk_feature_extractor.h
  25. 9 0
      syntaxnet/dragnn/conll2017/BUILD
  26. 40 0
      syntaxnet/dragnn/conll2017/conll_parser_trainer.sh
  27. 105 0
      syntaxnet/dragnn/conll2017/make_parser_spec.py
  28. 16 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/category-map
  29. 3518 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/char-map
  30. 16126 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/char-ngram-map
  31. 43 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/label-map
  32. 16263 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/lcword-map
  33. BIN
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/prefix-table
  34. BIN
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/suffix-table
  35. 43 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/tag-map
  36. 42 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/tag-to-category
  37. 16269 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/word-map
  38. BIN
      syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.data-00000-of-00001
  39. BIN
      syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.index
  40. BIN
      syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.meta
  41. 187 0
      syntaxnet/dragnn/conll2017/sample/zh-segmenter.master_spec
  42. 340 0
      syntaxnet/dragnn/core/BUILD
  43. 347 0
      syntaxnet/dragnn/core/beam.h
  44. 773 0
      syntaxnet/dragnn/core/beam_test.cc
  45. 8 0
      syntaxnet/dragnn/core/component_registry.cc
  46. 14 0
      syntaxnet/dragnn/core/component_registry.h
  47. 120 0
      syntaxnet/dragnn/core/compute_session.h
  48. 384 0
      syntaxnet/dragnn/core/compute_session_impl.cc
  49. 142 0
      syntaxnet/dragnn/core/compute_session_impl.h
  50. 1157 0
      syntaxnet/dragnn/core/compute_session_impl_test.cc
  51. 89 0
      syntaxnet/dragnn/core/compute_session_pool.cc
  52. 87 0
      syntaxnet/dragnn/core/compute_session_pool.h
  53. 211 0
      syntaxnet/dragnn/core/compute_session_pool_test.cc
  54. 67 0
      syntaxnet/dragnn/core/index_translator.cc
  55. 68 0
      syntaxnet/dragnn/core/index_translator.h
  56. 180 0
      syntaxnet/dragnn/core/index_translator_test.cc
  57. 78 0
      syntaxnet/dragnn/core/input_batch_cache.h
  58. 107 0
      syntaxnet/dragnn/core/input_batch_cache_test.cc
  59. 37 0
      syntaxnet/dragnn/core/interfaces/BUILD
  60. 52 0
      syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h
  61. 126 0
      syntaxnet/dragnn/core/interfaces/component.h
  62. 30 0
      syntaxnet/dragnn/core/interfaces/input_batch.h
  63. 53 0
      syntaxnet/dragnn/core/interfaces/transition_state.h
  64. 113 0
      syntaxnet/dragnn/core/interfaces/transition_state_starter_test.cc
  65. 70 0
      syntaxnet/dragnn/core/ops/compute_session_op.cc
  66. 54 0
      syntaxnet/dragnn/core/ops/compute_session_op.h
  67. 396 0
      syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
  68. 588 0
      syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
  69. 115 0
      syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc
  70. 631 0
      syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc
  71. 851 0
      syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc
  72. 239 0
      syntaxnet/dragnn/core/ops/dragnn_ops.cc
  73. 36 0
      syntaxnet/dragnn/core/resource_container.h
  74. 49 0
      syntaxnet/dragnn/core/resource_container_test.cc
  75. 57 0
      syntaxnet/dragnn/core/test/BUILD
  76. 21 0
      syntaxnet/dragnn/core/test/generic.cc
  77. 25 0
      syntaxnet/dragnn/core/test/generic.h
  78. 63 0
      syntaxnet/dragnn/core/test/mock_component.h
  79. 61 0
      syntaxnet/dragnn/core/test/mock_compute_session.h
  80. 30 0
      syntaxnet/dragnn/core/test/mock_transition_state.h
  81. BIN
      syntaxnet/dragnn/core/testdata/brain-parser-model
  82. 86 0
      syntaxnet/dragnn/core/testdata/master_spec_link.textproto
  83. BIN
      syntaxnet/dragnn/core/testdata/repository
  84. BIN
      syntaxnet/dragnn/core/testdata/simple-tagger.brain-parser-model
  85. BIN
      syntaxnet/dragnn/core/testdata/simple-tagger.repository
  86. 46 0
      syntaxnet/dragnn/core/testdata/simple-tagger.tag-map
  87. 59 0
      syntaxnet/dragnn/core/testdata/simple_parser_master_spec.textproto
  88. 52 0
      syntaxnet/dragnn/core/testdata/simple_tagger_lstm_master_spec.textproto
  89. 63 0
      syntaxnet/dragnn/core/testdata/simple_tagger_master_spec.textproto
  90. 65 0
      syntaxnet/dragnn/core/testdata/simple_tagger_wrapped_lstm_master_spec.textproto
  91. 111 0
      syntaxnet/dragnn/core/testdata/split_tagger_master_spec.textproto
  92. 47 0
      syntaxnet/dragnn/core/testdata/syntaxnet_tagger.label-map
  93. 50 0
      syntaxnet/dragnn/core/testdata/syntaxnet_tagger.tag-map
  94. 4 0
      syntaxnet/dragnn/core/testdata/syntaxnet_tagger.word-map
  95. 185 0
      syntaxnet/dragnn/core/testdata/tagger_parser_master_spec.textproto
  96. 213 0
      syntaxnet/dragnn/core/testdata/ud-hungarian.master-spec
  97. 34 0
      syntaxnet/dragnn/io/BUILD
  98. 31 0
      syntaxnet/dragnn/io/sentence_input_batch.cc
  99. 37 0
      syntaxnet/dragnn/io/sentence_input_batch.h
  100. 0 0
      syntaxnet/dragnn/io/sentence_input_batch_test.cc

+ 4 - 0
syntaxnet/.dockerignore

@@ -0,0 +1,4 @@
+.git
+bazel/
+Dockerfile*
+tensorflow/.git

+ 79 - 21
syntaxnet/Dockerfile

@@ -1,33 +1,91 @@
+# Java baseimage, for Bazel.
 FROM java:8
 
 ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
 
+# Install system packages. This doesn't include everything the TensorFlow
+# dockerfile specifies, so if anything goes awry, maybe install more packages
+# from there. Also, running apt-get clean before further commands will make the
+# Docker images smaller.
 RUN mkdir -p $SYNTAXNETDIR \
     && cd $SYNTAXNETDIR \
     && apt-get update \
-    && apt-get install git zlib1g-dev file swig python2.7 python-dev python-pip python-mock -y \
-    && pip install --upgrade pip \
-    && pip install -U protobuf==3.0.0b2 \
-    && pip install asciitree \
-    && pip install numpy \
-    && wget https://github.com/bazelbuild/bazel/releases/download/0.4.3/bazel-0.4.3-installer-linux-x86_64.sh \
+    && apt-get install -y \
+          file \
+          git \
+          graphviz \
+          libcurl3-dev \
+          libfreetype6-dev \
+          libgraphviz-dev \
+          liblapack-dev \
+          libopenblas-dev \
+          libpng12-dev \
+          libxft-dev \
+          python-dev \
+          python-mock \
+          python-pip \
+          python2.7 \
+          swig \
+          vim \
+          zlib1g-dev \
+    && apt-get clean \
+    && (rm -f /var/cache/apt/archives/*.deb \
+        /var/cache/apt/archives/partial/*.deb /var/cache/apt/*.bin || true)
+
+# Install common Python dependencies. Similar to above, remove caches
+# afterwards to help keep Docker images smaller.
+RUN pip install --ignore-installed pip \
+    && python -m pip install numpy \
+    && rm -rf /root/.cache/pip /tmp/pip*
+RUN python -m pip install \
+          asciitree \
+          ipykernel \
+          jupyter \
+          matplotlib \
+          pandas \
+          protobuf \
+          scipy \
+          sklearn \
+    && python -m ipykernel.kernelspec \
+    && python -m pip install pygraphviz \
+          --install-option="--include-path=/usr/include/graphviz" \
+          --install-option="--library-path=/usr/lib/graphviz/" \
+    && rm -rf /root/.cache/pip /tmp/pip*
+
+# Installs the latest version of Bazel.
+RUN wget --quiet https://github.com/bazelbuild/bazel/releases/download/0.4.3/bazel-0.4.3-installer-linux-x86_64.sh \
     && chmod +x bazel-0.4.3-installer-linux-x86_64.sh \
-    && ./bazel-0.4.3-installer-linux-x86_64.sh --user \
-    && git clone --recursive https://github.com/tensorflow/models.git \
-    && cd $SYNTAXNETDIR/models/syntaxnet/tensorflow \
-    && echo -e "\n\n\n\n\n\n\n\n\n" | ./configure \
-    && apt-get autoremove -y \
-    && apt-get clean
+    && ./bazel-0.4.3-installer-linux-x86_64.sh \
+    && rm ./bazel-0.4.3-installer-linux-x86_64.sh
+
+COPY WORKSPACE $SYNTAXNETDIR/syntaxnet/WORKSPACE
+COPY tools/bazel.rc $SYNTAXNETDIR/syntaxnet/tools/bazel.rc
+COPY tensorflow $SYNTAXNETDIR/syntaxnet/tensorflow
+
+# Compile common TensorFlow targets, which don't depend on DRAGNN / SyntaxNet
+# source. This makes it more convenient to re-compile DRAGNN / SyntaxNet for
+# development (though not as convenient as the docker-devel scripts).
+RUN cd $SYNTAXNETDIR/syntaxnet/tensorflow \
+    && tensorflow/tools/ci_build/builds/configured CPU \
+    && cd $SYNTAXNETDIR/syntaxnet \
+    && bazel build -c opt @org_tensorflow//tensorflow:tensorflow_py
 
-RUN cd $SYNTAXNETDIR/models/syntaxnet \
-    && bazel test --genrule_strategy=standalone syntaxnet/... util/utf8/...
+# Build the codez.
+WORKDIR $SYNTAXNETDIR/syntaxnet
+COPY dragnn $SYNTAXNETDIR/syntaxnet/dragnn
+COPY syntaxnet $SYNTAXNETDIR/syntaxnet/syntaxnet
+COPY third_party $SYNTAXNETDIR/syntaxnet/third_party
+COPY util/utf8 $SYNTAXNETDIR/syntaxnet/util/utf8
+RUN bazel build -c opt //dragnn/python:all //dragnn/tools:all
 
-WORKDIR $SYNTAXNETDIR/models/syntaxnet
+# This makes the IP exposed actually "*"; we'll do host restrictions by passing
+# a hostname to the `docker run` command.
+COPY tensorflow/tensorflow/tools/docker/jupyter_notebook_config.py /root/.jupyter/
+EXPOSE 8888
 
-CMD [ "sh", "-c", "echo 'Bob brought the pizza to Alice.' | syntaxnet/demo.sh" ]
+# This does not need to be compiled, only copied.
+COPY examples $SYNTAXNETDIR/syntaxnet/examples
+# Todo: Move this earlier in the file (don't want to invalidate caches for now).
+RUN jupyter nbextension enable --py --sys-prefix widgetsnbextension
 
-# COMMANDS to build and run
-# ===============================
-# mkdir build && cp Dockerfile build/ && cd build
-# docker build -t syntaxnet .
-# docker run syntaxnet
+CMD /bin/bash -c "bazel-bin/dragnn/tools/oss_notebook_launcher notebook --debug --notebook-dir=/opt/tensorflow/syntaxnet/examples"
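
The commented "COMMANDS to build and run" block at the end of the old Dockerfile is removed, so for reference, a minimal build-and-run sketch for the new image could look like the following (the `dragnn-oss` tag and the port mapping are borrowed from `docker-devel/build_devel.sh`, added later in this commit; adjust as needed):

```shell
# Build the image from the syntaxnet/ directory (the Dockerfile's build context).
docker build -t dragnn-oss .

# The default CMD starts a Jupyter notebook server on the exposed port 8888,
# so publish it and then open http://127.0.0.1:8888/ in a browser.
docker run --rm -ti -p 127.0.0.1:8888:8888 dragnn-oss
```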

+ 120 - 456
syntaxnet/README.md

@@ -1,90 +1,71 @@
 # SyntaxNet: Neural Models of Syntax.
 
-*A TensorFlow implementation of the models described in [Andor et al. (2016)]
-(http://arxiv.org/abs/1603.06042).*
+*A TensorFlow toolkit for deep learning powered natural language understanding
+(NLU).*
 
-**Update**: Parsey models are now [available](universal.md) for 40 languages
-trained on Universal Dependencies datasets, with support for text segmentation
-and morphological analysis.
+**CoNLL**: See [here](g3doc/conll2017/README.md) for instructions for using the
+SyntaxNet/DRAGNN baseline for the CoNLL2017 Shared Task.
 
 At Google, we spend a lot of time thinking about how computer systems can read
 and understand human language in order to process it in intelligent ways. We are
 excited to share the fruits of our research with the broader community by
-releasing SyntaxNet, an open-source neural network framework for [TensorFlow]
-(http://www.tensorflow.org) that provides a foundation for Natural Language
-Understanding (NLU) systems. Our release includes all the code needed to train
-new SyntaxNet models on your own data, as well as *Parsey McParseface*, an
-English parser that we have trained for you, and that you can use to analyze
-English text.
-
-So, how accurate is Parsey McParseface? For this release, we tried to balance a
-model that runs fast enough to be useful on a single machine (e.g. ~600
-words/second on a modern desktop) and that is also the most accurate parser
-available. Here's how Parsey McParseface compares to the academic literature on
-several different English domains: (all numbers are % correct head assignments
-in the tree, or unlabelled attachment score)
-
-Model                                                                                                           | News  | Web   | Questions
---------------------------------------------------------------------------------------------------------------- | :---: | :---: | :-------:
-[Martins et al. (2013)](http://www.cs.cmu.edu/~ark/TurboParser/)                                                | 93.10 | 88.23 | 94.21
-[Zhang and McDonald (2014)](http://research.google.com/pubs/archive/38148.pdf)                                  | 93.32 | 88.65 | 93.37
-[Weiss et al. (2015)](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43800.pdf) | 93.91 | 89.29 | 94.17
-[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                                                         | 94.44 | 90.17 | 95.40
-Parsey McParseface                                                                                              | 94.15 | 89.08 | 94.77
-
-We see that Parsey McParseface is state-of-the-art; more importantly, with
-SyntaxNet you can train larger networks with more hidden units and bigger beam
-sizes if you want to push the accuracy even further: [Andor et al. (2016)]
-(http://arxiv.org/abs/1603.06042)* is simply a SyntaxNet model with a
-larger beam and network. For futher information on the datasets, see that paper
-under the section "Treebank Union".
+releasing SyntaxNet, an open-source neural network framework for
+[TensorFlow](http://www.tensorflow.org) that provides a foundation for Natural
+Language Understanding (NLU) systems. Our release includes all the code needed
+to train new SyntaxNet models on your own data, as well as a suite of models
+that we have trained for you, and that you can use to analyze text in over 40
+languages.
+
+This repository is largely divided into two sub-packages:
+
+1.  **DRAGNN:
+    [code](https://github.com/tensorflow/models/tree/master/syntaxnet/dragnn),
+    [documentation](g3doc/DRAGNN.md)** implements Dynamic Recurrent Acyclic
+    Graphical Neural Networks (DRAGNN), a framework for building multi-task,
+    fully dynamic constructed computation graphs. Practically, we use DRAGNN to
+    extend our prior work from [Andor et al.
+    (2016)](http://arxiv.org/abs/1603.06042) with end-to-end, deep recurrent
+    models and to provide a much easier to use interface to SyntaxNet.
+1.  **SyntaxNet:
+    [code](https://github.com/tensorflow/models/tree/master/syntaxnet/syntaxnet),
+    [documentation](g3doc/syntaxnet-tutorial.md)** is a transition-based
+    framework for natural language processing, with core functionality for
+    feature extraction, representing annotated data, and evaluation. As of the
+    DRAGNN release, it is recommended to train and deploy SyntaxNet models using
+    the DRAGNN framework.
+
+## How to use this library
+
+There are three ways to use SyntaxNet:
+
+*   See [here](g3doc/conll2017/README.md) for instructions for using the
+    SyntaxNet/DRAGNN baseline for the CoNLL2017 Shared Task, and running the
+    ParseySaurus models.
+*   You can use DRAGNN to train your NLP models for other tasks and dataset. See
+    "Getting started with DRAGNN below."
+*   You can continue to use the Parsey McParseface family of pre-trained
+    SyntaxNet models. See "Pre-trained NLP models" below.
 
-Parsey McParseface is also state-of-the-art for part-of-speech (POS) tagging
-(numbers below are per-token accuracy):
+## Installation
 
-Model                                                                      | News  | Web   | Questions
--------------------------------------------------------------------------- | :---: | :---: | :-------:
-[Ling et al. (2015)](http://www.cs.cmu.edu/~lingwang/papers/emnlp2015.pdf) | 97.44 | 94.03 | 96.18
-[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                    | 97.77 | 94.80 | 96.86
-Parsey McParseface                                                         | 97.52 | 94.24 | 96.45
+### Docker installation
 
-The first part of this tutorial describes how to install the necessary tools and
-use the already trained models provided in this release. In the second part of
-the tutorial we provide more background about the models, as well as
-instructions for training models on other datasets.
-
-## Contents
-* [Installation](#installation)
-* [Getting Started](#getting-started)
-    * [Parsing from Standard Input](#parsing-from-standard-input)
-    * [Annotating a Corpus](#annotating-a-corpus)
-    * [Configuring the Python Scripts](#configuring-the-python-scripts)
-    * [Next Steps](#next-steps)
-* [Detailed Tutorial: Building an NLP Pipeline with SyntaxNet](#detailed-tutorial-building-an-nlp-pipeline-with-syntaxnet)
-    * [Obtaining Data](#obtaining-data)
-    * [Part-of-Speech Tagging](#part-of-speech-tagging)
-    * [Training the SyntaxNet POS Tagger](#training-the-syntaxnet-pos-tagger)
-    * [Preprocessing with the Tagger](#preprocessing-with-the-tagger)
-    * [Dependency Parsing: Transition-Based Parsing](#dependency-parsing-transition-based-parsing)
-    * [Training a Parser Step 1: Local Pretraining](#training-a-parser-step-1-local-pretraining)
-    * [Training a Parser Step 2: Global Training](#training-a-parser-step-2-global-training)
-* [Contact](#contact)
-* [Credits](#credits)
+The simplest way to get started with DRAGNN is by loading our Docker container.
+[Here](g3doc/CLOUD.md) is a tutorial for running the DRAGNN container on
+[GCP](https://cloud.google.com) (just as applicable to your own computer).
 
-## Installation
+### Manual installation
 
-Running and training SyntaxNet models requires building this package from
+Running and training SyntaxNet/DRAGNN models requires building this package from
 source. You'll need to install:
 
 *   python 2.7:
-    * python 3 support is not available yet
+    *   Python 3 support is not available yet
 *   bazel:
-    *   **version 0.4.3**
-    *   follow the instructions [here](http://bazel.build/docs/install.html)
-    *   Alternately, Download bazel (0.4.3) <.deb> from
-        [https://github.com/bazelbuild/bazel/releases]
-        (https://github.com/bazelbuild/bazel/releases) for your system
-        configuration.
+    *   Follow the instructions [here](http://bazel.build/docs/install.html)
+    *   Alternately, Download bazel <.deb> from
+        [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
+        for your system configuration.
     *   Install it using the command: sudo dpkg -i <.deb file>
     *   Check for the bazel version by typing: bazel version
 *   swig:
@@ -99,6 +80,11 @@ source. You'll need to install:
     *   `pip install asciitree`
 *   numpy, package for scientific computing:
     *   `pip install numpy`
+*   pygraphviz to visualize traces and parse trees:
+    *   `apt-get install -y graphviz libgraphviz-dev`
+    *   `pip install pygraphviz
+        --install-option="--include-path=/usr/include/graphviz"
+        --install-option="--library-path=/usr/lib/graphviz/"`
 
 Once you completed the above steps, you can build and test SyntaxNet with the
 following commands:
@@ -108,17 +94,14 @@ following commands:
   cd models/syntaxnet/tensorflow
   ./configure
   cd ..
-  bazel test syntaxnet/... util/utf8/...
+  bazel test ...
   # On Mac, run the following:
   bazel test --linkopt=-headerpad_max_install_names \
-    syntaxnet/... util/utf8/...
+    dragnn/... syntaxnet/... util/utf8/...
 ```
 
 Bazel should complete reporting all tests passed.
 
-You can also compile SyntaxNet in a [Docker](https://www.docker.com/what-docker)
-container using this [Dockerfile](Dockerfile).
-
 To build SyntaxNet with GPU support please refer to the instructions in
 [issues/248](https://github.com/tensorflow/models/issues/248).
 
@@ -127,12 +110,64 @@ memory allocated for your Docker VM.
 
 ## Getting Started
 
+We have a few guides on this README, as well as more extensive
+[documentation](g3doc/).
+
+### Learning the DRAGNN framework
+
+![DRAGNN](g3doc/unrolled-dragnn.png)
+
+An easy and visual way to get started with DRAGNN is to run [our Jupyter
+Notebook](examples/dragnn/basic_parser_tutorial.ipynb). Our tutorial
+[here](g3doc/CLOUD.md) explains how to start it up from the Docker container.
+
+### Using the Pre-trained NLP models
+
+We are happy to release *Parsey McParseface*, an English parser that we have
+trained for you, and that you can use to analyze English text, along with
+[trained models for 40 languages](g3doc/universal.md) and support for text
+segmentation and morphological analysis.
+
 Once you have successfully built SyntaxNet, you can start parsing text right
 away with Parsey McParseface, located under `syntaxnet/models`. The easiest
 thing is to use or modify the included script `syntaxnet/demo.sh`, which shows a
 basic setup to parse English taking plain text as input.
 
-### Parsing from Standard Input
+You can also skip right away to the [detailed SyntaxNet
+tutorial](g3doc/syntaxnet-tutorial.md).
+
+How accurate is Parsey McParseface? For the initial release, we tried to balance
+a model that runs fast enough to be useful on a single machine (e.g. ~600
+words/second on a modern desktop) and that is also the most accurate parser
+available. Here's how Parsey McParseface compares to the academic literature on
+several different English domains: (all numbers are % correct head assignments
+in the tree, or unlabelled attachment score)
+
+Model                                                                                                           | News  | Web   | Questions
+--------------------------------------------------------------------------------------------------------------- | :---: | :---: | :-------:
+[Martins et al. (2013)](http://www.cs.cmu.edu/~ark/TurboParser/)                                                | 93.10 | 88.23 | 94.21
+[Zhang and McDonald (2014)](http://research.google.com/pubs/archive/38148.pdf)                                  | 93.32 | 88.65 | 93.37
+[Weiss et al. (2015)](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43800.pdf) | 93.91 | 89.29 | 94.17
+[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                                                         | 94.44 | 90.17 | 95.40
+Parsey McParseface                                                                                              | 94.15 | 89.08 | 94.77
+
+We see that Parsey McParseface is state-of-the-art; more importantly, with
+SyntaxNet you can train larger networks with more hidden units and bigger beam
+sizes if you want to push the accuracy even further: [Andor et al.
+(2016)](http://arxiv.org/abs/1603.06042)* is simply a SyntaxNet model with a
+larger beam and network. For futher information on the datasets, see that paper
+under the section "Treebank Union".
+
+Parsey McParseface is also state-of-the-art for part-of-speech (POS) tagging
+(numbers below are per-token accuracy):
+
+Model                                                                      | News  | Web   | Questions
+-------------------------------------------------------------------------- | :---: | :---: | :-------:
+[Ling et al. (2015)](http://www.cs.cmu.edu/~lingwang/papers/emnlp2015.pdf) | 97.44 | 94.03 | 96.18
+[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                    | 97.77 | 94.80 | 96.86
+Parsey McParseface                                                         | 97.52 | 94.24 | 96.45
+
+#### Parsing from Standard Input
 
 Simply pass one sentence per line of text into the script at
 `syntaxnet/demo.sh`. The script will break the text into words, run the POS
@@ -160,7 +195,7 @@ visualized in our tutorial graphs. In this example, we see that the verb
 If you want to feed in tokenized, CONLL-formatted text, you can run `demo.sh
 --conll`.
 
-### Annotating a Corpus
+#### Annotating a Corpus
 
 To change the pipeline to read and write to specific files (as opposed to piping
 through stdin and stdout), we have to modify the `demo.sh` to point to the files
@@ -200,7 +235,7 @@ input {
 Then we can use `--input=wsj-data --output=wsj-data-tagged` on the command line
 to specify reading and writing to these files.
 
-### Configuring the Python Scripts
+#### Configuring the Python Scripts
 
 As mentioned above, the python scripts are configured in two ways:
 
@@ -234,386 +269,13 @@ There are many ways to extend this framework, e.g. adding new features, changing
 the model structure, training on other languages, etc. We suggest reading the
 detailed tutorial below to get a handle on the rest of the framework.
 
-## Detailed Tutorial: Building an NLP Pipeline with SyntaxNet
-
-In this tutorial, we'll go over how to train new models, and explain in a bit
-more technical detail the NLP side of the models. Our goal here is to explain
-the NLP pipeline produced by this package.
-
-### Obtaining Data
-
-The included English parser, Parsey McParseface, was trained on the the standard
-corpora of the [Penn Treebank](https://catalog.ldc.upenn.edu/LDC99T42) and
-[OntoNotes](https://catalog.ldc.upenn.edu/LDC2013T19), as well as the [English
-Web Treebank](https://catalog.ldc.upenn.edu/LDC2012T13), but these are
-unfortunately not freely available.
-
-However, the [Universal Dependencies](http://universaldependencies.org/) project
-provides freely available treebank data in a number of languages. SyntaxNet can
-be trained and evaluated on any of these corpora.
-
-### Part-of-Speech Tagging
-
-Consider the following sentence, which exhibits several ambiguities that affect
-its interpretation:
-
-> I saw the man with glasses.
-
-This sentence is composed of words: strings of characters that are segmented
-into groups (e.g. "I", "saw", etc.) Each word in the sentence has a *grammatical
-function* that can be useful for understanding the meaning of language. For
-example, "saw" in this example is a past tense of the verb "to see". But any
-given word might have different meanings in different contexts: "saw" could just
-as well be a noun (e.g., a saw used for cutting) or a present tense verb (using
-a saw to cut something).
-
-A logical first step in understanding language is figuring out these roles for
-each word in the sentence. This process is called *Part-of-Speech (POS)
-Tagging*. The roles are called POS tags. Although a given word might have
-multiple possible tags depending on the context, given any one interpretation of
-a sentence each word will generally only have one tag.
-
-One interesting challenge of POS tagging is that the problem of defining a
-vocabulary of POS tags for a given language is quite involved. While the concept
-of nouns and verbs is pretty common, it has been traditionally difficult to
-agree on a standard set of roles across all languages. The [Universal
-Dependencies](http://www.universaldependencies.org) project aims to solve this
-problem.
-
-### Training the SyntaxNet POS Tagger
-
-In general, determining the correct POS tag requires understanding the entire
-sentence and the context in which it is uttered. In practice, we can do very
-well just by considering a small window of words around the word of interest.
-For example, words that follow the word ‘the’ tend to be adjectives or nouns,
-rather than verbs.
-
-To predict POS tags, we use a simple setup. We process the sentences
-left-to-right. For any given word, we extract features of that word and a window
-around it, and use these as inputs to a feed-forward neural network classifier,
-which predicts a probability distribution over POS tags. Because we make
-decisions in left-to-right order, we also use prior decisions as features in
-subsequent ones (e.g. "the previous predicted tag was a noun.").
-
-All the models in this package use a flexible markup language to define
-features. For example, the features in the POS tagger are found in the
-`brain_pos_features` parameter in the `TaskSpec`, and look like this (modulo
-spacing):
-
-```
-stack(3).word stack(2).word stack(1).word stack.word input.word input(1).word input(2).word input(3).word;
-input.digit input.hyphen;
-stack.suffix(length=2) input.suffix(length=2) input(1).suffix(length=2);
-stack.prefix(length=2) input.prefix(length=2) input(1).prefix(length=2)
-```
-
-Note that `stack` here means "words we have already tagged." Thus, this feature
-spec uses three types of features: words, suffixes, and prefixes. The features
-are grouped into blocks that share an embedding matrix, concatenated together,
-and fed into a chain of hidden layers. This structure is based upon the model
-proposed by [Chen and Manning (2014)]
-(http://cs.stanford.edu/people/danqi/papers/emnlp2014.pdf).
-
-We show this layout in the schematic below: the state of the system (a stack and
-a buffer, visualized below for both the POS and the dependency parsing task) is
-used to extract sparse features, which are fed into the network in groups. We
-show only a small subset of the features to simplify the presentation in the
-schematic:
-
-![Schematic](ff_nn_schematic.png "Feed-forward Network Structure")
-
-In the configuration above, each block gets its own embedding matrix and the
-blocks in the configuration above are delineated with a semi-colon. The
-dimensions of each block are controlled in the `brain_pos_embedding_dims`
-parameter. **Important note:** unlike many simple NLP models, this is *not* a
-bag of words model. Remember that although certain features share embedding
-matrices, the above features will be concatenated, so the interpretation of
-`input.word` will be quite different from `input(1).word`. This also means that
-adding features increases the dimension of the `concat` layer of the model as
-well as the number of parameters for the first hidden layer.
-
-To train the model, first edit `syntaxnet/context.pbtxt` so that the inputs
-`training-corpus`, `tuning-corpus`, and `dev-corpus` point to the location of
-your training data. You can then train a part-of-speech tagger with:
-
-```shell
-bazel-bin/syntaxnet/parser_trainer \
-  --task_context=syntaxnet/context.pbtxt \
-  --arg_prefix=brain_pos \  # read from POS configuration
-  --compute_lexicon \       # required for first stage of pipeline
-  --graph_builder=greedy \  # no beam search
-  --training_corpus=training-corpus \  # names of training/tuning set
-  --tuning_corpus=tuning-corpus \
-  --output_path=models \  # where to save new resources
-  --batch_size=32 \       # Hyper-parameters
-  --decay_steps=3600 \
-  --hidden_layer_sizes=128 \
-  --learning_rate=0.08 \
-  --momentum=0.9 \
-  --seed=0 \
-  --params=128-0.08-3600-0.9-0  # name for these parameters
-```
-
-This will read in the data, construct a lexicon, build a tensorflow graph for
-the model with the specific hyperparameters, and train the model. Every so often
-the model will be evaluated on the tuning set, and only the checkpoint with the
-highest accuracy on this set will be saved. **Note that you should never use a
-corpus you intend to test your model on as your tuning set, as you will inflate
-your test set results.**
-
-For best results, you should repeat this command with at least 3 different
-seeds, and possibly with a few different values for `--learning_rate` and
-`--decay_steps`. Good values for `--learning_rate` are usually close to 0.1, and
-you usually want `--decay_steps` to correspond to about one tenth of your
-corpus. The `--params` flag is only a human readable identifier for the model
-being trained, used to construct the full output path, so that you don't need to
-worry about clobbering old models by accident.
-
-The `--arg_prefix` flag controls which parameters should be read from the task
-context file `context.pbtxt`. In this case `arg_prefix` is set to `brain_pos`,
-so the paramters being used in this training run are
-`brain_pos_transition_system`, `brain_pos_embedding_dims`, `brain_pos_features`
-and, `brain_pos_embedding_names`. To train the dependency parser later
-`arg_prefix` will be set to `brain_parser`.
-
-### Preprocessing with the Tagger
-
-Now that we have a trained POS tagging model, we want to use the output of this
-model as features in the parser. Thus the next step is to run the trained model
-over our training, tuning, and dev (evaluation) sets. We can use the
-parser_eval.py` script for this.
-
-For example, the model `128-0.08-3600-0.9-0` trained above can be run over the
-training, tuning, and dev sets with the following command:
-
-```shell
-PARAMS=128-0.08-3600-0.9-0
-for SET in training tuning dev; do
-  bazel-bin/syntaxnet/parser_eval \
-    --task_context=models/brain_pos/greedy/$PARAMS/context \
-    --hidden_layer_sizes=128 \
-    --input=$SET-corpus \
-    --output=tagged-$SET-corpus \
-    --arg_prefix=brain_pos \
-    --graph_builder=greedy \
-    --model_path=models/brain_pos/greedy/$PARAMS/model
-done
-```
-
-**Important note:** This command only works because we have created entries for
-you in `context.pbtxt` that correspond to `tagged-training-corpus`,
-`tagged-dev-corpus`, and `tagged-tuning-corpus`. From these default settings,
-the above will write tagged versions of the training, tuning, and dev set to the
-directory `models/brain_pos/greedy/$PARAMS/`. This location is chosen because
-the `input` entries do not have `file_pattern` set: instead, they have `creator:
-brain_pos/greedy`, which means that `parser_trainer.py` will construct *new*
-files when called with `--arg_prefix=brain_pos --graph_builder=greedy` using the
-`--model_path` flag to determine the location.
-
-For convenience, `parser_eval.py` also logs POS tagging accuracy after the
-output tagged datasets have been written.
-
-### Dependency Parsing: Transition-Based Parsing
-
-Now that we have a prediction for the grammatical role of the words, we want to
-understand how the words in the sentence relate to each other. This parser is
-built around the *head-modifier* construction: for each word, we choose a
-*syntactic head* that it modifies according to some grammatical role.
-
-An example for the above sentence is as follows:
-
-![Figure](sawman.png)
-
-Below each word in the sentence we see both a fine-grained part-of-speech
-(*PRP*, *VBD*, *DT*, *NN* etc.), and a coarse-grained part-of-speech (*PRON*,
-*VERB*, *DET*, *NOUN*, etc.). Coarse-grained POS tags encode basic grammatical
-categories, while the fine-grained POS tags make further distinctions: for
-example *NN* is a singular noun (as opposed, for example, to *NNS*, which is a
-plural noun), and *VBD* is a past-tense verb. For more discussion see [Petrov et
-al. (2012)](http://www.lrec-conf.org/proceedings/lrec2012/pdf/274_Paper.pdf).
-
-Crucially, we also see directed arcs signifying grammatical relationships
-between different words in the sentence. For example *I* is the subject of
-*saw*, as signified by the directed arc labeled *nsubj* between these words;
-*man* is the direct object (dobj) of *saw*; the preposition *with* modifies
-*man* with a prep relation, signifiying modification by a prepositional phrase;
-and so on. In addition the verb *saw* is identified as the *root* of the entire
-sentence.
-
-Whenever we have a directed arc between two words, we refer to the word at the
-start of the arc as the *head*, and the word at the end of the arc as the
-*modifier*. For example we have one arc where the head is *saw* and the modifier
-is *I*, another where the head is *saw* and the modifier is *man*, and so on.
-
-The grammatical relationships encoded in dependency structures are directly
-related to the underlying meaning of the sentence in question. They allow us to
-easily recover the answers to various questions, for example *whom did I see?*,
-*who saw the man with glasses?*, and so on.
-
-SyntaxNet is a **transition-based** dependency parser [Nivre (2007)]
-(http://www.mitpressjournals.org/doi/pdfplus/10.1162/coli.07-056-R1-07-027) that
-constructs a parse incrementally. Like the tagger, it processes words
-left-to-right. The words all start as unprocessed input, called the *buffer*. As
-words are encountered they are put onto a *stack*. At each step, the parser can
-do one of three things:
-
-1.  **SHIFT:** Push another word onto the top of the stack, i.e. shifting one
-    token from the buffer to the stack.
-1.  **LEFT_ARC:** Pop the top two words from the stack. Attach the second to the
-    first, creating an arc pointing to the **left**. Push the **first** word
-    back on the stack.
-1.  **RIGHT_ARC:** Pop the top two words from the stack. Attach the second to
-    the first, creating an arc point to the **right**. Push the **second** word
-    back on the stack.
-
-At each step, we call the combination of the stack and the buffer the
-*configuration* of the parser. For the left and right actions, we also assign a
-dependency relation label to that arc. This process is visualized in the
-following animation for a short sentence:
-
-![Animation](looping-parser.gif "Parsing in Action")
-
-Note that this parser is following a sequence of actions, called a
-**derivation**, to produce a "gold" tree labeled by a linguist. We can use this
-sequence of decisions to learn a classifier that takes a configuration and
-predicts the next action to take.
-
-### Training a Parser Step 1: Local Pretraining
-
-As described in our [paper](http://arxiv.org/abs/1603.06042), the first
-step in training the model is to *pre-train* using *local* decisions. In this
-phase, we use the gold dependency to guide the parser, and train a softmax layer
-to predict the correct action given these gold dependencies. This can be
-performed very efficiently, since the parser's decisions are all independent in
-this setting.
-
-Once the tagged datasets are available, a locally normalized dependency parsing
-model can be trained with the following command:
-
-```shell
-bazel-bin/syntaxnet/parser_trainer \
-  --arg_prefix=brain_parser \
-  --batch_size=32 \
-  --projectivize_training_set \
-  --decay_steps=4400 \
-  --graph_builder=greedy \
-  --hidden_layer_sizes=200,200 \
-  --learning_rate=0.08 \
-  --momentum=0.85 \
-  --output_path=models \
-  --task_context=models/brain_pos/greedy/$PARAMS/context \
-  --seed=4 \
-  --training_corpus=tagged-training-corpus \
-  --tuning_corpus=tagged-tuning-corpus \
-  --params=200x200-0.08-4400-0.85-4
-```
-
-Note that we point the trainer to the context corresponding to the POS tagger
-that we picked previously. This allows the parser to reuse the lexicons and the
-tagged datasets that were created in the previous steps. Processing data can be
-done similarly to how tagging was done above. For example if in this case we
-picked parameters `200x200-0.08-4400-0.85-4`, the training, tuning and dev sets
-can be parsed with the following command:
-
-```shell
-PARAMS=200x200-0.08-4400-0.85-4
-for SET in training tuning dev; do
-  bazel-bin/syntaxnet/parser_eval \
-    --task_context=models/brain_parser/greedy/$PARAMS/context \
-    --hidden_layer_sizes=200,200 \
-    --input=tagged-$SET-corpus \
-    --output=parsed-$SET-corpus \
-    --arg_prefix=brain_parser \
-    --graph_builder=greedy \
-    --model_path=models/brain_parser/greedy/$PARAMS/model
-done
-```
-
-### Training a Parser Step 2: Global Training
-
-As we describe in the paper, there are several problems with the locally
-normalized models we just trained. The most important is the *label-bias*
-problem: the model doesn't learn what a good parse looks like, only what action
-to take given a history of gold decisions. This is because the scores are
-normalized *locally* using a softmax for each decision.
-
-In the paper, we show how we can achieve much better results using a *globally*
-normalized model: in this model, the softmax scores are summed in log space, and
-the scores are not normalized until we reach a final decision. When the parser
-stops, the scores of each hypothesis are normalized against a small set of
-possible parses (in the case of this model, a beam size of 8). When training, we
-force the parser to stop during parsing when the gold derivation falls off the
-beam (a strategy known as early-updates).
-
-We give a simplified view of how this training works for a [garden path
-sentence](https://en.wikipedia.org/wiki/Garden_path_sentence), where it is
-important to maintain multiple hypotheses. A single mistake early on in parsing
-leads to a completely incorrect parse; after training, the model learns to
-prefer the second (correct) parse.
-
-![Beam search training](beam_search_training.png)
-
-Parsey McParseface correctly parses this sentence. Even though the correct parse
-is initially ranked 4th out of multiple hypotheses, when the end of the garden
-path is reached, Parsey McParseface can recover due to the beam; using a larger
-beam will get a more accurate model, but it will be slower (we used beam 32 for
-the models in the paper).
-
-Once you have the pre-trained locally normalized model, a globally normalized
-parsing model can now be trained with the following command:
-
-```shell
-bazel-bin/syntaxnet/parser_trainer \
-  --arg_prefix=brain_parser \
-  --batch_size=8 \
-  --decay_steps=100 \
-  --graph_builder=structured \
-  --hidden_layer_sizes=200,200 \
-  --learning_rate=0.02 \
-  --momentum=0.9 \
-  --output_path=models \
-  --task_context=models/brain_parser/greedy/$PARAMS/context \
-  --seed=0 \
-  --training_corpus=projectivized-training-corpus \
-  --tuning_corpus=tagged-tuning-corpus \
-  --params=200x200-0.02-100-0.9-0 \
-  --pretrained_params=models/brain_parser/greedy/$PARAMS/model \
-  --pretrained_params_names=\
-embedding_matrix_0,embedding_matrix_1,embedding_matrix_2,\
-bias_0,weights_0,bias_1,weights_1
-```
-
-Training a beam model with the structured builder will take a lot longer than
-the greedy training runs above, perhaps 3 or 4 times longer. Note once again
-that multiple restarts of training will yield the most reliable results.
-Evaluation can again be done with `parser_eval.py`. In this case we use
-parameters `200x200-0.02-100-0.9-0` to evaluate on the training, tuning and dev
-sets with the following command:
-
-```shell
-PARAMS=200x200-0.02-100-0.9-0
-for SET in training tuning dev; do
-  bazel-bin/syntaxnet/parser_eval \
-    --task_context=models/brain_parser/structured/$PARAMS/context \
-    --hidden_layer_sizes=200,200 \
-    --input=tagged-$SET-corpus \
-    --output=beam-parsed-$SET-corpus \
-    --arg_prefix=brain_parser \
-    --graph_builder=structured \
-    --model_path=models/brain_parser/structured/$PARAMS/model
-done
-```
-
-Hooray! You now have your very own cousin of Parsey McParseface, ready to go out
-and parse text in the wild.
-
 ## Contact
 
 To ask questions or report issues please post on Stack Overflow with the tag
-[syntaxnet](http://stackoverflow.com/questions/tagged/syntaxnet)
-or open an issue on the tensorflow/models
-[issues tracker](https://github.com/tensorflow/models/issues).
-Please assign SyntaxNet issues to @calberti or @andorardo.
+[syntaxnet](http://stackoverflow.com/questions/tagged/syntaxnet) or open an
+issue on the tensorflow/models [issues
+tracker](https://github.com/tensorflow/models/issues). Please assign SyntaxNet
+issues to @calberti or @andorardo.
 
 ## Credits
 
@@ -633,6 +295,7 @@ Original authors of the code in this package include (in alphabetical order):
 *   Keith Hall
 *   Kuzman Ganchev
 *   Livio Baldini Soares
+*   Mark Omernick
 *   Michael Collins
 *   Michael Ringgaard
 *   Ryan McDonald
@@ -640,3 +303,4 @@ Original authors of the code in this package include (in alphabetical order):
 *   Stefan Istrate
 *   Terry Koo
 *   Tim Credo
+*   Zora Tung

+ 15 - 2
syntaxnet/WORKSPACE

@@ -3,10 +3,23 @@ local_repository(
   path = "tensorflow",
   path = "tensorflow",
 )
 )
 
 
+# We need to pull in @io_bazel_rules_closure for TensorFlow. Bazel design
+# documentation states that this verbosity is intentional, to prevent
+# TensorFlow/SyntaxNet from depending on different versions of
+# @io_bazel_rules_closure.
+http_archive(
+    name = "io_bazel_rules_closure",
+    sha256 = "60fc6977908f999b23ca65698c2bb70213403824a84f7904310b6000d78be9ce",
+    strip_prefix = "rules_closure-5ca1dab6df9ad02050f7ba4e816407f88690cf7d",
+    urls = [
+        "http://bazel-mirror.storage.googleapis.com/github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",  # 2017-02-03
+        "https://github.com/bazelbuild/rules_closure/archive/5ca1dab6df9ad02050f7ba4e816407f88690cf7d.tar.gz",
+    ],
+)
+
 load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
 load("@org_tensorflow//tensorflow:workspace.bzl", "tf_workspace")
 tf_workspace(path_prefix="", tf_repo_name="org_tensorflow")
 tf_workspace(path_prefix="", tf_repo_name="org_tensorflow")
 
 
 # Test that Bazel is up-to-date.
 # Test that Bazel is up-to-date.
 load("@org_tensorflow//tensorflow:workspace.bzl", "check_version")
 load("@org_tensorflow//tensorflow:workspace.bzl", "check_version")
-check_version("0.4.3")
-
+check_version("0.4.2")

+ 27 - 0
syntaxnet/docker-devel/build_devel.sh

@@ -0,0 +1,27 @@
+#!/bin/bash
+#
+# This file puts you in a Docker sub-shell where you can build SyntaxNet
+# targets. It is intended for development, as the Dockerfile (build file) does
+# not actually build any of SyntaxNet, but instead mounts it in a volume.
+
+script_path="$(readlink -f "$0")"
+root_path="$(dirname "$(dirname "${script_path}")")"
+set -e
+
+if [[ -z "$(docker images -q dragnn-oss)" ]]; then
+  docker build -t dragnn-oss .
+else
+  echo "NOTE: dragnn-oss image already exists, not re-building." >&2
+  echo "Please run \`docker build -t dragnn-oss .\` if you need." >&2
+fi
+
+echo -e "\n\nRun bazel commands like \`bazel test syntaxnet/...\`"
+
+# NOTE: Unfortunately, we need to mount /tensorflow over /syntaxnet/tensorflow
+# (which happens via devel_entrypoint.sh). This requires privileged mode.
+syntaxnet_base="/opt/tensorflow/syntaxnet"
+docker run --rm -ti \
+  -v "${root_path}"/syntaxnet:"${syntaxnet_base}"/syntaxnet \
+  -v "${root_path}"/dragnn:"${syntaxnet_base}"/dragnn \
+  -p 127.0.0.1:8888:8888 \
+  dragnn-oss "$@"

+ 35 - 0
syntaxnet/docker-devel/build_wheels.sh

@@ -0,0 +1,35 @@
+#!/bin/bash
+#
+# Convenience script to build wheel files in Docker, and copy them out of the
+# container.
+#
+# Usage: docker-devel/build_wheels.sh (takes no arguments; run it from the base
+# directory).
+set -e
+docker build -t dragnn-oss .
+
+# Start building the wheels.
+script="bazel run //dragnn/tools:build_pip_package \
+  -- --output-dir=/opt/tensorflow/syntaxnet; \
+  bazel run //dragnn/tools:build_pip_package \
+  -- --output-dir=/opt/tensorflow/syntaxnet --include-tensorflow"
+container_id="$(docker run -d dragnn-oss /bin/bash -c "${script}")"
+
+echo "Waiting for container ${container_id} to finish building the wheel ..."
+if [[ "$(docker wait "${container_id}")" != 0 ]]; then
+  echo "Container failed! Please run \`docker logs <id>\` to see errors." >&2
+  exit 1
+fi
+
+# The build_pip_package.py script prints lines like "Wrote x.whl". The wheel
+# names are prefixed by architecture and such, so don't guess them.
+wheels=(
+  $(docker logs "${container_id}" 2>/dev/null | grep Wrote | awk '{print $2;}'))
+for wheel in "${wheels[@]}"; do
+  output=./"$(basename "${wheel}")"
+  docker cp "${container_id}:${wheel}" "${output}"
+  echo "Wrote ${output} ($(du -h "${output}" | awk '{print $1;}'))"
+done
+
+echo "Removing ${container_id} ..."
+docker rm "${container_id}" >/dev/null

+ 5 - 0
syntaxnet/dragnn/BUILD

@@ -0,0 +1,5 @@
+package_group(
+    name = "dragnn_visibility",
+    packages = [
+    ],
+)

+ 116 - 0
syntaxnet/dragnn/components/syntaxnet/BUILD

@@ -0,0 +1,116 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "syntaxnet_component",
+    srcs = ["syntaxnet_component.cc"],
+    hdrs = ["syntaxnet_component.h"],
+    deps = [
+        ":syntaxnet_link_feature_extractor",
+        ":syntaxnet_transition_state",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core:beam",
+        "//dragnn/core:component_registry",
+        "//dragnn/core:input_batch_cache",
+        "//dragnn/core/interfaces:component",
+        "//dragnn/core/interfaces:transition_state",
+        "//dragnn/io:sentence_input_batch",
+        "//dragnn/io:syntaxnet_sentence",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "//dragnn/protos:trace_proto",
+        "//syntaxnet:base",
+        "//syntaxnet:parser_transitions",
+        "//syntaxnet:registry",
+        "//syntaxnet:sparse_proto",
+        "//syntaxnet:task_context",
+        "//syntaxnet:task_spec_proto",
+        "//syntaxnet:utils",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "syntaxnet_link_feature_extractor",
+    srcs = ["syntaxnet_link_feature_extractor.cc"],
+    hdrs = ["syntaxnet_link_feature_extractor.h"],
+    deps = [
+        "//dragnn/protos:spec_proto",
+        "//syntaxnet:embedding_feature_extractor",
+        "//syntaxnet:parser_transitions",
+        "//syntaxnet:task_context",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+cc_library(
+    name = "syntaxnet_transition_state",
+    srcs = ["syntaxnet_transition_state.cc"],
+    hdrs = ["syntaxnet_transition_state.h"],
+    deps = [
+        "//dragnn/core/interfaces:cloneable_transition_state",
+        "//dragnn/core/interfaces:transition_state",
+        "//dragnn/io:syntaxnet_sentence",
+        "//dragnn/protos:trace_proto",
+        "//syntaxnet:base",
+        "//syntaxnet:parser_transitions",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+# Test data.
+filegroup(
+    name = "testdata",
+    data = glob(["testdata/**"]),
+)
+
+# Tests.
+cc_test(
+    name = "syntaxnet_component_test",
+    srcs = ["syntaxnet_component_test.cc"],
+    data = [":testdata"],
+    deps = [
+        ":syntaxnet_component",
+        "//dragnn/core:input_batch_cache",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_transition_state",
+        "//dragnn/io:sentence_input_batch",
+        "//syntaxnet:sentence_proto",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "syntaxnet_link_feature_extractor_test",
+    srcs = ["syntaxnet_link_feature_extractor_test.cc"],
+    deps = [
+        ":syntaxnet_link_feature_extractor",
+        "//dragnn/core/test:generic",
+        "//dragnn/protos:spec_proto",
+        "//syntaxnet:task_context",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+        "@org_tensorflow//tensorflow/core:testlib",
+    ],
+)
+
+cc_test(
+    name = "syntaxnet_transition_state_test",
+    srcs = ["syntaxnet_transition_state_test.cc"],
+    data = [":testdata"],
+    deps = [
+        ":syntaxnet_component",
+        ":syntaxnet_transition_state",
+        "//dragnn/core:input_batch_cache",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_transition_state",
+        "//dragnn/io:sentence_input_batch",
+        "//dragnn/protos:spec_proto",
+        "//syntaxnet:sentence_proto",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:test",
+        "@org_tensorflow//tensorflow/core:testlib",
+    ],
+)

+ 779 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc

@@ -0,0 +1,779 @@
+#include "dragnn/components/syntaxnet/syntaxnet_component.h"
+
+#include <vector>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "dragnn/io/syntaxnet_sentence.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/sparse.pb.h"
+#include "syntaxnet/task_spec.pb.h"
+#include "syntaxnet/utils.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::strings::StrCat;
+
+namespace {
+
+// Returns a new step in a trace based on a ComponentSpec.
+ComponentStepTrace GetNewStepTrace(const ComponentSpec &spec,
+                                   const TransitionState &state) {
+  ComponentStepTrace step;
+  for (auto &linked_spec : spec.linked_feature()) {
+    auto &channel_trace = *step.add_linked_feature_trace();
+    channel_trace.set_name(linked_spec.name());
+    channel_trace.set_source_component(linked_spec.source_component());
+    channel_trace.set_source_translator(linked_spec.source_translator());
+    channel_trace.set_source_layer(linked_spec.source_layer());
+  }
+  for (auto &fixed_spec : spec.fixed_feature()) {
+    step.add_fixed_feature_trace()->set_name(fixed_spec.name());
+  }
+  step.set_html_representation(state.HTMLRepresentation());
+  return step;
+}
+
+// Returns the last step in the trace.
+ComponentStepTrace *GetLastStepInTrace(ComponentTrace *trace) {
+  CHECK_GT(trace->step_trace_size(), 0) << "Trace has no steps added yet";
+  return trace->mutable_step_trace(trace->step_trace_size() - 1);
+}
+
+}  // anonymous namespace
+
+SyntaxNetComponent::SyntaxNetComponent()
+    : feature_extractor_("brain_parser"),
+      rewrite_root_labels_(false),
+      max_beam_size_(1),
+      input_data_(nullptr) {}
+
+void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
+  // Save off the passed spec for future reference.
+  spec_ = spec;
+
+  // Create and populate a TaskContext for the underlying parser.
+  TaskContext context;
+
+  // Add the specified resources.
+  for (const Resource &resource : spec_.resource()) {
+    auto *input = context.GetInput(resource.name());
+    for (const Part &part : resource.part()) {
+      auto *input_part = input->add_part();
+      input_part->set_file_pattern(part.file_pattern());
+      input_part->set_file_format(part.file_format());
+      input_part->set_record_format(part.record_format());
+    }
+  }
+
+  // Add the specified task args to the transition system.
+  for (const auto &param : spec_.transition_system().parameters()) {
+    context.SetParameter(param.first, param.second);
+  }
+
+  // Set the arguments for the feature extractor.
+  std::vector<string> names;
+  std::vector<string> dims;
+  std::vector<string> fml;
+  std::vector<string> predicate_maps;
+
+  for (const FixedFeatureChannel &channel : spec.fixed_feature()) {
+    names.push_back(channel.name());
+    fml.push_back(channel.fml());
+    predicate_maps.push_back(channel.predicate_map());
+    dims.push_back(StrCat(channel.embedding_dim()));
+  }
+
+  context.SetParameter("neurosis_feature_syntax_version", "2");
+  context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";"));
+  context.SetParameter("brain_parser_predicate_maps",
+                       utils::Join(predicate_maps, ";"));
+  context.SetParameter("brain_parser_features", utils::Join(fml, ";"));
+  context.SetParameter("brain_parser_embedding_names", utils::Join(names, ";"));
+
+  names.clear();
+  dims.clear();
+  fml.clear();
+  predicate_maps.clear();
+
+  std::vector<string> source_components;
+  std::vector<string> source_layers;
+  std::vector<string> source_translators;
+
+  for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
+    names.push_back(channel.name());
+    fml.push_back(channel.fml());
+    dims.push_back(StrCat(channel.embedding_dim()));
+    source_components.push_back(channel.source_component());
+    source_layers.push_back(channel.source_layer());
+    source_translators.push_back(channel.source_translator());
+    predicate_maps.push_back("none");
+  }
+
+  context.SetParameter("link_embedding_dims", utils::Join(dims, ";"));
+  context.SetParameter("link_predicate_maps", utils::Join(predicate_maps, ";"));
+  context.SetParameter("link_features", utils::Join(fml, ";"));
+  context.SetParameter("link_embedding_names", utils::Join(names, ";"));
+  context.SetParameter("link_source_layers", utils::Join(source_layers, ";"));
+  context.SetParameter("link_source_translators",
+                       utils::Join(source_translators, ";"));
+  context.SetParameter("link_source_components",
+                       utils::Join(source_components, ";"));
+
+  context.SetParameter("parser_transition_system",
+                       spec.transition_system().registered_name());
+
+  // Set up the fixed feature extractor.
+  feature_extractor_.Setup(&context);
+  feature_extractor_.Init(&context);
+  feature_extractor_.RequestWorkspaces(&workspace_registry_);
+
+  // Set up the underlying transition system.
+  transition_system_.reset(ParserTransitionSystem::Create(
+      context.Get("parser_transition_system", "arc-standard")));
+  transition_system_->Setup(&context);
+  transition_system_->Init(&context);
+
+  // Create label map.
+  string path = TaskContext::InputFile(*context.GetInput("label-map"));
+  label_map_ =
+      SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
+
+  // Set up link feature extractors.
+  if (spec.linked_feature_size() > 0) {
+    link_feature_extractor_.Setup(&context);
+    link_feature_extractor_.Init(&context);
+    link_feature_extractor_.RequestWorkspaces(&workspace_registry_);
+  }
+
+  // Get the legacy flag for simulating old parser processor behavior. If the
+  // flag is not set, default to 'false'.
+  rewrite_root_labels_ = context.Get("rewrite_root_labels", false);
+}
+
+std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam(
+    int max_size) {
+  std::unique_ptr<Beam<SyntaxNetTransitionState>> beam(
+      new Beam<SyntaxNetTransitionState>(max_size));
+  auto permission_function = [this](SyntaxNetTransitionState *state,
+                                    int action) {
+    VLOG(3) << "permission_function action:" << action
+            << " is_allowed:" << this->IsAllowed(state, action);
+    return this->IsAllowed(state, action);
+  };
+  auto finality_function = [this](SyntaxNetTransitionState *state) {
+    VLOG(2) << "finality_function is_final:" << this->IsFinal(state);
+    return this->IsFinal(state);
+  };
+  auto oracle_function = [this](SyntaxNetTransitionState *state) {
+    VLOG(2) << "oracle_function action:" << this->GetOracleLabel(state);
+    return this->GetOracleLabel(state);
+  };
+  auto beam_ptr = beam.get();
+  auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state,
+                                           int action) {
+    VLOG(2) << "advance_function beam ptr:" << beam_ptr << " action:" << action;
+    this->Advance(state, action, beam_ptr);
+  };
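+  // Hand the callbacks to the beam: the permission function gates which
+  // actions may be taken, the finality function detects terminal states, the
+  // advance function applies a chosen action, and the oracle function supplies
+  // gold actions when advancing without predictions.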
+  beam->SetFunctions(permission_function, finality_function, advance_function,
+                     oracle_function);
+
+  return beam;
+}
+
+void SyntaxNetComponent::InitializeData(
+    const std::vector<std::vector<const TransitionState *>> &parent_states,
+    int max_beam_size, InputBatchCache *input_data) {
+  // Save off the input data object.
+  input_data_ = input_data;
+
+  // If beam size has changed, change all beam sizes for existing beams.
+  if (max_beam_size_ != max_beam_size) {
+    CHECK_GT(max_beam_size, 0)
+        << "Requested max beam size must be greater than 0.";
+    VLOG(2) << "Adjusting max beam size from " << max_beam_size_ << " to "
+            << max_beam_size;
+    max_beam_size_ = max_beam_size;
+    for (auto &beam : batch_) {
+      beam->SetMaxSize(max_beam_size_);
+    }
+  }
+
+  SentenceInputBatch *sentences = input_data->GetAs<SentenceInputBatch>();
+
+  // Expect that the sentence data is the same size as the input states batch.
+  if (!parent_states.empty()) {
+    CHECK_EQ(parent_states.size(), sentences->data()->size());
+  }
+
+  // Adjust the beam vector so that it is the correct size for this batch.
+  if (batch_.size() < sentences->data()->size()) {
+    VLOG(1) << "Batch size is increased to " << sentences->data()->size()
+            << " from " << batch_.size();
+    for (int i = batch_.size(); i < sentences->data()->size(); ++i) {
+      batch_.push_back(CreateBeam(max_beam_size));
+    }
+  } else if (batch_.size() > sentences->data()->size()) {
+    VLOG(1) << "Batch size is decreased to " << sentences->data()->size()
+            << " from " << batch_.size();
+    batch_.erase(batch_.begin() + sentences->data()->size(), batch_.end());
+
+  } else {
+    VLOG(1) << "Batch size is constant at " << sentences->data()->size();
+  }
+  CHECK_EQ(batch_.size(), sentences->data()->size());
+
+  // Fill the beams with the relevant data for that batch.
+  for (int batch_index = 0; batch_index < sentences->data()->size();
+       ++batch_index) {
+    // Create a vector of states for this component's beam.
+    std::vector<std::unique_ptr<SyntaxNetTransitionState>> initial_states;
+    if (parent_states.empty()) {
+      // If no states have been passed in, create a single state to seed the
+      // beam.
+      initial_states.push_back(
+          CreateState(&(sentences->data()->at(batch_index))));
+    } else {
+      // If states have been passed in, seed the beam with them up to the max
+      // beam size.
+      int num_states =
+          std::min(batch_.at(batch_index)->max_size(),
+                   static_cast<int>(parent_states.at(batch_index).size()));
+      VLOG(2) << "Creating a beam using " << num_states << " initial states";
+      for (int i = 0; i < num_states; ++i) {
+        std::unique_ptr<SyntaxNetTransitionState> state(
+            CreateState(&(sentences->data()->at(batch_index))));
+        state->Init(*parent_states.at(batch_index).at(i));
+        initial_states.push_back(std::move(state));
+      }
+    }
+    batch_.at(batch_index)->Init(std::move(initial_states));
+  }
+}
+
+bool SyntaxNetComponent::IsReady() const { return input_data_ != nullptr; }
+
+string SyntaxNetComponent::Name() const {
+  return "SyntaxNet-backed beam parser";
+}
+
+int SyntaxNetComponent::BatchSize() const { return batch_.size(); }
+
+int SyntaxNetComponent::BeamSize() const { return max_beam_size_; }
+
+int SyntaxNetComponent::StepsTaken(int batch_index) const {
+  return batch_.at(batch_index)->num_steps();
+}
+
+int SyntaxNetComponent::GetBeamIndexAtStep(int step, int current_index,
+                                           int batch) const {
+  return batch_.at(batch)->FindPreviousIndex(current_index, step);
+}
+
+int SyntaxNetComponent::GetSourceBeamIndex(int current_index, int batch) const {
+  return batch_.at(batch)->FindPreviousIndex(current_index, 0);
+}
+
+std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
+    const string &method) {
+  if (method == "shift-reduce-step") {
+    // TODO(googleuser): Describe this function.
+    return [this](int batch_index, int beam_index, int value) {
+      SyntaxNetTransitionState *state =
+          batch_.at(batch_index)->beam_state(beam_index);
+      return state->step_for_token(value);
+    };
+  } else if (method == "reduce-step") {
+    // TODO(googleuser): Describe this function.
+    return [this](int batch_index, int beam_index, int value) {
+      SyntaxNetTransitionState *state =
+          batch_.at(batch_index)->beam_state(beam_index);
+      return state->parent_step_for_token(value);
+    };
+  } else if (method == "parent-shift-reduce-step") {
+    // TODO(googleuser): Describe this function.
+    return [this](int batch_index, int beam_index, int value) {
+      SyntaxNetTransitionState *state =
+          batch_.at(batch_index)->beam_state(beam_index);
+      return state->step_for_token(state->parent_step_for_token(value));
+    };
+  } else if (method == "reverse-token") {
+    // TODO(googleuser): Describe this function.
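+    // Maps a token index to its mirror position counted from the end of the
+    // sentence (token_size - value - 1), or -1 if that position is out of
+    // range.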
+    return [this](int batch_index, int beam_index, int value) {
+      SyntaxNetTransitionState *state =
+          batch_.at(batch_index)->beam_state(beam_index);
+      int result = state->sentence()->sentence()->token_size() - value - 1;
+      if (result >= 0 && result < state->sentence()->sentence()->token_size()) {
+        return result;
+      } else {
+        return -1;
+      }
+    };
+  } else {
+    LOG(FATAL) << "Unable to find step lookup function " << method;
+  }
+}
+
+void SyntaxNetComponent::AdvanceFromPrediction(const float transition_matrix[],
+                                               int transition_matrix_length) {
+  VLOG(2) << "Advancing from prediction.";
+  int matrix_index = 0;
+  int num_labels = transition_system_->NumActions(label_map_->Size());
+  for (int i = 0; i < batch_.size(); ++i) {
+    int max_beam_size = batch_.at(i)->max_size();
+    int matrix_size = num_labels * max_beam_size;
+    CHECK_LE(matrix_index + matrix_size, transition_matrix_length);
+    if (!batch_.at(i)->IsTerminal()) {
+      batch_.at(i)->AdvanceFromPrediction(&transition_matrix[matrix_index],
+                                          matrix_size, num_labels);
+    }
+    matrix_index += num_labels * max_beam_size;
+  }
+}
+
+void SyntaxNetComponent::AdvanceFromOracle() {
+  VLOG(2) << "Advancing from oracle.";
+  for (auto &beam : batch_) {
+    beam->AdvanceFromOracle();
+  }
+}
+
+bool SyntaxNetComponent::IsTerminal() const {
+  VLOG(2) << "Checking terminal status.";
+  for (const auto &beam : batch_) {
+    if (!beam->IsTerminal()) {
+      return false;
+    }
+  }
+  return true;
+}
+
+std::vector<std::vector<const TransitionState *>>
+SyntaxNetComponent::GetBeam() {
+  std::vector<std::vector<const TransitionState *>> state_beam;
+  for (auto &beam : batch_) {
+    // Because this component only finalizes the data of the highest-ranked
+    // state in each beam, the next component should only be initialized
+    // from the highest-ranked state in that beam.
+    state_beam.push_back({beam->beam().at(0)});
+  }
+  return state_beam;
+}
+
+int SyntaxNetComponent::GetFixedFeatures(
+    std::function<int32 *(int)> allocate_indices,
+    std::function<int64 *(int)> allocate_ids,
+    std::function<float *(int)> allocate_weights, int channel_id) const {
+  std::vector<SparseFeatures> features;
+
+  const int channel_size = spec_.fixed_feature(channel_id).size();
+
+  // For every beam in the batch...
+  for (const auto &beam : batch_) {
+    // For every element in the beam...
+    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+      // Get the SparseFeatures from the feature extractor.
+      auto state = beam->beam_state(beam_idx);
+      const std::vector<std::vector<SparseFeatures>> sparse_features =
+          feature_extractor_.ExtractSparseFeatures(
+              *(state->sentence()->workspace()), *(state->parser_state()));
+
+      // Hold the SparseFeatures for later processing.
+      for (const SparseFeatures &f : sparse_features[channel_id]) {
+        features.emplace_back(f);
+        if (do_tracing_) {
+          FixedFeatures fixed_features;
+          for (const string &name : f.description()) {
+            fixed_features.add_value_name(name);
+          }
+          fixed_features.set_feature_name("");
+          auto *trace = GetLastStepInTrace(state->mutable_trace());
+          auto *fixed_trace = trace->mutable_fixed_feature_trace(channel_id);
+          *fixed_trace->add_value_trace() = fixed_features;
+        }
+      }
+    }
+    const int pad_amount = max_beam_size_ - beam->size();
+    features.resize(features.size() + pad_amount * channel_size);
+  }
+
+  int feature_count = 0;
+  for (const auto &feature : features) {
+    feature_count += feature.id_size();
+  }
+
+  VLOG(2) << "Feature count is " << feature_count;
+  int32 *indices_tensor = allocate_indices(feature_count);
+  int64 *ids_tensor = allocate_ids(feature_count);
+  float *weights_tensor = allocate_weights(feature_count);
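+  // The three tensors are parallel arrays: entry k records that feature slot
+  // indices_tensor[k] holds sparse id ids_tensor[k] with weight
+  // weights_tensor[k] (defaulting to 1.0 when the feature has no weight).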
+
+  int array_index = 0;
+  for (int feature_index = 0; feature_index < features.size();
+       ++feature_index) {
+    VLOG(2) << "Extracting for feature_index " << feature_index;
+    const auto feature = features[feature_index];
+    for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
+      indices_tensor[array_index] = feature_index;
+      ids_tensor[array_index] = feature.id(sub_idx);
+      if (sub_idx < feature.weight_size()) {
+        weights_tensor[array_index] = feature.weight(sub_idx);
+      } else {
+        weights_tensor[array_index] = 1.0;
+      }
+      VLOG(2) << "Feature index: " << indices_tensor[array_index]
+              << " id: " << ids_tensor[array_index]
+              << " weight: " << weights_tensor[array_index];
+
+      ++array_index;
+    }
+  }
+  return feature_count;
+}
+
+int SyntaxNetComponent::BulkGetFixedFeatures(
+    const BulkFeatureExtractor &extractor) {
+  // Allocate a vector of SparseFeatures per channel.
+  const int num_channels = spec_.fixed_feature_size();
+  std::vector<int> channel_size(num_channels);
+  for (int i = 0; i < num_channels; ++i) {
+    channel_size[i] = spec_.fixed_feature(i).size();
+  }
+  std::vector<std::vector<SparseFeatures>> features(num_channels);
+  std::vector<std::vector<int>> feature_indices(num_channels);
+  std::vector<std::vector<int>> step_indices(num_channels);
+  std::vector<std::vector<int>> element_indices(num_channels);
+  std::vector<int> feature_counts(num_channels);
+  int step_count = 0;
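+  // Run the oracle to completion, recording for every extracted feature its
+  // (element, step, feature-slot) coordinates; these are linearized into the
+  // output tensors below via extractor.GetIndex().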
+
+  while (!IsTerminal()) {
+    int current_element = 0;
+
+    // For every beam in the batch...
+    for (const auto &beam : batch_) {
+      // For every element in the beam...
+      for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+        // Get the SparseFeatures from the parser.
+        auto state = beam->beam_state(beam_idx);
+        const std::vector<std::vector<SparseFeatures>> sparse_features =
+            feature_extractor_.ExtractSparseFeatures(
+                *(state->sentence()->workspace()), *(state->parser_state()));
+
+        for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
+          int feature_count = 0;
+          for (const SparseFeatures &f : sparse_features[channel_id]) {
+            // Trace, if requested.
+            if (do_tracing_) {
+              FixedFeatures fixed_features;
+              for (const string &name : f.description()) {
+                fixed_features.add_value_name(name);
+              }
+              fixed_features.set_feature_name("");
+              auto *trace = GetLastStepInTrace(state->mutable_trace());
+              auto *fixed_trace =
+                  trace->mutable_fixed_feature_trace(channel_id);
+              *fixed_trace->add_value_trace() = fixed_features;
+            }
+
+            // Hold the SparseFeatures for later processing.
+            features[channel_id].emplace_back(f);
+            element_indices[channel_id].emplace_back(current_element);
+            step_indices[channel_id].emplace_back(step_count);
+            feature_indices[channel_id].emplace_back(feature_count);
+            feature_counts[channel_id] += f.id_size();
+            ++feature_count;
+          }
+        }
+        ++current_element;
+      }
+
+      // Advance the current element to skip unused beam slots.
+      // Pad the beam out to max_beam_size.
+      int pad_amount = max_beam_size_ - beam->size();
+      current_element += pad_amount;
+    }
+    AdvanceFromOracle();
+    ++step_count;
+  }
+
+  const int total_steps = step_count;
+  const int num_elements = batch_.size() * max_beam_size_;
+
+  // This would be a good place to add threading.
+  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
+    int feature_count = feature_counts[channel_id];
+    LOG(INFO) << "Feature count is " << feature_count << " for channel "
+              << channel_id;
+    int32 *indices_tensor =
+        extractor.AllocateIndexMemory(channel_id, feature_count);
+    int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count);
+    float *weights_tensor =
+        extractor.AllocateWeightMemory(channel_id, feature_count);
+    int array_index = 0;
+    for (int feat_idx = 0; feat_idx < features[channel_id].size(); ++feat_idx) {
+      const auto &feature = features[channel_id][feat_idx];
+      int element_index = element_indices[channel_id][feat_idx];
+      int step_index = step_indices[channel_id][feat_idx];
+      int feature_index = feature_indices[channel_id][feat_idx];
+      for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
+        indices_tensor[array_index] =
+            extractor.GetIndex(total_steps, num_elements, feature_index,
+                               element_index, step_index);
+        ids_tensor[array_index] = feature.id(sub_idx);
+        if (sub_idx < feature.weight_size()) {
+          weights_tensor[array_index] = feature.weight(sub_idx);
+        } else {
+          weights_tensor[array_index] = 1.0;
+        }
+        ++array_index;
+      }
+    }
+  }
+  return step_count;
+}
+
+std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
+    int channel_id) const {
+  std::vector<LinkFeatures> features;
+  const int channel_size = spec_.linked_feature(channel_id).size();
+  std::unique_ptr<std::vector<string>> feature_names;
+  if (do_tracing_) {
+    feature_names.reset(new std::vector<string>);
+    *feature_names = utils::Split(spec_.linked_feature(channel_id).fml(), ' ');
+  }
+
+  // For every beam in the batch...
+  for (int batch_idx = 0; batch_idx < batch_.size(); ++batch_idx) {
+    // For every element in the beam...
+    const auto &beam = batch_[batch_idx];
+    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+      // Get the raw link features from the linked feature extractor.
+      auto state = beam->beam_state(beam_idx);
+      std::vector<FeatureVector> raw_features(
+          link_feature_extractor_.NumEmbeddings());
+      link_feature_extractor_.ExtractFeatures(*(state->sentence()->workspace()),
+                                              *(state->parser_state()),
+                                              &raw_features);
+
+      // Add the raw feature values to the LinkFeatures proto.
+      CHECK_LT(channel_id, raw_features.size());
+      for (int i = 0; i < raw_features[channel_id].size(); ++i) {
+        features.emplace_back();
+        features.back().set_feature_value(raw_features[channel_id].value(i));
+        features.back().set_batch_idx(batch_idx);
+        features.back().set_beam_idx(beam_idx);
+        if (do_tracing_) {
+          features.back().set_feature_name(feature_names->at(i));
+        }
+      }
+    }
+
+    // Pad the beam out to max_beam_size.
+    int pad_amount = max_beam_size_ - beam->size();
+    features.resize(features.size() + pad_amount * channel_size);
+  }
+
+  return features;
+}
+
+std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
+  std::vector<std::vector<int>> oracle_labels;
+  for (const auto &beam : batch_) {
+    oracle_labels.emplace_back();
+    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+      // Get the oracle label for this element of the beam.
+      auto state = beam->beam_state(beam_idx);
+      oracle_labels.back().push_back(GetOracleLabel(state));
+    }
+  }
+  return oracle_labels;
+}
+
+void SyntaxNetComponent::FinalizeData() {
+  // This chooses the top-scoring member of the beam to annotate the underlying
+  // document.
+  VLOG(2) << "Finalizing data.";
+  for (auto &beam : batch_) {
+    if (beam->size() != 0) {
+      auto top_state = beam->beam_state(0);
+      VLOG(3) << "Finalizing for sentence: "
+              << top_state->sentence()->sentence()->ShortDebugString();
+      top_state->parser_state()->AddParseToDocument(
+          top_state->sentence()->sentence(), rewrite_root_labels_);
+      VLOG(3) << "Sentence is now: "
+              << top_state->sentence()->sentence()->ShortDebugString();
+    } else {
+      LOG(WARNING) << "Attempting to finalize an empty beam for component "
+                   << spec_.name();
+    }
+  }
+}
+
+void SyntaxNetComponent::ResetComponent() {
+  for (auto &beam : batch_) {
+    beam->Reset();
+  }
+  input_data_ = nullptr;
+  max_beam_size_ = 0;
+}
+
+std::unique_ptr<SyntaxNetTransitionState> SyntaxNetComponent::CreateState(
+    SyntaxNetSentence *sentence) {
+  VLOG(3) << "Creating state for sentence "
+          << sentence->sentence()->DebugString();
+  std::unique_ptr<ParserState> parser_state(new ParserState(
+      sentence->sentence(), transition_system_->NewTransitionState(false),
+      label_map_));
+  sentence->workspace()->Reset(workspace_registry_);
+  feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
+  link_feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
+  std::unique_ptr<SyntaxNetTransitionState> transition_state(
+      new SyntaxNetTransitionState(std::move(parser_state), sentence));
+  return transition_state;
+}
+
+bool SyntaxNetComponent::IsAllowed(SyntaxNetTransitionState *state,
+                                   int action) const {
+  return transition_system_->IsAllowedAction(action, *(state->parser_state()));
+}
+
+bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const {
+  return transition_system_->IsFinalState(*(state->parser_state()));
+}
+
+int SyntaxNetComponent::GetOracleLabel(SyntaxNetTransitionState *state) const {
+  if (IsFinal(state)) {
+    // It is not permitted to request an oracle label from a state that is
+    // already final.
+    return -1;
+  } else {
+    return transition_system_->GetNextGoldAction(*(state->parser_state()));
+  }
+}
+
+void SyntaxNetComponent::Advance(SyntaxNetTransitionState *state, int action,
+                                 Beam<SyntaxNetTransitionState> *beam) {
+  auto parser_state = state->parser_state();
+  auto sentence_size = state->sentence()->sentence()->token_size();
+  const int num_steps = beam->num_steps();
+
+  if (transition_system_->SupportsActionMetaData()) {
+    const int parent_idx =
+        transition_system_->ParentIndex(*parser_state, action);
+    constexpr int kShiftAction = -1;
+    if (parent_idx == kShiftAction) {
+      if (parser_state->Next() < sentence_size && parser_state->Next() >= 0) {
+        // If all the input has already been consumed, this is not a real shift
+        // action, so only record the step when Next() is still in range.
+        state->set_step_for_token(parser_state->Next(), num_steps);
+      }
+    } else if (parent_idx >= 0) {
+      VLOG(2) << spec_.name() << ": Updating pointer: " << parent_idx << " -> "
+              << num_steps;
+      state->set_step_for_token(parent_idx, num_steps);
+      const int child_idx =
+          transition_system_->ChildIndex(*parser_state, action);
+      assert(child_idx >= 0 && child_idx < sentence_size);
+      state->set_parent_for_token(child_idx, parent_idx);
+
+      VLOG(2) << spec_.name() << ": Updating parent for child: " << parent_idx
+              << " -> " << child_idx;
+      state->set_parent_step_for_token(child_idx, num_steps);
+    } else {
+      VLOG(2) << spec_.name() << ": Invalid parent index: " << parent_idx;
+    }
+  }
+  if (do_tracing_) {
+    auto *trace = state->mutable_trace();
+    auto *last_step = GetLastStepInTrace(trace);
+
+    // Add action to the prior step.
+    last_step->set_caption(
+        transition_system_->ActionAsString(action, *parser_state));
+    last_step->set_step_finished(true);
+  }
+
+  transition_system_->PerformAction(action, parser_state);
+
+  if (do_tracing_) {
+    // Add info for the next step.
+    *state->mutable_trace()->add_step_trace() = GetNewStepTrace(spec_, *state);
+  }
+}
+
+void SyntaxNetComponent::InitializeTracing() {
+  do_tracing_ = true;
+  CHECK(IsReady()) << "Cannot initialize trace before InitializeData().";
+
+  // Initialize each element of the beam with a new trace.
+  for (auto &beam : batch_) {
+    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+      SyntaxNetTransitionState *state = beam->beam_state(beam_idx);
+      std::unique_ptr<ComponentTrace> trace(new ComponentTrace());
+      trace->set_name(spec_.name());
+      *trace->add_step_trace() = GetNewStepTrace(spec_, *state);
+      state->set_trace(std::move(trace));
+    }
+  }
+
+  feature_extractor_.set_add_strings(true);
+}
+
+void SyntaxNetComponent::DisableTracing() {
+  do_tracing_ = false;
+  feature_extractor_.set_add_strings(false);
+}
+
+void SyntaxNetComponent::AddTranslatedLinkFeaturesToTrace(
+    const std::vector<LinkFeatures> &features, int channel_id) {
+  CHECK(do_tracing_) << "Tracing is not enabled.";
+  int linear_idx = 0;
+  const int channel_size = spec_.linked_feature(channel_id).size();
+
+  // For every beam in the batch...
+  for (const auto &beam : batch_) {
+    // For every element in the beam...
+    for (int beam_idx = 0; beam_idx < max_beam_size_; ++beam_idx) {
+      for (int feature_idx = 0; feature_idx < channel_size; ++feature_idx) {
+        if (beam_idx < beam->size()) {
+          auto state = beam->beam_state(beam_idx);
+          auto *trace = GetLastStepInTrace(state->mutable_trace());
+          auto *link_trace = trace->mutable_linked_feature_trace(channel_id);
+          if (features[linear_idx].feature_value() >= 0 &&
+              features[linear_idx].step_idx() >= 0) {
+            *link_trace->add_value_trace() = features[linear_idx];
+          }
+        }
+        ++linear_idx;
+      }
+    }
+  }
+}
+
+std::vector<std::vector<ComponentTrace>> SyntaxNetComponent::GetTraceProtos()
+    const {
+  std::vector<std::vector<ComponentTrace>> traces;
+
+  // For every beam in the batch...
+  for (const auto &beam : batch_) {
+    std::vector<ComponentTrace> beam_trace;
+
+    // For every element in the beam...
+    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
+      auto state = beam->beam_state(beam_idx);
+      beam_trace.push_back(*state->mutable_trace());
+    }
+    traces.push_back(beam_trace);
+  }
+  return traces;
+}
+
+REGISTER_DRAGNN_COMPONENT(SyntaxNetComponent);
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 183 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h

@@ -0,0 +1,183 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
+
+#include <vector>
+
+#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
+#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/beam.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "dragnn/protos/trace.pb.h"
+#include "syntaxnet/base.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/registry.h"
+#include "syntaxnet/task_context.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class SyntaxNetComponent : public Component {
+ public:
+  // Create a SyntaxNet-backed DRAGNN component.
+  SyntaxNetComponent();
+
+  // Initializes this component from the spec.
+  void InitializeComponent(const ComponentSpec &spec) override;
+
+  // Provides the previous beam to the component.
+  void InitializeData(
+      const std::vector<std::vector<const TransitionState *>> &states,
+      int max_beam_size, InputBatchCache *input_data) override;
+
+  // Returns true if the component has had InitializeData called on it since
+  // the last time it was reset.
+  bool IsReady() const override;
+
+  // Returns the string name of this component.
+  string Name() const override;
+
+  // Returns the number of steps taken by the given batch in this component.
+  int StepsTaken(int batch_index) const override;
+
+  // Returns the current batch size of the component's underlying data.
+  int BatchSize() const override;
+
+  // Returns the maximum beam size of this component.
+  int BeamSize() const override;
+
+  // Return the beam index of the item which is currently at index
+  // 'current_index', when the beam was at step 'step', for batch element
+  // 'batch'.
+  int GetBeamIndexAtStep(int step, int current_index, int batch) const override;
+
+  // Return the source index of the item which is currently at 'current_index'
+  // for batch element 'batch'. This index is into the final beam of the
+  // Component that this Component was initialized from.
+  int GetSourceBeamIndex(int current_index, int batch) const override;
+
+  // Request a translation function based on the given method string.
+  // The translation function will be called with arguments (batch, beam, value)
+  // and should return the step index corresponding to the given value, for the
+  // data in the given beam and batch.
+  std::function<int(int, int, int)> GetStepLookupFunction(
+      const string &method) override;
+
+  // Advances this component from the given transition matrix.
+  void AdvanceFromPrediction(const float transition_matrix[],
+                             int transition_matrix_length) override;
+
+  // Advances this component from the state oracles.
+  void AdvanceFromOracle() override;
+
+  // Returns true if all states within this component are terminal.
+  bool IsTerminal() const override;
+
+  // Returns the current batch of beams for this component.
+  std::vector<std::vector<const TransitionState *>> GetBeam() override;
+
+  // Extracts and populates the vector of FixedFeatures for the specified
+  // channel.
+  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
+                       std::function<int64 *(int)> allocate_ids,
+                       std::function<float *(int)> allocate_weights,
+                       int channel_id) const override;
+
+  // Extracts and populates all FixedFeatures for all channels, advancing this
+  // component via the oracle until it is terminal.
+  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override;
+
+  // Extracts and returns the vector of LinkFeatures for the specified
+  // channel. Note: these are NOT translated.
+  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override;
+
+  // Returns a vector of oracle labels for each element in the beam and
+  // batch.
+  std::vector<std::vector<int>> GetOracleLabels() const override;
+
+  // Annotate the underlying data object with the results of this Component's
+  // calculation.
+  void FinalizeData() override;
+
+  // Reset this component.
+  void ResetComponent() override;
+
+  // Initializes the component for tracing execution. This will typically have
+  // the side effect of slowing down all subsequent Component calculations
+  // and storing a trace in memory that can be returned by GetTraceProtos().
+  void InitializeTracing() override;
+
+  // Disables tracing, freeing any additional memory and avoiding triggering
+  // additional computation in the future.
+  void DisableTracing() override;
+
+  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override;
+
+  void AddTranslatedLinkFeaturesToTrace(
+      const std::vector<LinkFeatures> &features, int channel_id) override;
+
+ private:
+  friend class SyntaxNetComponentTest;
+  friend class SyntaxNetTransitionStateTest;
+
+  // Permission function for this component.
+  bool IsAllowed(SyntaxNetTransitionState *state, int action) const;
+
+  // Returns true if this state is final.
+  bool IsFinal(SyntaxNetTransitionState *state) const;
+
+  // Oracle function for this component.
+  int GetOracleLabel(SyntaxNetTransitionState *state) const;
+
+  // State advance function for this component.
+  void Advance(SyntaxNetTransitionState *state, int action,
+               Beam<SyntaxNetTransitionState> *beam);
+
+  // Creates a new state for the given nlp_saft::SentenceExample.
+  std::unique_ptr<SyntaxNetTransitionState> CreateState(
+      SyntaxNetSentence *example);
+
+  // Creates a newly initialized Beam.
+  std::unique_ptr<Beam<SyntaxNetTransitionState>> CreateBeam(int max_size);
+
+  // Transition system.
+  std::unique_ptr<ParserTransitionSystem> transition_system_;
+
+  // Label map for transition system.
+  const TermFrequencyMap *label_map_;
+
+  // Extractor for fixed features
+  ParserEmbeddingFeatureExtractor feature_extractor_;
+
+  // Extractor for linked features.
+  SyntaxNetLinkFeatureExtractor link_feature_extractor_;
+
+  // Internal workspace registry for use in feature extraction.
+  WorkspaceRegistry workspace_registry_;
+
+  // Switch for simulating legacy parser behaviour.
+  bool rewrite_root_labels_;
+
+  // The ComponentSpec used to initialize this component.
+  ComponentSpec spec_;
+
+  // State search beams
+  std::vector<std::unique_ptr<Beam<SyntaxNetTransitionState>>> batch_;
+
+  // Current max beam size.
+  int max_beam_size_;
+
+  // Underlying input data.
+  InputBatchCache *input_data_;
+
+  // Whether or not to trace for each batch and beam element.
+  bool do_tracing_ = false;
+};
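+
+// A rough sketch of how a driver is expected to exercise this component
+// (assuming a single component and oracle-driven training; hypothetical, not
+// part of the checked-in API surface):
+//
+//   SyntaxNetComponent component;
+//   component.InitializeComponent(spec);
+//   component.InitializeData({}, /*max_beam_size=*/1, &input_batch);
+//   while (!component.IsTerminal()) {
+//     // ...extract features via GetFixedFeatures()/GetRawLinkFeatures(),
+//     // then either AdvanceFromOracle() or AdvanceFromPrediction(...).
+//     component.AdvanceFromOracle();
+//   }
+//   component.FinalizeData();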
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_

File diff suppressed because it is too large
+ 1174 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc


+ 49 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc

@@ -0,0 +1,49 @@
+#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+void SyntaxNetLinkFeatureExtractor::Setup(TaskContext *context) {
+  ParserEmbeddingFeatureExtractor::Setup(context);
+
+  if (NumEmbeddings() > 0) {
+    channel_sources_ = utils::Split(
+        context->Get(
+            tensorflow::strings::StrCat(ArgPrefix(), "_", "source_components"),
+            ""),
+        ';');
+    channel_layers_ = utils::Split(
+        context->Get(
+            tensorflow::strings::StrCat(ArgPrefix(), "_", "source_layers"), ""),
+        ';');
+    channel_translators_ = utils::Split(
+        context->Get(
+            tensorflow::strings::StrCat(ArgPrefix(), "_", "source_translators"),
+            ""),
+        ';');
+  }
+
+  CHECK_EQ(channel_sources_.size(), NumEmbeddings());
+  CHECK_EQ(channel_layers_.size(), NumEmbeddings());
+  CHECK_EQ(channel_translators_.size(), NumEmbeddings());
+}
+
+void SyntaxNetLinkFeatureExtractor::AddLinkedFeatureChannelProtos(
+    ComponentSpec *spec) const {
+  for (int embedding_idx = 0; embedding_idx < NumEmbeddings();
+       ++embedding_idx) {
+    LinkedFeatureChannel *channel = spec->add_linked_feature();
+    channel->set_name(embedding_name(embedding_idx));
+    channel->set_fml(embedding_fml()[embedding_idx]);
+    channel->set_embedding_dim(EmbeddingDims(embedding_idx));
+    channel->set_size(FeatureSize(embedding_idx));
+    channel->set_source_layer(channel_layers_[embedding_idx]);
+    channel->set_source_component(channel_sources_[embedding_idx]);
+    channel->set_source_translator(channel_translators_[embedding_idx]);
+  }
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 55 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h

@@ -0,0 +1,55 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
+
+#include <string>
+#include <vector>
+
+#include "dragnn/protos/spec.pb.h"
+#include "syntaxnet/embedding_feature_extractor.h"
+#include "syntaxnet/parser_state.h"
+#include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/task_context.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// Provides feature extraction for linked features in the
+// WrapperParserComponent. This re-uses the EmbeddingFeatureExtractor
+// architecture to get another set of feature extractors. Note that we should
+// ignore predicate maps here, and we don't care about the vocabulary size
+// because all the feature values will be used for translation, but this means
+// we can configure the extractor from the GCL using the standard
+// neurosis-lib.wf syntax.
+//
+// Because it uses a different prefix, it can be executed in the same wf.stage
+// as the regular fixed extractor.
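+//
+// As an example (mirroring syntaxnet_link_feature_extractor_test.cc), the
+// extractor reads ';'-separated TaskContext parameters such as
+//   link_features           = "input.focus;stack.focus"
+//   link_embedding_names    = "tagger;parser"
+//   link_embedding_dims     = "16;16"
+//   link_source_components  = "tagger;parser"
+//   link_source_layers      = "hidden0;lstm"
+//   link_source_translators = "token;last_action"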
+class SyntaxNetLinkFeatureExtractor : public ParserEmbeddingFeatureExtractor {
+ public:
+  SyntaxNetLinkFeatureExtractor() : ParserEmbeddingFeatureExtractor("link") {}
+  ~SyntaxNetLinkFeatureExtractor() override {}
+
+  const string ArgPrefix() const override { return "link"; }
+
+  // Parses the TaskContext to get additional information like target layers,
+  // etc.
+  void Setup(TaskContext *context) override;
+
+  // Called during InitComponentProtoTask to add the specification from the
+  // wrapped feature extractor as LinkedFeatureChannel protos.
+  void AddLinkedFeatureChannelProtos(ComponentSpec *spec) const;
+
+ private:
+  // Source component names for each channel.
+  std::vector<string> channel_sources_;
+
+  // Source layer names for each channel.
+  std::vector<string> channel_layers_;
+
+  // Source translator names for each channel.
+  std::vector<string> channel_translators_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_

+ 63 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc

@@ -0,0 +1,63 @@
+#include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
+
+#include <string>
+
+#include "dragnn/core/test/generic.h"
+#include "dragnn/protos/spec.pb.h"
+#include "syntaxnet/task_context.h"
+#include "tensorflow/core/platform/test.h"
+
+using syntaxnet::test::EqualsProto;
+
+namespace syntaxnet {
+namespace dragnn {
+
+class ExportSpecTest : public ::testing::Test {
+ public:
+};
+
+TEST_F(ExportSpecTest, WritesChannelSpec) {
+  TaskContext context;
+
+  context.SetParameter("neurosis_feature_syntax_version", "2");
+  context.SetParameter("link_features", "input.focus;stack.focus");
+  context.SetParameter("link_embedding_names", "tagger;parser");
+  context.SetParameter("link_predicate_maps", "none;none");
+  context.SetParameter("link_embedding_dims", "16;16");
+  context.SetParameter("link_source_components", "tagger;parser");
+  context.SetParameter("link_source_layers", "hidden0;lstm");
+  context.SetParameter("link_source_translators", "token;last_action");
+
+  SyntaxNetLinkFeatureExtractor link_features;
+  link_features.Setup(&context);
+  link_features.Init(&context);
+
+  ComponentSpec spec;
+  link_features.AddLinkedFeatureChannelProtos(&spec);
+  const string expected_spec_str = R"(
+    linked_feature {
+      name: "tagger"
+      fml: "input.focus"
+      embedding_dim: 16
+      size: 1
+      source_component: "tagger"
+      source_translator: "token"
+      source_layer: "hidden0"
+    }
+    linked_feature {
+      name: "parser"
+      fml: "stack.focus"
+      embedding_dim: 16
+      size: 1
+      source_component: "parser"
+      source_translator: "last_action"
+      source_layer: "lstm"
+    }
+  )";
+  ComponentSpec expected_spec;
+  TextFormat::ParseFromString(expected_spec_str, &expected_spec);
+  EXPECT_THAT(spec, EqualsProto(expected_spec));
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 85 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc

@@ -0,0 +1,85 @@
+#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
+
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+SyntaxNetTransitionState::SyntaxNetTransitionState(
+    std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence)
+    : parser_state_(std::move(parser_state)), sentence_(sentence) {
+  score_ = 0;
+  current_beam_index_ = -1;
+  parent_beam_index_ = 0;
+  step_for_token_.resize(sentence->sentence()->token_size(), -1);
+  parent_for_token_.resize(sentence->sentence()->token_size(), -1);
+  parent_step_for_token_.resize(sentence->sentence()->token_size(), -1);
+}
+
+void SyntaxNetTransitionState::Init(const TransitionState &parent) {
+  score_ = parent.GetScore();
+  parent_beam_index_ = parent.GetBeamIndex();
+}
+
+std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone()
+    const {
+  // Create a new state from a clone of the underlying parser state.
+  std::unique_ptr<ParserState> cloned_state(parser_state_->Clone());
+  std::unique_ptr<SyntaxNetTransitionState> new_state(
+      new SyntaxNetTransitionState(std::move(cloned_state), sentence_));
+
+  // Copy relevant data members and set non-copied ones to flag values.
+  new_state->score_ = score_;
+  new_state->current_beam_index_ = current_beam_index_;
+  new_state->parent_beam_index_ = parent_beam_index_;
+  new_state->step_for_token_ = step_for_token_;
+  new_state->parent_step_for_token_ = parent_step_for_token_;
+  new_state->parent_for_token_ = parent_for_token_;
+
+  // Copy trace if it exists.
+  if (trace_) {
+    new_state->trace_.reset(new ComponentTrace(*trace_));
+  }
+
+  return new_state;
+}
+
+const int SyntaxNetTransitionState::ParentBeamIndex() const {
+  return parent_beam_index_;
+}
+
+const int SyntaxNetTransitionState::GetBeamIndex() const {
+  return current_beam_index_;
+}
+
+void SyntaxNetTransitionState::SetBeamIndex(const int index) {
+  current_beam_index_ = index;
+}
+
+const float SyntaxNetTransitionState::GetScore() const { return score_; }
+
+void SyntaxNetTransitionState::SetScore(const float score) { score_ = score; }
+
+string SyntaxNetTransitionState::HTMLRepresentation() const {
+  // Crude HTML string showing the stack and the word on the input.
+  string html = "Stack: ";
+  for (int i = parser_state_->StackSize() - 1; i >= 0; --i) {
+    const int word_idx = parser_state_->Stack(i);
+    if (word_idx >= 0) {
+      tensorflow::strings::StrAppend(
+          &html, parser_state_->GetToken(word_idx).word(), " ");
+    }
+  }
+  tensorflow::strings::StrAppend(&html, "| Input: ");
+  const int word_idx = parser_state_->Input(0);
+  if (word_idx >= 0) {
+    tensorflow::strings::StrAppend(
+        &html, parser_state_->GetToken(word_idx).word(), " ");
+  }
+
+  return html;
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 144 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h

@@ -0,0 +1,144 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
+
+#include <vector>
+
+#include "dragnn/core/interfaces/cloneable_transition_state.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/io/syntaxnet_sentence.h"
+#include "dragnn/protos/trace.pb.h"
+#include "syntaxnet/base.h"
+#include "syntaxnet/parser_state.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class SyntaxNetTransitionState
+    : public CloneableTransitionState<SyntaxNetTransitionState> {
+ public:
+  // Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState.
+  SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
+                           SyntaxNetSentence *sentence);
+
+  // Initialize this TransitionState from a previous TransitionState. The
+  // ParentBeamIndex is the location of that previous TransitionState in the
+  // provided beam.
+  void Init(const TransitionState &parent) override;
+
+  // Produces a new state with the same backing data as this state.
+  std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
+
+  // Return the beam index of the state passed into the initializer of this
+  // TransitionState.
+  const int ParentBeamIndex() const override;
+
+  // Get the current beam index for this state.
+  const int GetBeamIndex() const override;
+
+  // Set the current beam index for this state.
+  void SetBeamIndex(const int index) override;
+
+  // Get the score associated with this transition state.
+  const float GetScore() const override;
+
+  // Set the score associated with this transition state.
+  void SetScore(const float score) override;
+
+  // Depicts this state as an HTML-language string.
+  string HTMLRepresentation() const override;
+
+  // **** END INHERITED INTERFACE ****
+
+  // TODO(googleuser): Make these comments actually mean something.
+  // Data accessor.
+  int step_for_token(int token) {
+    if (token < 0 || token >= step_for_token_.size()) {
+      return -1;
+    } else {
+      return step_for_token_.at(token);
+    }
+  }
+
+  // Data setter.
+  void set_step_for_token(int token, int step) {
+    step_for_token_.insert(step_for_token_.begin() + token, step);
+  }
+
+  // Data accessor.
+  int parent_step_for_token(int token) {
+    if (token < 0 || token >= step_for_token_.size()) {
+      return -1;
+    } else {
+      return parent_step_for_token_.at(token);
+    }
+  }
+
+  // Data setter.
+  void set_parent_step_for_token(int token, int parent_step) {
+    parent_step_for_token_.insert(parent_step_for_token_.begin() + token,
+                                  parent_step);
+  }
+
+  // Data accessor.
+  int parent_for_token(int token) {
+    if (token < 0 || token >= step_for_token_.size()) {
+      return -1;
+    } else {
+      return parent_for_token_.at(token);
+    }
+  }
+
+  // Data setter.
+  void set_parent_for_token(int token, int parent) {
+    parent_for_token_.insert(parent_for_token_.begin() + token, parent);
+  }
+
+  // Accessor for the underlying nlp_saft::ParserState.
+  ParserState *parser_state() { return parser_state_.get(); }
+
+  // Accessor for the underlying sentence object.
+  SyntaxNetSentence *sentence() { return sentence_; }
+
+  ComponentTrace *mutable_trace() {
+    CHECK(trace_) << "Trace is not initialized";
+    return trace_.get();
+  }
+  void set_trace(std::unique_ptr<ComponentTrace> trace) {
+    trace_ = std::move(trace);
+  }
+
+ private:
+  // Underlying ParserState object that is being wrapped.
+  std::unique_ptr<ParserState> parser_state_;
+
+  // Sentence object that is being examined with this state.
+  SyntaxNetSentence *sentence_;
+
+  // The current score of this state.
+  float score_;
+
+  // The current beam index of this state.
+  int current_beam_index_;
+
+  // The parent beam index for this state.
+  int parent_beam_index_;
+
+  // Maintains a list of which steps in the history correspond to
+  // representations for each of the tokens on the stack.
+  std::vector<int> step_for_token_;
+
+  // Maintains a list of which steps in the history correspond to the actions
+  // that assigned a parent for tokens when reduced.
+  std::vector<int> parent_step_for_token_;
+
+  // Maintain the parent index of a token in the system.
+  std::vector<int> parent_for_token_;
+
+  // Trace of the history to produce this state.
+  std::unique_ptr<ComponentTrace> trace_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_

+ 276 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc

@@ -0,0 +1,276 @@
+#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
+
+#include "dragnn/components/syntaxnet/syntaxnet_component.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/test/generic.h"
+#include "dragnn/core/test/mock_transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "dragnn/protos/spec.pb.h"
+#include "syntaxnet/sentence.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+// This test suite is intended to validate the contracts that the DRAGNN
+// system expects from all transition state subclasses. Developers creating
+// new TransitionStates should copy this test and modify it as necessary,
+// using it to ensure their state conforms to DRAGNN expectations.
+
+namespace syntaxnet {
+namespace dragnn {
+
+namespace {
+
+const char kSentence0[] = R"(
+token {
+  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+  break_level: NO_BREAK
+}
+token {
+  word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
+  break_level: NO_BREAK
+}
+)";
+
+}  // namespace
+
+using testing::Return;
+
+class SyntaxNetTransitionStateTest : public ::testing::Test {
+ public:
+  std::unique_ptr<SyntaxNetTransitionState> CreateState() {
+    // Get the master spec proto from the test data directory.
+    MasterSpec master_spec;
+    string file_name = tensorflow::io::JoinPath(
+        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
+        "master_spec.textproto");
+    TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
+                                          &master_spec));
+
+    // Get all the resource protos from the test data directory.
+    for (Resource &resource :
+         *(master_spec.mutable_component(0)->mutable_resource())) {
+      resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
+          test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
+          resource.part(0).file_pattern()));
+    }
+
+    // Create an input batch with a single sentence and an empty beam vector
+    // to initialize the parser.
+    Sentence sentence_0;
+    TextFormat::ParseFromString(kSentence0, &sentence_0);
+    string sentence_0_str;
+    sentence_0.SerializeToString(&sentence_0_str);
+    data_.reset(new InputBatchCache(sentence_0_str));
+    SentenceInputBatch *sentences = data_->GetAs<SentenceInputBatch>();
+
+    // Create a parser component that will generate a parser state for this
+    // test.
+    SyntaxNetComponent component;
+    component.InitializeComponent(*(master_spec.mutable_component(0)));
+    std::vector<std::vector<const TransitionState *>> states;
+    constexpr int kBeamSize = 1;
+    component.InitializeData(states, kBeamSize, data_.get());
+
+    // Get a transition state from the component.
+    std::unique_ptr<SyntaxNetTransitionState> test_state =
+        component.CreateState(&(sentences->data()->at(0)));
+    return test_state;
+  }
+
+  std::unique_ptr<InputBatchCache> data_;
+};
+
+// Validates the consistency of the beam index setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetBeamIndex) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr int kOldBeamIndex = 12;
+  test_state->SetBeamIndex(kOldBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
+
+  constexpr int kNewBeamIndex = 7;
+  test_state->SetBeamIndex(kNewBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kNewBeamIndex);
+}
+
+// Validates the consistency of the score setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr float kOldScore = 12.1;
+  test_state->SetScore(kOldScore);
+  EXPECT_EQ(test_state->GetScore(), kOldScore);
+
+  constexpr float kNewScore = 7.2;
+  test_state->SetScore(kNewScore);
+  EXPECT_EQ(test_state->GetScore(), kNewScore);
+}
+
+// This test ensures that the initializing state's current index is saved
+// as the parent beam index of the state being initialized.
+TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) {
+  // Create a mock transition state that will report a specific current index.
+  // This index should become the parent state index for the test state.
+  MockTransitionState mock_state;
+  constexpr int kParentBeamIndex = 1138;
+  EXPECT_CALL(mock_state, GetBeamIndex())
+      .WillRepeatedly(Return(kParentBeamIndex));
+
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+  EXPECT_EQ(test_state->ParentBeamIndex(), kParentBeamIndex);
+}
+
+// This test ensures that the initializing state's current score is saved
+// as the current score of the state being initialized.
+TEST_F(SyntaxNetTransitionStateTest, InitializationCopiesParentScore) {
+  // Create a mock transition state that will report a specific current index.
+  // This index should become the parent state index for the test state.
+  MockTransitionState mock_state;
+  constexpr float kParentScore = 24.12;
+  EXPECT_CALL(mock_state, GetScore()).WillRepeatedly(Return(kParentScore));
+
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+  EXPECT_EQ(test_state->GetScore(), kParentScore);
+}
+
+// This test ensures that calling Clone maintains the state data (parent beam
+// index, beam index, score, etc.) of the state that was cloned.
+TEST_F(SyntaxNetTransitionStateTest, CloningMaintainsState) {
+  // Create and initialize the state.
+  MockTransitionState mock_state;
+  constexpr int kParentBeamIndex = 1138;
+  EXPECT_CALL(mock_state, GetBeamIndex())
+      .WillRepeatedly(Return(kParentBeamIndex));
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  // Validate the internal state of the test state.
+  constexpr float kOldScore = 20.0;
+  test_state->SetScore(kOldScore);
+  EXPECT_EQ(test_state->GetScore(), kOldScore);
+  constexpr int kOldBeamIndex = 12;
+  test_state->SetBeamIndex(kOldBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
+
+  auto clone = test_state->Clone();
+
+  // The clone should have identical state to the old state.
+  EXPECT_EQ(clone->ParentBeamIndex(), kParentBeamIndex);
+  EXPECT_EQ(clone->GetScore(), kOldScore);
+  EXPECT_EQ(clone->GetBeamIndex(), kOldBeamIndex);
+}
+
+// Validates the consistency of the step_for_token setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetStepForToken) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr int kStepForTokenZero = 12;
+  constexpr int kStepForTokenTwo = 34;
+  test_state->set_step_for_token(0, kStepForTokenZero);
+  test_state->set_step_for_token(2, kStepForTokenTwo);
+
+  // Expect that the set tokens return their values and the unset tokens
+  // return the default.
+  constexpr int kDefaultValue = -1;
+  EXPECT_EQ(kStepForTokenZero, test_state->step_for_token(0));
+  EXPECT_EQ(kDefaultValue, test_state->step_for_token(1));
+  EXPECT_EQ(kStepForTokenTwo, test_state->step_for_token(2));
+
+  // Expect that out-of-bounds accesses will return the default. (There are only
+  // 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
+  EXPECT_EQ(kDefaultValue, test_state->step_for_token(-1));
+  EXPECT_EQ(kDefaultValue, test_state->step_for_token(3));
+}
+
+// Validates the consistency of the parent_step_for_token setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentStepForToken) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr int kStepForTokenZero = 12;
+  constexpr int kStepForTokenTwo = 34;
+  test_state->set_parent_step_for_token(0, kStepForTokenZero);
+  test_state->set_parent_step_for_token(2, kStepForTokenTwo);
+
+  // Expect that the set tokens return their values and the unset tokens
+  // return the default.
+  constexpr int kDefaultValue = -1;
+  EXPECT_EQ(kStepForTokenZero, test_state->parent_step_for_token(0));
+  EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(1));
+  EXPECT_EQ(kStepForTokenTwo, test_state->parent_step_for_token(2));
+
+  // Expect that out-of-bounds accesses will return the default. (There are only
+  // 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
+  EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(-1));
+  EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(3));
+}
+
+// Validates the consistency of the parent_for_token setter and getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentForToken) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr int kParentForTokenZero = 12;
+  constexpr int kParentForTokenTwo = 34;
+  test_state->set_parent_for_token(0, kParentForTokenZero);
+  test_state->set_parent_for_token(2, kParentForTokenTwo);
+
+  // Expect that the set tokens return their values and the unset tokens
+  // return the default.
+  constexpr int kDefaultValue = -1;
+  EXPECT_EQ(kParentForTokenZero, test_state->parent_for_token(0));
+  EXPECT_EQ(kDefaultValue, test_state->parent_for_token(1));
+  EXPECT_EQ(kParentForTokenTwo, test_state->parent_for_token(2));
+
+  // Expect that out-of-bounds accesses will return the default. (There are only
+  // 3 tokens in the backing sentence, so token 3 and greater are out of bounds.)
+  EXPECT_EQ(kDefaultValue, test_state->parent_for_token(-1));
+  EXPECT_EQ(kDefaultValue, test_state->parent_for_token(3));
+}
+
+// Validates the consistency of trace proto setter/getter.
+TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetTrace) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  const string kTestComponentName = "test";
+  std::unique_ptr<ComponentTrace> trace;
+  trace.reset(new ComponentTrace());
+  trace->set_name(kTestComponentName);
+  test_state->set_trace(std::move(trace));
+
+  EXPECT_EQ(trace.get(), nullptr);
+  EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
+
+  // The trace should be preserved when cloning.
+  auto cloned_state = test_state->Clone();
+  EXPECT_EQ(cloned_state->mutable_trace()->name(), kTestComponentName);
+  EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 48 - 0
syntaxnet/dragnn/components/syntaxnet/testdata/master_spec.textproto

@@ -0,0 +1,48 @@
+component {
+  name: "parser"
+  transition_system {
+    registered_name: "arc-standard"
+  }
+  resource {
+    name: 'label-map'
+    part {
+      file_pattern: 'syntaxnet-tagger.label-map'
+      file_format: 'text'
+    }
+  }
+  resource {
+    name: 'tag-map'
+    part {
+      file_pattern: 'syntaxnet-tagger.tag-map'
+      file_format: 'text'
+    }
+  }
+  fixed_feature {
+    name: "tags"
+    fml: "input.tag input(1).tag"
+    embedding_dim: 32
+    vocabulary_size: 46
+    size: 2
+    predicate_map: "hashed"
+  }
+  fixed_feature {
+    name: "tags"
+    fml: "input(-1).tag input.tag input(1).tag"
+    embedding_dim: 32
+    vocabulary_size: 46
+    size: 3
+    predicate_map: "hashed"
+  }
+  linked_feature {
+    name: "recurrent_stack"
+    fml: "stack.focus stack(1).focus"
+    embedding_dim: 32
+    size: 2
+    source_component: "parser"
+    source_translator: "identity"
+    source_layer: "hidden_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+}

+ 47 - 0
syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.label-map

@@ -0,0 +1,47 @@
+46
+punct 243160
+prep 194627
+pobj 186958
+det 170592
+nsubj 144821
+nn 144800
+amod 117242
+ROOT 90592
+dobj 88551
+aux 76523
+advmod 72893
+conj 59384
+cc 57532
+num 36350
+poss 35117
+dep 34986
+ccomp 29470
+cop 25991
+mark 25141
+xcomp 25111
+rcmod 16234
+auxpass 15740
+advcl 14996
+possessive 14866
+nsubjpass 14133
+pcomp 12488
+appos 11112
+partmod 11106
+neg 11090
+number 10658
+prt 7123
+quantmod 6653
+tmod 5418
+infmod 5134
+npadvmod 3213
+parataxis 3012
+mwe 2793
+expl 2712
+iobj 1642
+acomp 1632
+discourse 1381
+csubj 1225
+predet 1160
+preconj 749
+goeswith 146
+csubjpass 41

+ 65 - 0
syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.master-spec

@@ -0,0 +1,65 @@
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet-tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet-tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet-tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 5
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "tagger"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  network_unit {
+    registered_name: 'feed-forward'
+    parameters {
+      key: 'hidden_layer_sizes'
+      value: '64'
+    }
+  }
+}

+ 50 - 0
syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.tag-map

@@ -0,0 +1,50 @@
+49
+NN 285194
+IN 228165
+DT 179147
+NNP 175147
+JJ 125667
+NNS 115732
+, 97481
+. 85938
+RB 78513
+VB 63952
+CC 57554
+VBD 56635
+CD 55674
+PRP 55244
+VBZ 48126
+VBN 44458
+VBG 34524
+VBP 33669
+TO 28772
+MD 22364
+PRP$ 20706
+HYPH 18526
+POS 14905
+`` 12193
+'' 12154
+WDT 10267
+: 8713
+$ 7993
+WP 7336
+RP 7335
+WRB 6634
+JJR 6295
+NNPS 5917
+-RRB- 3904
+-LRB- 3840
+JJS 3596
+RBR 3186
+EX 2733
+UH 1521
+RBS 1467
+PDT 1271
+FW 928
+NFP 844
+SYM 652
+ADD 476
+LS 392
+WP$ 332
+GW 184
+AFX 42

+ 4 - 0
syntaxnet/dragnn/components/syntaxnet/testdata/syntaxnet-tagger.word-map

@@ -0,0 +1,4 @@
+3
+Sentence 4
+. 3
+0 2

+ 9 - 0
syntaxnet/dragnn/components/util/BUILD

@@ -0,0 +1,9 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "bulk_feature_extractor",
+    hdrs = ["bulk_feature_extractor.h"],
+    deps = [
+        "@org_tensorflow//tensorflow/core:lib",
+    ],
+)

+ 95 - 0
syntaxnet/dragnn/components/util/bulk_feature_extractor.h

@@ -0,0 +1,95 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
+
+#include <functional>
+#include <utility>
+#include "tensorflow/core/platform/types.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// Provides a wrapper for allocator functions and padding data for the Bulk
+// ExtractFixedFeatures operation.
+class BulkFeatureExtractor {
+ public:
+  // Create a BulkFeatureExtractor with the given allocator functions and
+  // padding. The allocator functions should take a channel and an element
+  // count and return a contiguous block of memory that is associated with that
+  // channel (the caller can decide what that means). If use_padding is true,
+  // the provided pad_to_step and pad_to_element will be used to calculate
+  // flattened feature indices (see GetIndex below).
+  BulkFeatureExtractor(
+      std::function<tensorflow::int32 *(int channel, int num_elements)>
+          allocate_indices_by_channel,
+      std::function<tensorflow::int64 *(int channel, int num_elements)>
+          allocate_ids_by_channel,
+      std::function<float *(int channel, int num_elements)>
+          allocate_weights_by_channel,
+      bool use_padding, int pad_to_step, int pad_to_element)
+      : use_padding_(use_padding),
+        pad_to_step_(pad_to_step),
+        pad_to_element_(pad_to_element),
+        allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
+        allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
+        allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
+
+  // Create a BulkFeatureExtractor with allocator functions as above, but with
+  // use_padding set to False. Useful when you know your caller will never
+  // need to pad.
+  BulkFeatureExtractor(
+      std::function<tensorflow::int32 *(int channel, int num_elements)>
+          allocate_indices_by_channel,
+      std::function<tensorflow::int64 *(int channel, int num_elements)>
+          allocate_ids_by_channel,
+      std::function<float *(int channel, int num_elements)>
+          allocate_weights_by_channel)
+      : use_padding_(false),
+        pad_to_step_(-1),
+        pad_to_element_(-1),
+        allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
+        allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
+        allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
+
+  // Invoke the index memory allocator.
+  tensorflow::int32 *AllocateIndexMemory(int channel, int num_elements) const {
+    return allocate_indices_by_channel_(channel, num_elements);
+  }
+
+  // Invoke the ID memory allocator.
+  tensorflow::int64 *AllocateIdMemory(int channel, int num_elements) const {
+    return allocate_ids_by_channel_(channel, num_elements);
+  }
+
+  // Invoke the weight memory allocator.
+  float *AllocateWeightMemory(int channel, int num_elements) const {
+    return allocate_weights_by_channel_(channel, num_elements);
+  }
+
+  // Given the total number of steps and total number of elements for a given
+  // feature, calculate the index (not ID) of that feature. Based on how the
+  // BulkFeatureExtractor was constructed, it uses either the given numbers of
+  // steps and elements or the padding values passed at construction.
+  int GetIndex(int total_steps, int num_elements, int feature_idx,
+               int element_idx, int step_idx) const {
+    const int steps = (use_padding_) ? pad_to_step_ : total_steps;
+    const int elements = (use_padding_) ? pad_to_element_ : num_elements;
+    const int feature_offset = elements * steps;
+    const int element_offset = steps;
+    return (feature_idx * feature_offset) + (element_idx * element_offset) +
+           step_idx;
+  }
+
+ private:
+  const bool use_padding_;
+  const int pad_to_step_;
+  const int pad_to_element_;
+  const std::function<tensorflow::int32 *(int, int)>
+      allocate_indices_by_channel_;
+  const std::function<tensorflow::int64 *(int, int)> allocate_ids_by_channel_;
+  const std::function<float *(int, int)> allocate_weights_by_channel_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
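For orientation (not part of the diff itself), here is a minimal sketch of how the allocator callbacks and GetIndex fit together; the vector-backed buffers and the BulkFeatureExtractorSketch function are illustrative stand-ins for the bulk op's real output tensors:

    #include <vector>
    #include "dragnn/components/util/bulk_feature_extractor.h"

    namespace syntaxnet {
    namespace dragnn {

    void BulkFeatureExtractorSketch() {
      // Vector-backed buffers standing in for the op's output tensors.
      std::vector<std::vector<tensorflow::int32>> indices(2);
      std::vector<std::vector<tensorflow::int64>> ids(2);
      std::vector<std::vector<float>> weights(2);

      // Unpadded extractor (three-argument constructor).
      BulkFeatureExtractor extractor(
          [&indices](int channel, int n) {
            indices[channel].resize(n);
            return indices[channel].data();
          },
          [&ids](int channel, int n) {
            ids[channel].resize(n);
            return ids[channel].data();
          },
          [&weights](int channel, int n) {
            weights[channel].resize(n);
            return weights[channel].data();
          });

      // Indices are laid out feature-major, then element, then step: with
      // 3 elements and 5 steps, feature 1 / element 2 / step 4 lands at
      // 1 * (3 * 5) + 2 * 5 + 4 = 29 in the flat buffer.
      const int flat_index = extractor.GetIndex(
          /*total_steps=*/5, /*num_elements=*/3,
          /*feature_idx=*/1, /*element_idx=*/2, /*step_idx=*/4);
      (void)flat_index;  // == 29
    }

    }  // namespace dragnn
    }  // namespace syntaxnet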

+ 9 - 0
syntaxnet/dragnn/conll2017/BUILD

@@ -0,0 +1,9 @@
+py_binary(
+    name = "make_parser_spec",
+    srcs = ["make_parser_spec.py"],
+    deps = [
+        "//dragnn/protos:spec_py_pb2",
+        "//dragnn/python:spec_builder",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+    ],
+)

+ 40 - 0
syntaxnet/dragnn/conll2017/conll_parser_trainer.sh

@@ -0,0 +1,40 @@
+#!/bin/sh
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# A script to train the CONLL2017 baseline.
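+#
+# Usage: ./conll_parser_trainer.sh <training corpus> <dev corpus>
+# The two positional arguments are passed to --training_corpus_path and
+# --tune_corpus_path below.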
+set -e
+
+language=English
+output_dir=./trained-"$language"
+
+training_corpus=$1
+dev_corpus=$2
+
+bazel build -c opt //dragnn/tools:trainer //dragnn/conll2017:make_parser_spec
+
+mkdir -p $output_dir
+bazel-bin/dragnn/conll2017/make_parser_spec \
+  --spec_file="$output_dir/parser_spec.textproto"
+
+bazel-bin/dragnn/tools/trainer \
+  --logtostderr \
+  --compute_lexicon \
+  --dragnn_spec="$output_dir/parser_spec.textproto" \
+  --resource_path="$output_dir/resources" \
+  --training_corpus_path="$training_corpus" \
+  --tune_corpus_path="$dev_corpus" \
+  --tensorboard_dir="$output_dir/tensorboard" \
+  --checkpoint_filename="$output_dir/checkpoint.model"

+ 105 - 0
syntaxnet/dragnn/conll2017/make_parser_spec.py

@@ -0,0 +1,105 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Construct the spec for the CONLL2017 Parser baseline."""
+
+import tensorflow as tf
+
+from tensorflow.python.platform import gfile
+
+from dragnn.protos import spec_pb2
+from dragnn.python import spec_builder
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('spec_file', 'parser_spec.textproto',
+                    'Filename to save the spec to.')
+
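+# Example invocation (mirroring conll2017/conll_parser_trainer.sh):
+#   bazel-bin/dragnn/conll2017/make_parser_spec \
+#     --spec_file=/path/to/parser_spec.textproto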
+
+def main(unused_argv):
+  # Left-to-right, character-based LSTM.
+  char2word = spec_builder.ComponentSpecBuilder('char_lstm')
+  char2word.set_network_unit(
+      name='wrapped_units.LayerNormBasicLSTMNetwork',
+      hidden_layer_sizes='256')
+  char2word.set_transition_system(name='char-shift-only', left_to_right='true')
+  char2word.add_fixed_feature(name='chars', fml='char-input.text-char',
+                              embedding_dim=16)
+
+  # Lookahead LSTM reads right-to-left to represent the rightmost context of the
+  # words. It gets word embeddings from the char model.
+  lookahead = spec_builder.ComponentSpecBuilder('lookahead')
+  lookahead.set_network_unit(
+      name='wrapped_units.LayerNormBasicLSTMNetwork',
+      hidden_layer_sizes='256')
+  lookahead.set_transition_system(name='shift-only', left_to_right='false')
+  lookahead.add_link(source=char2word, fml='input.last-char-focus',
+                     embedding_dim=64)
+
+  # Construct the tagger. This is a simple left-to-right LSTM sequence tagger.
+  tagger = spec_builder.ComponentSpecBuilder('tagger')
+  tagger.set_network_unit(
+      name='wrapped_units.LayerNormBasicLSTMNetwork',
+      hidden_layer_sizes='256')
+  tagger.set_transition_system(name='tagger')
+  tagger.add_token_link(source=lookahead, fml='input.focus', embedding_dim=64)
+
+  # Construct the parser.
+  parser = spec_builder.ComponentSpecBuilder('parser')
+  parser.set_network_unit(name='FeedForwardNetwork', hidden_layer_sizes='256',
+                          layer_norm_hidden='true')
+  parser.set_transition_system(name='arc-standard')
+  parser.add_token_link(source=lookahead, fml='input.focus', embedding_dim=64)
+  parser.add_token_link(
+      source=tagger, fml='input.focus stack.focus stack(1).focus',
+      embedding_dim=64)
+
+  # Add discrete features of the predicted parse tree so far, like in Parsey
+  # McParseface.
+  parser.add_fixed_feature(name='labels', embedding_dim=16,
+                           fml=' '.join([
+                               'stack.child(1).label',
+                               'stack.child(1).sibling(-1).label',
+                               'stack.child(-1).label',
+                               'stack.child(-1).sibling(1).label',
+                               'stack(1).child(1).label',
+                               'stack(1).child(1).sibling(-1).label',
+                               'stack(1).child(-1).label',
+                               'stack(1).child(-1).sibling(1).label',
+                               'stack.child(2).label',
+                               'stack.child(-2).label',
+                               'stack(1).child(2).label',
+                               'stack(1).child(-2).label']))
+
+  # Recurrent connection for the arc-standard parser. For both tokens on the
+  # stack, we connect to the last time step to either SHIFT or REDUCE that
+  # token. This allows the parser to build up compositional representations of
+  # phrases.
+  parser.add_link(
+      source=parser,  # recurrent connection
+      name='rnn-stack',  # unique identifier
+      fml='stack.focus stack(1).focus',  # look for both stack tokens
+      source_translator='shift-reduce-step',  # maps token indices -> step
+      embedding_dim=64)  # project down to 64 dims
+
+  master_spec = spec_pb2.MasterSpec()
+  master_spec.component.extend(
+      [char2word.spec, lookahead.spec, tagger.spec, parser.spec])
+
+  with gfile.FastGFile(FLAGS.spec_file, 'w') as f:
+    f.write(str(master_spec).encode('utf-8'))
+
+if __name__ == '__main__':
+  tf.app.run()

+ 16 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/category-map

@@ -0,0 +1,16 @@
+15
+NOUN 25758
+VERB 14242
+PUNCT 12945
+PART 9977
+PROPN 8280
+NUM 5082
+ADV 4323
+ADP 4165
+ADJ 2318
+AUX 2024
+PRON 1343
+CCONJ 1329
+DET 994
+X 948
+SYM 25

File diff suppressed because it is too large
+ 3518 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/char-map


File diff suppressed because it is too large
+ 16126 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/char-ngram-map


+ 43 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/label-map

@@ -0,0 +1,43 @@
+42
+punct 12965
+nmod 11147
+nsubj 7134
+obj 6016
+nummod 4732
+case:suff 4179
+acl 4163
+root 3797
+mark 3445
+det 3434
+advmod 2962
+case 2739
+case:dec 2517
+conj 2421
+obl 2000
+dep 1998
+mark:relcl 1833
+clf 1722
+ccomp 1655
+amod 1525
+xcomp 1382
+acl:relcl 1356
+cop 1349
+cc 1334
+nmod:tmod 1199
+appos 1089
+case:aspect 718
+aux 675
+case:pref 569
+aux:pass 324
+csubj 280
+flat:foreign 250
+nsubj:pass 211
+discourse 151
+aux:caus 149
+advcl 125
+mark:advb 79
+iobj 61
+dislocated 45
+mark:comp 17
+csubj:pass 5
+vocative 1

File diff suppressed because it is too large
+ 16263 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/lcword-map


BIN
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/prefix-table


BIN
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/suffix-table


+ 43 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/tag-map

@@ -0,0 +1,43 @@
+42
+NN 21794
+VV 13177
+NNP 8280
+, 5824
+CD 5082
+DEC 4350
+RB 4323
+SFN 4229
+IN 4165
+NNB 3963
+. 3807
+JJ 2318
+VC 1935
+CC 1329
+PRP 996
+DT 994
+EC 942
+FW 778
+AS 718
+MD 681
+( 641
+) 641
+PFA 555
+BB 472
+'' 331
+`` 329
+PRD 324
+/ 202
+: 165
+UH 150
+DEV 96
+HYPH 76
+WP 23
+SFV 18
+XX 17
+ADD 8
+SFA 7
+... 4
+PFN 4
+LS 3
+" 1
+VERB 1

+ 42 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/tag-to-category

@@ -0,0 +1,42 @@
+"	PUNCT
+''	PUNCT
+(	PUNCT
+)	PUNCT
+,	PUNCT
+.	PUNCT
+...	PUNCT
+/	PUNCT
+:	PUNCT
+ADD	NOUN
+AS	PART
+BB	VERB
+CC	CCONJ
+CD	NUM
+DEC	PART
+DEV	PART
+DT	DET
+EC	PUNCT
+FW	X
+HYPH	PUNCT
+IN	ADP
+JJ	ADJ
+LS	X
+MD	AUX
+NN	NOUN
+NNB	NOUN
+NNP	PROPN
+PFA	PART
+PFN	PART
+PRD	PRON
+PRP	PRON
+RB	ADV
+SFA	PART
+SFN	PART
+SFV	PART
+UH	X
+VC	VERB
+VERB	VERB
+VV	AUX
+WP	PRON
+XX	X
+``	PUNCT

File diff suppressed because it is too large
+ 16269 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter-resource/word-map


BIN
syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.data-00000-of-00001


BIN
syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.index


BIN
syntaxnet/dragnn/conll2017/sample/zh-segmenter.checkpoint.meta


+ 187 - 0
syntaxnet/dragnn/conll2017/sample/zh-segmenter.master_spec

@@ -0,0 +1,187 @@
+component {
+  name: "lookahead"
+  transition_system {
+    registered_name: "shift-only"
+    parameters {
+      key: "left_to_right"
+      value: "false"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "word-map"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "tag-map"
+    }
+  }
+  resource {
+    name: "tag-to-category"
+    part {
+      file_pattern: "tag-to-category"
+    }
+  }
+  resource {
+    name: "lcword-map"
+    part {
+      file_pattern: "lcword-map"
+    }
+  }
+  resource {
+    name: "category-map"
+    part {
+      file_pattern: "category-map"
+    }
+  }
+  resource {
+    name: "char-map"
+    part {
+      file_pattern: "char-map"
+    }
+  }
+  resource {
+    name: "char-ngram-map"
+    part {
+      file_pattern: "char-ngram-map"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "label-map"
+    }
+  }
+  resource {
+    name: "prefix-table"
+    part {
+      file_pattern: "prefix-table"
+    }
+  }
+  resource {
+    name: "suffix-table"
+    part {
+      file_pattern: "suffix-table"
+    }
+  }
+  fixed_feature {
+    name: "char"
+    fml: "input(-1).char input.char input(1).char"
+    embedding_dim: 32
+    vocabulary_size: 3521
+    size: 3
+  }
+  fixed_feature {
+    name: "char-bigram"
+    fml: "input.char-bigram"
+    embedding_dim: 32
+    vocabulary_size: 6579
+    size: 1
+  }
+  network_unit {
+    registered_name: "wrapped_units.LayerNormBasicLSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "256"
+    }
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  num_actions: 1
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+}
+component {
+  name: "segmenter"
+  transition_system {
+    registered_name: "binary-segment-transitions"
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "word-map"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "tag-map"
+    }
+  }
+  resource {
+    name: "tag-to-category"
+    part {
+      file_pattern: "tag-to-category"
+    }
+  }
+  resource {
+    name: "lcword-map"
+    part {
+      file_pattern: "lcword-map"
+    }
+  }
+  resource {
+    name: "category-map"
+    part {
+      file_pattern: "category-map"
+    }
+  }
+  resource {
+    name: "char-map"
+    part {
+      file_pattern: "char-map"
+    }
+  }
+  resource {
+    name: "char-ngram-map"
+    part {
+      file_pattern: "char-ngram-map"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "label-map"
+    }
+  }
+  resource {
+    name: "prefix-table"
+    part {
+      file_pattern: "prefix-table"
+    }
+  }
+  resource {
+    name: "suffix-table"
+    part {
+      file_pattern: "suffix-table"
+    }
+  }
+  linked_feature {
+    name: "lookahead"
+    fml: "input.focus stack.focus"
+    embedding_dim: 32
+    size: 2
+    source_component: "lookahead"
+    source_translator: "reverse-token"
+    source_layer: "state_h_0"
+  }
+  network_unit {
+    registered_name: "wrapped_units.LayerNormBasicLSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "128"
+    }
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  num_actions: 2
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+}

+ 340 - 0
syntaxnet/dragnn/core/BUILD

@@ -0,0 +1,340 @@
+package(default_visibility = ["//visibility:public"])
+
+# Test data.
+filegroup(
+    name = "testdata",
+    data = glob(["testdata/**"]),
+)
+
+cc_library(
+    name = "beam",
+    hdrs = ["beam.h"],
+    deps = [
+        "//dragnn/core/interfaces:cloneable_transition_state",
+        "//dragnn/core/interfaces:transition_state",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+cc_library(
+    name = "component_registry",
+    srcs = ["component_registry.cc"],
+    hdrs = ["component_registry.h"],
+    deps = [
+        "//dragnn/core/interfaces:component",
+        "//syntaxnet:registry",
+    ],
+)
+
+cc_library(
+    name = "compute_session",
+    hdrs = ["compute_session.h"],
+    deps = [
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core:index_translator",
+        "//dragnn/core/interfaces:component",
+        "//dragnn/protos:spec_proto",
+        "//dragnn/protos:trace_proto",
+    ],
+)
+
+cc_library(
+    name = "compute_session_impl",
+    srcs = ["compute_session_impl.cc"],
+    hdrs = ["compute_session_impl.h"],
+    deps = [
+        ":compute_session",
+        ":index_translator",
+        ":input_batch_cache",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "//dragnn/protos:trace_proto",
+        "//syntaxnet:registry",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+cc_library(
+    name = "compute_session_pool",
+    srcs = ["compute_session_pool.cc"],
+    hdrs = ["compute_session_pool.h"],
+    deps = [
+        ":component_registry",
+        ":compute_session",
+        ":compute_session_impl",
+        "//dragnn/protos:spec_proto",
+        "@org_tensorflow//tensorflow/core:lib",
+    ],
+)
+
+cc_library(
+    name = "index_translator",
+    srcs = ["index_translator.cc"],
+    hdrs = ["index_translator.h"],
+    deps = [
+        "//dragnn/core/interfaces:component",
+        "//dragnn/core/interfaces:transition_state",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+cc_library(
+    name = "input_batch_cache",
+    hdrs = ["input_batch_cache.h"],
+    deps = [
+        "//dragnn/core/interfaces:input_batch",
+        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+    ],
+)
+
+cc_library(
+    name = "resource_container",
+    hdrs = ["resource_container.h"],
+    deps = [
+        "//syntaxnet:base",
+        "@org_tensorflow//tensorflow/core:framework",
+    ],
+)
+
+# Tests
+
+cc_test(
+    name = "beam_test",
+    srcs = ["beam_test.cc"],
+    deps = [
+        ":beam",
+        "//dragnn/core/interfaces:cloneable_transition_state",
+        "//dragnn/core/interfaces:transition_state",
+        "//dragnn/core/test:mock_transition_state",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "compute_session_impl_test",
+    srcs = ["compute_session_impl_test.cc"],
+    deps = [
+        ":component_registry",
+        ":compute_session",
+        ":compute_session_impl",
+        ":compute_session_pool",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core/interfaces:component",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_component",
+        "//dragnn/core/test:mock_transition_state",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "compute_session_pool_test",
+    srcs = ["compute_session_pool_test.cc"],
+    deps = [
+        ":compute_session",
+        ":compute_session_pool",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_component",
+        "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "index_translator_test",
+    srcs = ["index_translator_test.cc"],
+    deps = [
+        ":index_translator",
+        "//dragnn/core/test:mock_component",
+        "//dragnn/core/test:mock_transition_state",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "input_batch_cache_test",
+    srcs = ["input_batch_cache_test.cc"],
+    deps = [
+        ":input_batch_cache",
+        "//dragnn/core/interfaces:input_batch",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_test(
+    name = "resource_container_test",
+    srcs = ["resource_container_test.cc"],
+    deps = [
+        ":resource_container",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+# Tensorflow op kernel BUILD rules.
+
+load(
+    "//dragnn:tensorflow_ops.bzl",
+    "tf_gen_op_libs",
+    "tf_gen_op_wrapper_py",
+    "tf_kernel_library",
+)
+
+tf_gen_op_libs(
+    op_lib_names = ["dragnn_ops"],
+)
+
+tf_gen_op_wrapper_py(
+    name = "dragnn_ops",
+    deps = [":dragnn_ops_op_lib"],
+)
+
+tf_gen_op_libs(
+    op_lib_names = ["dragnn_bulk_ops"],
+)
+
+tf_gen_op_wrapper_py(
+    name = "dragnn_bulk_ops",
+    deps = [":dragnn_bulk_ops_op_lib"],
+)
+
+cc_library(
+    name = "compute_session_op",
+    srcs = [
+        "ops/compute_session_op.cc",
+    ],
+    hdrs = ["ops/compute_session_op.h"],
+    deps = [
+        ":compute_session",
+        ":resource_container",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//third_party/eigen3",
+    ],
+)
+
+cc_library(
+    name = "dragnn_ops_cc",
+    srcs = [
+        "ops/dragnn_op_kernels.cc",
+        "ops/dragnn_ops.cc",
+    ],
+    deps = [
+        ":compute_session",
+        ":compute_session_op",
+        ":compute_session_pool",
+        ":resource_container",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//third_party/eigen3",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "dragnn_bulk_ops_cc",
+    srcs = [
+        "ops/dragnn_bulk_op_kernels.cc",
+        "ops/dragnn_bulk_ops.cc",
+    ],
+    deps = [
+        ":compute_session_op",
+        ":resource_container",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//third_party/eigen3",
+    ],
+)
+
+# Tensorflow kernel libraries, for use with unit tests.
+
+tf_kernel_library(
+    name = "dragnn_op_kernels",
+    srcs = [
+        "ops/dragnn_op_kernels.cc",
+        "ops/dragnn_ops.cc",
+    ],
+    hdrs = [
+    ],
+    deps = [
+        ":compute_session",
+        ":compute_session_op",
+        ":compute_session_pool",
+        ":resource_container",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//third_party/eigen3",
+    ],
+)
+
+tf_kernel_library(
+    name = "dragnn_bulk_op_kernels",
+    srcs = [
+        "ops/dragnn_bulk_op_kernels.cc",
+        "ops/dragnn_bulk_ops.cc",
+    ],
+    hdrs = [
+    ],
+    deps = [
+        ":compute_session",
+        ":compute_session_op",
+        ":compute_session_pool",
+        ":resource_container",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/protos:spec_proto",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//third_party/eigen3",
+    ],
+)
+
+# Tensorflow kernel tests.
+
+cc_test(
+    name = "dragnn_op_kernels_test",
+    srcs = ["ops/dragnn_op_kernels_test.cc"],
+    deps = [
+        ":compute_session",
+        ":compute_session_pool",
+        ":dragnn_op_kernels",
+        ":resource_container",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+        "@org_tensorflow//tensorflow/core:test",
+        "@org_tensorflow//tensorflow/core:testlib",
+        "@org_tensorflow//tensorflow/core/kernels:ops_testutil",
+        "@org_tensorflow//tensorflow/core/kernels:ops_util",
+        "@org_tensorflow//tensorflow/core/kernels:quantized_ops",
+    ],
+)
+
+cc_test(
+    name = "dragnn_bulk_op_kernels_test",
+    srcs = ["ops/dragnn_bulk_op_kernels_test.cc"],
+    deps = [
+        ":compute_session_pool",
+        ":dragnn_bulk_op_kernels",
+        ":resource_container",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:testlib",
+        "@org_tensorflow//tensorflow/core/kernels:ops_testutil",
+        "@org_tensorflow//tensorflow/core/kernels:quantized_ops",
+    ],
+)

+ 347 - 0
syntaxnet/dragnn/core/beam.h

@@ -0,0 +1,347 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
+
+#include <algorithm>
+#include <cmath>
+#include <memory>
+#include <vector>
+
+#include "dragnn/core/interfaces/cloneable_transition_state.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// The Beam class wraps the logic necessary to advance a set of transition
+// states for an arbitrary Component. Because the Beam class is generic, it
+// doesn't know how to act on the states it is provided - the instantiating
+// Component is expected to provide the four functions it needs to interact
+// with that Component's TransitionState subclasses.
+
+template <typename T>
+class Beam {
+ public:
+  // Creates a new Beam which can grow up to max_size elements.
+  explicit Beam(int max_size) : max_size_(max_size), num_steps_(0) {
+    VLOG(2) << "Creating beam with max size " << max_size_;
+    static_assert(
+        std::is_base_of<CloneableTransitionState<T>, T>::value,
+        "This class must be instantiated to use a CloneableTransitionState");
+  }
+
+  // Sets the Beam functions, as follows:
+  // bool is_allowed(TransitionState *, int): Returns true if transition 'int' is
+  //   allowed for transition state 'TransitionState *'.
+  // bool is_final(TransitionState *): Returns true if transition state
+  //   'TransitionState *' is final and can no longer be advanced.
+  // void perform_transition(TransitionState *, int): Performs transition 'int'
+  //   on transition state 'TransitionState *'.
+  // int oracle_function(TransitionState *): Returns the oracle-specified action
+  //   for transition state 'TransitionState *'.
+  void SetFunctions(std::function<bool(T *, int)> is_allowed,
+                    std::function<bool(T *)> is_final,
+                    std::function<void(T *, int)> perform_transition,
+                    std::function<int(T *)> oracle_function) {
+    is_allowed_ = is_allowed;
+    is_final_ = is_final;
+    perform_transition_ = perform_transition;
+    oracle_function_ = oracle_function;
+  }
+
+  // Resets the Beam and initializes it with the given set of states. The Beam
+  // takes ownership of these TransitionStates.
+  void Init(std::vector<std::unique_ptr<T>> initial_states) {
+    VLOG(2) << "Initializing beam. Beam max size is " << max_size_;
+    CHECK_LE(initial_states.size(), max_size_)
+        << "Attempted to initialize a beam with more states ("
+        << initial_states.size() << ") than the max size " << max_size_;
+    beam_ = std::move(initial_states);
+    std::vector<int> previous_beam_indices(max_size_, -1);
+    for (int i = 0; i < beam_.size(); ++i) {
+      previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
+      beam_[i]->SetBeamIndex(i);
+    }
+    beam_index_history_.emplace_back(previous_beam_indices);
+  }
+
+  // Advances the Beam from the given transition matrix.
+  void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
+                             int num_actions) {
+    // Ensure that the transition matrix is the correct size. All underlying
+    // states should have the same transition profile, so using the one at 0
+    // should be safe.
+    CHECK_EQ(matrix_length, max_size_ * num_actions)
+        << "Transition matrix size does not match max beam size * number of "
+           "state transitions!";
+
+    if (max_size_ == 1) {
+      // In the case where beam size is 1, we can advance by simply finding the
+      // highest score and advancing the beam state in place.
+      VLOG(2) << "Beam size is 1. Using fast beam path.";
+      int best_action = -1;
+      float best_score = -INFINITY;
+      auto &state = beam_[0];
+      for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
+        if (is_allowed_(state.get(), action_idx) &&
+            transition_matrix[action_idx] > best_score) {
+          best_score = transition_matrix[action_idx];
+          best_action = action_idx;
+        }
+      }
+      CHECK_GE(best_action, 0) << "Num actions: " << num_actions
+                               << " score[0]: " << transition_matrix[0];
+      perform_transition_(state.get(), best_action);
+      const float new_score = state->GetScore() + best_score;
+      state->SetScore(new_score);
+      state->SetBeamIndex(0);
+    } else {
+      // Create the vector of all possible transitions, along with their scores.
+      std::vector<Transition> candidates;
+
+      // Iterate through all beams, examining all actions for each beam.
+      for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
+        const auto &state = beam_[beam_idx];
+        for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
+          // If the action is allowed, calculate the proposed new score and add
+          // the candidate action to the vector of all actions at this state.
+          if (is_allowed_(state.get(), action_idx)) {
+            Transition candidate;
+
+            // The matrix is laid out by beam index, with a linear set of
+            // actions for that index - so beam N's actions start at [nr. of
+            // actions]*[N].
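+            // For example, with 4 actions, beam element 2's scores occupy
+            // matrix indices 8 through 11.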
+            const int matrix_idx = action_idx + beam_idx * num_actions;
+            CHECK_LT(matrix_idx, matrix_length)
+                << "Matrix index out of bounds!";
+            const double score_delta = transition_matrix[matrix_idx];
+            CHECK(!std::isnan(score_delta));
+            candidate.source_idx = beam_idx;
+            candidate.action = action_idx;
+            candidate.resulting_score = state->GetScore() + score_delta;
+            candidates.emplace_back(candidate);
+          }
+        }
+      }
+
+      // Sort the vector of all possible transitions and scores.
+      const auto comparator = [](const Transition &a, const Transition &b) {
+        return a.resulting_score > b.resulting_score;
+      };
+      std::sort(candidates.begin(), candidates.end(), comparator);
+
+      // Apply the top transitions, up to a maximum of 'max_size_'.
+      std::vector<std::unique_ptr<T>> new_beam;
+      std::vector<int> previous_beam_indices(max_size_, -1);
+      const int beam_size =
+          std::min(max_size_, static_cast<int>(candidates.size()));
+      VLOG(2) << "Previous beam size = " << beam_.size();
+      VLOG(2) << "New beam size = " << beam_size;
+      VLOG(2) << "Maximum beam size = " << max_size_;
+      for (int i = 0; i < beam_size; ++i) {
+        // Get the source of the i'th transition.
+        const auto &transition = candidates[i];
+        VLOG(2) << "Taking transition with score: "
+                << transition.resulting_score
+                << " and action: " << transition.action;
+        VLOG(2) << "transition.source_idx = " << transition.source_idx;
+        const auto &source = beam_[transition.source_idx];
+
+        // Put the new transition on the new state beam.
+        auto new_state = source->Clone();
+        perform_transition_(new_state.get(), transition.action);
+        new_state->SetScore(transition.resulting_score);
+        new_state->SetBeamIndex(i);
+        previous_beam_indices.at(i) = transition.source_idx;
+        new_beam.emplace_back(std::move(new_state));
+      }
+
+      beam_ = std::move(new_beam);
+      beam_index_history_.emplace_back(previous_beam_indices);
+    }
+
+    ++num_steps_;
+  }
+
+  // Advances the Beam from the state oracles.
+  void AdvanceFromOracle() {
+    std::vector<int> previous_beam_indices(max_size_, -1);
+    for (int i = 0; i < beam_.size(); ++i) {
+      previous_beam_indices.at(i) = i;
+      if (is_final_(beam_[i].get())) continue;
+      const auto oracle_label = oracle_function_(beam_[i].get());
+      VLOG(2) << "AdvanceFromOracle beam_index:" << i
+              << " oracle_label:" << oracle_label;
+      perform_transition_(beam_[i].get(), oracle_label);
+      beam_[i]->SetScore(0.0);
+      beam_[i]->SetBeamIndex(i);
+    }
+    if (max_size_ > 1) {
+      beam_index_history_.emplace_back(previous_beam_indices);
+    }
+    num_steps_++;
+  }
+
+  // Returns true if all states in the beam are final.
+  bool IsTerminal() {
+    for (auto &state : beam_) {
+      if (!is_final_(state.get())) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Destroys the states held by this beam and resets its history.
+  void Reset() {
+    beam_.clear();
+    beam_index_history_.clear();
+    num_steps_ = 0;
+  }
+
+  // Given an index into the current beam, determine the index of the item's
+  // parent at beam step "step", which should be less than the total number
+  // of steps taken by this beam.
+  int FindPreviousIndex(int current_index, int step) const {
+    VLOG(2) << "FindPreviousIndex requested for current_index:" << current_index
+            << " at step:" << step;
+    if (VLOG_IS_ON(2)) {
+      int step_index = 0;
+      for (const auto &history_step : beam_index_history_) {
+        string row =
+            "Step " + std::to_string(step_index) + " element source slot: ";
+        for (const auto &index : history_step) {
+          if (index == -1) {
+            row += "  X";
+          } else {
+            row += "  " + std::to_string(index);
+          }
+        }
+        VLOG(2) << row;
+        ++step_index;
+      }
+    }
+
+    // Unless the max size of the beam is 1, make sure the number of steps is
+    // in sync with the size of the history.
+    if (max_size_ > 1) {
+      CHECK(num_steps_ == beam_index_history_.size() - 1);
+    }
+
+    // Check if the step is too far into the past or future.
+    if (step < 0 || step > num_steps_) {
+      return -1;
+    }
+
+    // Check that the index is within the beam.
+    if (current_index < 0 || current_index >= max_size_) {
+      return -1;
+    }
+
+    // If the max size of the beam is 1, always return 0.
+    if (max_size_ == 1) {
+      return 0;
+    }
+
+    // Check that the start index isn't -1; -1 means that we don't have an
+    // actual transition state in that beam slot.
+    if (beam_index_history_.back().at(current_index) == -1) {
+      return -1;
+    }
+
+    int beam_index = current_index;
+    for (int i = beam_index_history_.size() - 1; i >= step; --i) {
+      beam_index = beam_index_history_.at(i).at(beam_index);
+    }
+    CHECK_GE(beam_index, 0);
+    VLOG(2) << "Index is " << beam_index;
+    return beam_index;
+  }
+
+  // Returns the current state of the beam.
+  std::vector<const TransitionState *> beam() const {
+    std::vector<const TransitionState *> state_ptrs;
+    for (const auto &beam_state : beam_) {
+      state_ptrs.emplace_back(beam_state.get());
+    }
+    return state_ptrs;
+  }
+
+  // Returns the beam at the current state index.
+  T *beam_state(int beam_index) { return beam_.at(beam_index).get(); }
+
+  // Returns the raw history vectors for this beam.
+  const std::vector<std::vector<int>> &history() {
+    if (max_size_ == 1) {
+      // If max size is 1, we haven't been keeping track of the beam history.
+      // Quickly create it here.
+      beam_index_history_.clear();
+      beam_index_history_.push_back({beam_[0]->ParentBeamIndex()});
+      for (int i = 0; i < num_steps_; ++i) {
+        beam_index_history_.push_back({0});
+      }
+    }
+    return beam_index_history_;
+  }
+
+  // Sets the max size of the beam.
+  void SetMaxSize(int max_size) {
+    max_size_ = max_size;
+    Reset();
+  }
+
+  // Returns the number of steps taken so far.
+  const int num_steps() const { return num_steps_; }
+
+  // Returns the max size of this beam.
+  const int max_size() const { return max_size_; }
+
+  // Returns the current size of the beam.
+  const int size() const { return beam_.size(); }
+
+ private:
+  // Associates an action taken on an index into current_state_ with a score.
+  struct Transition {
+    // The index of the source item.
+    int source_idx;
+
+    // The index of the action being taken.
+    int action;
+
+    // The score of the full derivation.
+    double resulting_score;
+  };
+
+  // The maximum beam size.
+  int max_size_;
+
+  // The current beam.
+  std::vector<std::unique_ptr<T>> beam_;
+
+  // Function to check if a transition is allowed for a given state.
+  std::function<bool(T *, int)> is_allowed_;
+
+  // Function to check if a state is final.
+  std::function<bool(T *)> is_final_;
+
+  // Function to perform a transition on a given state.
+  std::function<void(T *, int)> perform_transition_;
+
+  // Function to provide the oracle action for a given state.
+  std::function<int(T *)> oracle_function_;
+
+  // The history of the states in this beam. The vector indexes across steps.
+  // For every step, there is a vector in the vector. This inner vector denotes
+  // the state of the beam at that step, and contains the beam index that
+  // was transitioned to create the transition state at that index (so,
+  // if at step 2 the transition state at beam index 4 was created by applying
+  // a transition to the state in beam index 3 during step 1, the query would
+  // be "beam_index_history_.at(2).at(4)" and the value would be 3. Empty beam
+  // states will return -1.
+  std::vector<std::vector<int>> beam_index_history_;
+
+  // The number of steps taken so far.
+  int num_steps_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
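For orientation, a minimal sketch of the wiring a Component performs against this interface; MyState, initial_states, scores, and kNumActions are placeholders (the beam_test.cc diff that follows exercises the same contract with a concrete CloneableTransitionState subclass):

    // Sketch only: MyState stands in for a CloneableTransitionState subclass.
    Beam<MyState> beam(/*max_size=*/8);
    beam.SetFunctions(
        [](MyState *state, int action) { return true; },          // is_allowed
        [](MyState *state) { return false; },                     // is_final
        [](MyState *state, int action) { /* apply 'action' */ },  // perform_transition
        [](MyState *state) { return 0; });                        // oracle_function
    beam.Init(std::move(initial_states));  // std::vector<std::unique_ptr<MyState>>
    // One step: 'scores' holds max_size() * kNumActions entries, one row of
    // per-action scores for each beam slot.
    beam.AdvanceFromPrediction(scores, beam.max_size() * kNumActions, kNumActions);
    // Repeat until beam.IsTerminal(), then read results via beam.beam().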

+ 773 - 0
syntaxnet/dragnn/core/beam_test.cc

@@ -0,0 +1,773 @@
+#include "dragnn/core/beam.h"
+
+#include "dragnn/core/interfaces/cloneable_transition_state.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/core/test/mock_transition_state.h"
+#include <gmock/gmock.h>
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using testing::MockFunction;
+using testing::Return;
+using testing::Ne;
+using testing::_;
+
+namespace {
+
+// *****************************************************************************
+// Test-internal class definitions.
+// *****************************************************************************
+
+// Create a very basic transition state to test the beam. All it does is keep
+// track of its current beam index and score, as well as providing a field
+// for the transition function to write in what transition occurred.
+// Note that this class does not fulfill the entire TransitionState contract,
+// since it is only used in this particular test.
+class TestTransitionState
+    : public CloneableTransitionState<TestTransitionState> {
+ public:
+  TestTransitionState() {}
+
+  void Init(const TransitionState &parent) override {}
+
+  std::unique_ptr<TestTransitionState> Clone() const override {
+    std::unique_ptr<TestTransitionState> ptr(new TestTransitionState());
+    return ptr;
+  }
+
+  const int ParentBeamIndex() const override { return parent_beam_index_; }
+
+  // Get the current beam index for this state.
+  const int GetBeamIndex() const override { return beam_index_; }
+
+  // Set the current beam index for this state.
+  void SetBeamIndex(const int index) override { beam_index_ = index; }
+
+  // Get the score associated with this transition state.
+  const float GetScore() const override { return score_; }
+
+  // Set the score associated with this transition state.
+  void SetScore(const float score) override { score_ = score; }
+
+  // Depicts this state as an HTML-language string.
+  string HTMLRepresentation() const override { return ""; }
+
+  int parent_beam_index_;
+
+  int beam_index_;
+
+  float score_;
+
+  int transition_action_;
+};
+
+// This transition function annotates a TestTransitionState with the action that
+// was chosen for the transition.
+auto transition_function = [](TestTransitionState *state, int action) {
+  TestTransitionState *cast_state = dynamic_cast<TestTransitionState *>(state);
+  cast_state->transition_action_ = action;
+};
+
+// Create oracle and permission functions that do nothing.
+auto null_oracle = [](TestTransitionState *) { return 0; };
+auto null_permissions = [](TestTransitionState *, int) { return true; };
+auto null_finality = [](TestTransitionState *) { return false; };
+
+// Create a unique_ptr with a test transition state in it and set its initial
+// score.
+std::unique_ptr<TestTransitionState> CreateState(float score) {
+  std::unique_ptr<TestTransitionState> state;
+  state.reset(new TestTransitionState());
+  state->SetScore(score);
+  return state;
+}
+
+// Convenience accessor for the action field in TestTransitionState.
+int GetTransition(const TransitionState *state) {
+  return (dynamic_cast<const TestTransitionState *>(state))->transition_action_;
+}
+
+// Convenience accessor for the parent_beam_index_ field in TestTransitionState.
+void SetParentBeamIndex(TransitionState *state, int index) {
+  (dynamic_cast<TestTransitionState *>(state))->parent_beam_index_ = index;
+}
+
+}  // namespace
+
+// *****************************************************************************
+// Tests begin here.
+// *****************************************************************************
+TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
+  // Create a matrix of transitions.
+  constexpr int kNumTransitions = 4;
+  constexpr int kMatrixSize = kNumTransitions;
+  constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
+  constexpr int kBestTransition = 2;
+  constexpr float kOldScore = 3.0;
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScore));
+  constexpr int kBeamSize = 1;
+  Beam<TestTransitionState> beam(kBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), kBeamSize);
+
+  // Make sure the state has performed the expected transition.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition);
+
+  // Make sure the state has had its score updated properly.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
+
+  // Make sure that the beam_state accessor actually accesses the beam.
+  EXPECT_EQ(beam.beam().at(0), beam.beam_state(0));
+
+  // Validate the beam history field.
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 0);
+}
+
+TEST(BeamTest, AdvancingCreatesNewTransitions) {
+  // Create a matrix of transitions.
+  constexpr int kMaxBeamSize = 8;
+  constexpr int kNumTransitions = 4;
+  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
+  constexpr float matrix[kMatrixSize] = {
+      30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
+  constexpr float kOldScore = 4.0;
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScore));
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 4);
+
+  // Make sure the state has performed the expected transition.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(1)), 0);
+  EXPECT_EQ(GetTransition(beam.beam().at(2)), 1);
+  EXPECT_EQ(GetTransition(beam.beam().at(3)), 3);
+
+  // Make sure the state has had its score updated properly.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
+  EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]);
+  EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]);
+  EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  // In this case, we expect the top 4 results to have come from state 0 and
+  // the remaining 4 slots to be empty (-1).
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 0);
+  EXPECT_EQ(history.at(1).at(1), 0);
+  EXPECT_EQ(history.at(1).at(2), 0);
+  EXPECT_EQ(history.at(1).at(3), 0);
+  EXPECT_EQ(history.at(1).at(4), -1);
+  EXPECT_EQ(history.at(1).at(5), -1);
+  EXPECT_EQ(history.at(1).at(6), -1);
+  EXPECT_EQ(history.at(1).at(7), -1);
+}
+
+TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
+  // Create a matrix of transitions.
+  constexpr int kMaxBeamSize = 8;
+  constexpr int kNumTransitions = 4;
+  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
+
+  constexpr float matrix[kMatrixSize] = {
+      30.0, 20.0, 40.0, 10.0,  // State 0
+      31.0, 21.0, 41.0, 11.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
+
+  constexpr float kOldScores[] = {5.0, 7.0};
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScores[0]));
+  states.push_back(CreateState(kOldScores[1]));
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 8);
+
+  // Make sure the state has performed the expected transition.
+  // Note that the transition index is not the index into the matrix, but rather
+  // the index into the matrix 'row' for that state.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
+  EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
+  EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
+  EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
+  EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
+  EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
+
+  // Make sure the state has had its score updated properly.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
+  EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
+  EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
+  EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
+  EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
+  EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
+  EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
+  EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  // Validate the history at this step.
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 1);
+  EXPECT_EQ(history.at(1).at(1), 0);
+  EXPECT_EQ(history.at(1).at(2), 1);
+  EXPECT_EQ(history.at(1).at(3), 0);
+  EXPECT_EQ(history.at(1).at(4), 1);
+  EXPECT_EQ(history.at(1).at(5), 0);
+  EXPECT_EQ(history.at(1).at(6), 1);
+  EXPECT_EQ(history.at(1).at(7), 0);
+}
+
+TEST(BeamTest, AdvancingDropsLowValuePredictions) {
+  // Create a matrix of transitions.
+  constexpr int kNumTransitions = 4;
+  constexpr int kMaxBeamSize = 8;
+  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
+  constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0,   // State 0
+                                         31.0, 21.0, 41.0, 11.0,   // State 1
+                                         32.0, 22.0, 42.0, 12.0,   // State 2
+                                         33.0, 23.0, 43.0, 13.0,   // State 3
+                                         34.0, 24.0, 44.0, 14.0,   // State 4
+                                         35.0, 25.0, 45.0, 15.0,   // State 5
+                                         36.0, 26.0, 46.0, 16.0,   // State 6
+                                         37.0, 27.0, 47.0, 17.0};  // State 7
+  constexpr float kOldScores[] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScores[0]));
+  states.push_back(CreateState(kOldScores[1]));
+  states.push_back(CreateState(kOldScores[2]));
+  states.push_back(CreateState(kOldScores[3]));
+  states.push_back(CreateState(kOldScores[4]));
+  states.push_back(CreateState(kOldScores[5]));
+  states.push_back(CreateState(kOldScores[6]));
+  states.push_back(CreateState(kOldScores[7]));
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 8);
+
+  // Make sure the state has performed the expected transition.
+  // In this case, every state will perform transition 2.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(2)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(3)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(4)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(5)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(6)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(7)), 2);
+
+  // Make sure the state has had its score updated properly. (Note that row
+  // 0 had the smallest transition score, so it ends up on the bottom of the
+  // beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
+  // correct state row and we add 2 since that was the transition index.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(),
+            kOldScores[7] + matrix[7 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(1)->GetScore(),
+            kOldScores[6] + matrix[6 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(2)->GetScore(),
+            kOldScores[5] + matrix[5 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(3)->GetScore(),
+            kOldScores[4] + matrix[4 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(4)->GetScore(),
+            kOldScores[3] + matrix[3 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(5)->GetScore(),
+            kOldScores[2] + matrix[2 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(6)->GetScore(),
+            kOldScores[1] + matrix[1 * kNumTransitions + 2]);
+  EXPECT_EQ(beam.beam().at(7)->GetScore(),
+            kOldScores[0] + matrix[0 * kNumTransitions + 2]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 7);
+  EXPECT_EQ(history.at(1).at(1), 6);
+  EXPECT_EQ(history.at(1).at(2), 5);
+  EXPECT_EQ(history.at(1).at(3), 4);
+  EXPECT_EQ(history.at(1).at(4), 3);
+  EXPECT_EQ(history.at(1).at(5), 2);
+  EXPECT_EQ(history.at(1).at(6), 1);
+  EXPECT_EQ(history.at(1).at(7), 0);
+}
+
+TEST(BeamTest, AdvancesFromOracleWithSingleBeam) {
+  // Create an oracle function for this state.
+  constexpr int kOracleLabel = 3;
+  auto oracle_function = [](TransitionState *) { return kOracleLabel; };
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(0.0));
+  constexpr int kBeamSize = 1;
+  Beam<TestTransitionState> beam(kBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    oracle_function);
+  beam.Init(std::move(states));
+  beam.AdvanceFromOracle();
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), kBeamSize);
+
+  // Make sure the state has performed the expected transition.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), kOracleLabel);
+
+  // Make sure the state has had its score held to 0.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), 0.0);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
+
+  // Validate the beam history field.
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 0);
+}
+
+TEST(BeamTest, AdvancesFromOracleWithMultipleStates) {
+  constexpr int kMaxBeamSize = 8;
+
+  // Create a beam with 8 transition states.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    // This is nonzero to test the oracle holding scores to 0.
+    states.push_back(CreateState(10.0));
+  }
+
+  std::vector<int> expected_actions;
+
+  // Create an oracle function for this state. Use mocks for finer control.
+  testing::MockFunction<int(TestTransitionState *)> mock_oracle_function;
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    // We expect each state to be queried for its oracle label,
+    // and then to be transitioned in place with its oracle label.
+    int oracle_label = i % 3;  // 3 is arbitrary.
+    EXPECT_CALL(mock_oracle_function, Call(states.at(i).get()))
+        .WillOnce(Return(oracle_label));
+    expected_actions.push_back(oracle_label);
+  }
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    mock_oracle_function.AsStdFunction());
+  beam.Init(std::move(states));
+  beam.AdvanceFromOracle();
+
+  // Make sure the state has performed the expected transition, has had its
+  // score held to 0, and is self consistent.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(GetTransition(beam.beam().at(i)), expected_actions.at(i));
+    EXPECT_EQ(beam.beam().at(i)->GetScore(), 0.0);
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  auto history = beam.history();
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(history.at(1).at(i), i);
+  }
+}
+
+TEST(BeamTest, ReportsNonFinality) {
+  constexpr int kMaxBeamSize = 8;
+
+  // Create a beam with 8 transition states.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    // This is nonzero to test the oracle holding scores to 0.
+    states.push_back(CreateState(10.0));
+  }
+
+  std::vector<int> expected_actions;
+
+  // Create a finality function for this state. Use mocks for finer control.
+  testing::MockFunction<int(TestTransitionState *)> mock_finality_function;
+
+  // Make precisely one call return false, which should cause IsTerminal()
+  // to report false.
+  constexpr int incomplete_state = 3;
+  EXPECT_CALL(mock_finality_function, Call(states.at(incomplete_state).get()))
+      .WillOnce(Return(false));
+  EXPECT_CALL(mock_finality_function,
+              Call(Ne(states.at(incomplete_state).get())))
+      .WillRepeatedly(Return(true));
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
+                    transition_function, null_oracle);
+  beam.Init(std::move(states));
+
+  EXPECT_FALSE(beam.IsTerminal());
+}
+
+TEST(BeamTest, ReportsFinality) {
+  constexpr int kMaxBeamSize = 8;
+
+  // Create a beam with 8 transition states.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    // This is nonzero to test the oracle holding scores to 0.
+    states.push_back(CreateState(10.0));
+  }
+
+  std::vector<int> expected_actions;
+
+  // Create a finality function for this state. Use mocks for finer control.
+  testing::MockFunction<int(TransitionState *)> mock_finality_function;
+
+  // All calls will return true, so IsTerminal() should return true.
+  EXPECT_CALL(mock_finality_function, Call(_)).WillRepeatedly(Return(true));
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
+                    transition_function, null_oracle);
+  beam.Init(std::move(states));
+
+  EXPECT_TRUE(beam.IsTerminal());
+}
+
+TEST(BeamTest, IgnoresForbiddenTransitionActions) {
+  // Create a matrix of transitions.
+  constexpr int kMaxBeamSize = 4;
+  constexpr int kNumTransitions = 4;
+  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
+  constexpr float matrix[kMatrixSize] = {
+      10.0, 1000.0, 40.0, 30.0, 00.0, 0000.0, 00.0, 00.0,
+      00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0};
+  constexpr float kOldScore = 4.0;
+
+  // Create the beam.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScore));
+
+  // Forbid the second transition (index 1).
+  testing::MockFunction<int(TestTransitionState *, int)>
+      mock_permission_function;
+  EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 0))
+      .WillOnce(Return(true));
+  EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 1))
+      .WillOnce(Return(false));
+  EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 2))
+      .WillOnce(Return(true));
+  EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 3))
+      .WillOnce(Return(true));
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(mock_permission_function.AsStdFunction(), null_finality,
+                    transition_function, null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 3);
+
+  // Make sure the state has performed the expected transition.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(1)), 3);
+  EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
+
+  // Make sure the state has had its score updated properly.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
+  EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[3]);
+  EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[0]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  // In this case, we expect the top 3 results to have come from state 0 and
+  // the one remaining slot to be empty (-1).
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 0);
+  EXPECT_EQ(history.at(1).at(1), 0);
+  EXPECT_EQ(history.at(1).at(2), 0);
+  EXPECT_EQ(history.at(1).at(3), -1);
+}
+
+TEST(BeamTest, BadlySizedMatrixDies) {
+  // Create a matrix of transitions.
+  constexpr int kNumTransitions = 4;
+  constexpr int kMatrixSize = 4;  // With a max beam size of 8, this should be 32.
+  constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(0.0));
+  states.push_back(CreateState(0.0));
+  constexpr int kMaxBeamSize = 8;
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+
+  // This matrix should have 32 elements, not 4, so this should die.
+  EXPECT_DEATH(beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions),
+               "Transition matrix size does not match max beam size \\* number "
+               "of state transitions");
+}
+
+TEST(BeamTest, BadlySizedBeamInitializationDies) {
+  // Create an initialization beam too large for the max beam size.
+  constexpr int kMaxBeamSize = 4;
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  for (int i = 0; i < kMaxBeamSize + 1; ++i) {
+    states.push_back(CreateState(0.0));
+  }
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+
+  // Try to initialize the beam; this should die.
+  EXPECT_DEATH(beam.Init(std::move(states)),
+               "Attempted to initialize a beam with more states");
+}
+
+TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) {
+  // Create a standard beam.
+  constexpr int kMaxBeamSize = 4;
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    states.push_back(CreateState(0.0));
+  }
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+
+  beam.Init(std::move(states));
+
+  // Verify that all beam indices have been initialized.
+  for (int i = 0; i < kMaxBeamSize; ++i) {
+    EXPECT_EQ(i, beam.beam_state(i)->GetBeamIndex());
+  }
+}
+
+TEST(BeamTest, FindPreviousIndexTracesHistory) {
+  // Create a matrix of transitions.
+  constexpr int kNumTransitions = 4;
+  constexpr int kMaxBeamSize = 8;
+  constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
+  constexpr float matrix[kMatrixSize] = {
+      30.0, 20.0, 40.0, 10.0,  // State 0
+      31.0, 21.0, 41.0, 11.0,  // State 1
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
+      00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
+  constexpr float kOldScores[] = {5.0, 7.0};
+  constexpr int kParentBeamIndices[] = {1138, 42};
+
+  // Create the beam and transition it.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(kOldScores[0]));
+  states.push_back(CreateState(kOldScores[1]));
+
+  // Set parent beam indices.
+  SetParentBeamIndex(states.at(0).get(), kParentBeamIndices[0]);
+  SetParentBeamIndex(states.at(1).get(), kParentBeamIndices[1]);
+
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+  beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 8);
+
+  // Make sure the state has performed the expected transition.
+  // Note that the transition index is not the index into the matrix, but rather
+  // the index into the matrix 'row' for that state.
+  EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
+  EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
+  EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
+  EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
+  EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
+  EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
+  EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
+
+  // Make sure the state has had its score updated properly.
+  EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
+  EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
+  EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
+  EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
+  EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
+  EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
+  EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
+  EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
+
+  // Make sure that the beam index field is consistent with the actual beam idx.
+  for (int i = 0; i < beam.beam().size(); ++i) {
+    EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
+  }
+
+  // Validate the history at this step.
+  auto history = beam.history();
+  EXPECT_EQ(history.at(1).at(0), 1);
+  EXPECT_EQ(history.at(1).at(1), 0);
+  EXPECT_EQ(history.at(1).at(2), 1);
+  EXPECT_EQ(history.at(1).at(3), 0);
+  EXPECT_EQ(history.at(1).at(4), 1);
+  EXPECT_EQ(history.at(1).at(5), 0);
+  EXPECT_EQ(history.at(1).at(6), 1);
+  EXPECT_EQ(history.at(1).at(7), 0);
+
+  EXPECT_EQ(history.at(0).at(0), kParentBeamIndices[0]);
+  EXPECT_EQ(history.at(0).at(1), kParentBeamIndices[1]);
+  EXPECT_EQ(history.at(0).at(2), -1);
+  EXPECT_EQ(history.at(0).at(3), -1);
+  EXPECT_EQ(history.at(0).at(4), -1);
+  EXPECT_EQ(history.at(0).at(5), -1);
+  EXPECT_EQ(history.at(0).at(6), -1);
+  EXPECT_EQ(history.at(0).at(7), -1);
+
+  // Make sure that FindPreviousIndex can read through the history from step 1
+  // to step 0.
+  constexpr int kDesiredIndex = 0;
+  constexpr int kCurrentIndexOne = 4;
+  EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexOne, kDesiredIndex),
+            kParentBeamIndices[1]);
+
+  constexpr int kCurrentIndexTwo = 7;
+  EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexTwo, kDesiredIndex),
+            kParentBeamIndices[0]);
+}
+
+TEST(BeamTest, FindPreviousIndexReturnsInError) {
+  // Create the beam. This now has only one history state, 0.
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(0.0));
+  constexpr int kMaxBeamSize = 8;
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+
+  // If the requested step is greater than the number of steps taken, expect -1.
+  EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
+
+  // If the requested step is less than 0, expect -1.
+  EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
+
+  // If the requested index does not have a state, expect -1.
+  EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
+
+  // If the requested index is less than 0, expect -1.
+  EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
+
+  // If the requested index is larger than the maximum beam size -1, expect -1.
+  EXPECT_EQ(beam.FindPreviousIndex(0, kMaxBeamSize), -1);
+}
+
+TEST(BeamTest, ResetClearsBeamState) {
+  // Create the beam
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(1.0));
+  constexpr int kMaxBeamSize = 8;
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+
+  // Validate the new beam.
+  EXPECT_EQ(beam.beam().size(), 1);
+
+  // Reset the beam.
+  beam.Reset();
+
+  // Validate the now-reset beam, which should be empty.
+  EXPECT_EQ(beam.beam().size(), 0);
+}
+
+TEST(BeamTest, ResetClearsBeamHistory) {
+  // Create the beam
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(1.0));
+  constexpr int kMaxBeamSize = 8;
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+
+  // Validate the new beam history.
+  EXPECT_EQ(beam.history().size(), 1);
+
+  // Reset the beam.
+  beam.Reset();
+
+  // Validate the now-reset beam history, which should be empty.
+  EXPECT_EQ(beam.history().size(), 0);
+}
+
+TEST(BeamTest, SettingMaxSizeResetsBeam) {
+  // Create the beam
+  std::vector<std::unique_ptr<TestTransitionState>> states;
+  states.push_back(CreateState(1.0));
+  constexpr int kMaxBeamSize = 8;
+  Beam<TestTransitionState> beam(kMaxBeamSize);
+  beam.SetFunctions(null_permissions, null_finality, transition_function,
+                    null_oracle);
+  beam.Init(std::move(states));
+
+  // Validate the new beam history.
+  EXPECT_EQ(beam.history().size(), 1);
+
+  // Reset the beam.
+  constexpr int kNewMaxBeamSize = 4;
+  beam.SetMaxSize(kNewMaxBeamSize);
+  EXPECT_EQ(beam.max_size(), kNewMaxBeamSize);
+
+  // Validate the now-reset beam history, which should be empty.
+  EXPECT_EQ(beam.history().size(), 0);
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet
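The score matrix passed to AdvanceFromPrediction in the tests above is a flattened row-major [beam slot x transition] array: the score for beam slot b taking transition t lives at matrix[b * kNumTransitions + t]. A minimal sketch of that indexing follows; ScoreIndex is an illustrative helper, not part of the library:

    // Illustrative only: the flat index into the score matrix consumed by
    // Beam<T>::AdvanceFromPrediction, assuming the row-major
    // [beam slot x transition] layout exercised by the tests above.
    inline int ScoreIndex(int beam_slot, int transition, int num_transitions) {
      return beam_slot * num_transitions + transition;
    }
    // e.g. with kNumTransitions == 4, the score for beam slot 1 taking
    // transition 2 is matrix[ScoreIndex(1, 2, 4)] == matrix[6].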

+ 8 - 0
syntaxnet/dragnn/core/component_registry.cc

@@ -0,0 +1,8 @@
+#include "dragnn/core/component_registry.h"
+
+namespace syntaxnet {
+
+// Class registry for DRAGNN components.
+REGISTER_SYNTAXNET_CLASS_REGISTRY("DRAGNN Component", dragnn::Component);
+
+}  // namespace syntaxnet

+ 14 - 0
syntaxnet/dragnn/core/component_registry.h

@@ -0,0 +1,14 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
+
+#include "dragnn/core/interfaces/component.h"
+#include "syntaxnet/registry.h"
+
+// Macro to add a component to the registry. This macro associates a class with
+// its class name as a string, so FooComponent would be associated with the
+// string "FooComponent".
+#define REGISTER_DRAGNN_COMPONENT(component)                                   \
+  REGISTER_SYNTAXNET_CLASS_COMPONENT(syntaxnet::dragnn::Component, #component, \
+                                     component)
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
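As a rough illustration of how this macro is meant to be used, a concrete Component subclass would include this header and register itself in its .cc file so that Component::Create() can construct it by name. MyComponent below is hypothetical and not part of this change:

    // Hypothetical example; MyComponent does not exist in this release.
    #include "dragnn/core/component_registry.h"

    namespace syntaxnet {
    namespace dragnn {

    class MyComponent : public Component {
      // ... implementations of the Component interface ...
    };

    REGISTER_DRAGNN_COMPONENT(MyComponent);

    }  // namespace dragnn
    }  // namespace syntaxnet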

+ 120 - 0
syntaxnet/dragnn/core/compute_session.h

@@ -0,0 +1,120 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
+
+#include <string>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/index_translator.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/protos/spec.pb.h"
+#include "dragnn/protos/trace.pb.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// This defines the interface for a ComputeSession object. We expect
+// ComputeSessionImpl to be the only production implementation of ComputeSession;
+// this interface exists primarily to provide a mocking seam for tests.
+
+class ComputeSession {
+ public:
+  virtual ~ComputeSession() {}
+
+  // Initialize this ComputeSession to compute the graph defined in the given
+  // MasterSpec with the hyperparameters passed in the GridPoint. This should
+  // only be called once, when the ComputeSession is created.
+  virtual void Init(const MasterSpec &master_spec,
+                    const GridPoint &hyperparams) = 0;
+
+  // Initialize a component with data and a given maximum beam
+  // size. Note that attempting to initialize a component that depends on
+  // another component that has not yet finished will cause a CHECK failure.
+  virtual void InitializeComponentData(const string &component_name,
+                                       int max_beam_size) = 0;
+
+  // Return the batch size for the given component.
+  virtual int BatchSize(const string &component_name) const = 0;
+
+  // Return the beam size for the given component.
+  virtual int BeamSize(const string &component_name) const = 0;
+
+  // Returns the spec used to create this ComputeSession.
+  virtual const ComponentSpec &Spec(const string &component_name) const = 0;
+
+  // For a given component and linked feature channel, get the beam size of the
+  // component that is the source of the linked features.
+  virtual int SourceComponentBeamSize(const string &component_name,
+                                      int channel_id) = 0;
+
+  // Advance the given component using the component's oracle.
+  virtual void AdvanceFromOracle(const string &component_name) = 0;
+
+  // Advance the given component using the given score matrix.
+  virtual void AdvanceFromPrediction(const string &component_name,
+                                     const float score_matrix[],
+                                     int score_matrix_length) = 0;
+
+  // Get the input features for the given component and channel. This passes
+  // through to the relevant Component's GetFixedFeatures() call.
+  virtual int GetInputFeatures(
+      const string &component_name,
+      std::function<int32 *(int num_items)> allocate_indices,
+      std::function<int64 *(int num_items)> allocate_ids,
+      std::function<float *(int num_items)> allocate_weights,
+      int channel_id) const = 0;
+
+  // Get the input features for the given component and channel, advancing via
+  // the oracle until the state is final. This passes through to the relevant
+  // Component's BulkGetFixedFeatures() call.
+  virtual int BulkGetInputFeatures(const string &component_name,
+                                   const BulkFeatureExtractor &extractor) = 0;
+
+  // Get the input features for the given component and channel. This function
+  // can return empty LinkFeatures protos, which represent unused padding slots
+  // in the output weight tensor.
+  virtual std::vector<LinkFeatures> GetTranslatedLinkFeatures(
+      const string &component_name, int channel_id) = 0;
+
+  // Get the oracle labels for the given component.
+  virtual std::vector<std::vector<int>> EmitOracleLabels(
+      const string &component_name) = 0;
+
+  // Returns true if the given component is terminal.
+  virtual bool IsTerminal(const string &component_name) = 0;
+
+  // Force the given component to write out its predictions to the backing data.
+  virtual void FinalizeData(const string &component_name) = 0;
+
+  // Return the finalized predictions from this compute session.
+  virtual std::vector<string> GetSerializedPredictions() = 0;
+
+  // Returns the trace protos. This will CHECK-fail or return an empty vector
+  // if SetTracing() has not been called to initialize the underlying
+  // Component traces.
+  virtual std::vector<MasterTrace> GetTraceProtos() = 0;
+
+  // Provides the ComputeSession with a batch of data to compute.
+  virtual void SetInputData(const std::vector<string> &data) = 0;
+
+  // Resets all components owned by this ComputeSession.
+  virtual void ResetSession() = 0;
+
+  // Set the tracing for this ComputeSession.
+  virtual void SetTracing(bool tracing_on) = 0;
+
+  // Returns a unique identifier for this ComputeSession.
+  virtual int Id() const = 0;
+
+  // Returns a string describing the given component.
+  virtual string GetDescription(const string &component_name) const = 0;
+
+  // Get all the translators for the given component. Should only be used to
+  // validate correct construction of translators in tests.
+  virtual const std::vector<const IndexTranslator *> Translators(
+      const string &component_name) const = 0;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
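A minimal sketch of how a caller might drive this interface for one component, assuming the session has already been Init()-ed with a MasterSpec and hyperparameters; the helper name and the beam size of 1 are illustrative, and only methods declared above are used:

    // Sketch only, not part of this change: drives a single component of an
    // already-initialized ComputeSession with its oracle until it is terminal.
    std::vector<std::string> RunComponentWithOracle(
        ComputeSession *session, const std::string &component_name,
        const std::vector<std::string> &serialized_input) {
      session->SetInputData(serialized_input);
      session->InitializeComponentData(component_name, /*max_beam_size=*/1);
      while (!session->IsTerminal(component_name)) {
        session->AdvanceFromOracle(component_name);
      }
      session->FinalizeData(component_name);
      return session->GetSerializedPredictions();
    }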

+ 384 - 0
syntaxnet/dragnn/core/compute_session_impl.cc

@@ -0,0 +1,384 @@
+#include "dragnn/core/compute_session_impl.h"
+
+#include <algorithm>
+#include <utility>
+
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "dragnn/protos/trace.pb.h"
+#include "syntaxnet/registry.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+ComputeSessionImpl::ComputeSessionImpl(
+    int id,
+    std::function<std::unique_ptr<Component>(const string &component_name,
+                                             const string &backend_type)>
+        component_builder)
+    : component_builder_(std::move(component_builder)), id_(id) {}
+
+void ComputeSessionImpl::Init(const MasterSpec &master_spec,
+                              const GridPoint &hyperparams) {
+  spec_ = master_spec;
+  grid_point_ = hyperparams;
+
+  VLOG(2) << "Creating components.";
+  bool is_input = true;
+  Component *predecessor = nullptr;
+  for (const ComponentSpec &spec : master_spec.component()) {
+    // Construct the component using the specified backend.
+    VLOG(2) << "Creating component '" << spec.name()
+            << "' with backend: " << spec.backend().registered_name();
+    auto component =
+        component_builder_(spec.name(), spec.backend().registered_name());
+
+    // Initializes the component.
+    component->InitializeComponent(spec);
+
+    // Adds a predecessor to non-input components.
+    if (!is_input) {
+      predecessors_.insert(
+          std::pair<Component *, Component *>(component.get(), predecessor));
+    }
+
+    // The current component will be the predecessor component next time around.
+    predecessor = component.get();
+
+    // All components after the first are non-input components.
+    is_input = false;
+
+    // Move into components list.
+    components_.insert(std::pair<string, std::unique_ptr<Component>>(
+        spec.name(), std::move(component)));
+  }
+  VLOG(2) << "Done creating components.";
+
+  VLOG(2) << "Adding translators.";
+  for (const ComponentSpec &spec : master_spec.component()) {
+    // First, get the component object for this spec.
+    VLOG(2) << "Examining component: " << spec.name();
+    auto map_result = components_.find(spec.name());
+    CHECK(map_result != components_.end()) << "Unable to find component.";
+    Component *start_component = map_result->second.get();
+
+    if (spec.linked_feature_size() > 0) {
+      VLOG(2) << "Adding " << spec.linked_feature_size() << " translators for "
+              << spec.name();
+
+      // Attach all the translators described in the spec.
+      std::vector<IndexTranslator *> translator_set;
+      for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
+        // For every translator, save a raw (non-owning) pointer in the
+        // component-name-to-translator map, then move the unique_ptr onto the
+        // ownership vector.
+        auto translator = CreateTranslator(channel, start_component);
+        translator_set.push_back(translator.get());
+        owned_translators_.push_back(std::move(translator));
+      }
+
+      // Once all translators have been created, associate this group of
+      // translators with a component.
+      translators_.insert(std::pair<string, std::vector<IndexTranslator *>>(
+          spec.name(), std::move(translator_set)));
+    } else {
+      VLOG(2) << "No translators found for " << spec.name();
+    }
+  }
+  VLOG(2) << "Done adding translators.";
+
+  VLOG(2) << "Initialization complete.";
+}
+
+void ComputeSessionImpl::InitializeComponentData(const string &component_name,
+                                                 int max_beam_size) {
+  CHECK(input_data_ != nullptr) << "Attempted to access a component without "
+                                   "providing input data for this session.";
+  Component *component = GetComponent(component_name);
+
+  // Try to find the source component. If one exists, check that it is terminal
+  // and get its data; if not, pass in an empty vector for source data.
+  auto source_result = predecessors_.find(component);
+  if (source_result == predecessors_.end()) {
+    VLOG(1) << "Source result not found. Using empty initialization vector for "
+            << component_name;
+    component->InitializeData({}, max_beam_size, input_data_.get());
+  } else {
+    VLOG(1) << "Source result found. Using prior initialization vector for "
+            << component_name;
+    auto source = source_result->second;
+    CHECK(source->IsTerminal()) << "Source is not terminal for component '"
+                                << component_name << "'. Exiting.";
+    component->InitializeData(source->GetBeam(), max_beam_size,
+                              input_data_.get());
+  }
+  if (do_tracing_) {
+    component->InitializeTracing();
+  }
+}
+
+int ComputeSessionImpl::BatchSize(const string &component_name) const {
+  return GetReadiedComponent(component_name)->BatchSize();
+}
+
+int ComputeSessionImpl::BeamSize(const string &component_name) const {
+  return GetReadiedComponent(component_name)->BeamSize();
+}
+
+const ComponentSpec &ComputeSessionImpl::Spec(
+    const string &component_name) const {
+  for (const auto &component : spec_.component()) {
+    if (component.name() == component_name) {
+      return component;
+    }
+  }
+  LOG(FATAL) << "Missing component '" << component_name << "'. Exiting.";
+}
+
+int ComputeSessionImpl::SourceComponentBeamSize(const string &component_name,
+                                                int channel_id) {
+  const auto &translators = GetTranslators(component_name);
+  return translators.at(channel_id)->path().back()->BeamSize();
+}
+
+void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
+  GetReadiedComponent(component_name)->AdvanceFromOracle();
+}
+
+void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
+                                               const float score_matrix[],
+                                               int score_matrix_length) {
+  GetReadiedComponent(component_name)
+      ->AdvanceFromPrediction(score_matrix, score_matrix_length);
+}
+
+int ComputeSessionImpl::GetInputFeatures(
+    const string &component_name, std::function<int32 *(int)> allocate_indices,
+    std::function<int64 *(int)> allocate_ids,
+    std::function<float *(int)> allocate_weights, int channel_id) const {
+  return GetReadiedComponent(component_name)
+      ->GetFixedFeatures(allocate_indices, allocate_ids, allocate_weights,
+                         channel_id);
+}
+
+int ComputeSessionImpl::BulkGetInputFeatures(
+    const string &component_name, const BulkFeatureExtractor &extractor) {
+  return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
+}
+
+std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
+    const string &component_name, int channel_id) {
+  auto *component = GetReadiedComponent(component_name);
+  auto features = component->GetRawLinkFeatures(channel_id);
+
+  IndexTranslator *translator = GetTranslators(component_name).at(channel_id);
+  for (int i = 0; i < features.size(); ++i) {
+    LinkFeatures &feature = features[i];
+    if (feature.has_feature_value()) {
+      VLOG(2) << "Raw feature[" << i << "]: " << feature.ShortDebugString();
+      IndexTranslator::Index index = translator->Translate(
+          feature.batch_idx(), feature.beam_idx(), feature.feature_value());
+      feature.set_step_idx(index.step_index);
+      feature.set_batch_idx(index.batch_index);
+      feature.set_beam_idx(index.beam_index);
+    } else {
+      VLOG(2) << "Raw feature[" << i << "]: PADDING (empty proto)";
+    }
+  }
+
+  // Add the translated link features to the component's trace.
+  if (do_tracing_) {
+    component->AddTranslatedLinkFeaturesToTrace(features, channel_id);
+  }
+
+  return features;
+}
+std::vector<std::vector<int>> ComputeSessionImpl::EmitOracleLabels(
+    const string &component_name) {
+  return GetReadiedComponent(component_name)->GetOracleLabels();
+}
+
+bool ComputeSessionImpl::IsTerminal(const string &component_name) {
+  return GetReadiedComponent(component_name)->IsTerminal();
+}
+
+void ComputeSessionImpl::SetTracing(bool tracing_on) {
+  do_tracing_ = tracing_on;
+  for (auto &component_pair : components_) {
+    if (!tracing_on) {
+      component_pair.second->DisableTracing();
+    }
+  }
+}
+
+void ComputeSessionImpl::FinalizeData(const string &component_name) {
+  VLOG(2) << "Finalizing data for " << component_name;
+  GetReadiedComponent(component_name)->FinalizeData();
+}
+
+std::vector<string> ComputeSessionImpl::GetSerializedPredictions() {
+  VLOG(2) << "Geting serialized predictions.";
+  return input_data_->SerializedData();
+}
+
+std::vector<MasterTrace> ComputeSessionImpl::GetTraceProtos() {
+  std::vector<MasterTrace> traces;
+
+  // First compute all possible traces for each component.
+  std::map<string, std::vector<std::vector<ComponentTrace>>> component_traces;
+  std::vector<string> pipeline;
+  for (auto &component_spec : spec_.component()) {
+    pipeline.push_back(component_spec.name());
+    component_traces.insert(
+        {component_spec.name(),
+         GetComponent(component_spec.name())->GetTraceProtos()});
+  }
+
+  // Only output for the actual number of states in each beam.
+  auto final_beam = GetComponent(pipeline.back())->GetBeam();
+  for (int batch_idx = 0; batch_idx < final_beam.size(); ++batch_idx) {
+    for (int beam_idx = 0; beam_idx < final_beam[batch_idx].size();
+         ++beam_idx) {
+      std::vector<int> beam_path;
+      beam_path.push_back(beam_idx);
+
+      // Trace components backwards, finding the source of each state in the
+      // prior component.
+      VLOG(2) << "Start trace: " << beam_idx;
+      for (int i = pipeline.size() - 1; i > 0; --i) {
+        const auto *component = GetComponent(pipeline[i]);
+        int source_beam_idx =
+            component->GetSourceBeamIndex(beam_path.back(), batch_idx);
+        beam_path.push_back(source_beam_idx);
+
+        VLOG(2) << "Tracing path: " << pipeline[i] << " = " << source_beam_idx;
+      }
+
+      // Trace the path from the *start* to the end.
+      std::reverse(beam_path.begin(), beam_path.end());
+      MasterTrace master_trace;
+      for (int i = 0; i < pipeline.size(); ++i) {
+        *master_trace.add_component_trace() =
+            component_traces[pipeline[i]][batch_idx][beam_path[i]];
+      }
+      traces.push_back(master_trace);
+    }
+  }
+
+  return traces;
+}
+
+void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
+  input_data_.reset(new InputBatchCache(data));
+}
+
+void ComputeSessionImpl::ResetSession() {
+  // Reset all component states.
+  for (auto &component_pair : components_) {
+    component_pair.second->ResetComponent();
+  }
+
+  // Reset the input data pointer.
+  input_data_.reset();
+}
+
+int ComputeSessionImpl::Id() const { return id_; }
+
+string ComputeSessionImpl::GetDescription(const string &component_name) const {
+  return GetComponent(component_name)->Name();
+}
+
+const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
+    const string &component_name) const {
+  auto translators = GetTranslators(component_name);
+  std::vector<const IndexTranslator *> const_translators;
+  for (const auto &translator : translators) {
+    const_translators.push_back(translator);
+  }
+  return const_translators;
+}
+
+Component *ComputeSessionImpl::GetReadiedComponent(
+    const string &component_name) const {
+  auto component = GetComponent(component_name);
+  CHECK(component->IsReady())
+      << "Attempted to access component " << component_name
+      << " without first initializing it.";
+  return component;
+}
+
+Component *ComputeSessionImpl::GetComponent(
+    const string &component_name) const {
+  auto result = components_.find(component_name);
+  if (result == components_.end()) {
+    LOG(ERROR) << "Could not find component \"" << component_name
+               << "\" in the component set. Current components are: ";
+    for (const auto &component_pair : components_) {
+      LOG(ERROR) << component_pair.first;
+    }
+    LOG(FATAL) << "Missing component. Exiting.";
+  }
+
+  auto component = result->second.get();
+  return component;
+}
+
+const std::vector<IndexTranslator *> &ComputeSessionImpl::GetTranslators(
+    const string &component_name) const {
+  auto result = translators_.find(component_name);
+  if (result == translators_.end()) {
+    LOG(ERROR) << "Could not find component " << component_name
+               << " in the translator set. Current components are: ";
+    for (const auto &component_pair : translators_) {
+      LOG(ERROR) << component_pair.first;
+    }
+    LOG(FATAL) << "Missing component. Exiting.";
+  }
+  return result->second;
+}
+
+std::unique_ptr<IndexTranslator> ComputeSessionImpl::CreateTranslator(
+    const LinkedFeatureChannel &channel, Component *start_component) {
+  const int num_components = spec_.component_size();
+  VLOG(2) << "Channel spec: " << channel.ShortDebugString();
+
+  // Find the linked feature's source component, if it exists.
+  auto source_map_result = components_.find(channel.source_component());
+  CHECK(source_map_result != components_.end())
+      << "Unable to find source component " << channel.source_component();
+  const Component *end_component = source_map_result->second.get();
+
+  // Our goal here is to iterate up the source map from the
+  // start_component to the end_component.
+  Component *current_component = start_component;
+  std::vector<Component *> path;
+  path.push_back(current_component);
+  while (current_component != end_component) {
+    // Try to find the next link upwards in the source chain.
+    auto source_result = predecessors_.find(current_component);
+
+    // If this component doesn't have a source to find, that's an error.
+    CHECK(source_result != predecessors_.end())
+        << "No link to source " << channel.source_component();
+
+    // If we jump more times than there are components in the graph, that
+    // is an error state.
+    CHECK_LT(path.size(), num_components) << "Too many jumps. Is there a "
+                                             "loop in the MasterSpec "
+                                             "component definition?";
+
+    // Add the source to the vector and repeat.
+    path.push_back(source_result->second);
+    current_component = source_result->second;
+  }
+
+  // At this point, we have the source chain for the translator and can
+  // build it.
+  std::unique_ptr<IndexTranslator> translator(
+      new IndexTranslator(path, channel.source_translator()));
+  return translator;
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet
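For intuition about CreateTranslator(): in a pipeline declared as A -> B -> C, a linked feature on C whose source_component is A resolves to the path {C, B, A} by walking the predecessors_ map. The sketch below mirrors that walk with component names instead of Component pointers; ResolveSourcePath is illustrative only:

    #include <cassert>
    #include <map>
    #include <string>
    #include <vector>

    // Illustrative only: mirrors the predecessor walk in CreateTranslator().
    // "predecessors" maps each component to the component immediately upstream.
    std::vector<std::string> ResolveSourcePath(
        const std::map<std::string, std::string> &predecessors,
        const std::string &start, const std::string &source) {
      std::vector<std::string> path = {start};
      while (path.back() != source) {
        auto it = predecessors.find(path.back());
        assert(it != predecessors.end() && "No link to source component");
        path.push_back(it->second);
      }
      return path;
    }
    // For a pipeline A -> B -> C, ResolveSourcePath(..., "C", "A") yields
    // {"C", "B", "A"}, matching the path handed to IndexTranslator.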

+ 142 - 0
syntaxnet/dragnn/core/compute_session_impl.h

@@ -0,0 +1,142 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
+
+#include <memory>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/compute_session.h"
+#include "dragnn/core/index_translator.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "dragnn/protos/trace.pb.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class ComputeSessionImpl : public ComputeSession {
+ public:
+  // Creates a ComputeSessionImpl with the provided component builder function.
+  ComputeSessionImpl(
+      int id,
+      std::function<std::unique_ptr<Component>(const string &component_name,
+                                               const string &backend_type)>
+          component_builder);
+
+  void Init(const MasterSpec &master_spec,
+            const GridPoint &hyperparams) override;
+
+  void InitializeComponentData(const string &component_name,
+                               int max_beam_size) override;
+
+  int BatchSize(const string &component_name) const override;
+
+  int BeamSize(const string &component_name) const override;
+
+  const ComponentSpec &Spec(const string &component_name) const override;
+
+  int SourceComponentBeamSize(const string &component_name,
+                              int channel_id) override;
+
+  void AdvanceFromOracle(const string &component_name) override;
+
+  void AdvanceFromPrediction(const string &component_name,
+                             const float score_matrix[],
+                             int score_matrix_length) override;
+
+  int GetInputFeatures(const string &component_name,
+                       std::function<int32 *(int)> allocate_indices,
+                       std::function<int64 *(int)> allocate_ids,
+                       std::function<float *(int)> allocate_weights,
+                       int channel_id) const override;
+
+  int BulkGetInputFeatures(const string &component_name,
+                           const BulkFeatureExtractor &extractor) override;
+
+  std::vector<LinkFeatures> GetTranslatedLinkFeatures(
+      const string &component_name, int channel_id) override;
+
+  std::vector<std::vector<int>> EmitOracleLabels(
+      const string &component_name) override;
+
+  bool IsTerminal(const string &component_name) override;
+
+  void FinalizeData(const string &component_name) override;
+
+  std::vector<string> GetSerializedPredictions() override;
+
+  std::vector<MasterTrace> GetTraceProtos() override;
+
+  void SetInputData(const std::vector<string> &data) override;
+
+  void ResetSession() override;
+
+  void SetTracing(bool tracing_on) override;
+
+  int Id() const override;
+
+  string GetDescription(const string &component_name) const override;
+
+  const std::vector<const IndexTranslator *> Translators(
+      const string &component_name) const override;
+
+ private:
+  // Get a given component. Fails if the component is not found.
+  Component *GetComponent(const string &component_name) const;
+
+  // Get a given component. CHECK-fail if the component's IsReady method
+  // returns false.
+  Component *GetReadiedComponent(const string &component_name) const;
+
+  // Get the index translators for the given component.
+  const std::vector<IndexTranslator *> &GetTranslators(
+      const string &component_name) const;
+
+  // Create an index translator.
+  std::unique_ptr<IndexTranslator> CreateTranslator(
+      const LinkedFeatureChannel &channel, Component *start_component);
+
+  // Perform initialization on the given Component.
+  void InitComponent(Component *component);
+
+  // Holds all of the components owned by this ComputeSession, associated with
+  // their names in the MasterSpec.
+  std::map<string, std::unique_ptr<Component>> components_;
+
+  // Holds a vector of translators for each component, indexed by the name
+  // of the component they belong to.
+  std::map<string, std::vector<IndexTranslator *>> translators_;
+
+  // Holds ownership of all the IndexTranslators for this compute session.
+  std::vector<std::unique_ptr<IndexTranslator>> owned_translators_;
+
+  // The predecessor component for every component.
+  // If a component is not in this map, it has no predecessor component and
+  // will have its beam initialized without any data from other components.
+  std::map<Component *, Component *> predecessors_;
+
+  // Holds the current input data for this ComputeSession.
+  std::unique_ptr<InputBatchCache> input_data_;
+
+  // Function that, given a string, will return a Component.
+  std::function<std::unique_ptr<Component>(const string &component_name,
+                                           const string &backend_type)>
+      component_builder_;
+
+  // The master spec for this compute session.
+  MasterSpec spec_;
+
+  // The hyperparameters for this compute session.
+  GridPoint grid_point_;
+
+  // Unique identifier, assigned at construction.
+  int id_;
+
+  // Whether or not to perform tracing.
+  bool do_tracing_ = false;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_

File diff suppressed because it is too large
+ 1157 - 0
syntaxnet/dragnn/core/compute_session_impl_test.cc


+ 89 - 0
syntaxnet/dragnn/core/compute_session_pool.cc

@@ -0,0 +1,89 @@
+#include "dragnn/core/compute_session_pool.h"
+
+#include <utility>
+
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/compute_session_impl.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::mutex_lock;
+
+ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
+                                       const GridPoint &hyperparams)
+    : master_spec_(master_spec),
+      hyperparams_(hyperparams),
+      num_unique_sessions_(0) {
+  // Create a default component builder function. This function looks up
+  // components in the component registry and returns them.
+  component_builder_ = [](
+      const string &component_name,
+      const string &backend_type) -> std::unique_ptr<Component> {
+    VLOG(2) << "Creating component " << component_name << " with backend "
+            << backend_type;
+    std::unique_ptr<Component> component(Component::Create(backend_type));
+    return component;
+  };
+
+  // Create a default session builder function. This function returns a
+  // ComputeSessionImpl that uses the currently set component_builder_
+  // function to create its components.
+  session_builder_ = [this]() {
+    return std::unique_ptr<ComputeSession>(
+        new ComputeSessionImpl(num_unique_sessions_, this->component_builder_));
+  };
+}
+
+ComputeSessionPool::~ComputeSessionPool() {
+  LOG(INFO) << "Destroying pool: total number of sessions created = "
+            << num_unique_sessions_;
+  if (sessions_.size() < num_unique_sessions_) {
+    LOG(WARNING) << "Destroying pool: number of unreturned sessions = "
+                 << (num_unique_sessions_ - sessions_.size());
+  }
+}
+
+void ComputeSessionPool::SetComputeSessionBuilder(
+    std::function<std::unique_ptr<ComputeSession>()> session_builder) {
+  mutex_lock lock(lock_);
+  session_builder_ = std::move(session_builder);
+}
+
+void ComputeSessionPool::SetComponentBuilder(
+    std::function<std::unique_ptr<Component>(const string &component_name,
+                                             const string &backend_type)>
+        component_builder) {
+  mutex_lock lock(lock_);
+  component_builder_ = std::move(component_builder);
+}
+
+std::unique_ptr<ComputeSession> ComputeSessionPool::GetSession() {
+  mutex_lock lock(lock_);
+  std::unique_ptr<ComputeSession> session_ptr;
+  if (sessions_.empty()) {
+    // There are no available sessions, so create and initialize one.
+    VLOG(2) << "Creating new session.";
+    session_ptr = session_builder_();
+    num_unique_sessions_++;
+    session_ptr->Init(master_spec_, hyperparams_);
+  } else {
+    // Get the last free session, and remove it from the free sessions vector.
+    VLOG(2) << "Reusing session from pool of size " << sessions_.size();
+    session_ptr = std::move(sessions_.back());
+    sessions_.pop_back();
+
+    session_ptr->ResetSession();
+  }
+  return session_ptr;
+}
+
+void ComputeSessionPool::ReturnSession(
+    std::unique_ptr<ComputeSession> session) {
+  mutex_lock lock(lock_);
+  sessions_.push_back(std::move(session));
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 87 - 0
syntaxnet/dragnn/core/compute_session_pool.h

@@ -0,0 +1,87 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
+
+#include <memory>
+
+#include "dragnn/core/compute_session.h"
+#include "dragnn/protos/spec.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// This pool creates and manages the reuse of ComputeSession objects.
+
+class ComputeSessionPool {
+ public:
+  // Create a ComputeSessionPool that creates ComputeSessions for the given
+  // MasterSpec and hyperparameters.
+  ComputeSessionPool(const MasterSpec &master_spec,
+                     const GridPoint &hyperparams);
+
+  virtual ~ComputeSessionPool();
+
+  // Get a ComputeSession. This function will attempt to use an already-created
+  // ComputeSession, but if none are available a new one will be created.
+  std::unique_ptr<ComputeSession> GetSession();
+
+  // Returns a ComputeSession to the backing pool.
+  void ReturnSession(std::unique_ptr<ComputeSession> session);
+
+  // Returns the count of outstanding unique sessions.
+  int num_outstanding_sessions() {
+    tensorflow::mutex_lock lock(lock_);
+    return num_unique_sessions_ - sessions_.size();
+  }
+
+ private:
+  friend class ComputeSessionImplTestPoolAccessor;
+  friend class ComputeSessionPoolTestPoolAccessor;
+
+  // This is a creational injection setter. It should be used for tests
+  // where we want our ComputeSessionPool to prepare and return
+  // MockComputeSessions instead of actual ComputeSessionImpls.
+  void SetComputeSessionBuilder(
+      std::function<std::unique_ptr<ComputeSession>()> session_builder);
+
+  // This injector will cause ComputeSessions built in this pool to use the
+  // passed function to create Components. This is useful when you want a
+  // ComputeSession to create MockComponents instead of real ones.
+  void SetComponentBuilder(
+      std::function<std::unique_ptr<Component>(const string &component_name,
+                                               const string &backend_type)>
+          component_builder);
+
+  // The MasterSpec that will be used to initialize ComputeSessions from this
+  // pool.
+  const MasterSpec master_spec_;
+
+  // The hyperparameters that will be used to initialize ComputeSessions from
+  // this pool.
+  const GridPoint hyperparams_;
+
+  // The function that is used to create ComputeSessions.
+  std::function<std::unique_ptr<ComputeSession>()> session_builder_;
+
+  // The function passed to ComputeSessions that will be used by that session
+  // to create components.
+  std::function<std::unique_ptr<Component>(const string &component_name,
+                                           const string &backend_type)>
+      component_builder_;
+
+  // ComputeSessions that are not currently being used. These sessions are not
+  // reset until they are requested by another thread.
+  std::vector<std::unique_ptr<ComputeSession>> sessions_;
+
+  // Count of the number of unique ComputeSession objects that have been
+  // created. Used to assign IDs to new Sessions.
+  int num_unique_sessions_;
+
+  // Mutex that protects accesses to all members of this object.
+  tensorflow::mutex lock_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
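A minimal sketch of the intended checkout/return cycle, assuming spec and hyperparams are populated elsewhere; only methods declared above are used:

    // Sketch only; `spec` and `hyperparams` are assumed to be built elsewhere.
    ComputeSessionPool pool(spec, hyperparams);
    std::unique_ptr<ComputeSession> session = pool.GetSession();  // creates or reuses
    // ... run one batch of data through `session` ...
    pool.ReturnSession(std::move(session));  // hands the session back to the pool
    // num_outstanding_sessions() drops back to 0 once all sessions are returned.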

+ 211 - 0
syntaxnet/dragnn/core/compute_session_pool_test.cc

@@ -0,0 +1,211 @@
+#include "dragnn/core/compute_session_pool.h"
+
+#include <memory>
+
+#include <gmock/gmock.h>
+
+#include "dragnn/core/compute_session.h"
+#include "dragnn/core/test/generic.h"
+#include "dragnn/core/test/mock_component.h"
+#include "dragnn/core/test/mock_compute_session.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using syntaxnet::test::EqualsProto;
+using testing::Return;
+using testing::Invoke;
+using testing::MockFunction;
+
+class ComputeSessionPoolTestPoolAccessor {
+ public:
+  static void SetComponentBuilder(
+      ComputeSessionPool *pool,
+      std::function<std::unique_ptr<Component>(const string &component_name,
+                                               const string &backend_type)>
+          component_builder_function) {
+    pool->SetComponentBuilder(std::move(component_builder_function));
+  }
+
+  static void SetSessionBuilder(ComputeSessionPool *pool,
+                                std::function<std::unique_ptr<ComputeSession>()>
+                                    session_builder_function) {
+    pool->SetComputeSessionBuilder(std::move(session_builder_function));
+  }
+};
+
+TEST(ComputeSessionPoolTest, DefaultConstructorWorks) {
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+  auto request = pool.GetSession();
+  EXPECT_NE(request, nullptr);
+}
+
+TEST(ComputeSessionPoolTest, ComponentBuilderInjectionWorks) {
+  MasterSpec spec;
+  auto component = spec.add_component();
+  component->set_name("test_component_name");
+  auto backend = component->mutable_backend();
+  backend->set_registered_name("arbitrary_component");
+  GridPoint hyperparams;
+
+  ComputeSessionPool pool(spec, hyperparams);
+
+  // Set up a mock component builder.
+  MockFunction<std::unique_ptr<Component>(const string &component_name,
+                                          const string &backend_type)>
+      mock_component_builder;
+  auto mock_creation_function = [](string, string) {
+    return std::unique_ptr<MockComponent>(new MockComponent());
+  };
+  EXPECT_CALL(mock_component_builder,
+              Call("test_component_name", "arbitrary_component"))
+      .WillOnce(Invoke(mock_creation_function));
+  ComputeSessionPoolTestPoolAccessor::SetComponentBuilder(
+      &pool, mock_component_builder.AsStdFunction());
+
+  // Now, when the session is requested, the mock component builder should see
+  // the expected call.
+  auto request = pool.GetSession();
+  EXPECT_NE(request, nullptr);
+}
+
+TEST(ComputeSessionPoolTest, CreatesNewSessionIfNoSessionsExist) {
+  // We don't need to fill these for this test.
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+
+  // Create a function that will track calls to the session builder.
+  MockFunction<std::unique_ptr<ComputeSession>()> mock_session_builder;
+
+  // Initialize expectations for a request for a ComputeSession.
+  std::unique_ptr<MockComputeSession> session_one(new MockComputeSession());
+  MockComputeSession *session_one_ptr = session_one.get();
+  auto mock_creation_function = [&session_one]() {
+    return std::move(session_one);
+  };
+  EXPECT_CALL(mock_session_builder, Call())
+      .WillOnce(Invoke(mock_creation_function))
+      .RetiresOnSaturation();
+  EXPECT_CALL(*session_one_ptr,
+              Init(EqualsProto(spec), EqualsProto(hyperparams)));
+
+  // Initialize expectations for another request for a ComputeSession.
+  std::unique_ptr<MockComputeSession> session_two(new MockComputeSession());
+  MockComputeSession *session_two_ptr = session_two.get();
+  auto mock_creation_function_two = [&session_two]() {
+    return std::move(session_two);
+  };
+  EXPECT_CALL(mock_session_builder, Call())
+      .WillOnce(Invoke(mock_creation_function_two))
+      .RetiresOnSaturation();
+  EXPECT_CALL(*session_two_ptr,
+              Init(EqualsProto(spec), EqualsProto(hyperparams)));
+
+  // Inject the function to the pool.
+  ComputeSessionPoolTestPoolAccessor::SetSessionBuilder(
+      &pool, mock_session_builder.AsStdFunction());
+
+  // The first call will receive the second session because gMock matches the
+  // most recently set expectation first.
+  auto first_request = pool.GetSession();
+  EXPECT_EQ(first_request.get(), session_two_ptr);
+
+  auto second_request = pool.GetSession();
+  EXPECT_EQ(second_request.get(), session_one_ptr);
+}
+
+TEST(ComputeSessionPoolTest, ReusesAvailableSessions) {
+  // We don't need to fill these for this test.
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+
+  // Create a function that will track calls to the session builder.
+  MockFunction<std::unique_ptr<ComputeSession>()> mock_session_builder;
+
+  // Initialize expectations for a request for a ComputeSession.
+  std::unique_ptr<MockComputeSession> session_one(new MockComputeSession());
+  MockComputeSession *session_one_ptr = session_one.get();
+  auto mock_creation_function = [&session_one]() {
+    return std::move(session_one);
+  };
+  EXPECT_CALL(mock_session_builder, Call())
+      .WillOnce(Invoke(mock_creation_function))
+      .RetiresOnSaturation();
+  EXPECT_CALL(*session_one_ptr,
+              Init(EqualsProto(spec), EqualsProto(hyperparams)));
+
+  // Initialize expectations for another request for a ComputeSession.
+  std::unique_ptr<MockComputeSession> session_two(new MockComputeSession());
+  MockComputeSession *session_two_ptr = session_two.get();
+  auto mock_creation_function_two = [&session_two]() {
+    return std::move(session_two);
+  };
+  EXPECT_CALL(mock_session_builder, Call())
+      .WillOnce(Invoke(mock_creation_function_two))
+      .RetiresOnSaturation();
+  EXPECT_CALL(*session_two_ptr,
+              Init(EqualsProto(spec), EqualsProto(hyperparams)));
+
+  // Inject the function to the pool.
+  ComputeSessionPoolTestPoolAccessor::SetSessionBuilder(
+      &pool, mock_session_builder.AsStdFunction());
+
+  // The first call will receive the second session because gMock matches the
+  // most recently set expectation first.
+  auto first_request = pool.GetSession();
+  EXPECT_EQ(1, pool.num_outstanding_sessions());
+  EXPECT_EQ(first_request.get(), session_two_ptr);
+
+  // Return the first pointer. After this, the second request should get that
+  // pointer.
+  EXPECT_CALL(*session_two_ptr, ResetSession());
+  pool.ReturnSession(std::move(first_request));
+  EXPECT_EQ(0, pool.num_outstanding_sessions());
+  auto second_request = pool.GetSession();
+  EXPECT_EQ(1, pool.num_outstanding_sessions());
+  EXPECT_EQ(second_request.get(), session_two_ptr);
+
+  // There are now no spare sessions, so the next session request should
+  // create a second session.
+  auto third_request = pool.GetSession();
+  EXPECT_EQ(2, pool.num_outstanding_sessions());
+  EXPECT_EQ(third_request.get(), session_one_ptr);
+}
+
+TEST(ComputeSessionPoolTest, AssignsUniqueIds) {
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+  auto session = pool.GetSession();
+  auto session_2 = pool.GetSession();
+  EXPECT_NE(session->Id(), session_2->Id());
+}
+
+TEST(ComputeSessionPoolTest, SupportsMultithreadedAccess) {
+  MasterSpec spec;
+  GridPoint hyperparams;
+  ComputeSessionPool pool(spec, hyperparams);
+
+  std::vector<std::unique_ptr<tensorflow::Thread>> request_threads;
+  constexpr int kNumThreadsToTest = 100;
+  for (int i = 0; i < kNumThreadsToTest; ++i) {
+    request_threads.push_back(std::unique_ptr<tensorflow::Thread>(
+        tensorflow::Env::Default()->StartThread(
+            tensorflow::ThreadOptions(), "thread",
+            [this, &pool] { auto session = pool.GetSession(); })));
+  }
+
+  // Deleting a tensorflow::Thread blocks until the thread exits,
+  // so clearing the vector blocks until all threads have exited.
+  request_threads.clear();
+
+  // Make sure all the threads got their session.
+  EXPECT_EQ(kNumThreadsToTest, pool.num_outstanding_sessions());
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 67 - 0
syntaxnet/dragnn/core/index_translator.cc

@@ -0,0 +1,67 @@
+#include "dragnn/core/index_translator.h"
+
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using Index = IndexTranslator::Index;
+
+IndexTranslator::IndexTranslator(const std::vector<Component *> &path,
+                                 const string &method)
+    : path_(path), method_(method) {
+  if (method_ == "identity") {
+    // Identity lookup: Return the feature index.
+    step_lookup_ = [](int batch_index, int beam_index, int feature) {
+      return feature;
+    };
+  } else if (method_ == "history") {
+    // History lookup: Return the step that is `feature` steps back from the
+    // most recent step, i.e. (steps taken - 1) - feature, or -1 if that falls
+    // outside the component's history.
+    step_lookup_ = [this](int batch_index, int beam_index, int feature) {
+      if (feature > path_.back()->StepsTaken(batch_index) - 1) {
+        VLOG(2) << "Translation to outside: feature is " << feature
+                << " and steps_taken is "
+                << path_.back()->StepsTaken(batch_index);
+        return -1;
+      }
+      return ((path_.back()->StepsTaken(batch_index) - 1) - feature);
+    };
+  } else {
+    // Component defined lookup: Get the lookup function from the component.
+    // If the lookup function is not defined, this function will CHECK.
+    step_lookup_ = path_.back()->GetStepLookupFunction(method_);
+  }
+}
+
+Index IndexTranslator::Translate(int batch_index, int beam_index,
+                                 int feature_value) {
+  Index translated_index;
+  translated_index.batch_index = batch_index;
+  VLOG(2) << "Translation requested (type: " << method_ << ") for batch "
+          << batch_index << " beam " << beam_index << " feature "
+          << feature_value;
+
+  // For all but the last item in the path, map the current beam index to the
+  // corresponding beam index in the next component along the path.
+  int current_beam_index = beam_index;
+  VLOG(2) << "Beam index before walk is " << current_beam_index;
+  for (int i = 0; i < path_.size() - 1; ++i) {
+    // Backtrack through previous components. For each non-final component,
+    // figure out what state in the prior component was used to initialize the
+    // state at the current beam index.
+    current_beam_index =
+        path_.at(i)->GetSourceBeamIndex(current_beam_index, batch_index);
+    VLOG(2) << "Beam index updated to " << current_beam_index;
+  }
+  VLOG(2) << "Beam index after walk is " << current_beam_index;
+  translated_index.step_index =
+      step_lookup_(batch_index, current_beam_index, feature_value);
+  VLOG(2) << "Translated step index is " << translated_index.step_index;
+  translated_index.beam_index = path_.back()->GetBeamIndexAtStep(
+      translated_index.step_index, current_beam_index, batch_index);
+  VLOG(2) << "Translated beam index is " << translated_index.beam_index;
+  return translated_index;
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 68 - 0
syntaxnet/dragnn/core/index_translator.h

@@ -0,0 +1,68 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
+
+#include <memory>
+#include <vector>
+
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// An IndexTranslator provides an interface into the data of another component.
+// It allows one component to look up a translated array index from the history
+// or state of another component.
+//
+// When it is created, it is passed a path of Component pointers ending at the
+// source component (that is, the component whose data it will be accessing)
+// and a string representing the
+// type of data access it will perform. There are two universal data access
+// methods - "identity" and "history" - and components can declare more via
+// their GetStepLookupFunction function.
+
+class IndexTranslator {
+ public:
+  // Index into a TensorArray. Provides a given step, and the beam index within
+  // that step, for TensorArray access to data in the given batch.
+  struct Index {
+    int batch_index = -1;
+    int beam_index = -1;
+    int step_index = -1;
+  };
+
+  // Creates a new IndexTranslator with access method as determined by the
+  // passed string. The Translator will walk the path "path" in order, and will
+  // translate from the last Component in the path.
+  IndexTranslator(const std::vector<Component *> &path, const string &method);
+
+  // Returns an index in (step, beam, batch) index space as computed from the
+  // given feature value.
+  Index Translate(int batch_index, int beam_index, int feature_value);
+
+  // Returns the path to be walked by this translator.
+  const std::vector<Component *> &path() const { return path_; }
+
+  // Returns the method to be used by this translator.
+  const string &method() const { return method_; }
+
+ private:
+  // The ordered list of components that must be walked to get from the
+  // requesting component to the source component. This vector has the
+  // requesting component at index 0 and the source component at the end. If
+  // the requesting component is the source component, this vector has only one
+  // entry.
+  const std::vector<Component *> path_;
+
+  // The function this translator will use to look up the step in the source
+  // component. The function is invoked as:
+  // step_lookup_(batch_index, beam_index, feature).
+  std::function<int(int, int, int)> step_lookup_;
+
+  // This translator's method.
+  string method_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
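
As a small usage sketch (illustrative only; `parser` and `tagger` stand for two Components already wired into a ComputeSession, with the parser linked back into the tagger), a "history" translator resolves a feature value into (batch, beam, step) coordinates; for example, feature value 2 on a component that has taken 8 steps resolves to step (8 - 1) - 2 = 5:

#include "dragnn/core/index_translator.h"

void TranslateHistoryFeature(syntaxnet::dragnn::Component *parser,
                             syntaxnet::dragnn::Component *tagger) {
  // The path runs from the requesting component to the source component.
  syntaxnet::dragnn::IndexTranslator history({parser, tagger}, "history");

  // Batch element 0, beam slot 3, feature value 2.
  const syntaxnet::dragnn::IndexTranslator::Index index =
      history.Translate(/*batch_index=*/0, /*beam_index=*/3,
                        /*feature_value=*/2);

  // index.step_index is (tagger->StepsTaken(0) - 1) - 2, index.beam_index is
  // the tagger beam slot occupied at that step, and a step_index of -1 means
  // the feature pointed outside the recorded history.
}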

+ 180 - 0
syntaxnet/dragnn/core/index_translator_test.cc

@@ -0,0 +1,180 @@
+#include "dragnn/core/index_translator.h"
+
+#include "dragnn/core/test/mock_component.h"
+#include "dragnn/core/test/mock_transition_state.h"
+#include <gmock/gmock.h>
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using testing::MockFunction;
+using testing::Return;
+
+TEST(IndexTranslatorTest, PerformsIdentityTranslation) {
+  MockComponent mock_component;
+
+  // We are testing the Identity lookup with a single component (so, self-
+  // referencing) and thus we expect the translator to call GetBeamIndexAtStep
+  // for the step and index we pass in.
+  constexpr int kBeam = 4;
+  constexpr int kFeature = 2;
+  constexpr int kResultIndex = 3;
+  constexpr int kBatch = 99;
+  EXPECT_CALL(mock_component, GetBeamIndexAtStep(kFeature, kBeam, kBatch))
+      .WillOnce(Return(kResultIndex));
+
+  // Execute!
+  IndexTranslator translator({&mock_component}, "identity");
+  auto result = translator.Translate(kBatch, kBeam, kFeature);
+  EXPECT_EQ(kResultIndex, result.beam_index);
+  EXPECT_EQ(kFeature, result.step_index);
+  EXPECT_EQ(kBatch, result.batch_index);
+}
+
+TEST(IndexTranslatorTest, PerformsHistoryTranslation) {
+  MockComponent mock_component;
+
+  // We are testing the History lookup with a single component (so, self-
+  // referencing) and thus we expect the translator to call StepsTaken() to get
+  // the number of steps taken and GetBeamIndexAtStep with (total-desired).
+  constexpr int kBeam = 4;
+  constexpr int kFeature = 2;
+  constexpr int kTotalNumberSteps = 8;
+  constexpr int kBatch = 99;
+
+  // Here, the expected step result is two in from the final index, so
+  // (8-1) - 2, or 5.
+  constexpr int kExpectedResult = 5;
+  constexpr int kResultIndex = 3;
+  EXPECT_CALL(mock_component, StepsTaken(kBatch))
+      .WillRepeatedly(Return(kTotalNumberSteps));
+  EXPECT_CALL(mock_component,
+              GetBeamIndexAtStep(kExpectedResult, kBeam, kBatch))
+      .WillOnce(Return(kResultIndex));
+
+  // Execute!
+  IndexTranslator translator({&mock_component}, "history");
+  auto result = translator.Translate(kBatch, kBeam, kFeature);
+  EXPECT_EQ(kResultIndex, result.beam_index);
+  EXPECT_EQ(kExpectedResult, result.step_index);
+  EXPECT_EQ(kBatch, result.batch_index);
+}
+
+TEST(IndexTranslatorTest, TraversesPathToLookup) {
+  MockComponent mock_component_a;
+  MockComponent mock_component_b;
+  MockComponent mock_component_c;
+  constexpr int kBatch = 99;
+
+  // The translator should request the source index from mock component A.
+  constexpr int kBeam = 4;
+  constexpr int kSourceBIndex = 3;
+  EXPECT_CALL(mock_component_a, GetSourceBeamIndex(kBeam, kBatch))
+      .WillOnce(Return(kSourceBIndex));
+
+  // The translator should use the source index from A in a source index request
+  // to component B.
+  constexpr int kSourceCIndex = 17;
+  EXPECT_CALL(mock_component_b, GetSourceBeamIndex(kSourceBIndex, kBatch))
+      .WillOnce(Return(kSourceCIndex));
+
+  // The translator should request the beam index at the requested step in
+  // component C, using the beam index from the source index request to B.
+  constexpr int kFeature = 2;
+  constexpr int kResultIndex = 1157;
+
+  // This is testing with an identity translator, so kFeature == kStep.
+  EXPECT_CALL(mock_component_c,
+              GetBeamIndexAtStep(kFeature, kSourceCIndex, kBatch))
+      .WillOnce(Return(kResultIndex));
+
+  // Execute!
+  IndexTranslator translator(
+      {&mock_component_a, &mock_component_b, &mock_component_c}, "identity");
+  auto result = translator.Translate(kBatch, kBeam, kFeature);
+  EXPECT_EQ(kResultIndex, result.beam_index);
+  EXPECT_EQ(kFeature, result.step_index);
+  EXPECT_EQ(kBatch, result.batch_index);
+}
+
+TEST(IndexTranslatorTest, RequestsArbitraryTranslationFunction) {
+  MockComponent mock_component;
+  MockFunction<int(int, int, int)> mock_function;
+
+  // This test ensures that we can get an arbitrary translation function
+  // from the component and execute it properly.
+  constexpr int kBeam = 4;
+  constexpr int kFeature = 2;
+  constexpr int kFunctionResult = 10;
+  constexpr int kResultIndex = 3;
+  constexpr int kBatch = 99;
+
+  // The arbitrary function should be called with the desired input.
+  EXPECT_CALL(mock_function, Call(kBatch, kBeam, kFeature))
+      .WillOnce(Return(kFunctionResult));
+
+  // The translator should request the function from the component.
+  EXPECT_CALL(mock_component, GetStepLookupFunction("arbitrary_function"))
+      .WillOnce(Return(mock_function.AsStdFunction()));
+
+  // The translator should call GetBeamIndexAtStep with the result of calling
+  // the function.
+  EXPECT_CALL(mock_component,
+              GetBeamIndexAtStep(kFunctionResult, kBeam, kBatch))
+      .WillOnce(Return(kResultIndex));
+
+  // Execute!
+  IndexTranslator translator({&mock_component}, "arbitrary_function");
+  auto result = translator.Translate(kBatch, kBeam, kFeature);
+  EXPECT_EQ(kResultIndex, result.beam_index);
+  EXPECT_EQ(kFunctionResult, result.step_index);
+  EXPECT_EQ(kBatch, result.batch_index);
+}
+
+// This test ensures that the translation function is queried with the beam
+// index for that component, and that the translation function is taken from
+// the correct component.
+TEST(IndexTranslatorTest, RequestsArbitraryTranslationAcrossComponents) {
+  MockComponent mock_component_a;
+  MockComponent mock_component_b;
+  MockFunction<int(int, int, int)> mock_function;
+
+  // This test ensures that we can get an arbitrary translation function
+  // from the component and execute it properly.
+  constexpr int kFeature = 2;
+  constexpr int kFunctionResult = 10;
+  constexpr int kResultIndex = 3;
+  constexpr int kBatch = 99;
+
+  // The translator should request the source index from mock component A.
+  constexpr int kBeam = 4;
+  constexpr int kSourceBIndex = 3;
+  EXPECT_CALL(mock_component_a, GetSourceBeamIndex(kBeam, kBatch))
+      .WillOnce(Return(kSourceBIndex));
+
+  // The translator should request the function from the component.
+  EXPECT_CALL(mock_component_b, GetStepLookupFunction("arbitrary_function"))
+      .WillOnce(Return(mock_function.AsStdFunction()));
+
+  // The arbitrary function should be called with the desired input.
+  EXPECT_CALL(mock_function, Call(kBatch, kSourceBIndex, kFeature))
+      .WillOnce(Return(kFunctionResult));
+
+  // The translator should call GetBeamIndexAtStep with the result of calling
+  // the function.
+  EXPECT_CALL(mock_component_b,
+              GetBeamIndexAtStep(kFunctionResult, kSourceBIndex, kBatch))
+      .WillOnce(Return(kResultIndex));
+
+  // Execute!
+  IndexTranslator translator({&mock_component_a, &mock_component_b},
+                             "arbitrary_function");
+  auto result = translator.Translate(kBatch, kBeam, kFeature);
+  EXPECT_EQ(kResultIndex, result.beam_index);
+  EXPECT_EQ(kFunctionResult, result.step_index);
+  EXPECT_EQ(kBatch, result.batch_index);
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 78 - 0
syntaxnet/dragnn/core/input_batch_cache.h

@@ -0,0 +1,78 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
+
+#include <memory>
+#include <string>
+#include <typeindex>
+
+#include "dragnn/core/interfaces/input_batch.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// An InputBatchCache holds data converted to a DRAGNN internal representation.
+// It performs the conversion lazily via Data objects and caches the result.
+
+class InputBatchCache {
+ public:
+  // Create an empty cache.
+  InputBatchCache() : stored_type_(std::type_index(typeid(void))) {}
+
+  // Create an InputBatchCache from a single example. This copies the string.
+  explicit InputBatchCache(const string &data)
+      : stored_type_(std::type_index(typeid(void))), source_data_({data}) {}
+
+  // Create an InputBatchCache from a vector of examples. The vector is copied.
+  explicit InputBatchCache(const std::vector<string> &data)
+      : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
+
+  // Adds a single string to the cache. Only usable before GetAs() has been
+  // called.
+  void AddData(const string &data) {
+    CHECK(stored_type_ == std::type_index(typeid(void)))
+        << "You may not add data to an InputBatchCache after the cache has "
+           "been converted via GetAs().";
+    source_data_.emplace_back(data);
+  }
+
+  // Convert the stored strings into the representation defined by a specific
+  // InputBatch subclass and return that object. T must be a subclass of
+  // InputBatch. After this method is called once, all further calls must use
+  // the same data type.
+  template <class T>
+  T *GetAs() {
+    if (!converted_data_) {
+      stored_type_ = std::type_index(typeid(T));
+      converted_data_.reset(new T());
+      converted_data_->SetData(source_data_);
+    }
+    CHECK(std::type_index(typeid(T)) == stored_type_)
+        << "Attempted to convert to two object types! Existing object type was "
+        << stored_type_.name() << ", new object type was "
+        << std::type_index(typeid(T)).name();
+
+    return dynamic_cast<T *>(converted_data_.get());
+  }
+
+  // Return the serialized representation of the data held in the input batch
+  // object within this cache.
+  const std::vector<string> SerializedData() const {
+    CHECK(converted_data_) << "Cannot return batch without data.";
+    return converted_data_->GetSerializedData();
+  }
+
+ private:
+  // The typeid of the stored data.
+  std::type_index stored_type_;
+
+  // The raw data.
+  std::vector<string> source_data_;
+
+  // The converted data, contained in an InputBatch object.
+  std::unique_ptr<InputBatch> converted_data_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
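
A brief usage sketch, mirroring the StringData pattern in the test below (PassthroughBatch is a toy InputBatch subclass invented here for illustration; it is not part of this change):

#include <string>
#include <vector>

#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/input_batch.h"

// Toy InputBatch subclass that stores the raw strings unchanged.
class PassthroughBatch : public syntaxnet::dragnn::InputBatch {
 public:
  void SetData(const std::vector<std::string> &data) override { data_ = data; }
  const std::vector<std::string> GetSerializedData() const override {
    return data_;
  }

 private:
  std::vector<std::string> data_;
};

void CacheExample(const std::vector<std::string> &raw_inputs) {
  syntaxnet::dragnn::InputBatchCache cache(raw_inputs);

  // The first GetAs<T>() converts the raw strings and records T as the stored
  // type; later calls must request the same T or the CHECK in GetAs() fires.
  PassthroughBatch *batch = cache.GetAs<PassthroughBatch>();
  PassthroughBatch *same_batch = cache.GetAs<PassthroughBatch>();
  // same_batch == batch: the converted object is cached, not rebuilt.
}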

+ 107 - 0
syntaxnet/dragnn/core/input_batch_cache_test.cc

@@ -0,0 +1,107 @@
+#include "dragnn/core/input_batch_cache.h"
+
+#include "dragnn/core/interfaces/input_batch.h"
+#include <gmock/gmock.h>
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class StringData : public InputBatch {
+ public:
+  StringData() {}
+
+  void SetData(const std::vector<string> &data) override {
+    for (const auto &element : data) {
+      data_.push_back(element + "_converted");
+    }
+  }
+
+  const std::vector<string> GetSerializedData() const override { return data_; }
+
+  std::vector<string> *data() { return &data_; }
+
+ private:
+  std::vector<string> data_;
+};
+
+class DifferentStringData : public InputBatch {
+ public:
+  DifferentStringData() {}
+
+  void SetData(const std::vector<string> &data) override {
+    for (const auto &element : data) {
+      data_.push_back(element + "_also_converted");
+    }
+  }
+
+  const std::vector<string> GetSerializedData() const override { return data_; }
+
+  std::vector<string> *data() { return &data_; }
+
+ private:
+  std::vector<string> data_;
+};
+
+TEST(InputBatchCacheTest, ConvertsSingleInput) {
+  string test_string = "Foo";
+  InputBatchCache generic_set(test_string);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_EQ(data->data()->size(), 1);
+  EXPECT_EQ(data->data()->at(0), "Foo_converted");
+}
+
+TEST(InputBatchCacheTest, ConvertsAddedInput) {
+  string test_string = "Foo";
+  InputBatchCache generic_set;
+  generic_set.AddData(test_string);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_EQ(data->data()->size(), 1);
+  EXPECT_EQ(data->data()->at(0), "Foo_converted");
+}
+
+TEST(InputBatchCacheTest, ConvertsVectorOfInputs) {
+  std::vector<string> test_inputs;
+  test_inputs.push_back("Foo");
+  test_inputs.push_back("Bar");
+  test_inputs.push_back("Baz");
+  InputBatchCache generic_set(test_inputs);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_EQ(data->data()->size(), test_inputs.size());
+  EXPECT_EQ(data->data()->at(0), "Foo_converted");
+  EXPECT_EQ(data->data()->at(1), "Bar_converted");
+  EXPECT_EQ(data->data()->at(2), "Baz_converted");
+}
+
+TEST(InputBatchCacheTest, ConvertingMultipleDataTypesCausesCheck) {
+  string test_string = "Foo";
+  InputBatchCache generic_set(test_string);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_EQ(data->data()->at(0), "Foo_converted");
+  ASSERT_DEATH(generic_set.GetAs<DifferentStringData>(),
+               "Attempted to convert to two object types!.*");
+}
+
+TEST(InputBatchCacheTest, ReturnsSingleInput) {
+  string test_string = "Foo";
+  InputBatchCache generic_set(test_string);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_NE(nullptr, data);
+  auto returned = generic_set.SerializedData();
+  EXPECT_EQ(returned.size(), 1);
+  EXPECT_EQ(returned.at(0), "Foo_converted");
+}
+
+TEST(InputBatchCacheTest, ConvertsAddedInputDiesAfterGetAs) {
+  string test_string = "Foo";
+  InputBatchCache generic_set;
+  generic_set.AddData(test_string);
+  auto data = generic_set.GetAs<StringData>();
+  EXPECT_EQ(data->data()->size(), 1);
+  EXPECT_EQ(data->data()->at(0), "Foo_converted");
+  EXPECT_DEATH(generic_set.AddData("YOU MAY NOT DO THIS AND IT WILL DIE."),
+               "after the cache has been converted");
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 37 - 0
syntaxnet/dragnn/core/interfaces/BUILD

@@ -0,0 +1,37 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "cloneable_transition_state",
+    hdrs = ["cloneable_transition_state.h"],
+    deps = [":transition_state"],
+)
+
+cc_library(
+    name = "component",
+    hdrs = ["component.h"],
+    deps = [
+        ":transition_state",
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core:input_batch_cache",
+        "//dragnn/protos:spec_proto",
+        "//dragnn/protos:trace_proto",
+        "//syntaxnet:base",
+        "//syntaxnet:registry",
+    ],
+)
+
+cc_library(
+    name = "input_batch",
+    hdrs = ["input_batch.h"],
+    deps = [
+        "//syntaxnet:base",
+    ],
+)
+
+cc_library(
+    name = "transition_state",
+    hdrs = ["transition_state.h"],
+    deps = [
+        "//syntaxnet:base",
+    ],
+)

+ 52 - 0
syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h

@@ -0,0 +1,52 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
+
+#include <memory>
+#include <vector>
+
+#include "dragnn/core/interfaces/transition_state.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// This defines a TransitionState object that can be used with the Beam class.
+// Any class designed to be used with the Beam must inherit from
+// CloneableTransitionState<T>, not TransitionState.
+
+template <class T>
+class CloneableTransitionState : public TransitionState {
+ public:
+  ~CloneableTransitionState<T>() override {}
+
+  // Initialize this TransitionState from a previous TransitionState. The
+  // ParentBeamIndex is the location of that previous TransitionState in the
+  // provided beam.
+  void Init(const TransitionState &parent) override = 0;
+
+  // Return the beam index of the state passed into the initializer of this
+  // TransitionState.
+  const int ParentBeamIndex() const override = 0;
+
+  // Get the current beam index for this state.
+  const int GetBeamIndex() const override = 0;
+
+  // Set the current beam index for this state.
+  void SetBeamIndex(const int index) override = 0;
+
+  // Get the score associated with this transition state.
+  const float GetScore() const override = 0;
+
+  // Set the score associated with this transition state.
+  void SetScore(const float score) override = 0;
+
+  // Depicts this state as an HTML-language string.
+  string HTMLRepresentation() const override = 0;
+
+  // Produces a new state with the same backing data as this state.
+  virtual std::unique_ptr<T> Clone() const = 0;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
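
For orientation, a skeletal Beam-compatible subclass might be declared as follows (a sketch only; MyState is an invented name and the method bodies are left to the implementer). The template argument is the subclass itself so that Clone() returns the concrete type:

#include <memory>

#include "dragnn/core/interfaces/cloneable_transition_state.h"

class MyState
    : public syntaxnet::dragnn::CloneableTransitionState<MyState> {
 public:
  void Init(const syntaxnet::dragnn::TransitionState &parent) override;
  const int ParentBeamIndex() const override;
  const int GetBeamIndex() const override;
  void SetBeamIndex(const int index) override;
  const float GetScore() const override;
  void SetScore(const float score) override;
  string HTMLRepresentation() const override;

  // Returns a new MyState backed by the same underlying data.
  std::unique_ptr<MyState> Clone() const override;
};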

+ 126 - 0
syntaxnet/dragnn/core/interfaces/component.h

@@ -0,0 +1,126 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
+
+#include <vector>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/protos/spec.pb.h"
+#include "dragnn/protos/trace.pb.h"
+#include "syntaxnet/registry.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class Component : public RegisterableClass<Component> {
+ public:
+  virtual ~Component() {}
+
+  // Initializes this component from the spec.
+  virtual void InitializeComponent(const ComponentSpec &spec) = 0;
+
+  // Provides the previous beam to the component.
+  virtual void InitializeData(
+      const std::vector<std::vector<const TransitionState *>> &states,
+      int max_beam_size, InputBatchCache *input_data) = 0;
+
+  // Returns true if the component has had InitializeData called on it since
+  // the last time it was reset.
+  virtual bool IsReady() const = 0;
+
+  // Initializes the component for tracing execution, resetting any existing
+  // traces. This will typically have the side effect of slowing down all
+  // subsequent Component calculations and storing a trace in memory that can be
+  // returned by GetTraceProtos().
+  virtual void InitializeTracing() = 0;
+
+  // Disables tracing, freeing any associated traces and avoiding triggering
+  // additional computation in the future.
+  virtual void DisableTracing() = 0;
+
+  // Returns the string name of this component.
+  virtual string Name() const = 0;
+
+  // Returns the current batch size of the component's underlying data.
+  virtual int BatchSize() const = 0;
+
+  // Returns the maximum beam size of this component.
+  virtual int BeamSize() const = 0;
+
+  // Returns the number of steps taken by this component so far.
+  virtual int StepsTaken(int batch_index) const = 0;
+
+  // Return the beam index of the item which is currently at index
+  // 'current_index', when the beam was at step 'step', for batch element
+  // 'batch'.
+  virtual int GetBeamIndexAtStep(int step, int current_index,
+                                 int batch) const = 0;
+
+  // Return the source index of the item which is currently at index
+  // 'current_index' for batch element 'batch'. This index is into the final
+  // beam of the
+  // Component that this Component was initialized from.
+  virtual int GetSourceBeamIndex(int current_index, int batch) const = 0;
+
+  // Request a translation function based on the given method string.
+  // The translation function will be called with arguments
+  // (batch_index, beam_index, value) and should return the step index
+  // corresponding to the given value, for the data in the given batch and
+  // beam.
+  virtual std::function<int(int, int, int)> GetStepLookupFunction(
+      const string &method) = 0;
+
+  // Advances this component from the given transition matrix.
+  virtual void AdvanceFromPrediction(const float transition_matrix[],
+                                     int transition_matrix_length) = 0;
+
+  // Advances this component from the state oracles.
+  virtual void AdvanceFromOracle() = 0;
+
+  // Returns true if all states within this component are terminal.
+  virtual bool IsTerminal() const = 0;
+
+  // Returns the current batch of beams for this component.
+  virtual std::vector<std::vector<const TransitionState *>> GetBeam() = 0;
+
+  // Extracts and populates the vector of FixedFeatures for the specified
+  // channel. Each functor allocates storage space for the indices, the IDs, and
+  // the weights (respectively).
+  virtual int GetFixedFeatures(
+      std::function<int32 *(int num_elements)> allocate_indices,
+      std::function<int64 *(int num_elements)> allocate_ids,
+      std::function<float *(int num_elements)> allocate_weights,
+      int channel_id) const = 0;
+
+  // Extracts and populates all FixedFeatures for all channels, advancing this
+  // component via the oracle until it is terminal. This call uses a
+  // BulkFeatureExtractor object to contain the functors and other information.
+  virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0;
+
+  // Extracts and returns the vector of LinkFeatures for the specified
+  // channel. Note: these are NOT translated.
+  virtual std::vector<LinkFeatures> GetRawLinkFeatures(
+      int channel_id) const = 0;
+
+  // Returns a vector of oracle labels for each element in the beam and
+  // batch.
+  virtual std::vector<std::vector<int>> GetOracleLabels() const = 0;
+
+  // Annotate the underlying data object with the results of this Component's
+  // calculation.
+  virtual void FinalizeData() = 0;
+
+  // Reset this component.
+  virtual void ResetComponent() = 0;
+
+  // Get a vector of all traces managed by this component.
+  virtual std::vector<std::vector<ComponentTrace>> GetTraceProtos() const = 0;
+
+  // Add the translated link features (done outside the component) to the traces
+  // managed by this component.
+  virtual void AddTranslatedLinkFeaturesToTrace(
+      const std::vector<LinkFeatures> &features, int channel_id) = 0;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
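
To make the GetStepLookupFunction contract concrete, here is a sketch of a component-defined lookup (MyComponent and the "latest-step" method name are invented for illustration, and MyComponent is assumed to derive from Component); the returned function receives (batch_index, beam_index, value) and yields a step index, as the IndexTranslator expects:

#include <functional>

#include "tensorflow/core/platform/logging.h"

std::function<int(int, int, int)> MyComponent::GetStepLookupFunction(
    const string &method) {
  if (method == "latest-step") {
    // Ignore the feature value and always point at the most recent step.
    return [this](int batch_index, int beam_index, int value) {
      return StepsTaken(batch_index) - 1;
    };
  }
  LOG(FATAL) << "Unsupported step lookup method: " << method;
  return nullptr;
}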

+ 30 - 0
syntaxnet/dragnn/core/interfaces/input_batch.h

@@ -0,0 +1,30 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
+
+#include <string>
+#include <vector>
+
+#include "syntaxnet/base.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// An InputBatch object converts strings into a given data type. It is used to
+// abstract DRAGNN internal data typing. Each internal DRAGNN data type should
+// subclass InputBatch, with a public accessor to the type in question.
+
+class InputBatch {
+ public:
+  virtual ~InputBatch() {}
+
+  // Set the data to translate to the subclass' data type.
+  virtual void SetData(const std::vector<string> &data) = 0;
+
+  // Translate the underlying data back to a vector of strings, as appropriate.
+  virtual const std::vector<string> GetSerializedData() const = 0;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_

+ 53 - 0
syntaxnet/dragnn/core/interfaces/transition_state.h

@@ -0,0 +1,53 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
+
+#include <memory>
+#include <vector>
+
+#include "syntaxnet/base.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// TransitionState defines the minimal interface required to pass data between
+// Component objects. It is used to initialize one Component from the output of
+// another, and every backend should define one. Note that inheriting from
+// TransitionState directly is not sufficient to use the Beam class, which
+// requires extra functionality given by inheriting from the
+// CloneableTransitionState interface. (CloneableTransitionState is a subclass
+// of TransitionState, so inheriting from CloneableTransitionState is
+// sufficient to allow Components to pass your backing states.)
+
+class TransitionState {
+ public:
+  virtual ~TransitionState() {}
+
+  // Initialize this TransitionState from a previous TransitionState. The
+  // ParentBeamIndex is the location of that previous TransitionState in the
+  // provided beam.
+  virtual void Init(const TransitionState &parent) = 0;
+
+  // Return the beam index of the state passed into the initializer of this
+  // TransitionState.
+  virtual const int ParentBeamIndex() const = 0;
+
+  // Get the current beam index for this state.
+  virtual const int GetBeamIndex() const = 0;
+
+  // Set the current beam index for this state.
+  virtual void SetBeamIndex(const int index) = 0;
+
+  // Get the score associated with this transition state.
+  virtual const float GetScore() const = 0;
+
+  // Set the score associated with this transition state.
+  virtual void SetScore(const float score) = 0;
+
+  // Depicts this state as an HTML-language string.
+  virtual string HTMLRepresentation() const = 0;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_

+ 113 - 0
syntaxnet/dragnn/core/interfaces/transition_state_starter_test.cc

@@ -0,0 +1,113 @@
+#include "dragnn/core/test/mock_transition_state.h"
+#include <gmock/gmock.h>
+#include "tensorflow/core/platform/test.h"
+
+// This test suite is intended to validate the contracts that the DRAGNN
+// system expects from all transition state subclasses. Developers creating
+// new TransitionState subclasses should copy this test and modify it as needed,
+// using it to ensure their state conforms to DRAGNN expectations.
+
+namespace syntaxnet {
+namespace dragnn {
+
+using testing::Return;
+
+// When this test is instantiated, this function should be changed to
+// instantiate a TransitionState subclass of the appropriate type instead
+// of TransitionState.
+std::unique_ptr<TransitionState> CreateState() {
+  std::unique_ptr<TransitionState> test_state(new TransitionState());
+  return test_state;
+}
+
+// Validates the consistency of the beam index setter and getter.
+TEST(TransitionStateInterfaceTest, CanSetAndGetBeamIndex) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr int kOldBeamIndex = 12;
+  test_state->SetBeamIndex(kOldBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
+
+  constexpr int kNewBeamIndex = 7;
+  test_state->SetBeamIndex(kNewBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kNewBeamIndex);
+}
+
+// Validates the consistency of the score setter and getter.
+TEST(TransitionStateInterfaceTest, CanSetAndGetScore) {
+  // Create and initialize a test state.
+  MockTransitionState mock_state;
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  constexpr float kOldScore = 12.1;
+  test_state->SetScore(kOldScore);
+  EXPECT_EQ(test_state->GetScore(), kOldScore);
+
+  constexpr float kNewScore = 7.2;
+  test_state->SetScore(kNewScore);
+  EXPECT_EQ(test_state->GetScore(), kNewScore);
+}
+
+// This test ensures that the initializing state's current index is saved
+// as the parent beam index of the state being initialized.
+TEST(TransitionStateInterfaceTest, ReportsParentBeamIndex) {
+  // Create a mock transition state that will report a specific current index.
+  // This index should become the parent beam index for the test state.
+  MockTransitionState mock_state;
+  constexpr int kParentBeamIndex = 1138;
+  EXPECT_CALL(mock_state, GetBeamIndex())
+      .WillRepeatedly(Return(kParentBeamIndex));
+
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+  EXPECT_EQ(test_state->ParentBeamIndex(), kParentBeamIndex);
+}
+
+// This test ensures that the initializing state's current score is saved
+// as the current score of the state being initialized.
+TEST(TransitionStateInterfaceTest, InitializationCopiesParentScore) {
+  // Create a mock transition state that will report a specific score. This
+  // score should be copied into the test state at initialization.
+  MockTransitionState mock_state;
+  constexpr float kParentScore = 24.12;
+  EXPECT_CALL(mock_state, GetScore()).WillRepeatedly(Return(kParentScore));
+
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+  EXPECT_EQ(test_state->GetScore(), kParentScore);
+}
+
+// This test ensures that calling Clone maintains the state data (parent beam
+// index, beam index, score, etc.) of the state that was cloned.
+TEST(TransitionStateInterfaceTest, CloningMaintainsState) {
+  // Create and initialize the state.
+  MockTransitionState mock_state;
+  constexpr int kParentBeamIndex = 1138;
+  EXPECT_CALL(mock_state, GetBeamIndex())
+      .WillRepeatedly(Return(kParentBeamIndex));
+  auto test_state = CreateState();
+  test_state->Init(mock_state);
+
+  // Validate the internal state of the test state.
+  constexpr float kOldScore = 20.0;
+  test_state->SetScore(kOldScore);
+  EXPECT_EQ(test_state->GetScore(), kOldScore);
+  constexpr int kOldBeamIndex = 12;
+  test_state->SetBeamIndex(kOldBeamIndex);
+  EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
+
+  auto clone = test_state->Clone();
+
+  // The clone should have identical state to the old state.
+  EXPECT_EQ(clone->ParentBeamIndex(), kParentBeamIndex);
+  EXPECT_EQ(clone->GetScore(), kOldScore);
+  EXPECT_EQ(clone->GetBeamIndex(), kOldBeamIndex);
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 70 - 0
syntaxnet/dragnn/core/ops/compute_session_op.cc

@@ -0,0 +1,70 @@
+#include "dragnn/core/ops/compute_session_op.h"
+
+#include "dragnn/core/compute_session.h"
+#include "dragnn/core/resource_container.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::errors::InvalidArgument;
+
+typedef ResourceContainer<ComputeSession> ComputeSessionResource;
+
+ComputeSessionOp::ComputeSessionOp(OpKernelConstruction *context)
+    : OpKernel(context) {
+  OP_REQUIRES(context, context->num_inputs() > 0,
+              InvalidArgument("Must declare at least one input of type string "
+                              "for the ComputeSession handle."));
+  OP_REQUIRES(context, context->input_type(0) == tensorflow::DT_STRING,
+              InvalidArgument("Must declare at least one input of type string "
+                              "for the ComputeSession handle."));
+  OP_REQUIRES_OK(context, context->GetAttr("component", &component_name_));
+}
+
+// Compute() extracts the state from the resource manager and calls
+// ComputeWithState(). If OutputsHandle() is true, also outputs the handle for
+// subsequent ops.
+void ComputeSessionOp::Compute(OpKernelContext *context) {
+  // Validates the input/output tensors and the op attrs.
+  if (RequiresComponentName()) {
+    OP_REQUIRES(context, !component_name_.empty(),
+                InvalidArgument("Required \"component\" attribute is empty."));
+  }
+  if (OutputsHandle()) {
+    OP_REQUIRES(context, context->num_outputs() > 0,
+                InvalidArgument(
+                    "Must declare at least one output of type string "
+                    "for the ComputeSession handle if OutputsHandle is true."));
+    OP_REQUIRES(context,
+                context->expected_output_dtype(0) == tensorflow::DT_STRING,
+                InvalidArgument(
+                    "Must declare at least one output of type string "
+                    "for the ComputeSession handle if OutputsHandle is true."));
+  }
+
+  // Gets the relevant ComputeSessionResource and computes with it.
+  auto handle = context->input(0).vec<string>();
+  ComputeSessionResource *session_resource;
+  OP_REQUIRES_OK(context,
+                 context->resource_manager()->Lookup<ComputeSessionResource>(
+                     handle(0), handle(1), &session_resource));
+  ComputeWithState(context, session_resource->get());
+
+  // Outputs the passed handle, if necessary, allowing op dependency chains.
+  if (OutputsHandle()) {
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({2}), &output));
+    output->vec<string>() = handle;
+  }
+  session_resource->Unref();
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 54 - 0
syntaxnet/dragnn/core/ops/compute_session_op.h

@@ -0,0 +1,54 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
+
+#include <string>
+
+#include "dragnn/core/compute_session.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// Abstract base class: Given a ComputeSession handle and a component name,
+// runs some op on the component's state. The first input is always the
+// handle. If
+// OutputsHandle() is true in the derived class, then the first output will also
+// be the handle.
+class ComputeSessionOp : public tensorflow::OpKernel {
+ public:
+  explicit ComputeSessionOp(tensorflow::OpKernelConstruction *context);
+
+  // Virtual Compute()-like function that assumes the state has been extracted
+  // from the handle.
+  virtual void ComputeWithState(tensorflow::OpKernelContext *context,
+                                ComputeSession *compute_session) = 0;
+
+  // Compute extracts the state from the resource manager and calls
+  // ComputeWithState(). If OutputsHandle() is true, also outputs the handle for
+  // subsequent ops.
+  void Compute(tensorflow::OpKernelContext *context) override;
+
+ protected:
+  // If true, then the handle will be the first output of this op.
+  virtual bool OutputsHandle() const = 0;
+
+  // If true, then the constructor will check that the "component_name"
+  // attribute is set.
+  virtual bool RequiresComponentName() const = 0;
+
+  // Returns the component name.
+  string component_name() const {
+    CHECK(RequiresComponentName());
+    return component_name_;
+  }
+
+ private:
+  // Name of the component used by this op.
+  string component_name_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
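
A minimal derived op might look like the following sketch (the class name and registration string "MyAdvanceFromOracle" are invented for illustration, and the corresponding op definition is assumed to be registered elsewhere; the real bulk ops later in this diff follow the same pattern):

#include "dragnn/core/compute_session.h"
#include "dragnn/core/ops/compute_session_op.h"
#include "tensorflow/core/framework/op_kernel.h"

class MyAdvanceFromOracle : public syntaxnet::dragnn::ComputeSessionOp {
 public:
  explicit MyAdvanceFromOracle(tensorflow::OpKernelConstruction *context)
      : ComputeSessionOp(context) {
    // Input 0 is the ComputeSession handle; output 0 re-emits it so that
    // downstream ops can be sequenced after this one.
    OP_REQUIRES_OK(context,
                   context->MatchSignature({tensorflow::DT_STRING},
                                           {tensorflow::DT_STRING}));
  }

  bool OutputsHandle() const override { return true; }
  bool RequiresComponentName() const override { return true; }

  void ComputeWithState(tensorflow::OpKernelContext *context,
                        syntaxnet::dragnn::ComputeSession *session) override {
    // Advance the named component along its oracle path until it finishes.
    while (!session->IsTerminal(component_name())) {
      session->AdvanceFromOracle(component_name());
    }
  }
};

REGISTER_KERNEL_BUILDER(
    Name("MyAdvanceFromOracle").Device(tensorflow::DEVICE_CPU),
    MyAdvanceFromOracle);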

+ 396 - 0
syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc

@@ -0,0 +1,396 @@
+#include <math.h>
+#include <algorithm>
+#include <utility>
+#include <vector>
+
+#include "dragnn/core/ops/compute_session_op.h"
+#include "dragnn/core/resource_container.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+using std::vector;
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_STRING;
+using tensorflow::DataType;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::quint8;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+using tensorflow::uint8;
+
+namespace syntaxnet {
+namespace dragnn {
+
+namespace {
+
+// Helper struct for resource manager.
+struct VectorTriple {
+  std::unique_ptr<std::vector<std::unique_ptr<std::vector<int32>>>>
+      index_vectors;
+  std::unique_ptr<std::vector<std::unique_ptr<std::vector<int64>>>> id_vectors;
+  std::unique_ptr<std::vector<std::unique_ptr<std::vector<float>>>>
+      weight_vectors;
+};
+
+}  // namespace
+
+typedef ResourceContainer<VectorTriple> VectorTripleResource;
+
+// See docstring in dragnn_bulk_ops.cc.
+class BulkFixedFeatures : public ComputeSessionOp {
+ public:
+  explicit BulkFixedFeatures(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
+
+    // Input: state handle.
+    vector<DataType> input_types(1, DT_STRING);
+
+    // Output: indices, ids and weights for every fixed feature channel.
+    vector<DataType> output_types;
+    output_types.push_back(DT_STRING);
+    for (int c = 0; c < num_channels_; ++c) output_types.push_back(DT_INT32);
+    for (int c = 0; c < num_channels_; ++c) output_types.push_back(DT_INT64);
+    for (int c = 0; c < num_channels_; ++c) output_types.push_back(DT_FLOAT);
+    output_types.push_back(DT_INT32);
+    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    constexpr int kTensorOffset = 1;
+    auto indices_allocator = [context, kTensorOffset](int channel,
+                                                      int num_elements) {
+      Tensor *output;
+      CHECK(context
+                ->allocate_output(channel + kTensorOffset,
+                                  TensorShape({num_elements}), &output)
+                .ok());
+      return output->vec<int32>().data();
+    };
+
+    const int num_channels = num_channels_;
+    auto ids_allocator = [context, num_channels, kTensorOffset](
+                             int channel, int num_elements) {
+      Tensor *output;
+      CHECK(context
+                ->allocate_output(num_channels + channel + kTensorOffset,
+                                  TensorShape({num_elements}), &output)
+                .ok());
+      return output->vec<int64>().data();
+    };
+    auto weights_allocator = [context, num_channels, kTensorOffset](
+                                 int channel, int num_elements) {
+      Tensor *output;
+      CHECK(context
+                ->allocate_output(2 * num_channels + channel + kTensorOffset,
+                                  TensorShape({num_elements}), &output)
+                .ok());
+      return output->vec<float>().data();
+    };
+
+    BulkFeatureExtractor extractor(indices_allocator, ids_allocator,
+                                   weights_allocator);
+
+    int num_steps = session->BulkGetInputFeatures(component_name(), extractor);
+    VLOG(2) << "Extracted " << num_steps;
+    Tensor *num_steps_tensor;
+    OP_REQUIRES_OK(
+        context, context->allocate_output(3 * num_channels_ + 1,
+                                          TensorShape({}), &num_steps_tensor));
+    num_steps_tensor->scalar<int32>()() = num_steps;
+  }
+
+ private:
+  // Number of fixed feature channels.
+  int num_channels_;
+};
+
+REGISTER_KERNEL_BUILDER(Name("BulkFixedFeatures").Device(DEVICE_CPU),
+                        BulkFixedFeatures);
+
+// See docstring in dragnn_bulk_ops.cc.
+class BulkFixedEmbeddings : public ComputeSessionOp {
+ public:
+  explicit BulkFixedEmbeddings(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("num_channels", &num_channels_));
+
+    // Input: state handle.
+    vector<DataType> input_types;
+    input_types.push_back(DT_STRING);
+    for (int c = 0; c < num_channels_; ++c) input_types.push_back(DT_FLOAT);
+    const vector<DataType> output_types = {DT_STRING, DT_FLOAT, DT_INT32};
+    OP_REQUIRES_OK(context, context->MatchSignature(input_types, output_types));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_to_batch", &pad_to_batch_));
+    OP_REQUIRES_OK(context, context->GetAttr("pad_to_steps", &pad_to_steps_));
+    use_padding_ = (pad_to_steps_ != -1) || (pad_to_batch_ != -1);
+    VLOG(2) << "Created a BulkFixedEmbeddings with use_padding = "
+            << use_padding_;
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const int batch_size = session->BatchSize(component_name());
+    tensorflow::ResourceMgr *rmgr = context->resource_manager();
+
+    // Create the pool for this container, or re-use one that was allocated in a
+    // previous call.
+    auto create = [this](VectorTripleResource **resource) {
+      LOG(INFO) << "Creating new VectorTripleResource";
+      std::unique_ptr<VectorTriple> triple(new VectorTriple());
+      *resource = new VectorTripleResource(std::move(triple));
+      (*resource)->get()->index_vectors.reset(
+          new std::vector<std::unique_ptr<std::vector<int32>>>(num_channels_));
+      (*resource)->get()->id_vectors.reset(
+          new std::vector<std::unique_ptr<std::vector<int64>>>(num_channels_));
+      (*resource)->get()->weight_vectors.reset(
+          new std::vector<std::unique_ptr<std::vector<float>>>(num_channels_));
+      for (int i = 0; i < num_channels_; ++i) {
+        (*resource)->get()->index_vectors->at(i).reset(
+            new std::vector<int32>());
+        (*resource)->get()->id_vectors->at(i).reset(new std::vector<int64>());
+        (*resource)->get()->weight_vectors->at(i).reset(
+            new std::vector<float>());
+      }
+      return Status::OK();
+    };
+
+    VectorTripleResource *vector_triple;
+    auto handle = context->input(0).vec<string>();
+    OP_REQUIRES_OK(context, rmgr->LookupOrCreate<VectorTripleResource>(
+                                handle(0), handle(1), &vector_triple, create));
+
+    std::vector<std::unique_ptr<std::vector<int32>>> *indices =
+        vector_triple->get()->index_vectors.get();
+    std::vector<std::unique_ptr<std::vector<int64>>> *ids =
+        vector_triple->get()->id_vectors.get();
+    std::vector<std::unique_ptr<std::vector<float>>> *weights =
+        vector_triple->get()->weight_vectors.get();
+
+    auto indices_allocator = [context, &indices](int channel, int size) {
+      (*indices)[channel]->resize(size);
+      return (*indices)[channel]->data();
+    };
+    auto ids_allocator = [context, &ids](int channel, int size) {
+      (*ids)[channel]->resize(size);
+      return (*ids)[channel]->data();
+    };
+    auto weights_allocator = [context, &weights](int channel, int size) {
+      (*weights)[channel]->resize(size);
+      return (*weights)[channel]->data();
+    };
+
+    BulkFeatureExtractor extractor(indices_allocator, ids_allocator,
+                                   weights_allocator, use_padding_,
+                                   pad_to_steps_, pad_to_batch_);
+
+    int num_steps = session->BulkGetInputFeatures(component_name(), extractor);
+    VLOG(2) << "Extracted " << num_steps;
+
+    Tensor *num_steps_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(2, TensorShape({}),
+                                                     &num_steps_tensor));
+    num_steps_tensor->scalar<int32>()() = num_steps;
+
+    // Looks up and outputs embedding vectors.
+    const auto &spec = session->Spec(component_name());
+
+    int embedding_size = 0;
+    for (int channel = 0; channel < num_channels_; ++channel) {
+      embedding_size += context->input(1 + channel).shape().dim_size(1) *
+                        spec.fixed_feature(channel).size();
+    }
+
+    const int padded_batch = std::max(pad_to_batch_, batch_size);
+    const int padded_num_steps = std::max(pad_to_steps_, num_steps);
+    Tensor *embedding_vectors;
+    OP_REQUIRES_OK(
+        context,
+        context->allocate_output(
+            1, TensorShape({padded_num_steps * padded_batch, embedding_size}),
+            &embedding_vectors));
+    embedding_vectors->flat<float>().setZero();
+
+    int channel_offset = 0;
+    for (int channel = 0; channel < num_channels_; ++channel) {
+      ExtractForChannel(*(indices->at(channel)), *(ids->at(channel)),
+                        *(weights->at(channel)), channel_offset,
+                        context->input(1 + channel), embedding_vectors);
+      channel_offset += context->input(1 + channel).shape().dim_size(1) *
+                        spec.fixed_feature(channel).size();
+    }
+    vector_triple->Unref();
+  }
+
+ private:
+  void ExtractForChannel(const std::vector<int32> &indices,
+                         const std::vector<int64> &ids,
+                         const std::vector<float> &weights, int channel_base,
+                         const Tensor &embeddings, Tensor *output) {
+    // The output is a (num_elements x embedding_size) matrix. Each flattened
+    // index encodes an output row (the remainder modulo num_elements) and a
+    // feature slot (the quotient), which selects the block of embedding
+    // columns, offset by channel_base, to accumulate the weighted row into.
+    int num_elements = output->shape().dim_size(0);
+    int embedding_length = embeddings.shape().dim_size(1);
+    VLOG(2) << "Num elements: " << num_elements;
+    VLOG(2) << "Embedding length: " << embedding_length;
+    auto output_matrix = output->matrix<float>();
+    auto embedding_matrix = embeddings.matrix<float>();
+    VLOG(2) << "Channel base:" << channel_base;
+    for (int i = 0; i < indices.size(); ++i) {
+      VLOG(2) << "Feature: ind:" << indices[i] << ", id: " << ids[i]
+              << ", wt: " << weights[i];
+      int y_base =
+          (indices[i] / num_elements) * embedding_length + channel_base;
+      int x_base = indices[i] % num_elements;
+      VLOG(2) << "Extracting to (x,y) = (" << x_base << "," << y_base << ")";
+      for (int j = 0; j < embedding_length; ++j) {
+        output_matrix(x_base, y_base + j) +=
+            embedding_matrix(ids[i], j) * weights[i];
+      }
+    }
+  }
+
+  // Number of fixed feature channels.
+  int num_channels_;
+
+  // Will pad output to at least this many batch elements.
+  int pad_to_batch_ = -1;
+
+  // Will pad output to at least this many steps.
+  int pad_to_steps_ = -1;
+
+  // Set if either pad_to_batch or pad_to_steps is not -1.
+  bool use_padding_ = false;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(BulkFixedEmbeddings);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BulkFixedEmbeddings").Device(DEVICE_CPU),
+                        BulkFixedEmbeddings);
+
+// See docstring in dragnn_bulk_ops.cc.
+class BulkAdvanceFromOracle : public ComputeSessionOp {
+ public:
+  explicit BulkAdvanceFromOracle(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING}, {DT_STRING, DT_INT32}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  // Advances all transition states along the oracle path.
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const int batch_size = session->BatchSize(component_name());
+    const int beam_size = session->BeamSize(component_name());
+    const int num_items = batch_size * beam_size;
+    vector<vector<vector<int32>>> gold;
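+    // gold[step][batch][beam] holds the oracle label emitted at each step.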
+
+    int num_steps = 0;
+    while (!session->IsTerminal(component_name())) {
+      gold.emplace_back(session->EmitOracleLabels(component_name()));
+
+      // Advance the component.
+      session->AdvanceFromOracle(component_name());
+      ++num_steps;
+    }
+
+    // Fills output tensor with oracle labels where possible, or -1.
+    Tensor *gold_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       1, TensorShape({num_items * num_steps}), &gold_output));
+    int item = 0;
+    for (int batch_ix = 0; batch_ix < batch_size; ++batch_ix) {
+      for (int beam_ix = 0; beam_ix < beam_size; ++beam_ix, ++item) {
+        for (int step = 0; step < num_steps; ++step) {
+          gold_output->vec<int32>()(item * num_steps + step) =
+              step < gold.size() ? gold[step][batch_ix][beam_ix] : -1;
+        }
+      }
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(BulkAdvanceFromOracle);
+};
+
+REGISTER_KERNEL_BUILDER(Name("BulkAdvanceFromOracle").Device(DEVICE_CPU),
+                        BulkAdvanceFromOracle);
+
+// See docstring in dragnn_bulk_ops.cc.
+template <typename T>
+class BulkAdvanceFromPrediction : public ComputeSessionOp {
+ public:
+  explicit BulkAdvanceFromPrediction(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    const DataType dt = tensorflow::DataTypeToEnum<T>::v();
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING, dt}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  // Advances all transition states as much as possible using the given scores.
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const Tensor &scores_tensor = context->input(1);
+    const auto &scores = scores_tensor.matrix<T>();
+    const int num_items = (session->BatchSize(component_name()) *
+                           session->BeamSize(component_name()));
+    const int num_actions = scores_tensor.shape().dim_size(1);
+    const int num_steps = scores_tensor.shape().dim_size(0) / num_items;
+    vector<float> scores_per_step(num_items * num_actions);
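+    // The scores tensor is laid out with rows ordered as
+    // (item * num_steps + step); gather each step's scores for all items into
+    // a contiguous buffer before advancing the component.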
+    for (int step = 0; step < num_steps; ++step) {
+      for (int item = 0; item < num_items; ++item) {
+        for (int action = 0; action < num_actions; ++action) {
+          scores_per_step[item * num_actions + action] =
+              scores(item * num_steps + step, action);
+        }
+      }
+      if (!session->IsTerminal(component_name())) {
+        session->AdvanceFromPrediction(component_name(), scores_per_step.data(),
+                                       scores_per_step.size());
+      }
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(BulkAdvanceFromPrediction);
+};
+
+#define REGISTER_BULK_ADVANCE(type)                         \
+  REGISTER_KERNEL_BUILDER(Name("BulkAdvanceFromPrediction") \
+                              .Device(DEVICE_CPU)           \
+                              .TypeConstraint<type>("T"),   \
+                          BulkAdvanceFromPrediction<type>)
+
+REGISTER_BULK_ADVANCE(float);
+REGISTER_BULK_ADVANCE(quint8);
+REGISTER_BULK_ADVANCE(uint8);
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 588 - 0
syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc

@@ -0,0 +1,588 @@
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/compute_session_pool.h"
+#include "dragnn/core/resource_container.h"
+#include "dragnn/core/test/mock_compute_session.h"
+
+#include <gmock/gmock.h>
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::AllocatorAttributes;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_STRING;
+using tensorflow::FrameAndIter;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpKernelContext;
+using tensorflow::ResourceMgr;
+using tensorflow::ScopedStepContainer;
+using tensorflow::Status;
+using tensorflow::TensorShape;
+using tensorflow::checkpoint::TensorSliceReaderCacheWrapper;
+using tensorflow::test::SetOutputAttrs;
+
+using testing::Return;
+using testing::_;
+
+typedef ResourceContainer<ComputeSession> ComputeSessionResource;
+typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
+
+class DragnnBulkOpKernelsTest : public tensorflow::OpsTestBase {
+ public:
+  static const int kEmbeddingSize = 2;
+  static const int kNumActions = 3;
+  static const int kNumChannels = 2;
+  static const int kNumIds = 8;
+  static const int kNumItems = 3;
+  static const int kNumSteps = 3;
+  const string kComponentName = "TESTING_COMPONENT_NAME";
+
+  MockComputeSession *GetMockSession() {
+    TF_CHECK_OK(InitOp());
+
+    // Set the input data.
+    const string container_string = "container_str";
+    const string id_string = "id_str";
+    AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+    // Reset the test context to ensure it's clean.
+    ResetOpKernelContext();
+
+    // Create a MockComputeSession and set expectations.
+    std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+    MockComputeSession *mock_session_ptr = mock_session.get();
+
+    // Wrap the ComputeSessionResource and put it into the resource manager.
+    TF_CHECK_OK(resource_mgr()->Create<ComputeSessionResource>(
+        container_string, id_string,
+        new ComputeSessionResource(std::move(mock_session))));
+    return mock_session_ptr;
+  }
+
+  void ResetOpKernelContext() {
+    params_.reset(new OpKernelContext::Params);
+    params_->device = device_.get();
+    params_->frame_iter = FrameAndIter(0, 0);
+    params_->inputs = &inputs_;
+    params_->op_kernel = kernel_.get();
+    step_container_.reset(new ScopedStepContainer(0, [](const string &) {}));
+    params_->step_container = step_container_.get();
+    attrs_.clear();
+    SetOutputAttrs(params_.get(), &attrs_);
+    TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
+    params_->slice_reader_cache = &slice_reader_cache_wrapper;
+    params_->resource_manager = device_->resource_manager();
+    context_.reset(new OpKernelContext(params_.get()));
+  }
+
+  Status RunOpKernelWithContext() {
+    device_->Compute(kernel_.get(), context_.get());
+    return context_->status();
+  }
+
+  // Accessor for the underlying resource manager.
+  ResourceMgr *resource_mgr() { return params_->resource_manager; }
+  /*
+  // Returns a vector with dimensions: channel x batch x step.
+  // For each item we return features for three steps:
+  //   feature step 0: (5, 1)
+  //   feature step 1: (5, 0.5), (6, 0.7)
+  //   feature step 2: (3, 0.1), (7, [empty]) <- Default weight is 1.0.
+  void ExpectFeatures(MockComputeSession *mock_session) {
+    vector<FixedFeatures> feature_step_zero, feature_step_one, feature_step_two;
+    for (int item = 0; item < kNumItems; ++item) {
+      feature_step_zero.emplace_back();
+      feature_step_zero.back().add_id(5);
+      feature_step_zero.back().add_weight(1.0);
+      feature_step_one.emplace_back();
+      feature_step_one.back().add_id(5);
+      feature_step_one.back().add_weight(0.5);
+      feature_step_one.back().add_id(6);
+      feature_step_one.back().add_weight(0.7);
+      feature_step_two.emplace_back();
+      feature_step_two.back().add_id(3);
+      feature_step_two.back().add_weight(0.1);
+      feature_step_two.back().add_id(7);
+    }
+    for (int channel = 0; channel < kNumChannels; ++channel) {
+      EXPECT_CALL(*mock_session, GetInputFeatures(kComponentName, channel))
+          .Times(3)
+          .WillOnce(Return(feature_step_zero))
+          .WillOnce(Return(feature_step_one))
+          .WillOnce(Return(feature_step_two));
+    }
+  }
+
+  // Returns a vector with dimensions: channel x batch x step.
+  // For each item we return features for three steps with ids only:
+  //   feature step 0: id=5
+  //   feature step 1: id=6
+  //   feature step 2: id=3
+  void ExpectFeatureIds(MockComputeSession *mock_session) {
+    vector<FixedFeatures> feature_step_zero, feature_step_one, feature_step_two;
+    for (int item = 0; item < kNumItems; ++item) {
+      feature_step_zero.emplace_back();
+      feature_step_zero.back().add_id(5);
+      feature_step_one.emplace_back();
+      feature_step_one.back().add_id(6);
+      feature_step_two.emplace_back();
+      feature_step_two.back().add_id(3);
+    }
+    for (int channel = 0; channel < kNumChannels; ++channel) {
+      EXPECT_CALL(*mock_session, GetInputFeatures(kComponentName, channel))
+          .Times(3)
+          .WillOnce(Return(feature_step_zero))
+          .WillOnce(Return(feature_step_one))
+          .WillOnce(Return(feature_step_two));
+    }
+  }
+  */
+  // This needs to maintain its existence throughout the compute call.
+  std::vector<AllocatorAttributes> attrs_;
+};
+
+const int DragnnBulkOpKernelsTest::kEmbeddingSize;
+const int DragnnBulkOpKernelsTest::kNumActions;
+const int DragnnBulkOpKernelsTest::kNumChannels;
+const int DragnnBulkOpKernelsTest::kNumIds;
+const int DragnnBulkOpKernelsTest::kNumItems;
+const int DragnnBulkOpKernelsTest::kNumSteps;
+
+// The BulkFixedFeatures op should return a set of fixed feature vectors
+// as described below.
+TEST_F(DragnnBulkOpKernelsTest, BulkFixedFeatures) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkFixedFeatures", "BulkFixedFeatures")
+          .Attr("component", kComponentName)
+          .Attr("num_channels", kNumChannels)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+
+  MockComputeSession *mock_session = GetMockSession();
+  const std::vector<int> expected_indices({0, 2, 1, 0, 1});
+  const std::vector<int> expected_ids({5, 5, 6, 3, 7});
+  const std::vector<float> expected_weights({1.0, 0.5, 0.7, 0.1, 1.0});
+
+  // This function takes the allocator callbacks packaged in the
+  // BulkFeatureExtractor passed to BulkGetInputFeatures, uses them to allocate
+  // output memory, then fills that memory based on the channel.
+  auto assigner_function = [=](string, const BulkFeatureExtractor &extractor) {
+    constexpr int kFeatureCount = 3;
+    constexpr int kTotalFeatures = 5;
+    constexpr int kNumSteps = 3;
+    for (int i = 0; i < kNumChannels; ++i) {
+      // Allocate a new tensor set for every channel.
+      int32 *indices =
+          extractor.AllocateIndexMemory(i, kTotalFeatures * kNumSteps);
+      int64 *ids = extractor.AllocateIdMemory(i, kTotalFeatures * kNumSteps);
+      float *weights =
+          extractor.AllocateWeightMemory(i, kTotalFeatures * kNumSteps);
+
+      // Fill the tensor.
+      int array_index = 0;
+      for (int step = 0; step < kNumSteps; step++) {
+        for (int j = 0; j < kTotalFeatures; ++j) {
+          int offset = i + 1;
+          indices[array_index] =
+              (expected_indices[j] + step * kFeatureCount) * offset;
+          ids[array_index] = expected_ids[j] * offset;
+          weights[array_index] = expected_weights[j] * offset;
+          ++array_index;
+        }
+      }
+    }
+    return kNumSteps;
+  };
+
+  EXPECT_CALL(*mock_session, BulkGetInputFeatures(kComponentName, _))
+      .WillOnce(testing::Invoke(assigner_function));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  // In this case, for every channel we should have:
+  //   indices = [0  , 2  , 1  , 0  , 1  ]
+  //             [3  , 5  , 4  , 3  , 4  ]
+  //             [6  , 8  , 7  , 6  , 7  ]
+  //   ids =     [5  , 5  , 6  , 3  , 7  ]
+  //             [5  , 5  , 6  , 3  , 7  ]
+  //             [5  , 5  , 6  , 3  , 7  ]
+  //   weights = [1.0, 0.5, 0.7, 0.1, 1.0]
+  //             [1.0, 0.5, 0.7, 0.1, 1.0]
+  //             [1.0, 0.5, 0.7, 0.1, 1.0]
+
+  for (int i = 0; i < kNumChannels * 3; ++i) {
+    EXPECT_EQ(expected_indices.size() * kNumSteps,
+              GetOutput(i + 1)->NumElements());
+  }
+  for (int channel = 0; channel < kNumChannels; ++channel) {
+    LOG(INFO) << "Channel " << channel;
+    for (int step = 0; step < kNumSteps; ++step) {
+      for (int i = 0; i < expected_indices.size(); ++i) {
+        const int j = i + step * expected_indices.size();
+
+        // Note that the expectation on the indices changes per step, unlike the
+        // expectation for ids and weights.
+        int offset = channel + 1;
+        EXPECT_EQ((expected_indices[i] + step * kNumItems) * offset,
+                  GetOutput(channel + 1)->vec<int32>()(j));
+        EXPECT_EQ(expected_ids[i] * offset,
+                  GetOutput(kNumChannels + channel + 1)->vec<int64>()(j));
+        EXPECT_EQ(expected_weights[i] * offset,
+                  GetOutput(2 * kNumChannels + channel + 1)->vec<float>()(j));
+      }
+    }
+  }
+  EXPECT_EQ(kNumSteps, GetOutput(3 * kNumChannels + 1)->scalar<int32>()());
+}
+
+TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddings) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkFixedEmbeddings", "BulkFixedEmbeddings")
+          .Attr("component", kComponentName)
+          .Attr("num_channels", kNumChannels)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+  ComponentSpec spec;
+  spec.set_name(kComponentName);
+  auto chan0_spec = spec.add_fixed_feature();
+  chan0_spec->set_size(2);
+  auto chan1_spec = spec.add_fixed_feature();
+  chan1_spec->set_size(1);
+  EXPECT_CALL(*mock_session, Spec(kComponentName))
+      .WillOnce(testing::ReturnRef(spec));
+
+  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
+      .WillOnce(Return(kNumItems));
+
+  const std::vector<int> feature_step_1({0, 1, 2, 1, 2, 2, 1, 0, 1, 0});
+  const std::vector<int> feature_index_1({0, 0, 0, 0, 0, 1, 1, 1, 1, 1});
+  const std::vector<int> feature_ids_1({5, 6, 3, 5, 7, 5, 6, 3, 5, 7});
+  const std::vector<float> feature_weights_1(
+      {1.0, 0.7, 0.1, 0.5, 1.0, 10, 7, 1, 5, 10});
+
+  const std::vector<int> feature_step_2({0, 1, 2, 1, 2});
+  const std::vector<int> feature_index_2({0, 0, 0, 0, 0});
+  const std::vector<int> feature_ids_2({5, 6, 3, 5, 7});
+  const std::vector<float> feature_weights_2({1.0, 0.7, 0.1, 0.5, 1.0});
+
+  const std::vector<std::vector<int>> feature_steps_by_channel(
+      {feature_step_1, feature_step_2});
+  const std::vector<std::vector<int>> feature_index_by_channel(
+      {feature_index_1, feature_index_2});
+  const std::vector<std::vector<int>> feature_ids_by_channel(
+      {feature_ids_1, feature_ids_2});
+  const std::vector<std::vector<float>> feature_weights_by_channel(
+      {feature_weights_1, feature_weights_2});
+
+  // This function takes the allocator callbacks packaged in the
+  // BulkFeatureExtractor passed to BulkGetInputFeatures, uses them to allocate
+  // output memory, then fills that memory based on the channel.
+  auto assigner_function = [=](string, const BulkFeatureExtractor &extractor) {
+    constexpr int kNumElements = 3;
+    constexpr int kNumSteps = 3;
+    for (int i = 0; i < kNumChannels; ++i) {
+      auto feature_step = feature_steps_by_channel.at(i);
+      auto feature_index = feature_index_by_channel.at(i);
+      auto feature_ids = feature_ids_by_channel.at(i);
+      auto feature_weights = feature_weights_by_channel.at(i);
+
+      // Allocate a new tensor set for every channel.
+      int32 *indices =
+          extractor.AllocateIndexMemory(i, kNumElements * feature_step.size());
+      int64 *ids =
+          extractor.AllocateIdMemory(i, kNumElements * feature_step.size());
+      float *weights =
+          extractor.AllocateWeightMemory(i, kNumElements * feature_step.size());
+
+      // Fill the tensor.
+      int array_index = 0;
+
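+      // GetIndex flattens (feature slot, element, step) into the single
+      // per-channel index that BulkFixedEmbeddings later decodes with div/mod
+      // in ExtractForChannel.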
+      for (int element = 0; element < kNumElements; ++element) {
+        for (int feature = 0; feature < feature_step.size(); ++feature) {
+          indices[array_index] = extractor.GetIndex(
+              kNumSteps, kNumElements, feature_index[feature], element,
+              feature_step[feature]);
+          ids[array_index] = feature_ids[feature];
+          weights[array_index] = feature_weights[feature];
+          ++array_index;
+        }
+      }
+    }
+    return kNumSteps;
+  };
+
+  EXPECT_CALL(*mock_session, BulkGetInputFeatures(kComponentName, _))
+      .WillOnce(testing::Invoke(assigner_function));
+
+  // Embedding matrices as additional inputs.
+  // For channel 0, the embeddings are [id, 0].
+  // For channel 1, the embeddings are [0, id].
+  vector<float> embedding_matrix_a;
+  vector<float> embedding_matrix_b;
+  for (int id = 0; id < kNumIds; ++id) {
+    embedding_matrix_a.push_back(id);
+    embedding_matrix_a.push_back(0);
+    embedding_matrix_b.push_back(0);
+    embedding_matrix_b.push_back(id);
+  }
+  AddInputFromArray<float>(TensorShape({8, 2}), embedding_matrix_a);
+  AddInputFromArray<float>(TensorShape({8, 2}), embedding_matrix_b);
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  // In this case we should have, for every item, these three steps:
+  const vector<vector<float>> expected_embeddings = {{5.0, 0, 73, 0, 0, 5.0},
+                                                     {6.7, 0, 67, 0, 0, 6.7},
+                                                     {7.3, 0, 50, 0, 0, 7.3}};
+  EXPECT_EQ(kNumSteps * kNumItems, GetOutput(1)->shape().dim_size(0));
+  constexpr int kNumFeatures = 3;
+  EXPECT_EQ(kNumFeatures * kEmbeddingSize, GetOutput(1)->shape().dim_size(1));
+  for (int item = 0; item < kNumItems; ++item) {
+    for (int step = 0; step < kNumSteps; ++step) {
+      for (int col = 0; col < kNumChannels * kEmbeddingSize; ++col) {
+        const int row = item * kNumSteps + step;
+        EXPECT_EQ(expected_embeddings[step][col],
+                  GetOutput(1)->matrix<float>()(row, col))
+            << "step: " << step << ", row: " << row << ", col: " << col;
+      }
+    }
+  }
+
+  EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()());
+}
+
+TEST_F(DragnnBulkOpKernelsTest, BulkFixedEmbeddingsWithPadding) {
+  // Create and initialize the kernel under test.
+  constexpr int kPaddedNumSteps = 5;
+  constexpr int kPaddedBatchSize = 4;
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkFixedEmbeddings", "BulkFixedEmbeddings")
+          .Attr("component", kComponentName)
+          .Attr("num_channels", kNumChannels)
+          .Attr("pad_to_steps", kPaddedNumSteps)
+          .Attr("pad_to_batch", kPaddedBatchSize)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // Embedding matrices.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+  ComponentSpec spec;
+  spec.set_name(kComponentName);
+  auto chan0_spec = spec.add_fixed_feature();
+  chan0_spec->set_size(2);
+  auto chan1_spec = spec.add_fixed_feature();
+  chan1_spec->set_size(1);
+  EXPECT_CALL(*mock_session, Spec(kComponentName))
+      .WillOnce(testing::ReturnRef(spec));
+
+  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
+      .WillOnce(Return(kNumItems));
+
+  const std::vector<int> feature_step_1({0, 1, 2, 1, 2, 2, 1, 0, 1, 0});
+  const std::vector<int> feature_index_1({0, 0, 0, 0, 0, 1, 1, 1, 1, 1});
+  const std::vector<int> feature_ids_1({5, 6, 3, 5, 7, 5, 6, 3, 5, 7});
+  const std::vector<float> feature_weights_1(
+      {1.0, 0.7, 0.1, 0.5, 1.0, 10, 7, 1, 5, 10});
+
+  const std::vector<int> feature_step_2({0, 1, 2, 1, 2});
+  const std::vector<int> feature_index_2({0, 0, 0, 0, 0});
+  const std::vector<int> feature_ids_2({5, 6, 3, 5, 7});
+  const std::vector<float> feature_weights_2({1.0, 0.7, 0.1, 0.5, 1.0});
+
+  const std::vector<std::vector<int>> feature_steps_by_channel(
+      {feature_step_1, feature_step_2});
+  const std::vector<std::vector<int>> feature_index_by_channel(
+      {feature_index_1, feature_index_2});
+  const std::vector<std::vector<int>> feature_ids_by_channel(
+      {feature_ids_1, feature_ids_2});
+  const std::vector<std::vector<float>> feature_weights_by_channel(
+      {feature_weights_1, feature_weights_2});
+
+  // This function takes the allocator callbacks packaged in the
+  // BulkFeatureExtractor passed to BulkGetInputFeatures, uses them to allocate
+  // output memory, then fills that memory based on the channel.
+  auto assigner_function = [=](string, const BulkFeatureExtractor &extractor) {
+    constexpr int kNumElements = 3;
+    constexpr int kNumSteps = 3;
+    for (int i = 0; i < kNumChannels; ++i) {
+      auto feature_step = feature_steps_by_channel.at(i);
+      auto feature_index = feature_index_by_channel.at(i);
+      auto feature_ids = feature_ids_by_channel.at(i);
+      auto feature_weights = feature_weights_by_channel.at(i);
+
+      // Allocate a new tensor set for every channel.
+      int32 *indices =
+          extractor.AllocateIndexMemory(i, kNumElements * feature_step.size());
+      int64 *ids =
+          extractor.AllocateIdMemory(i, kNumElements * feature_step.size());
+      float *weights =
+          extractor.AllocateWeightMemory(i, kNumElements * feature_step.size());
+
+      // Fill the tensor.
+      int array_index = 0;
+
+      for (int element = 0; element < kNumElements; ++element) {
+        for (int feature = 0; feature < feature_step.size(); ++feature) {
+          indices[array_index] = extractor.GetIndex(
+              kNumSteps, kNumElements, feature_index[feature], element,
+              feature_step[feature]);
+          ids[array_index] = feature_ids[feature];
+          weights[array_index] = feature_weights[feature];
+          ++array_index;
+        }
+      }
+    }
+    return kNumSteps;
+  };
+
+  EXPECT_CALL(*mock_session, BulkGetInputFeatures(kComponentName, _))
+      .WillOnce(testing::Invoke(assigner_function));
+
+  // Embedding matrices as additional inputs.
+  // For channel 0, the embeddings are [id, 0].
+  // For channel 1, the embeddings are [0, id].
+  vector<float> embedding_matrix_a;
+  vector<float> embedding_matrix_b;
+  for (int id = 0; id < kNumIds; ++id) {
+    embedding_matrix_a.push_back(id);
+    embedding_matrix_a.push_back(0);
+    embedding_matrix_b.push_back(0);
+    embedding_matrix_b.push_back(id);
+  }
+  AddInputFromArray<float>(TensorShape({8, 2}), embedding_matrix_a);
+  AddInputFromArray<float>(TensorShape({8, 2}), embedding_matrix_b);
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  // In this case we should have, for every item, these three steps:
+  const vector<vector<float>> expected_embeddings = {{5.0, 0, 73, 0, 0, 5.0},
+                                                     {6.7, 0, 67, 0, 0, 6.7},
+                                                     {7.3, 0, 50, 0, 0, 7.3}};
+  EXPECT_EQ(kPaddedNumSteps * kPaddedBatchSize,
+            GetOutput(1)->shape().dim_size(0));
+
+  constexpr int kNumFeatures = 3;
+  EXPECT_EQ(kNumFeatures * kEmbeddingSize, GetOutput(1)->shape().dim_size(1));
+  for (int item = 0; item < kNumItems; ++item) {
+    for (int step = 0; step < kNumSteps; ++step) {
+      for (int col = 0; col < kNumChannels * kEmbeddingSize; ++col) {
+        const int row = item * kPaddedNumSteps + step;
+        EXPECT_EQ(expected_embeddings[step][col],
+                  GetOutput(1)->matrix<float>()(row, col))
+            << "step: " << step << ", row: " << row << ", col: " << col;
+      }
+    }
+  }
+
+  EXPECT_EQ(kNumSteps, GetOutput(2)->scalar<int32>()());
+}
+
+TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromOracle) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkAdvanceFromOracle", "BulkAdvanceFromOracle")
+          .Attr("component", kComponentName)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+  EXPECT_CALL(*mock_session, IsTerminal(kComponentName))
+      .WillOnce(Return(false))
+      .WillOnce(Return(false))
+      .WillOnce(Return(false))
+      .WillOnce(Return(true));
+  EXPECT_CALL(*mock_session, AdvanceFromOracle(kComponentName))
+      .Times(kNumSteps);
+  const vector<vector<vector<int32>>> gold = {
+      {{1}, {1}, {1}}, {{2}, {2}, {2}}, {{3}, {3}, {3}},
+  };
+  EXPECT_CALL(*mock_session, EmitOracleLabels(kComponentName))
+      .WillOnce(Return(gold[0]))
+      .WillOnce(Return(gold[1]))
+      .WillOnce(Return(gold[2]));
+  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
+  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
+      .WillOnce(Return(kNumItems));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  // For every item we should have:
+  const vector<int32> expected_gold = {1, 2, 3};
+  EXPECT_EQ(kNumSteps * kNumItems, GetOutput(1)->NumElements());
+  for (int item = 0; item < kNumItems; ++item) {
+    for (int step = 0; step < kNumSteps; ++step) {
+      EXPECT_EQ(expected_gold[step],
+                GetOutput(1)->vec<int32>()(step + item * kNumSteps));
+    }
+  }
+}
+
+string ArrayToString(const float *array, const int size) {
+  string str = "[ ";
+  for (int i = 0; i < size; ++i) {
+    str += tensorflow::strings::Printf("%.1f ", array[i]);
+  }
+  return str + "]";
+}
+
+MATCHER(CheckScoresAreConsecutiveIntegersDivTen, "") {
+  const int size =
+      DragnnBulkOpKernelsTest::kNumItems * DragnnBulkOpKernelsTest::kNumActions;
+  for (int i(0), score(arg[0] * 10); i < size; ++i, ++score) {
+    EXPECT_NEAR(score / 10.0f, arg[i], 1e-4)
+        << "i: " << i << ", scores: " << ArrayToString(arg, size);
+  }
+  return true;
+}
+
+TEST_F(DragnnBulkOpKernelsTest, BulkAdvanceFromPrediction) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("BulkAdvanceFromPrediction", "BulkAdvanceFromPrediction")
+          .Attr("component", kComponentName)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // Prediction scores for advancing.
+          .Finalize(node_def()));
+  MockComputeSession *mock_session = GetMockSession();
+
+  // Creates an input tensor such that each step will see a list of consecutive
+  // integers divided by 10 as scores.
+  vector<float> scores(kNumItems * kNumSteps * kNumActions);
+  for (int step(0), cnt(0); step < kNumSteps; ++step) {
+    for (int item = 0; item < kNumItems; ++item) {
+      for (int action = 0; action < kNumActions; ++action, ++cnt) {
+        scores[action + kNumActions * (step + item * kNumSteps)] = cnt / 10.0f;
+      }
+    }
+  }
+  AddInputFromArray<float>(TensorShape({kNumItems * kNumSteps, kNumActions}),
+                           scores);
+
+  EXPECT_CALL(*mock_session, BeamSize(kComponentName)).WillOnce(Return(1));
+  EXPECT_CALL(*mock_session, BatchSize(kComponentName))
+      .WillOnce(Return(kNumItems));
+  EXPECT_CALL(*mock_session, IsTerminal(kComponentName))
+      .Times(kNumSteps)
+      .WillRepeatedly(Return(false));
+  EXPECT_CALL(*mock_session,
+              AdvanceFromPrediction(kComponentName,
+                                    CheckScoresAreConsecutiveIntegersDivTen(),
+                                    kNumItems * kNumActions))
+      .Times(kNumSteps);
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 115 - 0
syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc

@@ -0,0 +1,115 @@
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+REGISTER_OP("BulkFixedFeatures")
+    .Input("handle: string")
+    .Output("output_handle: string")
+    .Output("indices: num_channels * int32")
+    .Output("ids: num_channels * int64")
+    .Output("weights: num_channels * float")
+    .Output("num_steps: int32")
+    .Attr("component: string")
+    .Attr("num_channels: int")
+    .Doc(R"doc(
+Given a ComputeSession and a component, outputs fixed features for all steps.
+
+This op outputs features for the entire oracle path of the component. Unlike
+ExtractFixedFeatures, this op mutates the master state, advancing all of its
+states until they are final. For every channel, indices[channel], ids[channel],
+and weights[channel] have the same length, i.e. the number of predicates,
+ordered by batch, beam, step.
+
+handle: A handle to a ComputeSession.
+output_handle: A handle to the same ComputeSession after advancement.
+indices: (num_channels vectors of int32) If indices[i] = j, then
+  embedding_sum[j] += embedding_matrix[ids[i]] * weights[i].
+ids: (num_channels vectors of int64) Ids to lookup in embedding matrices.
+weights: (num_channels vectors of float) Weight for each embedding.
+num_steps: (int32 scalar) The batch was unrolled for this many steps.
+component: The name of a Component instance, matching the ComponentSpec.name.
+num_channels: The number of FixedFeature channels.
+)doc");
+
+REGISTER_OP("BulkFixedEmbeddings")
+    .Input("handle: string")
+    .Input("embedding_matrix: num_channels * T")
+    .Output("output_handle: string")
+    .Output("embedding_vectors: T")
+    .Output("num_steps: int32")
+    .Attr("component: string")
+    .Attr("num_channels: int")
+    .Attr("T: type")
+    .Attr("pad_to_batch: int=-1")
+    .Attr("pad_to_steps: int=-1")
+    .SetIsStateful()
+    .Doc(R"doc(
+This op is a more efficient version of BulkFixedFeatures.
+
+It is intended to be run with large batch sizes at inference time. The op takes
+a handle to a ComputeSession and embedding matrices as tensor inputs, and
+directly outputs concatenated embedding vectors.
+
+handle: A handle to ComputeSession.
+embedding_matrix: Embedding matrices.
+output_handle: A handle to the same ComputeSession after advancement.
+embedding_vectors: (matrix of float) Concatenated embeddings,
+  shaped as (batch * beam * token) x sum_channel(embedding_dim[channel]).
+num_steps: The batch was unrolled for this many steps.
+component: The name of a Component instance, matching the ComponentSpec.name.
+num_channels: The number of FixedFeature channels.
+T: The datatype to emit.
+pad_to_batch: If set, the op will pad/truncate to this number of elements.
+pad_to_steps: If set, the op will pad/truncate to this number of steps.
+)doc");
+
+REGISTER_OP("BulkAdvanceFromOracle")
+    .Input("handle: string")
+    .Output("output_handle: string")
+    .Output("gold_labels: int32")
+    .Attr("component: string")
+    .Doc(R"doc(
+Given a ComputeSession, advances until all states are final.
+
+Note that, unlike AdvanceFromOracle, which takes a single oracle step, this op
+advances all of the component's states until they are final.
+
+handle: A handle to a ComputeSession.
+output_handle: A handle to the same ComputeSession, after it has advanced.
+gold_labels: [batch_size * beam_size * max_num_steps] vector of oracle actions,
+             where max_num_steps is the maximum number of steps in the oracle
+             action sequences for every state in the batch of beams.  Each
+             sub-segment of length max_num_steps provides the oracle action
+             sequence for the corresponding state in the batch of beams, padded
+             with trailing -1s.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
+
+REGISTER_OP("BulkAdvanceFromPrediction")
+    .Input("handle: string")
+    .Input("scores: T")
+    .Output("output_handle: string")
+    .Attr("component: string")
+    .Attr("T: type")
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
+      auto scores = context->input(1);
+      TF_RETURN_IF_ERROR(context->WithRank(scores, 2, &scores));
+      return tensorflow::Status::OK();
+    })
+    .Doc(R"doc(
+Given a ComputeSession and a tensor of scores, advances the state.
+
+The state will be advanced until all scores are used up or all states are final.
+
+handle: A handle to a ComputeSession.
+scores: A tensor of scores with shape
+        {batch_size * beam_size * num_steps, num_actions}.
+output_handle: handle to the same ComputeSession after advancement.
+component: The name of a Component instance, matching the ComponentSpec.name.
+T: The datatype to emit.
+)doc");
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 631 - 0
syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc

@@ -0,0 +1,631 @@
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "dragnn/core/compute_session.h"
+#include "dragnn/core/compute_session_pool.h"
+#include "dragnn/core/ops/compute_session_op.h"
+#include "dragnn/core/resource_container.h"
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/mutex.h"
+
+using tensorflow::DEVICE_CPU;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_INT32;
+using tensorflow::DT_INT64;
+using tensorflow::DT_STRING;
+using tensorflow::DataType;
+using tensorflow::OpKernel;
+using tensorflow::OpKernelConstruction;
+using tensorflow::OpKernelContext;
+using tensorflow::ResourceMgr;
+using tensorflow::Status;
+using tensorflow::Tensor;
+using tensorflow::TensorShape;
+
+namespace syntaxnet {
+namespace dragnn {
+
+typedef ResourceContainer<ComputeSession> ComputeSessionResource;
+typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
+
+// Given a MasterSpec proto, outputs a handle to a ComputeSession.
+class GetSession : public OpKernel {
+ public:
+  explicit GetSession(OpKernelConstruction *context) : OpKernel(context) {
+    string master_spec_str;
+    string grid_point_spec_str;
+    OP_REQUIRES_OK(context, context->GetAttr("master_spec", &master_spec_str));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("grid_point", &grid_point_spec_str));
+    CHECK(master_spec_.ParseFromString(master_spec_str));
+    CHECK(grid_point_.ParseFromString(grid_point_spec_str));
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    const string container = context->input(0).scalar<string>()();
+    ResourceMgr *rmgr = context->resource_manager();
+
+    // Create the pool for this container, or re-use one that was allocated in a
+    // previous call.
+    auto create_pool = [this,
+                        &container](ComputeSessionPoolResource **resource) {
+      LOG(INFO) << "Creating new ComputeSessionPool in container handle: "
+                << container;
+      std::unique_ptr<ComputeSessionPool> pool(
+          new ComputeSessionPool(master_spec_, grid_point_));
+      *resource = new ComputeSessionPoolResource(std::move(pool));
+      return Status::OK();
+    };
+
+    ComputeSessionPoolResource *pool_resource;
+
+    // Synchronize access to the resource manager when getting or creating the
+    // ComputeSessionPool.
+    // Scoping for minimal mutex locking.
+    {
+      mutex_lock lock(lock_);
+      OP_REQUIRES_OK(context,
+                     rmgr->LookupOrCreate<ComputeSessionPoolResource>(
+                         container, "pool", &pool_resource, create_pool));
+    }
+    ComputeSessionPool *pool = pool_resource->get();
+    CHECK(pool != nullptr);
+
+    // Get a new Session for this computation from the pool.
+    std::unique_ptr<ComputeSession> session = pool->GetSession();
+    const string id = std::to_string(session->Id());
+
+    // Store it in the ResourceManager.
+    OP_REQUIRES_OK(
+        context,
+        rmgr->Create<ComputeSessionResource>(
+            container, id, new ComputeSessionResource(std::move(session))));
+
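+    // The output handle is a 2-element string vector {container, id};
+    // downstream ComputeSessionOps use it to look the session up again in the
+    // ResourceMgr.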
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({2}), &output));
+    output->vec<string>()(0) = container;
+    output->vec<string>()(1) = id;
+
+    // Unref the pool so it gets destroyed properly.
+    pool_resource->Unref();
+    VLOG(1) << "Returning session: " << id;
+  }
+
+ private:
+  MasterSpec master_spec_;
+  GridPoint grid_point_;
+
+  // Mutex that serializes accesses to the resource manager. (These would block
+  // in the compute session pool anyway, so there's no regression there, and
+  // we need to protect against racy multiple initialization.)
+  tensorflow::mutex lock_;
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GetSession);
+};
+
+REGISTER_KERNEL_BUILDER(Name("GetSession").Device(DEVICE_CPU), GetSession);
+
+// Given a handle to a ComputeSession, returns it to the pool. As long as we
+// start with "GetSession", DRAGNN graphs are thread-safe and there is no need
+// for explicit multi-thread logic. As long as we end with "ReleaseSession",
+// memory usage will be constrained to the maximum number of concurrent
+// requests.
+class ReleaseSession : public OpKernel {
+ public:
+  explicit ReleaseSession(OpKernelConstruction *context) : OpKernel(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {}));
+  }
+
+  void Compute(OpKernelContext *context) override {
+    auto handle = context->input(0).vec<string>();
+    const string &container = handle(0);
+    const string &id = handle(1);
+    VLOG(1) << "Releasing session: " << id;
+    ResourceMgr *rmgr = context->resource_manager();
+
+    // Get the pool for this container.
+    ComputeSessionPoolResource *pool_resource;
+    TF_CHECK_OK(rmgr->Lookup<ComputeSessionPoolResource>(container, "pool",
+                                                         &pool_resource));
+    auto *pool = pool_resource->get();
+    CHECK(pool != nullptr);
+
+    // Get the compute session.
+    ComputeSessionResource *session_resource;
+    TF_CHECK_OK(
+        rmgr->Lookup<ComputeSessionResource>(container, id, &session_resource));
+
+    // We need to release the ComputeSession from both the ResourceMgr and
+    // the ComputeSessionPool. The order of release is critical. If the
+    // resource is not first Delete()-ed from the ResourceMgr, then another
+    // thread may try to Create() the same resource, resulting in an
+    // "Already exists" error.
+    //
+    // First, delete the ResourceMgr reference so it can be used in the future.
+    TF_CHECK_OK(rmgr->Delete<ComputeSessionResource>(container, id));
+
+    // Second, return the ComputeSession to the pool.
+    pool->ReturnSession(session_resource->release());
+
+    // Unref the resources so they get destroyed properly.
+    session_resource->Unref();
+    pool_resource->Unref();
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(ReleaseSession);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ReleaseSession").Device(DEVICE_CPU),
+                        ReleaseSession);
+
+/*******************************************************************************
+ *                   ComputeSessionOps below here.
+ ******************************************************************************/
+
+// Given a handle to a BatchedBeamComponentState, advances based on the next
+// oracle (gold) action.
+class AdvanceFromOracle : public ComputeSessionOp {
+ public:
+  explicit AdvanceFromOracle(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    session->AdvanceFromOracle(component_name());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(AdvanceFromOracle);
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdvanceFromOracle").Device(DEVICE_CPU),
+                        AdvanceFromOracle);
+
+// Given a handle to a BatchedBeamComponentState and a tensor of scores,
+// advances the state. The tensor of scores has shape batch_size x beam_size
+// x num_actions.
+class AdvanceFromPrediction : public ComputeSessionOp {
+ public:
+  explicit AdvanceFromPrediction(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING, DT_FLOAT}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const Tensor &scores = context->input(1);
+    session->AdvanceFromPrediction(component_name(),
+                                   scores.tensor<float, 2>().data(),
+                                   scores.NumElements());
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(AdvanceFromPrediction);
+};
+
+REGISTER_KERNEL_BUILDER(Name("AdvanceFromPrediction").Device(DEVICE_CPU),
+                        AdvanceFromPrediction);
+
+// Given a handle to a ComputeSession and a channel index, outputs fixed
+// features.
+// Fixed features are returned as 3 vectors of equal length:
+//   - ids: specifies which rows should be looked up in the embedding matrix,
+//   - weights: specifies a scale for each embedding vector,
+//   - indices: sorted vector that assigns the same index to embedding vectors
+//     that should be summed together.
+//
+// For example if we have 3 features, for a given channel, we might have:
+//   feature a: (5, 1)
+//   feature b: (5, 0.5), (6, 0.5)
+//   feature c: (7, 1)
+// In this case:
+//   indices should look like: [0, 1, 1, 2]
+//   ids should be [5, 5, 6, 7]
+//   weights should be [1, 0.5, 0.5, 1]
+class ExtractFixedFeatures : public ComputeSessionOp {
+ public:
+  explicit ExtractFixedFeatures(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channel_id_));
+    OP_REQUIRES_OK(context, context->MatchSignature(
+                                {DT_STRING}, {DT_INT32, DT_INT64, DT_FLOAT}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    // Allocates output tensors.
+    auto indices_allocator = [context](int num_elements) {
+      Tensor *output;
+      CHECK(context->allocate_output(0, TensorShape({num_elements}), &output)
+                .ok());
+      return output->vec<int32>().data();
+    };
+    auto ids_allocator = [context](int num_elements) {
+      Tensor *ids_tensor;
+      CHECK(
+          context->allocate_output(1, TensorShape({num_elements}), &ids_tensor)
+              .ok());
+      return ids_tensor->vec<int64>().data();
+    };
+    auto weights_allocator = [context](int num_elements) {
+      Tensor *output;
+      CHECK(context->allocate_output(2, TensorShape({num_elements}), &output)
+                .ok());
+      return output->vec<float>().data();
+    };
+    int num_features = session->GetInputFeatures(
+        component_name(), indices_allocator, ids_allocator, weights_allocator,
+        channel_id_);
+    VLOG(2) << "Extracted " << num_features;
+  }
+
+ private:
+  int channel_id_;
+  TF_DISALLOW_COPY_AND_ASSIGN(ExtractFixedFeatures);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExtractFixedFeatures").Device(DEVICE_CPU),
+                        ExtractFixedFeatures);
+
+// Given a handle to a ComputeSession and a channel index, outputs link
+// features. Link features are returned as two vectors of length batch_size *
+// beam_size * channel_size:
+//   - step_idx: specifies the element to read in a tensor array of activations,
+//   - idx: specifies the row within the tensor array element.
+class ExtractLinkFeatures : public ComputeSessionOp {
+ public:
+  explicit ExtractLinkFeatures(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channel_id_));
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING}, {DT_INT32, DT_INT32}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    auto features =
+        session->GetTranslatedLinkFeatures(component_name(), channel_id_);
+
+    // Computes output size.
+    const int64 num_indices = features.size();
+
+    // Allocates output tensors.
+    Tensor *step_idx_output;
+    Tensor *idx_output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({num_indices}),
+                                            &step_idx_output));
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                1, TensorShape({num_indices}), &idx_output));
+
+    const int source_beam_size =
+        session->SourceComponentBeamSize(component_name(), channel_id_);
+    VLOG(2) << "source_beam_size:" << source_beam_size;
+
+    // Clip step_idx for all features. If a feature is empty, set the step
+    // index to -1.
+    for (int i = 0; i < features.size(); ++i) {
+      if (!features[i].has_step_idx() || features[i].step_idx() < -1) {
+        features[i].set_step_idx(-1);
+      }
+    }
+
+    // Fills output tensors.
+    for (int i = 0; i < features.size(); ++i) {
+      // Sets the element to read from a tensor array of activations.
+      step_idx_output->vec<int32>()(i) = features[i].step_idx();
+
+      // Within the tensor array element the id has to account for beam index
+      // and batch index for this specific component state.
+      idx_output->vec<int32>()(i) =
+          features[i].step_idx() >= 0
+              ? OutputLinearIndex(features[i], source_beam_size)
+              : 0;
+
+      VLOG(2) << "features[" << i << "]: " << features[i].ShortDebugString();
+    }
+  }
+
+ private:
+  // Given the beam index and the batch index in a LinkFeatures proto, returns
+  // the corresponding linear index, assuming that the matrix we're indexing
+  // into has shape {batch_size * beam_size, activation_size}, reshaped from a
+  // tensor of shape {batch_size, beam_size, activation_size}.
+  static uint64 OutputLinearIndex(const LinkFeatures &feature,
+                                  const int beam_size) {
+    VLOG(2) << "OutputLinearIndex batch_idx:" << feature.batch_idx()
+            << " beam_size:" << beam_size << " beam_idx:" << feature.beam_idx();
+    return feature.batch_idx() * beam_size + feature.beam_idx();
+  }
+
+  int channel_id_;
+  TF_DISALLOW_COPY_AND_ASSIGN(ExtractLinkFeatures);
+};
+
+REGISTER_KERNEL_BUILDER(Name("ExtractLinkFeatures").Device(DEVICE_CPU),
+                        ExtractLinkFeatures);
+
+// Given a handle to a BatchedBeamComponentState, emits a vector of gold
+// labels.
+// The vector of gold labels has size batch_size * beam_size.
+class EmitOracleLabels : public ComputeSessionOp {
+ public:
+  explicit EmitOracleLabels(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_INT32}));
+  }
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    VLOG(2) << "state->BatchSize: " << session->BatchSize(component_name());
+    VLOG(2) << "state->BeamSize: " << session->BeamSize(component_name());
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(
+                       0,
+                       TensorShape({session->BatchSize(component_name()) *
+                                    session->BeamSize(component_name())}),
+                       &output));
+    std::vector<std::vector<int>> batched_labels =
+        session->EmitOracleLabels(component_name());
+    int raw_index = 0;
+    for (const auto &batch_vector : batched_labels) {
+      for (const auto &label : batch_vector) {
+        output->vec<int32>()(raw_index) = label;
+        ++raw_index;
+      }
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EmitOracleLabels);
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmitOracleLabels").Device(DEVICE_CPU),
+                        EmitOracleLabels);
+
+// Given a handle to a ComponentState, emits a single bool indicating
+// whether every beam in every batch element contains only final states.
+class EmitAllFinal : public ComputeSessionOp {
+ public:
+  explicit EmitAllFinal(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_BOOL}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({1}), &output));
+    const bool is_terminal = session->IsTerminal(component_name());
+    VLOG(2) << "EmitAllFinal: is_terminal = " << is_terminal;
+    output->vec<bool>()(0) = is_terminal;
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EmitAllFinal);
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmitAllFinal").Device(DEVICE_CPU), EmitAllFinal);
+
+// Prepares the given component for computation.
+class InitComponentData : public ComputeSessionOp {
+ public:
+  explicit InitComponentData(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING, DT_INT32}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    const int beam_size = context->input(1).scalar<int32>()();
+    session->InitializeComponentData(component_name(), beam_size);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("InitComponentData").Device(DEVICE_CPU),
+                        InitComponentData);
+
+// Returns the given component's batch size.
+class BatchSize : public ComputeSessionOp {
+ public:
+  explicit BatchSize(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_INT32}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    Tensor *output;
+    OP_REQUIRES_OK(context,
+                   context->allocate_output(0, TensorShape({}), &output));
+    output->scalar<int>()() = session->BatchSize(component_name());
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("BatchSize").Device(DEVICE_CPU), BatchSize);
+
+// Attaches a data source to the master.
+class AttachDataReader : public ComputeSessionOp {
+ public:
+  explicit AttachDataReader(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(
+        context, context->MatchSignature({DT_STRING, DT_STRING}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return false; }
+
+  // Calls SetInputData() on the ComputeSession.
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    auto input_data(context->input(1).vec<string>());
+
+    std::vector<string> data;
+    for (int i = 0; i < input_data.size(); ++i) {
+      data.push_back(input_data(i));
+    }
+    session->SetInputData(data);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("AttachDataReader").Device(DEVICE_CPU),
+                        AttachDataReader);
+
+// Sets the tracing flag on the master state, which will enable or disable
+// tracing as inference / training is run.
+class SetTracing : public ComputeSessionOp {
+ public:
+  explicit SetTracing(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context,
+                   context->MatchSignature({DT_STRING, DT_BOOL}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return false; }
+
+  // Calls SetTracing() on the ComputeSession.
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    auto tracing_on = context->input(1).scalar<bool>()();
+    session->SetTracing(tracing_on);
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("SetTracing").Device(DEVICE_CPU), SetTracing);
+
+class WriteAnnotations : public ComputeSessionOp {
+ public:
+  explicit WriteAnnotations(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return true; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    session->FinalizeData(component_name());
+  }
+};
+
+REGISTER_KERNEL_BUILDER(Name("WriteAnnotations").Device(DEVICE_CPU),
+                        WriteAnnotations);
+
+// Given a handle to a ComputeSession, emits a vector of strings
+// corresponding to the serialized predictions of the model.
+class EmitAnnotations : public ComputeSessionOp {
+ public:
+  explicit EmitAnnotations(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    // Get the annotations from the state.
+    auto annotations = session->GetSerializedPredictions();
+
+    // Copy annotations to the output.
+    Tensor *output;
+    const int64 output_size = annotations.size();
+    OP_REQUIRES_OK(context, context->allocate_output(
+                                0, TensorShape({output_size}), &output));
+    auto annotations_output = output->vec<string>();
+    for (int i = 0; i < annotations.size(); ++i) {
+      annotations_output(i) = annotations[i];
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(EmitAnnotations);
+};
+
+REGISTER_KERNEL_BUILDER(Name("EmitAnnotations").Device(DEVICE_CPU),
+                        EmitAnnotations);
+
+// Get the component trace.
+class GetComponentTrace : public ComputeSessionOp {
+ public:
+  explicit GetComponentTrace(OpKernelConstruction *context)
+      : ComputeSessionOp(context) {
+    OP_REQUIRES_OK(context, context->MatchSignature({DT_STRING}, {DT_STRING}));
+  }
+
+  bool OutputsHandle() const override { return false; }
+  bool RequiresComponentName() const override { return true; }
+
+  void ComputeWithState(OpKernelContext *context,
+                        ComputeSession *session) override {
+    auto traces = session->GetTraceProtos();
+
+    const int64 size = traces.size();
+    Tensor *trace_output_tensor;
+    OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape({size}),
+                                                     &trace_output_tensor));
+    auto trace_output = trace_output_tensor->vec<string>();
+    for (int i = 0; i < size; ++i) {
+      CHECK(traces[i].SerializeToString(&trace_output(i)));
+    }
+  }
+
+  TF_DISALLOW_COPY_AND_ASSIGN(GetComponentTrace);
+};
+
+REGISTER_KERNEL_BUILDER(Name("GetComponentTrace").Device(DEVICE_CPU),
+                        GetComponentTrace);
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 851 - 0
syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc

@@ -0,0 +1,851 @@
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "dragnn/core/compute_session.h"
+#include "dragnn/core/compute_session_pool.h"
+#include "dragnn/core/resource_container.h"
+#include "dragnn/core/test/generic.h"
+#include "dragnn/core/test/mock_compute_session.h"
+
+#include <gmock/gmock.h>
+
+#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/control_flow.h"
+#include "tensorflow/core/framework/fake_input.h"
+#include "tensorflow/core/framework/node_def_builder.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/kernels/ops_testutil.h"
+#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::AllocatorAttributes;
+using tensorflow::checkpoint::TensorSliceReaderCacheWrapper;
+using tensorflow::DT_BOOL;
+using tensorflow::DT_FLOAT;
+using tensorflow::DT_STRING;
+using tensorflow::DT_INT32;
+using tensorflow::FrameAndIter;
+using tensorflow::DataType;
+using tensorflow::NodeDefBuilder;
+using tensorflow::OpKernelContext;
+using tensorflow::ResourceMgr;
+using tensorflow::ScopedStepContainer;
+using tensorflow::Status;
+using tensorflow::test::SetOutputAttrs;
+using tensorflow::TensorShape;
+
+using testing::_;
+using testing::ElementsAreArray;
+using testing::Invoke;
+using testing::Pointwise;
+using testing::Return;
+
+typedef ResourceContainer<ComputeSession> ComputeSessionResource;
+typedef ResourceContainer<ComputeSessionPool> ComputeSessionPoolResource;
+
+class DragnnOpKernelsTest : public tensorflow::OpsTestBase {
+ public:
+  void ResetOpKernelContext() {
+    params_.reset(new OpKernelContext::Params);
+    params_->device = device_.get();
+    params_->frame_iter = FrameAndIter(0, 0);
+    params_->inputs = &inputs_;
+    params_->op_kernel = kernel_.get();
+    step_container_.reset(new ScopedStepContainer(0, [](const string &) {}));
+    params_->step_container = step_container_.get();
+    attrs_.clear();
+    SetOutputAttrs(params_.get(), &attrs_);
+    slice_reader_cache_wrapper_.reset(new TensorSliceReaderCacheWrapper);
+    params_->slice_reader_cache = slice_reader_cache_wrapper_.get();
+    params_->resource_manager = device_->resource_manager();
+    context_.reset(new OpKernelContext(params_.get()));
+  }
+
+  Status RunOpKernelWithContext() {
+    device_->Compute(kernel_.get(), context_.get());
+    return context_->status();
+  }
+
+  // Accessor for the underlying resource manager.
+  ResourceMgr *resource_mgr() { return params_->resource_manager; }
+
+  // These need to maintain their existence throughout the compute call. In
+  // particular, the slice reader cache wrapper would dangle if it were left
+  // as a local inside ResetOpKernelContext().
+  std::vector<AllocatorAttributes> attrs_;
+  std::unique_ptr<TensorSliceReaderCacheWrapper> slice_reader_cache_wrapper_;
+};
+
+// Helper function to build LinkFeatures.
+LinkFeatures MakeFeatures(int batch_index, int beam_index, int step) {
+  LinkFeatures features;
+  features.set_batch_idx(batch_index);
+  features.set_beam_idx(beam_index);
+  features.set_step_idx(step);
+  return features;
+}
+
+// The GetSessionOp should
+// 1. create a ComputeSessionPool resource and store it in the ResourceMgr,
+// 2. create a ComputeSession resource and store it in the ResourceMgr,
+// 3. return the container and id strings in its output.
+TEST_F(DragnnOpKernelsTest, GetSessionOpTest) {
+  // Create a MasterSpec and GridPoint string to pass into the attrs for this
+  // op.
+  MasterSpec spec;
+  spec.set_debug_tracing(true);
+  string master_spec_str;
+  spec.SerializeToString(&master_spec_str);
+
+  GridPoint hyperparams;
+  string hyperparams_str;
+  hyperparams.SerializeToString(&hyperparams_str);
+
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("get_session", "GetSession")
+          .Attr("master_spec", master_spec_str)
+          .Attr("grid_point", hyperparams_str)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  AddInputFromList<string>(TensorShape({1}), {container_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Expect that the 0th output contains two strings, and that the ResourceMgr
+  // contains a ComputeSessionResource associated with those two strings.
+  const string container_str = GetOutput(0)->vec<string>()(0);
+  const string id_str = GetOutput(0)->vec<string>()(1);
+  VLOG(2) << "container: " << container_str << " id: " << id_str;
+
+  // The first compute session should have id "0".
+  EXPECT_EQ("0", id_str);
+  ComputeSessionResource *session_resource;
+  TF_EXPECT_OK(resource_mgr()->Lookup<ComputeSessionResource>(
+      container_str, id_str, &session_resource));
+
+  // Expect that the ResourceMgr also contains a ComputeSessionPoolResource.
+  const string pool_id_str = "pool";
+  ComputeSessionPoolResource *pool_resource;
+  TF_EXPECT_OK(resource_mgr()->Lookup<ComputeSessionPoolResource>(
+      container_str, pool_id_str, &pool_resource));
+
+  // Unref the managed resources so they get destroyed properly.
+  session_resource->Unref();
+  pool_resource->Unref();
+}
+
+// The ReleaseSession op should take a session stored in the resource manager
+// and return it to the ComputeSessionPool.
+TEST_F(DragnnOpKernelsTest, ReleaseSessionOpTest) {
+  // Create and initialize the kernel under test.
+  TF_ASSERT_OK(
+      NodeDefBuilder("release_session", "ReleaseSession")
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a ComputeSessionPool.
+  MasterSpec spec;
+  GridPoint hyperparams;
+  std::unique_ptr<ComputeSessionPool> pool(
+      new ComputeSessionPool(spec, hyperparams));
+
+  // Get an unowned pointer to the ComputeSessionPool before moving
+  // the pool to the resource manager.
+  ComputeSessionPool *pool_ptr = pool.get();
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionPoolResource>(
+      container_string, "pool",
+      new ComputeSessionPoolResource(std::move(pool))));
+
+  // Create a ComputeSession and move it to the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(pool_ptr->GetSession())));
+
+  // At this point, the pool should report that it has one outstanding session.
+  EXPECT_EQ(1, pool_ptr->num_outstanding_sessions());
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // At this point, the pool should report that it has no outstanding sessions.
+  EXPECT_EQ(0, pool_ptr->num_outstanding_sessions());
+
+  // The resource manager should no longer contain the session object.
+  ComputeSessionResource *null_resource = nullptr;
+  auto result = resource_mgr()->Lookup<ComputeSessionResource>(
+      container_string, id_string, &null_resource);
+  EXPECT_NE(Status::OK(), result);
+  EXPECT_EQ(null_resource, nullptr);
+}
+
+// The AdvanceFromOracle op should call AdvanceFromOracle on the specified
+// component name.
+TEST_F(DragnnOpKernelsTest, AdvanceFromOracleOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("advance_from_oracle", "AdvanceFromOracle")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set expectations on the mock session.
+  EXPECT_CALL(*mock_session_ptr, AdvanceFromOracle(component_name));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+// The AdvanceFromPrediction op should call AdvanceFromPrediction on the
+// specified component with the passed scores.
+TEST_F(DragnnOpKernelsTest, AdvanceFromPredictionOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("advance_from_prediction", "AdvanceFromPrediction")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_FLOAT))   // The prediction tensor.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+  const std::vector<float> weights = {1.1, 2.2, 3.3, 4.4};
+  AddInputFromArray<float>(TensorShape({2, 2}), weights);
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set expectations on the mock session.
+  auto validator_function = [weights](const string &component_name,
+                                      const float score_matrix[],
+                                      int score_matrix_length) {
+    EXPECT_EQ(weights.size(), score_matrix_length);
+    for (int i = 0; i < weights.size(); ++i) {
+      EXPECT_EQ(weights[i], score_matrix[i]);
+    }
+  };
+  EXPECT_CALL(*mock_session_ptr, AdvanceFromPrediction(component_name, _, _))
+      .WillOnce(Invoke(validator_function));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+// The ExtractFixedFeatures op should return a set of fixed feature vectors
+// as described below.
+TEST_F(DragnnOpKernelsTest, ExtractFixedFeaturesOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  constexpr int kChannelId = 78;
+  TF_ASSERT_OK(
+      NodeDefBuilder("extract_fixed_features", "ExtractFixedFeatures")
+          .Attr("component", component_name)
+          .Attr("channel_id", kChannelId)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // If we have 3 features, for a given channel, we might have:
+  //   feature a: (5, 1)
+  //   feature b: (5, 0.5), (6, 0.7)
+  //   feature c: (3, 0.1), (7, [empty]) <- Empty weights are equivalent to 1.0.
+  // In this case:
+  //   indices should look like  [0  , 1  , 1  , 2  , 2  ]
+  //   ids should be             [5  , 5  , 6  , 3  , 7  ]
+  //   weights should be         [1.0, 0.5, 0.7, 0.1, 1.0]
+  const std::vector<int32> expected_indices({0, 1, 1, 2, 2});
+  const std::vector<int64> expected_ids({5, 5, 6, 3, 7});
+  const std::vector<float> expected_weights({1.0, 0.5, 0.7, 0.1, 1.0});
+
+  auto assigner_function =
+      [=](string, std::function<int32 *(int)> indices_allocator,
+          std::function<int64 *(int)> ids_allocator,
+          std::function<float *(int)> weights_allocator, int) {
+        constexpr int kFeatureCount = 5;
+        int32 *indices = indices_allocator(kFeatureCount);
+        int64 *ids = ids_allocator(kFeatureCount);
+        float *weights = weights_allocator(kFeatureCount);
+        for (int i = 0; i < kFeatureCount; ++i) {
+          indices[i] = expected_indices[i];
+          ids[i] = expected_ids[i];
+          weights[i] = expected_weights[i];
+        }
+        return kFeatureCount;
+      };
+
+  EXPECT_CALL(*mock_session_ptr,
+              GetInputFeatures(component_name, _, _, _, kChannelId))
+      .WillOnce(testing::Invoke(assigner_function));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(expected_indices.size(), GetOutput(0)->NumElements());
+  for (int i = 0; i < expected_indices.size(); ++i) {
+    EXPECT_EQ(expected_indices[i], GetOutput(0)->vec<int32>()(i));
+  }
+  EXPECT_EQ(expected_ids.size(), GetOutput(1)->NumElements());
+  for (int i = 0; i < expected_ids.size(); ++i) {
+    EXPECT_EQ(expected_ids[i], GetOutput(1)->vec<int64>()(i));
+  }
+  EXPECT_EQ(expected_weights.size(), GetOutput(2)->NumElements());
+  for (int i = 0; i < expected_weights.size(); ++i) {
+    EXPECT_EQ(expected_weights[i], GetOutput(2)->vec<float>()(i));
+  }
+}
+
+// The ExtractLinkFeatures op should return a set of linked feature vectors
+// as described below.
+TEST_F(DragnnOpKernelsTest, ExtractLinkFeaturesOpTest) {
+  // TODO(googleuser): Is a 2-vector output the correct way to do this?
+  // Why reshape instead of passing [batch, beam, index] or just
+  // [batch,index] ?
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  constexpr int kChannelId = 3421;
+  TF_ASSERT_OK(
+      NodeDefBuilder("extract_link_features", "ExtractLinkFeatures")
+          .Attr("component", component_name)
+          .Attr("channel_id", kChannelId)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // This op will return link features in two flat arrays using batch-major
+  // ordering. So, if we have a batch of 2 and a beam of 3, with data as follows
+  // (note that the features are {batch,beam,step} and [] is 'empty')
+  // batch 1 features: {{02,03,[]},{01,00,04},{08,06,01}}
+  // batch 2 features: {{12,13,14},{11,12,-1},{18,16,20}}
+  //
+  // and a **source component** beam size of 5 should result in output tensors:
+  // step_idx  (tensor 0): {-1,  4,  1, 14, -1, 20}
+  // array_idx (tensor 1): { 0,  5, 46, 73,  0, 106}
+  // (0 [step=-1]),(5=1*5+0),(46=8*5+6),(73=12*5+13),(0 [step=-1]),(106=18*5+16)
+  constexpr int kSourceComponentBeamSize = 5;
+
+  std::vector<LinkFeatures> features;
+  features.push_back(MakeFeatures(2, 3, -1));
+  features.back().clear_step_idx();  // step_idx is now empty.
+  features.push_back(MakeFeatures(1, 0, 4));
+  features.push_back(MakeFeatures(8, 6, 1));
+  features.push_back(MakeFeatures(12, 13, 14));
+  features.push_back(MakeFeatures(11, 12, -1));
+  features.push_back(MakeFeatures(18, 16, 20));
+
+  const std::vector<int> expected_step_idx({-1, 4, 1, 14, -1, 20});
+  const std::vector<int> expected_array_idx({0, 5, 46, 73, 0, 106});
+
+  EXPECT_CALL(*mock_session_ptr,
+              SourceComponentBeamSize(component_name, kChannelId))
+      .WillRepeatedly(Return(kSourceComponentBeamSize));
+  EXPECT_CALL(*mock_session_ptr,
+              GetTranslatedLinkFeatures(component_name, kChannelId))
+      .WillOnce(Return(features));
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(expected_step_idx.size(), GetOutput(0)->NumElements());
+  for (int i = 0; i < expected_step_idx.size(); ++i) {
+    EXPECT_EQ(expected_step_idx[i], GetOutput(0)->vec<int32>()(i));
+  }
+  EXPECT_EQ(expected_array_idx.size(), GetOutput(1)->NumElements());
+  for (int i = 0; i < expected_array_idx.size(); ++i) {
+    EXPECT_EQ(expected_array_idx[i], GetOutput(1)->vec<int32>()(i));
+  }
+}
+
+// The EmitOracleLabels op should return a set of oracle labels for all
+// elements in all beams in all batches.
+TEST_F(DragnnOpKernelsTest, EmitOracleLabelsOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("emit_oracle_labels", "EmitOracleLabels")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // The op should request the batch and beam size, then request the oracle
+  // labels. They should be returned batch major, so:
+  // batch 1 oracle labels: {1, 3, 5, 7}
+  // batch 2 oracle labels: {2, 4, 6, 8}
+  // should result in an output tensor as follows:
+  // {1, 3, 5, 7, 2, 4, 6, 8}
+
+  constexpr int kBatchSize = 2;
+  constexpr int kBeamSize = 4;
+  const std::vector<std::vector<int>> oracle_labels(
+      {{1, 3, 5, 7}, {2, 4, 6, 8}});
+
+  EXPECT_CALL(*mock_session_ptr, BatchSize(component_name))
+      .WillRepeatedly(Return(kBatchSize));
+  EXPECT_CALL(*mock_session_ptr, BeamSize(component_name))
+      .WillRepeatedly(Return(kBeamSize));
+  EXPECT_CALL(*mock_session_ptr, EmitOracleLabels(component_name))
+      .WillOnce(Return(oracle_labels));
+
+  const std::vector<int> expected_labels({1, 3, 5, 7, 2, 4, 6, 8});
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(expected_labels.size(), GetOutput(0)->NumElements());
+  for (int i = 0; i < expected_labels.size(); ++i) {
+    EXPECT_EQ(expected_labels[i], GetOutput(0)->vec<int32>()(i));
+  }
+}
+
+// The EmitAllFinal op should return the result of IsTerminal(component_name).
+TEST_F(DragnnOpKernelsTest, EmitAllFinalOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("emit_all_final", "EmitAllFinal")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set up mocks.
+  constexpr bool kIsTerminal = true;
+  EXPECT_CALL(*mock_session_ptr, IsTerminal(component_name))
+      .WillOnce(Return(kIsTerminal));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(1, GetOutput(0)->NumElements());
+  EXPECT_EQ(kIsTerminal, GetOutput(0)->vec<bool>()(0));
+}
+
+// The InitComponentData op should initialize the given component with the given
+// beam size.
+// TODO(googleuser): Should we just store the beam size somewhere in the
+// ComputeSession?
+TEST_F(DragnnOpKernelsTest, InitComponentDataOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("init_component_data", "InitComponentData")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_INT32))   // The beam size.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+  constexpr int32 kBeamSize = 9001;
+  AddInputFromList<int32>(TensorShape({1}), {kBeamSize});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set up mocks.
+  EXPECT_CALL(*mock_session_ptr,
+              InitializeComponentData(component_name, kBeamSize));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Output should be the input handle.
+  EXPECT_EQ(container_string, GetOutput(0)->vec<string>()(0));
+  EXPECT_EQ(id_string, GetOutput(0)->vec<string>()(1));
+}
+
+// The BatchSize op should call BatchSize on the ComputeSession with the given
+// component as argument.
+TEST_F(DragnnOpKernelsTest, BatchSizeOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("batch_size", "BatchSize")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set up mocks.
+  constexpr int kBatchSize = 8;
+  EXPECT_CALL(*mock_session_ptr, BatchSize(component_name))
+      .WillOnce(Return(kBatchSize));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Output should be the batch size.
+  EXPECT_EQ(kBatchSize, GetOutput(0)->scalar<int>()());
+}
+
+// The AttachDataReader op should push the given vector of strings into the
+// session.
+TEST_F(DragnnOpKernelsTest, AttachDataReaderOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("attach_data_reader", "AttachDataReader")
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_STRING))  // The data to pass to the session.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  const std::vector<string> data(
+      {"one string", "two string", "red string", "blue string"});
+  AddInputFromArray<string>(TensorShape({4}), data);
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set up mocks.
+  EXPECT_CALL(*mock_session_ptr, SetInputData(ElementsAreArray(data)));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+// The SetTracingOp should pass its argument through to the underlying
+// ComputeSession.
+TEST_F(DragnnOpKernelsTest, SetTracingOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("set_tracing", "SetTracing")
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Input(FakeInput(DT_BOOL))    // The boolean to set tracing to.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+  constexpr bool kSetTracing = true;
+  AddInputFromList<bool>(TensorShape({1}), {kSetTracing});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set expectations on the mock session.
+  EXPECT_CALL(*mock_session_ptr, SetTracing(kSetTracing));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+// The WriteAnnotations op should call FinalizeData on the current component.
+TEST_F(DragnnOpKernelsTest, WriteAnnotationsOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("write_annotations", "WriteAnnotations")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  // Create a MockComputeSession and set expectations.
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Set expectations on the mock session.
+  EXPECT_CALL(*mock_session_ptr, FinalizeData(component_name));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+}
+
+// The EmitAnnotations op should return a vector of annotated strings as
+// described below.
+TEST_F(DragnnOpKernelsTest, EmitAnnotationsOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("emit_annotations", "EmitAnnotations")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  constexpr int kBatchSize = 2;
+  std::vector<string> predictions({"one", "two"});
+
+  EXPECT_CALL(*mock_session_ptr, BatchSize(component_name))
+      .WillRepeatedly(Return(kBatchSize));
+  EXPECT_CALL(*mock_session_ptr, GetSerializedPredictions())
+      .WillOnce(Return(predictions));
+
+  // The expected output has one annotation string per batch element.
+  const std::vector<string> expected_output({"one", "two"});
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(expected_output.size(), GetOutput(0)->NumElements());
+  for (int i = 0; i < expected_output.size(); ++i) {
+    EXPECT_EQ(expected_output[i], GetOutput(0)->vec<string>()(i));
+  }
+}
+
+// The GetComponentTrace op should return a vector of serialized trace protos.
+TEST_F(DragnnOpKernelsTest, GetComponentTraceOpTest) {
+  // Create and initialize the kernel under test.
+  const string component_name = "TESTING_COMPONENT_NAME";
+  TF_ASSERT_OK(
+      NodeDefBuilder("get_component_trace", "GetComponentTrace")
+          .Attr("component", component_name)
+          .Input(FakeInput(DT_STRING))  // The handle for the ComputeSession.
+          .Finalize(node_def()));
+  TF_ASSERT_OK(InitOp());
+
+  // Set the input data.
+  const string container_string = "container_str";
+  const string id_string = "id_str";
+  AddInputFromList<string>(TensorShape({2}), {container_string, id_string});
+
+  // Reset the test context to ensure it's clean.
+  ResetOpKernelContext();
+
+  std::unique_ptr<MockComputeSession> mock_session(new MockComputeSession());
+  MockComputeSession *mock_session_ptr = mock_session.get();
+
+  // This op will request a set of MasterTraces from GetTraceProtos(), then
+  // return them.
+
+  MasterTrace trace;
+  auto component_trace = trace.add_component_trace();
+  component_trace->set_name("arbitrary_component_name_for_html");
+  auto component_trace_2 = trace.add_component_trace();
+  component_trace_2->set_name("arbitrary_component_name_2_for_html");
+  const std::vector<MasterTrace> master_traces({trace});
+
+  EXPECT_CALL(*mock_session_ptr, GetTraceProtos())
+      .WillOnce(Return(master_traces));
+
+  // Wrap the ComputeSessionResource and put it into the resource manager.
+  TF_ASSERT_OK(resource_mgr()->Create<ComputeSessionResource>(
+      container_string, id_string,
+      new ComputeSessionResource(std::move(mock_session))));
+
+  // Run the kernel.
+  TF_EXPECT_OK(RunOpKernelWithContext());
+
+  // Validate the outputs.
+  EXPECT_EQ(master_traces.size(), GetOutput(0)->NumElements());
+  for (int i = 0; i < master_traces.size(); ++i) {
+    string expected;
+    master_traces.at(i).SerializeToString(&expected);
+    EXPECT_EQ(expected, GetOutput(0)->vec<string>()(i));
+  }
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 239 - 0
syntaxnet/dragnn/core/ops/dragnn_ops.cc

@@ -0,0 +1,239 @@
+#include "tensorflow/core/framework/op.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+REGISTER_OP("GetSession")
+    .Input("container: string")
+    .Attr("master_spec: string")
+    .Attr("grid_point: string")
+    .Output("handle: string")
+    .SetIsStateful()
+    .Doc(R"doc(
+Given MasterSpec and GridPoint protos, outputs a handle to a ComputeSession.
+
+container: A unique identifier for the ComputeSessionPool from which a
+    ComputeSession will be allocated.
+master_spec: A serialized syntaxnet.dragnn.MasterSpec proto.
+grid_point: A serialized syntaxnet.dragnn.GridPoint proto.
+handle: A string handle to a ComputeSession.
+)doc");
+
+REGISTER_OP("ReleaseSession").Input("handle: string").SetIsStateful().Doc(R"doc(
+Given a ComputeSession, return it to the ComputeSession pool.
+
+This ComputeSession will no longer be available after this op returns.
+
+handle: A handle to a ComputeSession that will be returned to the backing pool.
+)doc");
+
+REGISTER_OP("InitComponentData")
+    .Input("handle: string")
+    .Input("beam_size: int32")
+    .Attr("component: string")
+    .Output("output_handle: string")
+    .Doc(R"doc(
+Initialize a component with the given beam size for a given ComputeSession.
+
+handle: A handle to a ComputeSession.
+beam_size: The size of the beam to use on the component.
+component: The name of a Component instance, matching the ComponentSpec.name.
+output_handle: The handle to the same ComputeSession after initialization.
+)doc");
+
+REGISTER_OP("BatchSize")
+    .Input("handle: string")
+    .Attr("component: string")
+    .Output("batch_size: int32")
+    .Doc(R"doc(
+Given a ComputeSession and a component name, return the component batch size.
+
+handle: A handle to a ComputeSession.
+component: The name of a Component instance, matching the ComponentSpec.name.
+batch_size: The size of the given component's batch.
+)doc");
+
+REGISTER_OP("SetTracing")
+    .Input("handle: string")
+    .Input("tracing_on: bool")
+    .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
+    .Output("output_handle: string")
+    .Doc(R"doc(
+Given a ComputeSession, turns on or off tracing for all components.
+
+handle: A handle to a ComputeSession.
+tracing_on: Whether or not to record traces.
+output_handle: The handle to the same ComputeSession, with the tracing status changed.
+)doc");
+
+REGISTER_OP("AttachDataReader")
+    .Input("handle: string")
+    .Input("input_spec: string")
+    .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
+    .Output("output_handle: string")
+    .Doc(R"doc(
+Given a ComputeSession, attach a data source.
+
+This op is agnostic to the type of input data. The vector of input strings is
+interpreted by the backend.
+
+handle: A handle to a ComputeSession.
+input_spec: A vector of strings, where each string represents one batch item.
+output_handle: The handle to the same ComputeSession after attachment.
+)doc");
+
+REGISTER_OP("AdvanceFromOracle")
+    .Input("handle: string")
+    .Attr("component: string")
+    .Output("output_handle: string")
+    .Doc(R"doc(
+Given a ComputeSession and a Component name, advance the component via oracle.
+
+handle: A handle to a ComputeSession.
+component: The name of a Component instance, matching the ComponentSpec.name.
+output_handle: The handle to the same ComputeSession after advancement.
+)doc");
+
+REGISTER_OP("AdvanceFromPrediction")
+    .Input("handle: string")
+    .Input("scores: float")
+    .Attr("component: string")
+    .Output("output_handle: string")
+    .Doc(R"doc(
+Given a ComputeSession, a Component name, and a score tensor, advance the state.
+
+handle: A handle to a ComputeSession.
+scores: A tensor of scores, ordered by {batch_size, beam_size, num_actions}.
+component: The name of a Component instance, matching the ComponentSpec.name.
+output_handle: A handle to the same ComputeSession after advancement.
+)doc");
+
+REGISTER_OP("DragnnEmbeddingInitializer")
+    .Output("embeddings: float")
+    .Attr("embedding_input: string")
+    .Attr("vocab: string")
+    .Attr("scaling_coefficient: float = 1.0")
+    .Doc(R"doc(
+*** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
+
+Read embeddings from an input for every key specified in a text vocab file.
+
+embeddings: A tensor containing embeddings from the specified input.
+embedding_input: Path to location with embedding vectors.
+vocab: Path to list of keys corresponding to the input.
+scaling_coefficient: A scaling coefficient for the embedding matrix.
+)doc");
+
+REGISTER_OP("ExtractFixedFeatures")
+    .Input("handle: string")
+    .Output("indices: int32")
+    .Output("ids: int64")
+    .Output("weights: float")
+    .Attr("component: string")
+    .Attr("channel_id: int")
+    .Doc(R"doc(
+Given a ComputeSession, Component, and channel index, output fixed features.
+
+Fixed features are returned as 3 vectors, 'indices', 'ids', and 'weights' of equal
+length. 'ids' specifies which rows should be looked up in the embedding
+matrix. 'weights' specifies a scale for each embedding vector. 'indices' is a
+sorted vector that assigns the same index to embedding vectors that should be
+summed together.
+
+handle: A handle to a ComputeSession.
+indices: The row to add the feature to.
+ids: The indices into embedding matrices for each feature.
+weights: The weight for each looked up feature.
+component: The name of a Component instance, matching the ComponentSpec.name.
+channel_id: The feature channel to extract features for.
+)doc");
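
For context on the contract above, here is a minimal consumer-side sketch of how
the 'indices'/'ids'/'weights' triplet combines with a row-major embedding
matrix; the helper name and the plain std::vector containers are illustrative
assumptions, not part of this change:

#include <cstdint>
#include <vector>

// Sums weighted embedding rows into per-feature slots:
//   result[indices[k]] += weights[k] * embedding_matrix[ids[k]]
std::vector<std::vector<float>> SumFixedFeatureEmbeddings(
    const std::vector<int32_t> &indices, const std::vector<int64_t> &ids,
    const std::vector<float> &weights,
    const std::vector<std::vector<float>> &embedding_matrix,
    int num_feature_slots) {
  const int dim =
      embedding_matrix.empty() ? 0 : static_cast<int>(embedding_matrix[0].size());
  std::vector<std::vector<float>> result(num_feature_slots,
                                         std::vector<float>(dim, 0.0f));
  for (size_t k = 0; k < ids.size(); ++k) {
    const std::vector<float> &row = embedding_matrix[ids[k]];
    for (int d = 0; d < dim; ++d) {
      result[indices[k]][d] += weights[k] * row[d];
    }
  }
  return result;
}

With the example from dragnn_op_kernels_test.cc (indices {0, 1, 1, 2, 2}, ids
{5, 5, 6, 3, 7}, weights {1.0, 0.5, 0.7, 0.1, 1.0}), slot 1 receives
0.5 * row(5) + 0.7 * row(6).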
+
+REGISTER_OP("ExtractLinkFeatures")
+    .Input("handle: string")
+    .Output("step_idx: int32")
+    .Output("idx: int32")
+    .Attr("component: string")
+    .Attr("channel_id: int")
+    .Doc(R"doc(
+Given a ComputeSession, Component, and a channel index, outputs link features.
+
+Output indices have shape {batch_size * beam_size * channel_size}.
+
+handle: A handle to a ComputeSession.
+step_idx: The step indices to read activations from.
+idx: The index within a step to read the activations from.
+component: The name of a Component instance, matching the ComponentSpec.name.
+channel_id: The feature channel to extract features for.
+)doc");
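
To make the index arithmetic concrete, here is a minimal sketch of the
flattening behind the two outputs, consistent with the expectations in
dragnn_op_kernels_test.cc; the helper itself is hypothetical:

#include <utility>

// Maps one LinkFeatures entry to (step_idx, array_idx). Features with no
// valid step read nothing, so they map to step -1 and array index 0.
std::pair<int, int> FlattenLinkFeature(bool has_step, int step_idx,
                                       int batch_idx, int beam_idx,
                                       int source_beam_size) {
  if (!has_step || step_idx < 0) {
    return {-1, 0};
  }
  return {step_idx, batch_idx * source_beam_size + beam_idx};
}

With a source-component beam size of 5, the feature {batch 18, beam 16,
step 20} maps to step_idx 20 and idx 18 * 5 + 16 = 106; features without a
step index map to (-1, 0).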
+
+REGISTER_OP("EmitOracleLabels")
+    .Input("handle: string")
+    .Output("gold_labels: int32")
+    .Attr("component: string")
+    .Doc(R"doc(
+Given a ComputeSession and Component, emit a vector of gold labels.
+
+handle: A handle to a ComputeSession.
+gold_labels: A [batch_size * beam_size] vector of gold labels for the current
+             ComputeSession.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
+
+REGISTER_OP("EmitAllFinal")
+    .Input("handle: string")
+    .Output("all_final: bool")
+    .Attr("component: string")
+    .Doc(R"doc(
+Given a ComputeSession and Component, returns whether the Component is final.
+
+A component is considered final when all elements in the batch have beams
+containing all final states.
+
+handle: A handle to a ComputeSession.
+all_final: Whether every element in the specified component is 'final'.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
+
+REGISTER_OP("WriteAnnotations")
+    .Input("handle: string")
+    .Output("output_handle: string")
+    .Attr("component: string")
+    .Doc(R"doc(
+Given a ComputeSession, has the given component write out its annotations.
+
+The annotations are written to the underlying data objects passed in at the
+beginning of the computation.
+
+handle: A handle to a ComputeSession.
+output_handle: A handle to the same ComputeSession after writing.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
+
+REGISTER_OP("EmitAnnotations")
+    .Input("handle: string")
+    .Output("annotations: string")
+    .Attr("component: string")
+    .Doc(R"doc(
+Given a ComputeSession, emits strings with final predictions for the model.
+
+Predictions are given for each element in the final component's batch.
+
+handle: A handle to a ComputeSession.
+annotations: A vector of strings representing the annotated data.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
+
+REGISTER_OP("GetComponentTrace")
+    .Input("handle: string")
+    .Output("trace: string")
+    .Attr("component: string")
+    .Doc(R"doc(
+Gets the raw MasterTrace proto for each batch, state, and beam slot.
+
+handle: A handle to a ComputeSession.
+trace: A vector of MasterTrace protos.
+component: The name of a Component instance, matching the ComponentSpec.name.
+)doc");
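
As a usage note, the serialized traces emitted by this op can be parsed back
into MasterTrace protos; a small sketch follows (the proto header path is an
assumption):

#include <string>
#include <vector>

#include "dragnn/protos/trace.pb.h"  // Assumed header for MasterTrace.

std::vector<syntaxnet::dragnn::MasterTrace> DecodeTraces(
    const std::vector<std::string> &serialized_traces) {
  std::vector<syntaxnet::dragnn::MasterTrace> traces(serialized_traces.size());
  for (size_t i = 0; i < serialized_traces.size(); ++i) {
    // ParseFromString() returns false on malformed input; check it in real code.
    traces[i].ParseFromString(serialized_traces[i]);
  }
  return traces;
}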
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 36 - 0
syntaxnet/dragnn/core/resource_container.h

@@ -0,0 +1,36 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
+
+#include <memory>
+
+#include "syntaxnet/base.h"
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+using tensorflow::strings::StrCat;
+
+// Wrapper to store a data type T in the ResourceMgr. There should be one per
+// Session->Run() call that may happen concurrently.
+template <class T>
+class ResourceContainer : public tensorflow::ResourceBase {
+ public:
+  explicit ResourceContainer(std::unique_ptr<T> data)
+      : data_(std::move(data)) {}
+
+  ~ResourceContainer() override {}
+
+  T *get() { return data_.get(); }
+  std::unique_ptr<T> release() { return std::move(data_); }
+
+  string DebugString() override { return "ResourceContainer"; }
+
+ private:
+  std::unique_ptr<T> data_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
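
A minimal usage sketch for this wrapper, mirroring how the kernels above store
sessions in a tensorflow::ResourceMgr; the payload type and string keys here
are placeholders, not part of this change:

#include <memory>
#include <utility>

#include "dragnn/core/resource_container.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"

struct MyPayload {
  int value = 0;
};
typedef syntaxnet::dragnn::ResourceContainer<MyPayload> MyPayloadResource;

void StoreAndFetch(tensorflow::ResourceMgr *mgr) {
  // Create() hands the resource manager its own reference to the container.
  std::unique_ptr<MyPayload> payload(new MyPayload());
  TF_CHECK_OK(mgr->Create<MyPayloadResource>(
      "container_name", "resource_id",
      new MyPayloadResource(std::move(payload))));

  // Lookup() returns an additional reference that the caller must Unref().
  MyPayloadResource *resource = nullptr;
  TF_CHECK_OK(mgr->Lookup<MyPayloadResource>("container_name", "resource_id",
                                             &resource));
  resource->get()->value = 42;
  resource->Unref();
}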

+ 49 - 0
syntaxnet/dragnn/core/resource_container_test.cc

@@ -0,0 +1,49 @@
+// Tests the methods of ResourceContainer.
+//
+// NOTE(danielandor): For all tests: ResourceContainer is derived from
+// RefCounted, which requires the use of Unref to reduce the ref count
+// to zero and automatically delete the pointer.
+#include "dragnn/core/resource_container.h"
+
+#include <gmock/gmock.h>
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class MockDatatype {};
+
+TEST(ResourceContainerTest, Get) {
+  std::unique_ptr<MockDatatype> data(new MockDatatype());
+  MockDatatype *data_ptr = data.get();
+  auto *container = new ResourceContainer<MockDatatype>(std::move(data));
+  EXPECT_EQ(data_ptr, container->get());
+  container->Unref();
+}
+
+TEST(ResourceContainerTest, Release) {
+  std::unique_ptr<MockDatatype> data(new MockDatatype());
+  MockDatatype *data_ptr = data.get();
+  auto *container = new ResourceContainer<MockDatatype>(std::move(data));
+  std::unique_ptr<MockDatatype> data_again = container->release();
+  container->Unref();
+  EXPECT_EQ(data_ptr, data_again.get());
+}
+
+TEST(ResourceContainerTest, NullptrOnGetAfterRelease) {
+  std::unique_ptr<MockDatatype> data(new MockDatatype());
+  auto *container = new ResourceContainer<MockDatatype>(std::move(data));
+  container->release();
+  EXPECT_EQ(nullptr, container->get());
+  container->Unref();
+}
+
+TEST(ResourceContainerTest, DebugString) {
+  std::unique_ptr<MockDatatype> data(new MockDatatype());
+  auto *container = new ResourceContainer<MockDatatype>(std::move(data));
+  EXPECT_EQ("ResourceContainer", container->DebugString());
+  container->Unref();
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 57 - 0
syntaxnet/dragnn/core/test/BUILD

@@ -0,0 +1,57 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "mock_component",
+    testonly = True,
+    hdrs = ["mock_component.h"],
+    deps = [
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core:index_translator",
+        "//dragnn/core/interfaces:component",
+        "//dragnn/core/interfaces:transition_state",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "//syntaxnet:base",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_library(
+    name = "mock_compute_session",
+    testonly = True,
+    hdrs = ["mock_compute_session.h"],
+    deps = [
+        "//dragnn/components/util:bulk_feature_extractor",
+        "//dragnn/core:compute_session",
+        "//dragnn/protos:data_proto",
+        "//dragnn/protos:spec_proto",
+        "//syntaxnet:base",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_library(
+    name = "mock_transition_state",
+    testonly = True,
+    hdrs = ["mock_transition_state.h"],
+    deps = [
+        "//dragnn/core/interfaces:transition_state",
+        "//syntaxnet:base",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)
+
+cc_library(
+    name = "generic",
+    testonly = True,
+    srcs = ["generic.cc"],
+    hdrs = ["generic.h"],
+    deps = [
+        "//syntaxnet:base",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:test",
+        "@org_tensorflow//tensorflow/core:testlib",
+    ],
+)

+ 21 - 0
syntaxnet/dragnn/core/test/generic.cc

@@ -0,0 +1,21 @@
+#include "dragnn/core/test/generic.h"
+
+#include "tensorflow/core/lib/io/path.h"
+
+namespace syntaxnet {
+namespace test {
+
+string GetTestDataPrefix() {
+  const char *env = getenv("TEST_SRCDIR");
+  const char *workspace = getenv("TEST_WORKSPACE");
+  if (!env || env[0] == '\0' || !workspace || workspace[0] == '\0') {
+    LOG(FATAL) << "Test directories not set up";
+  }
+  return tensorflow::io::JoinPath(env, workspace);
+}
+
+}  // namespace test
+}  // namespace syntaxnet

+ 25 - 0
syntaxnet/dragnn/core/test/generic.h

@@ -0,0 +1,25 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
+
+#include <utility>
+
+#include <gmock/gmock.h>
+
+#include "syntaxnet/base.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace test {
+
+MATCHER_P(EqualsProto, a, "Protos are not equivalent:") {
+  return a.DebugString() == arg.DebugString();
+}
+
+// Returns the prefix for where the test data is stored.
+string GetTestDataPrefix();
+
+}  // namespace test
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_

+ 63 - 0
syntaxnet/dragnn/core/test/mock_component.h

@@ -0,0 +1,63 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
+
+#include <gmock/gmock.h>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/index_translator.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "syntaxnet/base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class MockComponent : public Component {
+ public:
+  MOCK_METHOD1(InitializeComponent, void(const ComponentSpec &spec));
+  MOCK_METHOD3(
+      InitializeData,
+      void(const std::vector<std::vector<const TransitionState *>> &states,
+           int max_beam_size, InputBatchCache *input_data));
+  MOCK_CONST_METHOD0(IsReady, bool());
+  MOCK_METHOD0(InitializeTracing, void());
+  MOCK_METHOD0(DisableTracing, void());
+  MOCK_CONST_METHOD0(Name, string());
+  MOCK_CONST_METHOD0(BatchSize, int());
+  MOCK_CONST_METHOD0(BeamSize, int());
+  MOCK_CONST_METHOD1(StepsTaken, int(int batch_index));
+  MOCK_CONST_METHOD3(GetBeamIndexAtStep,
+                     int(int step, int current_index, int batch));
+  MOCK_CONST_METHOD2(GetSourceBeamIndex, int(int current_index, int batch));
+  MOCK_METHOD2(AdvanceFromPrediction,
+               void(const float transition_matrix[], int matrix_length));
+  MOCK_METHOD0(AdvanceFromOracle, void());
+  MOCK_CONST_METHOD0(IsTerminal, bool());
+  MOCK_METHOD0(GetBeam, std::vector<std::vector<const TransitionState *>>());
+  MOCK_CONST_METHOD4(GetFixedFeatures,
+                     int(std::function<int32 *(int)> allocate_indices,
+                         std::function<int64 *(int)> allocate_ids,
+                         std::function<float *(int)> allocate_weights,
+                         int channel_id));
+  MOCK_METHOD1(BulkGetFixedFeatures,
+               int(const BulkFeatureExtractor &extractor));
+  MOCK_CONST_METHOD1(GetRawLinkFeatures,
+                     std::vector<LinkFeatures>(int channel_id));
+  MOCK_CONST_METHOD0(GetOracleLabels, std::vector<std::vector<int>>());
+  MOCK_METHOD0(ResetComponent, void());
+  MOCK_METHOD1(GetStepLookupFunction,
+               std::function<int(int, int, int)>(const string &method));
+  MOCK_METHOD0(FinalizeData, void());
+  MOCK_CONST_METHOD0(GetTraceProtos,
+                     std::vector<std::vector<ComponentTrace>>());
+  MOCK_METHOD2(AddTranslatedLinkFeaturesToTrace,
+               void(const std::vector<LinkFeatures> &features, int channel_id));
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_

+ 61 - 0
syntaxnet/dragnn/core/test/mock_compute_session.h

@@ -0,0 +1,61 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
+
+#include <gmock/gmock.h>
+
+#include "dragnn/components/util/bulk_feature_extractor.h"
+#include "dragnn/core/compute_session.h"
+#include "dragnn/protos/data.pb.h"
+#include "dragnn/protos/spec.pb.h"
+#include "syntaxnet/base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class MockComputeSession : public ComputeSession {
+ public:
+  MOCK_METHOD2(Init, void(const MasterSpec &master_spec,
+                          const GridPoint &hyperparams));
+  MOCK_METHOD2(InitializeComponentData,
+               void(const string &component_name, int max_beam_size));
+  MOCK_CONST_METHOD1(BatchSize, int(const string &component_name));
+  MOCK_CONST_METHOD1(BeamSize, int(const string &component_name));
+  MOCK_CONST_METHOD1(Spec, const ComponentSpec &(const string &component_name));
+  MOCK_METHOD2(SourceComponentBeamSize,
+               int(const string &component_name, int channel_id));
+  MOCK_METHOD1(AdvanceFromOracle, void(const string &component_name));
+  MOCK_METHOD3(AdvanceFromPrediction,
+               void(const string &component_name, const float score_matrix[],
+                    int score_matrix_length));
+  MOCK_CONST_METHOD5(GetInputFeatures,
+                     int(const string &component_name,
+                         std::function<int32 *(int)> allocate_indices,
+                         std::function<int64 *(int)> allocate_ids,
+                         std::function<float *(int)> allocate_weights,
+                         int channel_id));
+  MOCK_METHOD2(BulkGetInputFeatures,
+               int(const string &component_name,
+                   const BulkFeatureExtractor &extractor));
+  MOCK_METHOD2(GetTranslatedLinkFeatures,
+               std::vector<LinkFeatures>(const string &component_name,
+                                         int channel_id));
+  MOCK_METHOD1(EmitOracleLabels,
+               std::vector<std::vector<int>>(const string &component_name));
+  MOCK_METHOD1(IsTerminal, bool(const string &component_name));
+  MOCK_METHOD1(FinalizeData, void(const string &component_name));
+  MOCK_METHOD0(GetSerializedPredictions, std::vector<string>());
+  MOCK_METHOD0(GetTraceProtos, std::vector<MasterTrace>());
+  MOCK_METHOD1(SetInputData, void(const std::vector<string> &data));
+  MOCK_METHOD0(ResetSession, void());
+  MOCK_METHOD1(SetTracing, void(bool tracing_on));
+  MOCK_CONST_METHOD0(Id, int());
+  MOCK_CONST_METHOD1(GetDescription, string(const string &component_name));
+  MOCK_CONST_METHOD1(Translators, const std::vector<const IndexTranslator *>(
+                                      const string &component_name));
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
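As a rough illustration of how this mock could be used (a sketch, not part of the change; the component name "tagger" and the return values are invented for the example):

// Hypothetical test snippet: stubbing MockComputeSession with gmock.
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "dragnn/core/test/mock_compute_session.h"

using syntaxnet::dragnn::MockComputeSession;
using testing::Return;

TEST(MockComputeSessionExample, StubsCalls) {
  MockComputeSession session;
  // Pretend the "tagger" component has beam size 1 and finishes after one step.
  EXPECT_CALL(session, BeamSize("tagger")).WillRepeatedly(Return(1));
  EXPECT_CALL(session, IsTerminal("tagger"))
      .WillOnce(Return(false))
      .WillOnce(Return(true));
  EXPECT_EQ(1, session.BeamSize("tagger"));
  EXPECT_FALSE(session.IsTerminal("tagger"));
  EXPECT_TRUE(session.IsTerminal("tagger"));
}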

+ 30 - 0
syntaxnet/dragnn/core/test/mock_transition_state.h

@@ -0,0 +1,30 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
+
+#include <memory>
+
+#include <gmock/gmock.h>
+
+#include "dragnn/core/interfaces/transition_state.h"
+#include "syntaxnet/base.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+class MockTransitionState : public TransitionState {
+ public:
+  MOCK_METHOD1(Init, void(const TransitionState &parent));
+  MOCK_CONST_METHOD0(Clone, std::unique_ptr<TransitionState>());
+  MOCK_CONST_METHOD0(ParentBeamIndex, const int());
+  MOCK_METHOD1(SetBeamIndex, void(const int index));
+  MOCK_CONST_METHOD0(GetBeamIndex, const int());
+  MOCK_CONST_METHOD0(GetScore, const float());
+  MOCK_METHOD1(SetScore, void(const float score));
+  MOCK_CONST_METHOD0(HTMLRepresentation, string());
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
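Similarly, a test could give the mock state default behavior via ON_CALL (again only a sketch; the helper name and values are arbitrary):

// Hypothetical helper: give a mock state a fixed score and beam index so code
// under test can read them without explicit EXPECT_CALLs.
#include <gmock/gmock.h>

#include "dragnn/core/test/mock_transition_state.h"

using syntaxnet::dragnn::MockTransitionState;
using testing::NiceMock;
using testing::Return;

void GiveDefaultBehavior(NiceMock<MockTransitionState> *state) {
  ON_CALL(*state, GetScore()).WillByDefault(Return(1.0f));
  ON_CALL(*state, GetBeamIndex()).WillByDefault(Return(0));
}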

BIN
syntaxnet/dragnn/core/testdata/brain-parser-model


+ 86 - 0
syntaxnet/dragnn/core/testdata/master_spec_link.textproto

@@ -0,0 +1,86 @@
+component {
+  name: "parser"
+  num_actions: 100
+  transition_system {
+    registered_name: "arc-standard"
+    parameters {
+      key: "entity_name_tokenizer"
+      value: "pre-tokenized"
+    }
+    parameters {
+      key: "language"
+      value: "en"
+    }
+    parameters {
+      key: "neurosis_feature_syntax_version"
+      value: "2"
+    }
+    parameters {
+      key: "parser_skip_deterministic"
+      value: "false"
+    }
+    parameters {
+      key: "parser_transition_system"
+      value: "arc-standard"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "tags"
+    fml: "input.tag stack.tag stack(1).tag"
+    embedding_dim: 32
+    vocabulary_size: 51
+    size: 3
+    predicate_map: "hashed"
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input.word"
+    embedding_dim: 64
+    vocabulary_size: 39396
+    size: 1
+    predicate_map: "hashed"
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "parser"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: 'SyntaxNetComponent'
+  }
+  component_builder {
+    registered_name: 'DynamicComponentBuilder'
+  }
+  network_unit {
+    registered_name: 'FeedForwardNetwork'
+    parameters {
+      key: 'hidden_layer_sizes'
+      value: '64'
+    }
+  }
+}
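For context, a spec like this is typically read with protobuf's TextFormat after substituting the TESTDATA placeholder with a real directory; the helper below is an assumed workflow for illustration, not code from this change (LoadSpec and its arguments are invented):

// Sketch: load a master-spec textproto and point its resources at a
// test-data directory before parsing it into a MasterSpec.
#include <fstream>
#include <sstream>
#include <string>

#include "dragnn/protos/spec.pb.h"
#include "google/protobuf/text_format.h"

syntaxnet::dragnn::MasterSpec LoadSpec(const std::string &path,
                                       const std::string &testdata_dir) {
  std::ifstream file(path);
  std::stringstream buffer;
  buffer << file.rdbuf();
  std::string text = buffer.str();
  // Replace each "TESTDATA" placeholder with the actual resource directory.
  size_t pos = 0;
  while ((pos = text.find("TESTDATA", pos)) != std::string::npos) {
    text.replace(pos, 8, testdata_dir);
    pos += testdata_dir.size();
  }
  syntaxnet::dragnn::MasterSpec spec;
  if (!google::protobuf::TextFormat::ParseFromString(text, &spec)) {
    // Handle the parse failure as appropriate for the caller.
  }
  return spec;
}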

BIN
syntaxnet/dragnn/core/testdata/repository


BIN
syntaxnet/dragnn/core/testdata/simple-tagger.brain-parser-model


BIN
syntaxnet/dragnn/core/testdata/simple-tagger.repository


+ 46 - 0
syntaxnet/dragnn/core/testdata/simple-tagger.tag-map

@@ -0,0 +1,46 @@
+45
+NN 132998
+IN 98554
+NNP 91466
+DT 81832
+JJ 61217
+NNS 59856
+, 48727
+. 39478
+CD 36568
+RB 30907
+VBD 29889
+VB 26438
+CC 23959
+TO 22357
+VBZ 21672
+VBN 20024
+PRP 17436
+VBG 14846
+VBP 12491
+MD 9803
+POS 8701
+PRP$ 8407
+$ 7372
+`` 7092
+'' 6919
+: 4772
+WDT 4294
+JJR 3238
+NNPS 2673
+RP 2662
+WP 2363
+WRB 2143
+JJS 1947
+RBR 1768
+-RRB- 1376
+-LRB- 1366
+EX 863
+RBS 451
+PDT 368
+FW 234
+WP$ 168
+# 142
+UH 97
+SYM 58
+LS 36

+ 59 - 0
syntaxnet/dragnn/core/testdata/simple_parser_master_spec.textproto

@@ -0,0 +1,59 @@
+component {
+  name: "parser"
+  num_actions : 93
+  transition_system {
+    registered_name: "arc-standard"
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus stack(1).focus"
+    embedding_dim: 32
+    size: 2
+    source_component: "parser"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: 'DynamicComponentBuilder'
+  }
+  network_unit {
+    registered_name: 'FeedForwardNetwork'
+    parameters {
+      key: 'hidden_layer_sizes'
+      value: '64'
+    }
+  }
+  
+  inference_beam_size: 4
+}

+ 52 - 0
syntaxnet/dragnn/core/testdata/simple_tagger_lstm_master_spec.textproto

@@ -0,0 +1,52 @@
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "LSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64"
+    }
+  }
+}

+ 63 - 0
syntaxnet/dragnn/core/testdata/simple_tagger_master_spec.textproto

@@ -0,0 +1,63 @@
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "tagger"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "FeedForwardNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64"
+    }
+  }
+  training_beam_size: 1
+  inference_beam_size: 3
+}

+ 65 - 0
syntaxnet/dragnn/core/testdata/simple_tagger_wrapped_lstm_master_spec.textproto

@@ -0,0 +1,65 @@
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+    predicate_map: "hashed"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "wrapped_units.LayerNormBasicLSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64,64,64"
+    }
+    parameters {
+      key: "input_dropout_rate"
+      value: "0.9"
+    }
+    parameters {
+      key: "recurrent_dropout_rate"
+      value: "0.9"
+    }
+    parameters {
+      key: "layer_norm"
+      value: "true"
+    }
+  }
+}

+ 111 - 0
syntaxnet/dragnn/core/testdata/split_tagger_master_spec.textproto

@@ -0,0 +1,111 @@
+component {
+  name: "tagger-features"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "tagger-features"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "FeedForwardNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64"
+    }
+  }
+}
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  linked_feature {
+    name: "features"
+    fml: "input.focus"
+    embedding_dim: -1
+    size: 1
+    source_component: "tagger-features"
+    source_translator: "identity"
+    source_layer: "logits"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "IdentityNetwork"
+  }
+}

+ 47 - 0
syntaxnet/dragnn/core/testdata/syntaxnet_tagger.label-map

@@ -0,0 +1,47 @@
+46
+punct 243160
+prep 194627
+pobj 186958
+det 170592
+nsubj 144821
+nn 144800
+amod 117242
+ROOT 90592
+dobj 88551
+aux 76523
+advmod 72893
+conj 59384
+cc 57532
+num 36350
+poss 35117
+dep 34986
+ccomp 29470
+cop 25991
+mark 25141
+xcomp 25111
+rcmod 16234
+auxpass 15740
+advcl 14996
+possessive 14866
+nsubjpass 14133
+pcomp 12488
+appos 11112
+partmod 11106
+neg 11090
+number 10658
+prt 7123
+quantmod 6653
+tmod 5418
+infmod 5134
+npadvmod 3213
+parataxis 3012
+mwe 2793
+expl 2712
+iobj 1642
+acomp 1632
+discourse 1381
+csubj 1225
+predet 1160
+preconj 749
+goeswith 146
+csubjpass 41

+ 50 - 0
syntaxnet/dragnn/core/testdata/syntaxnet_tagger.tag-map

@@ -0,0 +1,50 @@
+49
+NN 285194
+IN 228165
+DT 179147
+NNP 175147
+JJ 125667
+NNS 115732
+, 97481
+. 85938
+RB 78513
+VB 63952
+CC 57554
+VBD 56635
+CD 55674
+PRP 55244
+VBZ 48126
+VBN 44458
+VBG 34524
+VBP 33669
+TO 28772
+MD 22364
+PRP$ 20706
+HYPH 18526
+POS 14905
+`` 12193
+'' 12154
+WDT 10267
+: 8713
+$ 7993
+WP 7336
+RP 7335
+WRB 6634
+JJR 6295
+NNPS 5917
+-RRB- 3904
+-LRB- 3840
+JJS 3596
+RBR 3186
+EX 2733
+UH 1521
+RBS 1467
+PDT 1271
+FW 928
+NFP 844
+SYM 652
+ADD 476
+LS 392
+WP$ 332
+GW 184
+AFX 42

+ 4 - 0
syntaxnet/dragnn/core/testdata/syntaxnet_tagger.word-map

@@ -0,0 +1,4 @@
+3
+sentence 4
+. 3
+0 2

+ 185 - 0
syntaxnet/dragnn/core/testdata/tagger_parser_master_spec.textproto

@@ -0,0 +1,185 @@
+component {
+  name: "features"
+  num_actions : 1
+  transition_system {
+    registered_name: "shift-only"
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input(-1).word input(-2).word input(-3).word input.word input(1).word input(2).word input(3).word"
+    embedding_dim: 64
+    vocabulary_size: 39397
+    size: 7
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "IdentityNetwork"
+  }
+  inference_beam_size: 1
+}
+component {
+  name: "tagger"
+  num_actions : 49
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  linked_feature {
+    name: "words"
+    fml: "input.focus"
+    embedding_dim: -1
+    size: 1
+    source_component: "features"
+    source_translator: "identity"
+    source_layer: "input_embeddings"
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "tagger"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "FeedForwardNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64"
+    }
+  }
+  inference_beam_size: 1
+}
+component {
+  name: "parser"
+  num_actions : 93
+  transition_system {
+    registered_name: "arc-standard"
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.tag-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.word-map"
+      file_format: "text"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TESTDATA/syntaxnet_tagger.label-map"
+      file_format: "text"
+    }
+  }
+  fixed_feature {
+    name: "action"
+    fml: "last-action"
+    embedding_dim: 32
+    vocabulary_size: 93
+    size: 1
+  }
+  linked_feature {
+    name: "words"
+    fml: "input.focus"
+    embedding_dim: -1
+    size: 1
+    source_component: "features"
+    source_translator: "identity"
+    source_layer: "input_embeddings"
+  }
+  linked_feature {
+    name: "tagger"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "tagger"
+    source_translator: "identity"
+    source_layer: "layer_0"
+  }
+  linked_feature {
+    name: "rnn"
+    fml: "stack.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "parser"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  component_builder {
+    registered_name: "DynamicComponentBuilder"
+  }
+  network_unit {
+    registered_name: "FeedForwardNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "64"
+    }
+  }
+  inference_beam_size: 5
+}
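To see how the three components above chain together (features feeding both the tagger and the parser, and the parser also reading the tagger's hidden layer), one could walk a parsed spec and print each linked feature's source. This is only a sketch and assumes the generated protobuf accessors follow the field names used in the spec (component, linked_feature, source_component, source_layer):

// Sketch: print the component pipeline of a parsed MasterSpec.
#include <iostream>

#include "dragnn/protos/spec.pb.h"

void PrintPipeline(const syntaxnet::dragnn::MasterSpec &spec) {
  for (const auto &component : spec.component()) {
    std::cout << component.name() << std::endl;
    for (const auto &link : component.linked_feature()) {
      // e.g. "words <- features:input_embeddings"
      std::cout << "  " << link.name() << " <- " << link.source_component()
                << ":" << link.source_layer() << std::endl;
    }
  }
}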

+ 213 - 0
syntaxnet/dragnn/core/testdata/ud-hungarian.master-spec

@@ -0,0 +1,213 @@
+component {
+  name: "rl_rnn"
+  transition_system {
+    registered_name: "shift-only"
+    parameters {
+      key: "left-to-right"
+      value: "false"
+    }
+    parameters {
+      key: "parser_skip_deterministic"
+      value: "false"
+    }
+  }
+  resource {
+    name: "char-ngram-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.char-ngram-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  resource {
+    name: "word-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.word-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.label-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  fixed_feature {
+    name: "char_ngram"
+    fml: "input.token.char-ngram"
+    embedding_dim: 16
+    vocabulary_size: 9943
+    size: 1
+  }
+  fixed_feature {
+    name: "other"
+    fml: "input.token {digit hyphen punctuation-amount quote }"
+    embedding_dim: 8
+    vocabulary_size: 5
+    size: 4
+  }
+  fixed_feature {
+    name: "words"
+    fml: "input.word"
+    embedding_dim: 64
+    vocabulary_size: 11090
+    size: 1
+  }
+  network_unit {
+    registered_name: "wrapped_units.LayerNormBasicLSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "256"
+    }
+  }
+  component_builder {
+    registered_name: 'DynamicComponentBuilder'
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  num_actions: 1
+  attention_component: ""
+}
+component {
+  name: "tagger"
+  transition_system {
+    registered_name: "tagger"
+    parameters {
+      key: "join_category_to_pos"
+      value: "true"
+    }
+    parameters {
+      key: "parser_skip_deterministic"
+      value: "false"
+    }
+  }
+  resource {
+    name: "tag-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.tag-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.label-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  fixed_feature {
+    name: "action"
+    fml: "last-action"
+    embedding_dim: 32
+    vocabulary_size: 100
+    size: 1
+  }
+  linked_feature {
+    name: "encoder"
+    fml: "input.focus"
+    embedding_dim: 64
+    size: 1
+    source_component: "rl_rnn"
+    source_translator: "reverse-token"
+    source_layer: "state_h_0"
+  }
+  network_unit {
+    registered_name: "wrapped_units.LayerNormBasicLSTMNetwork"
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "256"
+    }
+  }
+  component_builder {
+    registered_name: 'DynamicComponentBuilder'
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  num_actions: 642
+  attention_component: ""
+}
+component {
+  name: "parser"
+  transition_system {
+    registered_name: "arc-standard"
+    parameters {
+      key: "parser_skip_deterministic"
+      value: "false"
+    }
+  }
+  resource {
+    name: "label-map"
+    part {
+      file_pattern: "TOPDIR/ud-hungarian.label-map"
+      file_format: "text"
+      record_format: ""
+    }
+  }
+  fixed_feature {
+    name: "action"
+    fml: "last-action"
+    embedding_dim: 32
+    vocabulary_size: 100
+    size: 1
+  }
+  fixed_feature {
+    name: "labels"
+    fml: "stack.child(1).label stack.child(1).sibling(-1).label stack.child(-1).label stack.child(-1).sibling(1).label stack(1).child(1).label stack(1).child(1).sibling(-1).label stack(1).child(-1).label stack(1).child(-1).sibling(1).label stack.child(2).label stack.child(-2).label stack(1).child(2).label stack(1).child(-2).label"
+    embedding_dim: 16
+    vocabulary_size: 57
+    size: 12
+  }
+  linked_feature {
+    name: "encoder"
+    fml: "input.focus"
+    embedding_dim: 64
+    size: 1
+    source_component: "rl_rnn"
+    source_translator: "reverse-token"
+    source_layer: "state_h_0"
+  }
+  linked_feature {
+    name: "parser-rnn"
+    fml: "stack.focus stack(1).focus"
+    embedding_dim: 64
+    size: 2
+    source_component: "parser"
+    source_translator: "shift-reduce-step"
+    source_layer: "layer_0"
+  }
+  linked_feature {
+    name: "tagger"
+    fml: "input.focus stack.focus stack(1).focus"
+    embedding_dim: 64
+    size: 3
+    source_component: "tagger"
+    source_translator: "identity"
+    source_layer: "state_h_0"
+  }
+  network_unit {
+    registered_name: 'FeedForwardNetwork'
+    parameters {
+      key: "hidden_layer_sizes"
+      value: "256,256"
+    }
+    parameters {
+      key: "layer_norm_hidden"
+      value: "True"
+    }
+  }
+  component_builder {
+    registered_name: 'DynamicComponentBuilder'
+  }
+  backend {
+    registered_name: "SyntaxNetComponent"
+  }
+  num_actions: 109
+  attention_component: ""
+}

+ 34 - 0
syntaxnet/dragnn/io/BUILD

@@ -0,0 +1,34 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "sentence_input_batch",
+    srcs = ["sentence_input_batch.cc"],
+    hdrs = ["sentence_input_batch.h"],
+    deps = [
+        ":syntaxnet_sentence",
+        "//dragnn/core/interfaces:input_batch",
+        "//syntaxnet:base",
+        "//syntaxnet:sentence_proto",
+    ],
+)
+
+cc_library(
+    name = "syntaxnet_sentence",
+    hdrs = ["syntaxnet_sentence.h"],
+    deps = [
+        "//syntaxnet:sentence_proto",
+        "//syntaxnet:workspace",
+    ],
+)
+
+cc_test(
+    name = "sentence_input_batch_test",
+    srcs = ["sentence_input_batch_test.cc"],
+    deps = [
+        ":sentence_input_batch",
+        "//dragnn/core/test:generic",
+        "//syntaxnet:sentence_proto",
+        "//syntaxnet:test_main",
+        "@org_tensorflow//tensorflow/core:test",
+    ],
+)

+ 31 - 0
syntaxnet/dragnn/io/sentence_input_batch.cc

@@ -0,0 +1,31 @@
+#include "dragnn/io/sentence_input_batch.h"
+
+#include "syntaxnet/sentence.pb.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+void SentenceInputBatch::SetData(
+    const std::vector<string> &stringified_sentence_protos) {
+  for (const auto &stringified_proto : stringified_sentence_protos) {
+    std::unique_ptr<Sentence> sentence(new Sentence);
+    std::unique_ptr<WorkspaceSet> workspace_set(new WorkspaceSet);
+    CHECK(sentence->ParseFromString(stringified_proto))
+        << "Unable to parse string input as syntaxnet.Sentence.";
+    SyntaxNetSentence aug_sentence(std::move(sentence),
+                                   std::move(workspace_set));
+    data_.push_back(std::move(aug_sentence));
+  }
+}
+
+const std::vector<string> SentenceInputBatch::GetSerializedData() const {
+  std::vector<string> output_data;
+  output_data.resize(data_.size());
+  for (int i = 0; i < data_.size(); ++i) {
+    data_[i].sentence()->SerializeToString(&(output_data[i]));
+  }
+  return output_data;
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 37 - 0
syntaxnet/dragnn/io/sentence_input_batch.h

@@ -0,0 +1,37 @@
+#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
+#define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
+
+#include <string>
+#include <vector>
+
+#include "dragnn/core/interfaces/input_batch.h"
+#include "dragnn/io/syntaxnet_sentence.h"
+#include "syntaxnet/base.h"
+
+namespace syntaxnet {
+namespace dragnn {
+
+// Data accessor backed by a syntaxnet::Sentence object.
+class SentenceInputBatch : public InputBatch {
+ public:
+  SentenceInputBatch() {}
+
+  // Translates from a vector of stringified Sentence protos.
+  void SetData(
+      const std::vector<string> &stringified_sentence_protos) override;
+
+  // Translates to a vector of stringified Sentence protos.
+  const std::vector<string> GetSerializedData() const override;
+
+  // Get the underlying Sentences.
+  std::vector<SyntaxNetSentence> *data() { return &data_; }
+
+ private:
+  // The backing Sentence protos.
+  std::vector<SyntaxNetSentence> data_;
+};
+
+}  // namespace dragnn
+}  // namespace syntaxnet
+
+#endif  // NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
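A minimal round trip through this class might look like the following (a sketch for illustration, assuming the Sentence proto's text field; not part of the change):

// Sketch: serialize a Sentence, load it via SetData(), and read it back.
#include <string>
#include <vector>

#include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/sentence.pb.h"

std::vector<std::string> RoundTrip() {
  syntaxnet::Sentence sentence;
  sentence.set_text("A tiny example.");
  std::string serialized;
  sentence.SerializeToString(&serialized);

  syntaxnet::dragnn::SentenceInputBatch batch;
  batch.SetData({serialized});
  return batch.GetSerializedData();  // One serialized Sentence, unchanged.
}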

+ 0 - 0
syntaxnet/dragnn/io/sentence_input_batch_test.cc


Some files were not shown because too many files changed in this diff