Parcourir la source

New features, locator and text format for tokenization (#306)

* Adding:
  - offset feature locator,
  - last word feature function,
  - untokenized text format.
calberti il y a 9 ans
Parent
commit
6f6a453913

+ 4 - 0
syntaxnet/syntaxnet/BUILD

@@ -132,6 +132,7 @@ cc_library(
     srcs = ["text_formats.cc"],
     deps = [
         ":document_format",
+        ":segmenter_utils",
         ":sentence_proto",
     ],
     alwayslink = 1,
@@ -558,7 +559,9 @@ cc_test(
     deps = [
         ":parser_transitions",
         ":sentence_proto",
+        ":task_context",
         ":test_main",
+        ":workspace",
     ],
 )
 
@@ -646,6 +649,7 @@ py_binary(
         ":graph_builder",
         ":sentence_py_pb2",
         ":structured_graph_builder",
+        ":task_spec_py_pb2",
     ],
 )
 

+ 100 - 0
syntaxnet/syntaxnet/binary_segment_transitions.cc

@@ -14,8 +14,10 @@ limitations under the License.
 ==============================================================================*/
 
 #include "syntaxnet/binary_segment_state.h"
+#include "syntaxnet/parser_features.h"
 #include "syntaxnet/parser_state.h"
 #include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/term_frequency_map.h"
 
 namespace syntaxnet {
 
@@ -118,4 +120,102 @@ class BinarySegmentTransitionSystem : public ParserTransitionSystem {
 REGISTER_TRANSITION_SYSTEM("binary-segment-transitions",
                            BinarySegmentTransitionSystem);
 
+// Parser feature locator that returns the token in the sentence that is
+// argument() positions from the provided focus token.
+class OffsetFeatureLocator : public ParserIndexLocator<OffsetFeatureLocator> {
+ public:
+  // Update the current focus to a new location.  If the initial focus or new
+  // focus is outside the range of the sentence, returns -2.
+  void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
+                  int *focus) const {
+    if (*focus < -1 || *focus >= state.sentence().token_size()) {
+      *focus = -2;
+      return;
+    }
+    int new_focus = *focus + argument();
+    if (new_focus < -1 || new_focus >= state.sentence().token_size()) {
+      *focus = -2;
+      return;
+    }
+    *focus = new_focus;
+  }
+};
+
+REGISTER_PARSER_IDX_FEATURE_FUNCTION("offset", OffsetFeatureLocator);
+
+// Feature function that returns the id of the n-th most recently constructed
+// word. Note that the argument, n, should be larger than 0. When equals to 0,
+// it points to the word which is not yet completed.
+class LastWordFeatureFunction : public ParserFeatureFunction {
+ public:
+  void Setup(TaskContext *context) override {
+    input_word_map_ = context->GetInput("word-map", "text", "");
+  }
+
+  void Init(TaskContext *context) override {
+    min_freq_ = GetIntParameter("min-freq", 0);
+    max_num_terms_ = GetIntParameter("max-num-terms", 0);
+    word_map_.Load(
+        TaskContext::InputFile(*input_word_map_), min_freq_, max_num_terms_);
+    unk_id_ = word_map_.Size();
+    outside_id_ = unk_id_ + 1;
+    set_feature_type(
+        new ResourceBasedFeatureType<LastWordFeatureFunction>(
+        name(), this, {}));
+  }
+
+  int64 NumValues() const {
+    return outside_id_ + 1;
+  }
+
+  // Returns the string representation of the given feature value.
+  string GetFeatureValueName(FeatureValue value) const {
+    if (value == outside_id_) return "<OUTSIDE>";
+    if (value == unk_id_) return "<UNKNOWN>";
+    DCHECK_GE(value, 0);
+    DCHECK_LT(value, word_map_.Size());
+    return word_map_.GetTerm(value);
+  }
+
+  FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
+                       const FeatureVector *result) const override {
+    // n should be larger than 0, since the current word is still under
+    // construction.
+    const int n = argument();
+    CHECK_GT(n, 0);
+    const auto *segment_state = static_cast<const BinarySegmentState *>(
+        state.transition_state());
+    if (n >= segment_state->NumStarts(state)) {
+      return outside_id_;
+    }
+
+    const auto &sentence = state.sentence();
+    const int start = segment_state->LastStart(n, state);
+    const int end = segment_state->LastStart(n - 1, state) - 1;
+    CHECK_GE(end, start);
+
+    const int start_offset = state.GetToken(start).start();
+    const int length = state.GetToken(end).end() - start_offset + 1;
+    const auto *data = sentence.text().data() + start_offset;
+    return word_map_.LookupIndex(string(data, length), unk_id_);
+  }
+
+ private:
+  // Task input for the word to id map. Not owned.
+  TaskInput *input_word_map_ = nullptr;
+  TermFrequencyMap word_map_;
+
+  // Special ids of unknown words and out-of-range.
+  int unk_id_ = 0;
+  int outside_id_ = 0;
+
+  // Minimum frequency for term map.
+  int min_freq_;
+
+  // Maximum number of terms for term map.
+  int max_num_terms_;
+};
+
+REGISTER_PARSER_FEATURE_FUNCTION("last-word", LastWordFeatureFunction);
+
 }  // namespace syntaxnet

+ 122 - 1
syntaxnet/syntaxnet/binary_segment_transitions_test.cc

@@ -14,9 +14,12 @@ limitations under the License.
 ==============================================================================*/
 
 #include "syntaxnet/binary_segment_state.h"
+#include "syntaxnet/parser_features.h"
 #include "syntaxnet/parser_state.h"
 #include "syntaxnet/parser_transitions.h"
+#include "syntaxnet/task_context.h"
 #include "syntaxnet/term_frequency_map.h"
+#include "syntaxnet/workspace.h"
 #include "tensorflow/core/platform/test.h"
 
 namespace syntaxnet {
@@ -38,6 +41,58 @@ class SegmentationTransitionTest : public ::testing::Test {
         "token { word: '样' start: 14 end: 16 break_level: NO_BREAK } ";
     sentence_ = std::unique_ptr<Sentence>(new Sentence());
     TextFormat::ParseFromString(str_sentence, sentence_.get());
+
+    word_map_.Increment("因为");
+    word_map_.Increment("因为");
+    word_map_.Increment("有");
+    word_map_.Increment("这");
+    word_map_.Increment("这");
+    word_map_.Increment("样");
+    word_map_.Increment("样");
+    word_map_.Increment("这样");
+    word_map_.Increment("这样");
+    string filename = tensorflow::strings::StrCat(
+        tensorflow::testing::TmpDir(), "word-map");
+    word_map_.Save(filename);
+
+    // Re-load in sorted order, ignore words that only occurs once.
+    word_map_.Load(filename, 2, -1);
+
+    // Prepare task context.
+    context_ = std::unique_ptr<TaskContext>(new TaskContext());
+    AddInputToContext("word-map", filename, "text", "");
+    registry_ = std::unique_ptr<WorkspaceRegistry>( new WorkspaceRegistry());
+  }
+
+  // Adds an input to the task context.
+  void AddInputToContext(const string &name,
+                         const string &file_pattern,
+                         const string &file_format,
+                         const string &record_format) {
+    TaskInput *input = context_->GetInput(name);
+    TaskInput::Part *part = input->add_part();
+    part->set_file_pattern(file_pattern);
+    part->set_file_format(file_format);
+    part->set_record_format(record_format);
+  }
+
+  // Prepares a feature for computations.
+  void PrepareFeature(const string &feature_name, ParserState *state) {
+    feature_extractor_ = std::unique_ptr<ParserFeatureExtractor>(
+        new ParserFeatureExtractor());
+    feature_extractor_->Parse(feature_name);
+    feature_extractor_->Setup(context_.get());
+    feature_extractor_->Init(context_.get());
+    feature_extractor_->RequestWorkspaces(registry_.get());
+    workspace_.Reset(*registry_);
+    feature_extractor_->Preprocess(&workspace_, state);
+  }
+
+  // Computes the feature value for the parser state.
+  FeatureValue ComputeFeature(const ParserState &state) const {
+    FeatureVector result;
+    feature_extractor_->ExtractFeatures(workspace_, state, &result);
+    return result.size() > 0 ? result.value(0) : -1;
   }
 
   void CheckStarts(const ParserState &state, const vector<int> &target) {
@@ -48,10 +103,18 @@ class SegmentationTransitionTest : public ::testing::Test {
     }
   }
 
-  // The test document, parse tree, and sentence with tags and partial parses.
+  // The test sentence.
   std::unique_ptr<Sentence> sentence_;
+
+  // Members for testing features.
+  std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
+  std::unique_ptr<TaskContext> context_;
+  std::unique_ptr<WorkspaceRegistry> registry_;
+  WorkspaceSet workspace_;
+
   std::unique_ptr<ParserTransitionSystem> transition_system_;
   TermFrequencyMap label_map_;
+  TermFrequencyMap word_map_;
 };
 
 TEST_F(SegmentationTransitionTest, GoldNextActionTest) {
@@ -108,4 +171,62 @@ TEST_F(SegmentationTransitionTest, DefaultActionTest) {
   EXPECT_EQ(sentence_->token(4).word(), "样");
 }
 
+TEST_F(SegmentationTransitionTest, LastWordFeatureTest) {
+  const int unk_id = word_map_.Size();
+  const int outside_id = unk_id + 1;
+
+  // Prepare a parser state.
+  BinarySegmentState *segment_state = new BinarySegmentState();
+  auto state = std::unique_ptr<ParserState>(new ParserState(
+      sentence_.get(), segment_state, &label_map_));
+
+  // Test initial state which contains no words.
+  PrepareFeature("last-word(1,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+  PrepareFeature("last-word(2,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+  PrepareFeature("last-word(3,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+
+  // Test when the state contains only one start.
+  segment_state->AddStart(0, state.get());
+  PrepareFeature("last-word(1,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+  PrepareFeature("last-word(2,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+
+  // Test when the state contains two starts, which forms a complete word and
+  // the start of another new word.
+  segment_state->AddStart(2, state.get());
+  EXPECT_NE(word_map_.LookupIndex("因为", unk_id), unk_id);
+  PrepareFeature("last-word(1)", state.get());
+  EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
+
+  // The last-word still points to outside.
+  PrepareFeature("last-word(2,min-freq=2)", state.get());
+  EXPECT_EQ(outside_id, ComputeFeature(*state));
+
+  // Adding more starts that leads to the following words:
+  // 因为 ‘ ’ 有 ‘ ’
+  segment_state->AddStart(3, state.get());
+  segment_state->AddStart(4, state.get());
+
+  // Note 有 is pruned from the map since its frequency is less than 2.
+  EXPECT_EQ(word_map_.LookupIndex("有", unk_id), unk_id);
+  PrepareFeature("last-word(1,min-freq=2)", state.get());
+  EXPECT_EQ(unk_id, ComputeFeature(*state));
+
+  // Note that last-word(2) points to ' ' which is also a unk.
+  PrepareFeature("last-word(2,min-freq=2)", state.get());
+  EXPECT_EQ(unk_id, ComputeFeature(*state));
+  PrepareFeature("last-word(3,min-freq=2)", state.get());
+  EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
+
+  // Adding two words: "这" and "样".
+  segment_state->AddStart(5, state.get());
+  segment_state->AddStart(6, state.get());
+  PrepareFeature("last-word(1,min-freq=2)", state.get());
+  EXPECT_EQ(word_map_.LookupIndex("这", unk_id), ComputeFeature(*state));
+}
+
 }  // namespace syntaxnet

+ 1 - 0
syntaxnet/syntaxnet/char_properties.cc

@@ -795,6 +795,7 @@ DEFINE_CHAR_PROPERTY(separator, prop) {
 DEFINE_CHAR_PROPERTY_AS_SET(digit,
   RANGE('0', '9'),
   RANGE(0x0660, 0x0669),  // Arabic-Indic digits
+
   RANGE(0x06F0, 0x06F9),  // Eastern Arabic-Indic digits
 )
 

+ 31 - 17
syntaxnet/syntaxnet/parser_eval.py

@@ -20,13 +20,19 @@ import os
 import os.path
 import time
 
+import tempfile
 import tensorflow as tf
 
+from tensorflow.python.platform import gfile
 from tensorflow.python.platform import tf_logging as logging
+
+from google.protobuf import text_format
+
 from syntaxnet import sentence_pb2
 from syntaxnet import graph_builder
 from syntaxnet import structured_graph_builder
 from syntaxnet.ops import gen_parser_ops
+from syntaxnet import task_spec_pb2
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -35,6 +41,8 @@ FLAGS = flags.FLAGS
 flags.DEFINE_string('task_context', '',
                     'Path to a task context with inputs and parameters for '
                     'feature extractors.')
+flags.DEFINE_string('resource_dir', '',
+                    'Optional base directory for task context resources.')
 flags.DEFINE_string('model_path', '', 'Path to model parameters.')
 flags.DEFINE_string('arg_prefix', None, 'Prefix for context parameters.')
 flags.DEFINE_string('graph_builder', 'greedy',
@@ -53,16 +61,28 @@ flags.DEFINE_bool('slim_model', False,
                   'Whether to expect only averaged variables.')
 
 
-def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
-  """Builds and evaluates a network.
+def RewriteContext(task_context):
+  context = task_spec_pb2.TaskSpec()
+  with gfile.FastGFile(task_context) as fin:
+    text_format.Merge(fin.read(), context)
+  for resource in context.input:
+    for part in resource.part:
+      if part.file_pattern != '-':
+        part.file_pattern = os.path.join(FLAGS.resource_dir, part.file_pattern)
+  with tempfile.NamedTemporaryFile(delete=False) as fout:
+    fout.write(str(context))
+    return fout.name
+
+
+def Eval(sess):
+  """Builds and evaluates a network."""
+  task_context = FLAGS.task_context
+  if FLAGS.resource_dir:
+    task_context = RewriteContext(task_context)
+  feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
+      gen_parser_ops.feature_size(task_context=task_context,
+                                  arg_prefix=FLAGS.arg_prefix))
 
-  Args:
-    sess: tensorflow session to use
-    num_actions: number of possible golden actions
-    feature_sizes: size of each feature vector
-    domain_sizes: number of possible feature ids in each feature vector
-    embedding_dims: embedding dimension for each feature group
-  """
   t = time.time()
   hidden_layer_sizes = map(int, FLAGS.hidden_layer_sizes.split(','))
   logging.info('Building training network with parameters: feature_sizes: %s '
@@ -86,7 +106,6 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
         arg_prefix=FLAGS.arg_prefix,
         beam_size=FLAGS.beam_size,
         max_steps=FLAGS.max_steps)
-  task_context = FLAGS.task_context
   parser.AddEvaluation(task_context,
                        FLAGS.batch_size,
                        corpus_name=FLAGS.input,
@@ -98,7 +117,7 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
 
   sink_documents = tf.placeholder(tf.string)
   sink = gen_parser_ops.document_sink(sink_documents,
-                                      task_context=FLAGS.task_context,
+                                      task_context=task_context,
                                       corpus_name=FLAGS.output)
   t = time.time()
   num_epochs = None
@@ -136,12 +155,7 @@ def Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims):
 def main(unused_argv):
   logging.set_verbosity(logging.INFO)
   with tf.Session() as sess:
-    feature_sizes, domain_sizes, embedding_dims, num_actions = sess.run(
-        gen_parser_ops.feature_size(task_context=FLAGS.task_context,
-                                    arg_prefix=FLAGS.arg_prefix))
-
-  with tf.Session() as sess:
-    Eval(sess, num_actions, feature_sizes, domain_sizes, embedding_dims)
+    Eval(sess)
 
 
 if __name__ == '__main__':

+ 1 - 0
syntaxnet/syntaxnet/sentence_features.h

@@ -243,6 +243,7 @@ class TermFrequencyMapSetFeature : public TokenLookupSetFeature {
   void Init(TaskContext *context) override;
 
   // Number of unique values.
+
   int64 NumValues() const override { return term_map_->Size(); }
 
   // Special value for strings not in the map.

+ 52 - 10
syntaxnet/syntaxnet/text_formats.cc

@@ -19,6 +19,7 @@ limitations under the License.
 
 #include "syntaxnet/document_format.h"
 #include "syntaxnet/sentence.pb.h"
+#include "syntaxnet/segmenter_utils.h"
 #include "syntaxnet/utils.h"
 #include "tensorflow/core/lib/io/inputbuffer.h"
 #include "tensorflow/core/lib/strings/strcat.h"
@@ -172,13 +173,13 @@ class CoNLLSyntaxFormat : public DocumentFormat {
       if (add_pos_as_attribute_) RemovePosFromAttributes(&token);
       vector<string> fields(10);
       fields[0] = tensorflow::strings::Printf("%d", i + 1);
-      fields[1] = token.word();
+      fields[1] = UnderscoreIfEmpty(token.word());
       fields[2] = "_";
-      fields[3] = token.category();
-      fields[4] = token.tag();
+      fields[3] = UnderscoreIfEmpty(token.category());
+      fields[4] = UnderscoreIfEmpty(token.tag());
       fields[5] = GetMorphAttributes(token);
       fields[6] = tensorflow::strings::Printf("%d", token.head() + 1);
-      fields[7] = token.label();
+      fields[7] = UnderscoreIfEmpty(token.label());
       fields[8] = "_";
       fields[9] = "_";
       lines.push_back(utils::Join(fields, "\t"));
@@ -187,6 +188,11 @@ class CoNLLSyntaxFormat : public DocumentFormat {
   }
 
  private:
+  // Replaces empty fields with an undescore.
+  string UnderscoreIfEmpty(const string &field) {
+    return field.empty() ? "_" : field;
+  }
+
   // Creates a TokenMorphology object out of a list of attribute values of the
   // form: a1=v1|a2=v2|... or v1|v2|...
   void AddMorphAttributes(const string &attributes, Token *token) {
@@ -194,11 +200,7 @@ class CoNLLSyntaxFormat : public DocumentFormat {
         token->MutableExtension(TokenMorphology::morphology);
     vector<string> att_vals = utils::Split(attributes, '|');
     for (int i = 0; i < att_vals.size(); ++i) {
-      vector<string> att_val = utils::Split(att_vals[i], '=');
-      CHECK_LE(att_val.size(), 2)
-          << "Error parsing morphology features "
-          << "column, must be of format "
-          << "a1=v1|a2=v2|... or v1|v2|... <field>: " << attributes;
+      vector<string> att_val = utils::SplitOne(att_vals[i], '=');
 
       // Format is either:
       //   1) a1=v1|a2=v2..., e.g., Czech CoNLL data, or,
@@ -268,7 +270,8 @@ class CoNLLSyntaxFormat : public DocumentFormat {
     // Assumes the "fPOS" attribute, if present, is the last one.
     TokenMorphology *morph =
         token->MutableExtension(TokenMorphology::morphology);
-    if (morph->attribute().rbegin()->name() == "fPOS") {
+    if (morph->attribute_size() > 0 &&
+        morph->attribute().rbegin()->name() == "fPOS") {
       morph->mutable_attribute()->RemoveLast();
     }
   }
@@ -346,6 +349,45 @@ class TokenizedTextFormat : public DocumentFormat {
 
 REGISTER_DOCUMENT_FORMAT("tokenized-text", TokenizedTextFormat);
 
+// Reader for un-tokenized text. This reader expects every sentence to be on a
+// single line. For each line in the input, a sentence proto will be created,
+// where tokens are utf8 characters of that line.
+//
+class UntokenizedTextFormat : public TokenizedTextFormat {
+ public:
+  UntokenizedTextFormat() {}
+
+  void ConvertFromString(const string &key, const string &value,
+                         vector<Sentence *> *sentences) override {
+    Sentence *sentence = new Sentence();
+    vector<tensorflow::StringPiece> chars;
+    SegmenterUtils::GetUTF8Chars(value, &chars);
+    int start = 0;
+    for (auto utf8char : chars) {
+      Token *token = sentence->add_token();
+      token->set_word(utf8char.ToString());
+      token->set_start(start);
+      start += utf8char.size();
+      token->set_end(start - 1);
+    }
+
+    if (sentence->token_size() > 0) {
+      sentence->set_docid(key);
+      sentence->set_text(value);
+      sentences->push_back(sentence);
+    } else {
+      // If the sentence was empty (e.g., blank lines at the beginning of a
+      // file), then don't save it.
+      delete sentence;
+    }
+  }
+
+ private:
+  TF_DISALLOW_COPY_AND_ASSIGN(UntokenizedTextFormat);
+};
+
+REGISTER_DOCUMENT_FORMAT("untokenized-text", UntokenizedTextFormat);
+
 // Text reader that attmpts to perform Penn Treebank tokenization on arbitrary
 // raw text. Adapted from https://www.cis.upenn.edu/~treebank/tokenizer.sed
 // by Robert MacIntyre, University of Pennsylvania, late 1995.

+ 23 - 0
syntaxnet/syntaxnet/text_formats_test.py

@@ -83,6 +83,29 @@ class TextFormatsTest(test_util.TensorFlowTestCase):
       self.assertEqual(' '.join([t.word for t in sentence_doc.token]),
                        tokenization)
 
+  def CheckUntokenizedDoc(self, sentence, words, starts, ends):
+    self.WriteContext('untokenized-text')
+    logging.info('Writing text file to: %s', self.corpus_file)
+    with open(self.corpus_file, 'w') as f:
+      f.write(sentence)
+    sentence, _ = gen_parser_ops.document_source(
+        self.context_file, batch_size=1)
+    with self.test_session() as sess:
+      sentence_doc = self.ReadNextDocument(sess, sentence)
+      self.assertEqual(len(sentence_doc.token), len(words))
+      self.assertEqual(len(sentence_doc.token), len(starts))
+      self.assertEqual(len(sentence_doc.token), len(ends))
+      for i, token in enumerate(sentence_doc.token):
+        self.assertEqual(token.word.encode('utf-8'), words[i])
+        self.assertEqual(token.start, starts[i])
+        self.assertEqual(token.end, ends[i])
+
+  def testUntokenized(self):
+    self.CheckUntokenizedDoc('一个测试', ['一', '个', '测', '试'],
+                             [0, 3, 6, 9], [2, 5, 8, 11])
+    self.CheckUntokenizedDoc('Hello ', ['H', 'e', 'l', 'l', 'o', ' '],
+                             [0, 1, 2, 3, 4, 5], [0, 1, 2, 3, 4, 5])
+
   def testSimple(self):
     self.CheckTokenization('Hello, world!', 'Hello , world !')
     self.CheckTokenization('"Hello"', "`` Hello ''")

+ 10 - 0
syntaxnet/syntaxnet/utils.cc

@@ -95,6 +95,16 @@ std::vector<string> Split(const string &text, char delim) {
   return result;
 }
 
+std::vector<string> SplitOne(const string &text, char delim) {
+  std::vector<string> result;
+  size_t split = text.find_first_of(delim);
+  result.push_back(text.substr(0, split));
+  if (split != string::npos) {
+    result.push_back(text.substr(split + 1));
+  }
+  return result;
+}
+
 bool IsAbsolutePath(tensorflow::StringPiece path) {
   return !path.empty() && path[0] == '/';
 }

+ 5 - 0
syntaxnet/syntaxnet/utils.h

@@ -49,8 +49,13 @@ T ParseUsing(const string &str, T defval,
 
 string CEscape(const string &src);
 
+// Splits the given string on every occurrence of the given delimiter char.
 std::vector<string> Split(const string &text, char delim);
 
+// Splits the given string on the first occurrence of the given delimiter char,
+// or returns the given string if the given delimiter is not found.
+std::vector<string> SplitOne(const string &text, char delim);
+
 template <typename T>
 string Join(const std::vector<T> &s, const char *sep) {
   string result;