// segmenter_utils.cc
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/segmenter_utils.h"
  13. #include "util/utf8/unicodetext.h"
  14. #include "util/utf8/unilib.h"
  15. #include "util/utf8/unilib_utf8_utils.h"
  16. namespace syntaxnet {
  17. // Separators, code Zs from http://www.unicode.org/Public/UNIDATA/PropList.txt
  18. // NB: This list is not necessarily exhaustive.
  19. const std::unordered_set<int> SegmenterUtils::kBreakChars({
  20. 0x2028, // line separator
  21. 0x2029, // paragraph separator
  22. 0x0020, // space
  23. 0x00a0, // no-break space
  24. 0x1680, // Ogham space mark
  25. 0x180e, // Mongolian vowel separator
  26. 0x202f, // narrow no-break space
  27. 0x205f, // medium mathematical space
  28. 0x3000, // ideographic space
  29. 0xe5e5, // Google addition
  30. 0x2000, 0x2001, 0x2002, 0x2003, 0x2004, 0x2005, 0x2006, 0x2007, 0x2008,
  31. 0x2009, 0x200a
  32. });
  33. bool SegmenterUtils::ConvertToCharTokenDoc(const Sentence &sentence,
  34. Sentence *char_sentence) {
  35. CHECK(char_sentence);
  36. // Extracts tokens and byte offsets.
  37. std::vector<tensorflow::StringPiece> orig_chars;
  38. GetUTF8Chars(sentence.text(), &orig_chars);
  39. // If sentence's token start/end bytes are not consistent with UTF-8
  40. // characters, then do not process the sentence.
  41. if (!DocTokensUTF8Consistent(orig_chars, sentence)) {
  42. LOG(WARNING) << "Document token boundaries not UTF8 consistent.";
  43. return false;
  44. }
  45. // Create a mapping from text byte to token index or -1.
  46. // token_ids[i] = index of the token byte i is contained in and -1 o.w.
  47. std::vector<int> token_ids;
  48. for (int t = 0; t < sentence.token_size(); ++t) {
  49. const Token &token = sentence.token(t);
  50. while (token_ids.size() < token.start()) token_ids.push_back(-1);
  51. while (token_ids.size() <= token.end()) token_ids.push_back(t);
  52. }
  53. while (token_ids.size() < sentence.text().size()) token_ids.push_back(-1);
  54. // Infer SPLIT/MERGE for each UTF-8 char.
  55. std::vector<Token::BreakLevel> break_levels;
  56. break_levels.push_back(Token::SPACE_BREAK); // first token is always a split.
  57. for (int c = 1; c < orig_chars.size(); ++c) {
  58. int char_start, char_end;
  59. GetCharStartEndBytes(sentence.text(), orig_chars[c], &char_start,
  60. &char_end);
  61. int prev_char_start, prev_char_end;
  62. GetCharStartEndBytes(sentence.text(), orig_chars[c - 1], &prev_char_start,
  63. &prev_char_end);
  64. // We split if this character is a break token (token_ids = -1) or if this
  65. // character is part of a different token from the previous.
  66. const bool is_split =
  67. token_ids[char_start] == -1 ||
  68. token_ids[char_start] != token_ids[prev_char_end];
  69. break_levels.push_back(is_split ? Token::SPACE_BREAK : Token::NO_BREAK);
  70. }
  71. // Initialize character sentence.
  72. SetCharsAsTokens(sentence.text(), orig_chars, char_sentence);
  73. CHECK_EQ(break_levels.size(), char_sentence->token_size());
  74. for (int i = 0 ; i < break_levels.size(); ++i) {
  75. char_sentence->mutable_token(i)->set_break_level(break_levels[i]);
  76. }
  77. return true;
  78. }
  79. bool SegmenterUtils::DocTokensUTF8Consistent(
  80. const std::vector<tensorflow::StringPiece> &chars,
  81. const Sentence &sentence) {
  82. std::set<int> starts;
  83. std::set<int> ends;
  84. for (const tensorflow::StringPiece c : chars) {
  85. int start_byte, end_byte;
  86. GetCharStartEndBytes(sentence.text(), c, &start_byte, &end_byte);
  87. starts.insert(start_byte);
  88. ends.insert(end_byte);
  89. }
  90. for (const Token &t : sentence.token()) {
  91. if (starts.find(t.start()) == starts.end()) return false;
  92. if (ends.find(t.end()) == ends.end()) return false;
  93. }
  94. return true;
  95. }
  96. void SegmenterUtils::GetUTF8Chars(const string &text,
  97. std::vector<tensorflow::StringPiece> *chars) {
  98. const char *start = text.c_str();
  99. const char *end = text.c_str() + text.size();
  100. while (start < end) {
  101. int char_length = UniLib::OneCharLen(start);
  102. chars->emplace_back(start, char_length);
  103. start += char_length;
  104. }
  105. }
  106. void SegmenterUtils::SetCharsAsTokens(
  107. const string &text,
  108. const std::vector<tensorflow::StringPiece> &chars,
  109. Sentence *sentence) {
  110. sentence->clear_token();
  111. sentence->set_text(text);
  112. for (int i = 0; i < chars.size(); ++i) {
  113. Token *tok = sentence->add_token();
  114. tok->set_word(chars[i].ToString()); // NOLINT
  115. int start_byte, end_byte;
  116. GetCharStartEndBytes(text, chars[i], &start_byte, &end_byte);
  117. tok->set_start(start_byte);
  118. tok->set_end(end_byte);
  119. }
  120. }
  121. bool SegmenterUtils::IsValidSegment(const Sentence &sentence,
  122. const Token &token) {
  123. // Check that the token is not empty, both by string and by bytes.
  124. if (token.word().empty()) return false;
  125. if (token.start() > token.end()) return false;
  126. // Check token boudaries inside of text.
  127. if (token.start() < 0) return false;
  128. if (token.end() >= sentence.text().size()) return false;
  129. // Check that token string is valid UTF8, by bytes.
  130. const char s = sentence.text()[token.start()];
  131. const char e = sentence.text()[token.end() + 1];
  132. if (UniLib::IsTrailByte(s)) return false;
  133. if (UniLib::IsTrailByte(e)) return false;
  134. return true;
  135. }
  136. } // namespace syntaxnet