
Updates to syntaxnet: update the tensorflow sub-module and the bazel requirement, and fix a trainer crash (#479)

* syntaxnet: Cosmetic fixes recommended by python lint.

* syntaxnet: Fix crash in parser_trainer due to an inconsistency between
  LexiconBuilder::Compute() and the context.pbtxt definition (the 'char-map'
  input declaration was missing).

* syntaxnet: Reduce flakiness in GraphBuilderTest.

* syntaxnet: Update tensorflow submodule to version > 0.10.

* syntaxnet: Update to latest stable bazel (0.3.1).

This update is needed in part to allow the TensorFlow submodule to build
successfully. In this commit, I also update and simplify the WORKSPACE to
avoid declaring dependencies already present in tensorflow.

* syntaxnet: Update bazel version check to require version 0.3.0.

* syntaxnet: Document pip requirement, along with python mock module.
Livio Soares · 9 years ago
commit 51238b1b52
+ 6 - 6
syntaxnet/Dockerfile

@@ -5,17 +5,17 @@ ENV SYNTAXNETDIR=/opt/tensorflow PATH=$PATH:/root/bin
 RUN mkdir -p $SYNTAXNETDIR \
     && cd $SYNTAXNETDIR \
     && apt-get update \
-    && apt-get install git zlib1g-dev file swig python2.7 python-dev python-pip -y \
+    && apt-get install git zlib1g-dev file swig python2.7 python-dev python-pip python-mock -y \
     && pip install --upgrade pip \
-    && pip install -U protobuf==3.0.0b2 \
+    && pip install -U protobuf==3.0.0 \
     && pip install asciitree \
     && pip install numpy \
-    && wget https://github.com/bazelbuild/bazel/releases/download/0.2.2b/bazel-0.2.2b-installer-linux-x86_64.sh \
-    && chmod +x bazel-0.2.2b-installer-linux-x86_64.sh \
-    && ./bazel-0.2.2b-installer-linux-x86_64.sh --user \
+    && wget https://github.com/bazelbuild/bazel/releases/download/0.3.1/bazel-0.3.1-installer-linux-x86_64.sh \
+    && chmod +x bazel-0.3.1-installer-linux-x86_64.sh \
+    && ./bazel-0.3.1-installer-linux-x86_64.sh --user \
     && git clone --recursive https://github.com/tensorflow/models.git \
     && cd $SYNTAXNETDIR/models/syntaxnet/tensorflow \
-    && echo "\n\n\n" | ./configure \
+    && echo "\n\n\n\n" | ./configure \
     && apt-get autoremove -y \
     && apt-get clean
 

+ 13 - 5
syntaxnet/README.md

@@ -29,7 +29,7 @@ Model
 [Martins et al. (2013)](http://www.cs.cmu.edu/~ark/TurboParser/)                                                | 93.10 | 88.23 | 94.21
 [Zhang and McDonald (2014)](http://research.google.com/pubs/archive/38148.pdf)                                  | 93.32 | 88.65 | 93.37
 [Weiss et al. (2015)](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/43800.pdf) | 93.91 | 89.29 | 94.17
-[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                                                   | 94.44 | 90.17 | 95.40
+[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                                                         | 94.44 | 90.17 | 95.40
 Parsey McParseface                                                                                              | 94.15 | 89.08 | 94.77
 
 We see that Parsey McParseface is state-of-the-art; more importantly, with
@@ -45,7 +45,7 @@ Parsey McParseface is also state-of-the-art for part-of-speech (POS) tagging
 Model                                                                      | News  | Web   | Questions
 -------------------------------------------------------------------------- | :---: | :---: | :-------:
 [Ling et al. (2015)](http://www.cs.cmu.edu/~lingwang/papers/emnlp2015.pdf) | 97.78 | 94.03 | 96.18
-[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*              | 97.77 | 94.80 | 96.86
+[Andor et al. (2016)](http://arxiv.org/abs/1603.06042)*                    | 97.77 | 94.80 | 96.86
 Parsey McParseface                                                         | 97.52 | 94.24 | 96.45
 
 The first part of this tutorial describes how to install the necessary tools and
@@ -78,10 +78,16 @@ source. You'll need to install:
 
 *   python 2.7:
     * python 3 support is not available yet
+*   pip (python package manager)
+    * `apt-get install python-pip` on Ubuntu
+    * `brew` installs pip along with python on OSX
 *   bazel:
-    *   **versions 0.2.0 - 0.2.2b, NOT 0.2.3**
+    *   **versions 0.3.0 - 0.3.1**
     *   follow the instructions [here](http://bazel.io/docs/install.html)
-    *   Alternately, Download bazel (0.2.2-0.2.2b) <.deb> from [here](https://github.com/bazelbuild/bazel/releases) for your system configuration.
+    *   Alternatively, download the bazel <.deb> package from
+        [https://github.com/bazelbuild/bazel/releases](https://github.com/bazelbuild/bazel/releases)
+        for your system configuration.
     *   Install it using the command: sudo dpkg -i <.deb file>
     *   Check for the bazel version by typing: bazel version
 *   swig:
@@ -94,12 +100,14 @@ source. You'll need to install:
     *   `pip install asciitree`
 *   numpy, package for scientific computing:
     *   `pip install numpy`
+*   mock, package for unit testing:
+    *   `pip install mock`
 
 Once you have completed the above steps, you can build and test SyntaxNet with the
 following commands:
 
 ```shell
-  git clone --recursive --recurse-submodules https://github.com/tensorflow/models.git
+  git clone --recursive https://github.com/tensorflow/models.git
   cd models/syntaxnet/tensorflow
   ./configure
   cd ..

+ 4 - 31
syntaxnet/WORKSPACE

@@ -1,38 +1,11 @@
 local_repository(
   name = "org_tensorflow",
-  path = __workspace_dir__ + "/tensorflow",
+  path = "tensorflow",
 )
 
-load('//tensorflow/tensorflow:workspace.bzl', 'tf_workspace')
-tf_workspace("tensorflow/", "@org_tensorflow")
+load('@org_tensorflow//tensorflow:workspace.bzl', 'tf_workspace')
+tf_workspace()
 
 # Specify the minimum required Bazel version.
 load("@org_tensorflow//tensorflow:tensorflow.bzl", "check_version")
-check_version("0.2.0")
-
-# ===== gRPC dependencies =====
-
-bind(
-    name = "libssl",
-    actual = "@ts_boringssl_git//:ssl",
-)
-
-git_repository(
-    name = "ts_boringssl_git",
-    commit = "436432d849b83ab90f18773e4ae1c7a8f148f48d",
-    init_submodules = True,
-    remote = "https://github.com/mdsteele/boringssl-bazel.git",
-)
-
-bind(
-    name = "zlib",
-    actual = "@ts_zlib_archive//:zlib",
-)
-
-new_http_archive(
-    name = "ts_zlib_archive",
-    build_file = "zlib.BUILD",
-    sha256 = "879d73d8cd4d155f31c1f04838ecd567d34bebda780156f0e82a20721b3973d5",
-    strip_prefix = "zlib-1.2.8",
-    url = "http://zlib.net/zlib128.zip",
-)
+check_version("0.3.0")

+ 1 - 1
syntaxnet/syntaxnet/BUILD

@@ -78,7 +78,7 @@ cc_library(
     hdrs = ["base.h"],
     visibility = ["//visibility:public"],
     deps = [
-        "@re2//:re2",
+        "@com_googlesource_code_re2//:re2",
         "@protobuf//:protobuf",
         "@org_tensorflow//third_party/eigen3",
     ] + select({

+ 0 - 1
syntaxnet/syntaxnet/beam_reader_ops.cc

@@ -35,7 +35,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
 #include "tensorflow/core/platform/env.h"
 
 using tensorflow::DEVICE_CPU;

+ 0 - 1
syntaxnet/syntaxnet/beam_reader_ops_test.py

@@ -18,7 +18,6 @@
 
 import os.path
 import time
-
 import tensorflow as tf
 
 from tensorflow.python.framework import test_util

+ 8 - 5
syntaxnet/syntaxnet/conll2tree.py

@@ -40,9 +40,11 @@ flags.DEFINE_string('corpus_name', 'stdin-conll',
 
 def to_dict(sentence):
   """Builds a dictionary representing the parse tree of a sentence.
-     Note that the suffix "@id" (where 'id' is a number) is appended to each element
-     to handle the sentence that has multiple elements with identical representation.
-     Those suffix needs to be removed after the asciitree is rendered.
+
+     Note that the suffix "@id" (where 'id' is a number) is appended to each
+     element to handle sentences that have multiple elements with identical
+     representations. These suffixes need to be removed after the asciitree
+     is rendered.
 
   Args:
     sentence: Sentence protocol buffer to represent.
@@ -54,7 +56,8 @@ def to_dict(sentence):
   root = -1
   for i in range(0, len(sentence.token)):
     token = sentence.token[i]
-    token_str.append('%s %s %s @%d' % (token.word, token.tag, token.label, (i+1)))
+    token_str.append('%s %s %s @%d' %
+                     (token.word, token.tag, token.label, (i+1)))
     if token.head == -1:
       root = i
     else:
@@ -88,7 +91,7 @@ def main(unused_argv):
         print 'Input: %s' % sentence.text
         print 'Parse:'
         tr_str = tr(d)
-        pat = re.compile('\s*@\d+$')
+        pat = re.compile(r'\s*@\d+$')
         for tr_ln in tr_str.splitlines():
           print pat.sub('', tr_ln)
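
For context on the docstring and regex changes above: asciitree keys nodes by
their string representation, so duplicate tokens would otherwise collapse into
a single node. A minimal Python sketch of the suffix-and-strip scheme
(illustrative only, not part of the commit; the token strings are made up):

```python
import re

# Duplicate representations ('the DT det') would collide as asciitree keys,
# so conll2tree appends a unique '@id' suffix to each element.
tokens = ['the DT det', 'dog NN nsubj', 'the DT det']
unique = ['%s @%d' % (t, i + 1) for i, t in enumerate(tokens)]

# The raw string r'\s*@\d+$' keeps the backslashes intact for the regex
# engine, which is why the commit adds the r prefix above.
pat = re.compile(r'\s*@\d+$')
for line in unique:
    print(pat.sub('', line))  # suffixes are stripped after rendering
```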
 

+ 4 - 0
syntaxnet/syntaxnet/context.pbtxt

@@ -87,6 +87,10 @@ input {
   creator: 'brain_pos/greedy'
 }
 input {
+  name: 'char-map'
+  creator: 'brain_pos/greedy'
+}
+input {
   name: 'prefix-table'
   creator: 'brain_pos/greedy'
 }

+ 2 - 2
syntaxnet/syntaxnet/document_format.h

@@ -25,7 +25,7 @@ limitations under the License.
 #include "syntaxnet/registry.h"
 #include "syntaxnet/sentence.pb.h"
 #include "syntaxnet/task_context.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
 
 namespace syntaxnet {
 
@@ -42,7 +42,7 @@ class DocumentFormat : public RegisterableClass<DocumentFormat> {
 
   // Reads a record from the given input buffer with format specific logic.
   // Returns false if no record could be read because we reached end of file.
-  virtual bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+  virtual bool ReadRecord(tensorflow::io::BufferedInputStream *buffer,
                           string *record) = 0;
 
   // Converts a key/value pair to one or more documents.

+ 0 - 1
syntaxnet/syntaxnet/feature_extractor.h

@@ -50,7 +50,6 @@ limitations under the License.
 #include "syntaxnet/workspace.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
 #include "tensorflow/core/lib/io/record_reader.h"
 #include "tensorflow/core/lib/io/record_writer.h"
 #include "tensorflow/core/lib/strings/strcat.h"

+ 2 - 2
syntaxnet/syntaxnet/graph_builder.py

@@ -256,7 +256,7 @@ class GreedyParser(object):
             self.params[name])
 
   def GetStep(self):
-    def OnesInitializer(shape, dtype=tf.float32):
+    def OnesInitializer(shape, dtype=tf.float32, partition_info=None):
       return tf.ones(shape, dtype)
     return self._AddVariable([], tf.int32, 'step', OnesInitializer)
 
@@ -475,7 +475,7 @@ class GreedyParser(object):
   def AddPretrainedEmbeddings(self, index, embeddings_path, task_context):
     """Embeddings at the given index will be set to pretrained values."""
 
-    def _Initializer(shape, dtype=tf.float32):
+    def _Initializer(shape, dtype=tf.float32, partition_info=None):
       unused_dtype = dtype
       t = gen_parser_ops.word_embedding_initializer(
           vectors=embeddings_path,
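
The two initializer signature changes above track a TensorFlow API change:
after the submodule update, variable initializers may be called with an extra
partition_info argument. A hedged sketch of the pattern (the variable name and
shape are illustrative):

```python
import tensorflow as tf

def ones_initializer(shape, dtype=tf.float32, partition_info=None):
    # Newer TF may pass partition_info when a variable is partitioned;
    # accepting (and ignoring) the keyword keeps old initializers working.
    return tf.ones(shape, dtype)

step = tf.get_variable('step_demo', shape=[], dtype=tf.int32,
                       initializer=ones_initializer)
```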

+ 2 - 3
syntaxnet/syntaxnet/graph_builder_test.py

@@ -18,7 +18,6 @@
 
 # disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
 import os.path
-
 import tensorflow as tf
 
 from tensorflow.python.framework import test_util
@@ -221,7 +220,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
     with self.test_session(graph=graph1) as sess:
       sess.run(parser.inits.values())
       metrics1 = None
-      for _ in range(500):
+      for _ in range(50):
         cost1, _ = sess.run([parser.training['cost'],
                              parser.training['train_op']])
         em1 = parser.evaluation['eval_metrics'].eval()
@@ -240,7 +239,7 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
     with self.test_session(graph=graph2) as sess:
       sess.run(parser.inits.values())
       metrics2 = None
-      for _ in range(500):
+      for _ in range(50):
         cost2, _ = sess.run([parser.training['cost'],
                              parser.training['train_op']])
         em2 = parser.evaluation['eval_metrics'].eval()

+ 0 - 1
syntaxnet/syntaxnet/lexicon_builder_test.py

@@ -19,7 +19,6 @@
 
 # disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
 import os.path
-
 import tensorflow as tf
 
 import syntaxnet.load_parser_ops

+ 0 - 1
syntaxnet/syntaxnet/parser_eval.py

@@ -19,7 +19,6 @@
 import os
 import os.path
 import time
-
 import tempfile
 import tensorflow as tf
 

+ 0 - 1
syntaxnet/syntaxnet/parser_trainer.py

@@ -20,7 +20,6 @@
 import os
 import os.path
 import time
-
 import tensorflow as tf
 
 from tensorflow.python.platform import gfile

+ 1 - 4
syntaxnet/syntaxnet/parser_trainer_test.sh

@@ -17,12 +17,9 @@
 # This test trains a parser on a small dataset, then runs it in greedy mode and
 # in structured mode with beam 1, and checks that the result is identical.
 
-
-
-
 set -eux
 
-BINDIR=$TEST_SRCDIR/syntaxnet
+BINDIR=$TEST_SRCDIR/$TEST_WORKSPACE/syntaxnet
 CONTEXT=$BINDIR/testdata/context.pbtxt
 TMP_DIR=/tmp/syntaxnet-output
 

+ 13 - 7
syntaxnet/syntaxnet/proto_io.h

@@ -32,7 +32,8 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/lib/core/stringpiece.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
 #include "tensorflow/core/lib/io/record_reader.h"
 #include "tensorflow/core/lib/io/record_writer.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -181,22 +182,27 @@ class TextReader {
     if (filename_ == "-") {
       static const int kInputBufferSize = 8 * 1024; /* bytes */
       file_.reset(new StdIn());
-      buffer_.reset(
-          new tensorflow::io::InputBuffer(file_.get(), kInputBufferSize));
+      stream_.reset(new tensorflow::io::RandomAccessInputStream(file_.get()));
+      buffer_.reset(new tensorflow::io::BufferedInputStream(file_.get(),
+                                                            kInputBufferSize));
     } else {
       static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
       TF_CHECK_OK(
           tensorflow::Env::Default()->NewRandomAccessFile(filename_, &file_));
-      buffer_.reset(
-          new tensorflow::io::InputBuffer(file_.get(), kInputBufferSize));
+      stream_.reset(new tensorflow::io::RandomAccessInputStream(file_.get()));
+      buffer_.reset(new tensorflow::io::BufferedInputStream(file_.get(),
+                                                            kInputBufferSize));
     }
   }
 
  private:
   string filename_;
   int sentence_count_ = 0;
-  std::unique_ptr<tensorflow::RandomAccessFile> file_;  // must outlive buffer_
-  std::unique_ptr<tensorflow::io::InputBuffer> buffer_;
+  std::unique_ptr<tensorflow::RandomAccessFile>
+      file_;  // must outlive buffer_, stream_
+  std::unique_ptr<tensorflow::io::RandomAccessInputStream>
+      stream_;  // Must outlive buffer_
+  std::unique_ptr<tensorflow::io::BufferedInputStream> buffer_;
   std::unique_ptr<DocumentFormat> format_;
 };
 

+ 0 - 1
syntaxnet/syntaxnet/reader_ops.cc

@@ -35,7 +35,6 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
 #include "tensorflow/core/lib/io/table.h"
 #include "tensorflow/core/lib/io/table_options.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"

+ 3 - 3
syntaxnet/syntaxnet/reader_ops_test.py

@@ -17,12 +17,10 @@
 
 
 import os.path
-
 import numpy as np
 import tensorflow as tf
 
 from tensorflow.python.framework import test_util
-from tensorflow.python.ops import control_flow_ops as cf
 from tensorflow.python.platform import googletest
 from tensorflow.python.platform import tf_logging as logging
 
@@ -164,7 +162,9 @@ class ParsingReaderOpsTest(test_util.TensorFlowTestCase):
       loop_vars = [epoch, num_actions]
 
       res = sess.run(
-          cf.While(Condition, Body, loop_vars, parallel_iterations=1))
+          tf.while_loop(Condition, Body, loop_vars,
+                        shape_invariants=[tf.TensorShape(None)] * 2,
+                        parallel_iterations=1))
       logging.info('Result: %s', res)
       self.assertEqual(res[0], 2)
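
The replacement above swaps the internal control_flow_ops.While helper for the
public tf.while_loop, which by default requires every loop variable to keep a
fixed shape across iterations; passing tf.TensorShape(None) in
shape_invariants opts a variable out of that check. A minimal, hedged sketch
of the same pattern (toy cond/body, not the test's):

```python
import tensorflow as tf

def cond(i, acc):
    return i < 3

def body(i, acc):
    # acc doubles in length each iteration, so its shape is not invariant.
    return i + 1, tf.tile(acc, [2])

i0 = tf.constant(0)
acc0 = tf.ones([1])
_, acc_final = tf.while_loop(
    cond, body, [i0, acc0],
    shape_invariants=[i0.get_shape(), tf.TensorShape(None)],
    parallel_iterations=1)
```

Note that shape_invariants needs one entry per loop variable, which is why the
test passes `[tf.TensorShape(None)] * 2` for its two variables.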
 

+ 2 - 1
syntaxnet/syntaxnet/sentence_batch.h

@@ -18,6 +18,7 @@ limitations under the License.
 
 #include <memory>
 #include <string>
+#include <utility>
 #include <vector>
 
 #include "syntaxnet/embedding_feature_extractor.h"
@@ -38,7 +39,7 @@ class SentenceBatch {
  public:
   SentenceBatch(int batch_size, string input_name)
       : batch_size_(batch_size),
-        input_name_(input_name),
+        input_name_(std::move(input_name)),
         sentences_(batch_size) {}
 
   // Initializes all resources and opens the corpus file.

+ 1 - 1
syntaxnet/syntaxnet/sentence_features.h

@@ -300,7 +300,7 @@ class Word : public TermFrequencyMapFeature {
   Word() : TermFrequencyMapFeature("word-map") {}
 
   FeatureValue ComputeValue(const Token &token) const override {
-    string form = token.word();
+    const string &form = token.word();
     return term_map().LookupIndex(form, UnknownValue());
   }
 };

+ 1 - 1
syntaxnet/syntaxnet/shared_store.h

@@ -71,7 +71,7 @@ class SharedStore {
     int refcount;
 
     SharedObject(void *o, std::function<void()> d)
-        : object(o), delete_callback(d), refcount(1) {}
+        : object(o), delete_callback(std::move(d)), refcount(1) {}
   };
 
   // A map from keys to shared objects.

+ 4 - 3
syntaxnet/syntaxnet/structured_graph_builder.py

@@ -24,9 +24,9 @@ from tensorflow.python.ops import tensor_array_ops
 from syntaxnet import graph_builder
 from syntaxnet.ops import gen_parser_ops
 
-tf.NoGradient('BeamParseReader')
-tf.NoGradient('BeamParser')
-tf.NoGradient('BeamParserOutput')
+tf.NotDifferentiable('BeamParseReader')
+tf.NotDifferentiable('BeamParser')
+tf.NotDifferentiable('BeamParserOutput')
 
 
 def AddCrossEntropy(batch_size, n):
@@ -122,6 +122,7 @@ class StructuredGraphBuilder(graph_builder.GreedyParser):
         KeepGoing,
         Advance,
         [state, step, scores_array, alive, alive_steps] + list(features),
+        shape_invariants=[tf.TensorShape(None)] * (len(features) + 5),
         parallel_iterations=100)
 
     # Link to the final nodes/values of ops that have passed through While:
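
tf.NoGradient was renamed in the updated TensorFlow; both names register an op
type as having no gradient, so the gradient machinery propagates None through
such ops rather than raising an unregistered-gradient LookupError during
backprop. A hedged one-line sketch (the op type name here is hypothetical):

```python
import tensorflow as tf

# Registration happens once, at module import time; 'MyBeamReader' is a
# made-up op type standing in for the parser's custom reader ops.
tf.NotDifferentiable('MyBeamReader')
```

The shape_invariants list added to the While loop above follows the same rule
as in reader_ops_test.py: one entry per loop variable, hence len(features) + 5.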

+ 2 - 2
syntaxnet/syntaxnet/syntaxnet.bzl

@@ -19,8 +19,8 @@ load("@protobuf//:protobuf.bzl", "py_proto_library")
 def if_cuda(if_true, if_false = []):
     """Shorthand for select()'ing on whether we're building with CUDA."""
     return select({
-        "@org_tensorflow//third_party/gpus/cuda:using_nvcc": if_true,
-        "@org_tensorflow//third_party/gpus/cuda:using_gcudacc": if_true,
+        "@local_config_cuda//cuda:using_nvcc": if_true,
+        "@local_config_cuda//cuda:using_clang": if_true,
         "//conditions:default": if_false
     })
 

+ 9 - 6
syntaxnet/syntaxnet/term_frequency_map.cc

@@ -20,7 +20,8 @@ limitations under the License.
 #include <limits>
 
 #include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
+#include "tensorflow/core/lib/io/random_inputstream.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/platform/env.h"
 
@@ -61,9 +62,10 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
   std::unique_ptr<tensorflow::RandomAccessFile> file;
   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
   static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
-  tensorflow::io::InputBuffer input(file.get(), kInputBufferSize);
+  tensorflow::io::RandomAccessInputStream stream(file.get());
+  tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
   string line;
-  TF_CHECK_OK(input.ReadLine(&line));
+  TF_CHECK_OK(buffer.ReadLine(&line));
   int32 total = -1;
   CHECK(utils::ParseInt32(line.c_str(), &total));
   CHECK_GE(total, 0);
@@ -71,7 +73,7 @@ void TermFrequencyMap::Load(const string &filename, int min_frequency,
   // Read the mapping.
   int64 last_frequency = -1;
   for (int i = 0; i < total && i < max_num_terms; ++i) {
-    TF_CHECK_OK(input.ReadLine(&line));
+    TF_CHECK_OK(buffer.ReadLine(&line));
     vector<string> elements = utils::Split(line, ' ');
     CHECK_EQ(2, elements.size());
     CHECK(!elements[0].empty());
@@ -143,9 +145,10 @@ TagToCategoryMap::TagToCategoryMap(const string &filename) {
   std::unique_ptr<tensorflow::RandomAccessFile> file;
   TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
   static const int kInputBufferSize = 1 * 1024 * 1024; /* bytes */
-  tensorflow::io::InputBuffer input(file.get(), kInputBufferSize);
+  tensorflow::io::RandomAccessInputStream stream(file.get());
+  tensorflow::io::BufferedInputStream buffer(&stream, kInputBufferSize);
   string line;
-  while (input.ReadLine(&line) == tensorflow::Status::OK()) {
+  while (buffer.ReadLine(&line) == tensorflow::Status::OK()) {
     vector<string> pair = utils::Split(line, '\t');
     CHECK(line.empty() || pair.size() == 2) << line;
     tag_to_category_[pair[0]] = pair[1];

+ 120 - 4
syntaxnet/syntaxnet/text_formats.cc

@@ -18,10 +18,10 @@ limitations under the License.
 #include <vector>
 
 #include "syntaxnet/document_format.h"
-#include "syntaxnet/sentence.pb.h"
 #include "syntaxnet/segmenter_utils.h"
+#include "syntaxnet/sentence.pb.h"
 #include "syntaxnet/utils.h"
-#include "tensorflow/core/lib/io/inputbuffer.h"
+#include "tensorflow/core/lib/io/buffered_inputstream.h"
 #include "tensorflow/core/lib/strings/strcat.h"
 #include "tensorflow/core/lib/strings/stringprintf.h"
 #include "tensorflow/core/platform/regexp.h"
@@ -70,7 +70,7 @@ class CoNLLSyntaxFormat : public DocumentFormat {
   }
 
  // Reads up to the first empty line and returns false if end of file is reached.
-  bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+  bool ReadRecord(tensorflow::io::BufferedInputStream *buffer,
                   string *record) override {
     string line;
     record->clear();
@@ -284,6 +284,122 @@ class CoNLLSyntaxFormat : public DocumentFormat {
 
 REGISTER_DOCUMENT_FORMAT("conll-sentence", CoNLLSyntaxFormat);
 
+// Reader for segmentation training data format. This reader assumes the input
+// format is similar to CoNLL format but with only two fields:
+//
+// Fields:
+// 1  FORM:        Word form or punctuation symbol.
+// 2  SPACE FLAG:  Either 'SPACE' or 'NO_SPACE', indicating whether there
+//                 should be a space between this word and the next one in
+//                 the raw text.
+//
+// Examples:
+// To create a training example for a sentence with raw text:
+//   That's a good point.
+// and the corresponding gold segmentation:
+//   That 's a good point .
+// Then the correct input is:
+// That	NO_SPACE
+// 's	SPACE
+// a	SPACE
+// good	SPACE
+// point	NO_SPACE
+// .	NO_SPACE
+//
+// Yet another example:
+// To create a training example for a sentence with raw text:
+//   这是一个测试
+// and the corresponding gold segmentation:
+//   这 是 一 个 测试
+// Then the correct input is:
+// 这	NO_SPACE
+// 是	NO_SPACE
+// 一	NO_SPACE
+// 个	NO_SPACE
+// 测试	NO_SPACE
+class SegmentationTrainingDataFormat : public CoNLLSyntaxFormat {
+ public:
+  // Converts to segmentation training data by breaking the words in the input
+  // tokens into UTF-8 character-based tokens. Moreover, if a character is the
+  // first char of the word in the original token, then its break level is set
+  // to SPACE_BREAK to indicate that the corresponding gold transition for that
+  // character token is START; otherwise it is set to NO_BREAK to indicate
+  // MERGE.
+  void ConvertFromString(const string &key, const string &value,
+                         vector<Sentence *> *sentences) override {
+    // Create new sentence.
+    Sentence *sentence = new Sentence();
+
+    // Each line corresponds to one token.
+    string text;
+    vector<string> lines = utils::Split(value, '\n');
+
+    // Add each token to the sentence.
+    vector<string> fields;
+    for (size_t i = 0; i < lines.size(); ++i) {
+      // Split line into tab-separated fields.
+      fields.clear();
+      fields = utils::Split(lines[i], '\t');
+      if (fields.empty()) continue;
+
+      // Skip comment lines.
+      if (fields[0][0] == '#') continue;
+
+      // Check that the line is valid.
+      CHECK_GE(fields.size(), 2)
+          << "Every line has to have at least 8 tab separated fields.";
+
+      // Get relevant fields.
+      const string &word = fields[0];
+      CHECK(fields[1] == "SPACE" || fields[1] == "NO_SPACE")
+          << "The space field can only be either 'SPACE' or 'NO_SPACE'";
+      const bool space_after = fields[1] == "SPACE";
+
+      // Add token to sentence text.
+      int start = text.size();
+      text.append(word);
+      if (space_after && i != lines.size() - 1) {
+        text.append(" ");
+      }
+
+      // Add character-based token to sentence.
+      vector<tensorflow::StringPiece> chars;
+      SegmenterUtils::GetUTF8Chars(word, &chars);
+      bool is_first_char = true;
+      for (auto utf8char : chars) {
+        Token *char_token = sentence->add_token();
+        char_token->set_word(utf8char.ToString());
+        char_token->set_start(start);
+        start += char_token->word().size();
+        char_token->set_end(start - 1);
+        char_token->set_break_level(
+            is_first_char ? Token::SPACE_BREAK : Token::NO_BREAK);
+        is_first_char = false;
+      }
+
+      // Add another space token.
+      if (space_after) {
+        Token *char_token = sentence->add_token();
+        char_token->set_word(" ");
+        char_token->set_start(start);
+        char_token->set_end(start);
+        char_token->set_break_level(Token::SPACE_BREAK);
+      }
+    }
+
+    if (sentence->token_size() > 0) {
+      sentence->set_docid(key);
+      sentence->set_text(text);
+      sentences->push_back(sentence);
+    } else {
+      // If the sentence was empty (e.g., blank lines at the beginning of a
+      // file), then don't save it.
+      delete sentence;
+    }
+  }
+};
+
+REGISTER_DOCUMENT_FORMAT("segment-train-data", SegmentationTrainingDataFormat);
+
 // Reader for tokenized text. This reader expects every sentence to be on a
 // single line and tokens on that line to be separated by single spaces.
 //
@@ -292,7 +408,7 @@ class TokenizedTextFormat : public DocumentFormat {
   TokenizedTextFormat() {}
 
   // Reads a line and returns false if end of file is reached.
-  bool ReadRecord(tensorflow::io::InputBuffer *buffer,
+  bool ReadRecord(tensorflow::io::BufferedInputStream *buffer,
                   string *record) override {
     return buffer->ReadLine(record).ok();
   }
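
To make the segment-train-data conversion concrete, here is a hedged Python 2
sketch (not the commit's code) of the same logic: the first UTF-8 character of
each word is marked SPACE_BREAK (the gold START transition) and the rest
NO_BREAK (MERGE), with an extra space token appended after words flagged
SPACE:

```python
# -*- coding: utf-8 -*-
SPACE_BREAK, NO_BREAK = 1, 0  # numeric values as used in the tests below

def convert(lines):
    tokens, break_levels = [], []
    for line in lines:
        word, flag = line.rstrip('\n').split('\t')
        for j, char in enumerate(word.decode('utf-8')):
            tokens.append(char)
            break_levels.append(SPACE_BREAK if j == 0 else NO_BREAK)
        if flag == 'SPACE':  # insert an explicit space token
            tokens.append(u' ')
            break_levels.append(SPACE_BREAK)
    return tokens, break_levels

# Mirrors the doc2 expectations in text_formats_test.py below:
_, levels = convert(["That\tNO_SPACE", "'s\tSPACE", "a\tSPACE",
                     "good\tSPACE", "point\tNO_SPACE", ".\tNO_SPACE"])
assert levels == [1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1]
```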

+ 48 - 1
syntaxnet/syntaxnet/text_formats_test.py

@@ -19,7 +19,6 @@
 
 # disable=no-name-in-module,unused-import,g-bad-import-order,maybe-no-member
 import os.path
-
 import tensorflow as tf
 
 import syntaxnet.load_parser_ops
@@ -51,6 +50,11 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
     inp.record_format.append(record_format)
     inp.part.add().file_pattern = file_pattern
 
+  def AddParameter(self, name, value, context):
+    param = context.parameter.add()
+    param.name = name
+    param.value = value
+
   def WriteContext(self, corpus_format):
     context = task_spec_pb2.TaskSpec()
     self.AddInput('documents', self.corpus_file, corpus_format, context)
@@ -106,6 +110,49 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
     self.CheckUntokenizedDoc('Hello ', ['H', 'e', 'l', 'l', 'o', ' '],
                              [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5])
 
+  def testSegmentationTrainingData(self):
+    doc1_lines = ['测试	NO_SPACE\n',
+                  '的	NO_SPACE\n',
+                  '句子	NO_SPACE']
+    doc1_text = '测试的句子'
+    doc1_tokens = ['测', '试', '的', '句', '子']
+    doc1_break_levels = [1, 0, 1, 1, 0]
+    doc2_lines = ['That	NO_SPACE\n',
+                  '\'s	SPACE\n',
+                  'a	SPACE\n',
+                  'good	SPACE\n',
+                  'point	NO_SPACE\n',
+                  '.	NO_SPACE']
+    doc2_text = 'That\'s a good point.'
+    doc2_tokens = ['T', 'h', 'a', 't', '\'', 's', ' ', 'a', ' ', 'g', 'o', 'o',
+                   'd', ' ', 'p', 'o', 'i', 'n', 't', '.']
+    doc2_break_levels = [1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0,
+                         0, 1]
+    self.CheckSegmentationTrainingData(doc1_lines, doc1_text, doc1_tokens,
+                                       doc1_break_levels)
+    self.CheckSegmentationTrainingData(doc2_lines, doc2_text, doc2_tokens,
+                                       doc2_break_levels)
+
+  def CheckSegmentationTrainingData(self, doc_lines, doc_text, doc_words,
+                                    break_levels):
+    # Prepare context.
+    self.WriteContext('segment-train-data')
+
+    # Prepare test sentence.
+    with open(self.corpus_file, 'w') as f:
+      f.write(''.join(doc_lines))
+
+    # Test converted sentence.
+    sentence, _ = gen_parser_ops.document_source(
+        self.context_file, batch_size=1)
+    with self.test_session() as sess:
+      sentence_doc = self.ReadNextDocument(sess, sentence)
+      self.assertEqual(doc_text.decode('utf-8'), sentence_doc.text)
+      self.assertEqual([t.decode('utf-8') for t in doc_words],
+                       [t.word for t in sentence_doc.token])
+      self.assertEqual(break_levels,
+                       [t.break_level for t in sentence_doc.token])
+
   def testSimple(self):
     self.CheckTokenization('Hello, world!', 'Hello , world !')
     self.CheckTokenization('"Hello"', "`` Hello ''")

+ 1 - 1
syntaxnet/tensorflow

@@ -1 +1 @@
-Subproject commit 861644c0bcae5d56f7b3f439696eefa6df8580ec
+Subproject commit 8ed00233c0cd530fec78cfad5b34f54b6f902e31