segmenter_utils_test.cc 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/segmenter_utils.h"
  13. #include <string>
  14. #include <vector>
  15. #include "syntaxnet/char_properties.h"
  16. #include "syntaxnet/sentence.pb.h"
  17. #include <gmock/gmock.h>
  18. #include "tensorflow/core/lib/strings/strcat.h"
  19. namespace syntaxnet {
  20. // Creates a Korean senence and also initializes the token field.
  21. static Sentence GetKoSentence() {
  22. Sentence sentence;
  23. string text = "서울시는 2012년부터";
  24. // Add tokens.
  25. sentence.set_text(text);
  26. Token *tok = sentence.add_token();
  27. tok->set_word("서울시");
  28. tok->set_start(0);
  29. tok->set_end(8);
  30. tok = sentence.add_token();
  31. tok->set_word("는");
  32. tok->set_start(9);
  33. tok->set_end(11);
  34. tok = sentence.add_token();
  35. tok->set_word("2012");
  36. tok->set_start(13);
  37. tok->set_end(16);
  38. tok = sentence.add_token();
  39. tok->set_word("년");
  40. tok->set_start(17);
  41. tok->set_end(19);
  42. tok = sentence.add_token();
  43. tok->set_word("부터");
  44. tok->set_start(20);
  45. tok->set_end(25);
  46. return sentence;
  47. }
  48. // Gets the start end bytes of the given chars in the given text.
  49. static void GetStartEndBytes(const string &text,
  50. const vector<tensorflow::StringPiece> &chars,
  51. vector<int> *starts,
  52. vector<int> *ends) {
  53. SegmenterUtils segment_utils;
  54. for (const tensorflow::StringPiece &c : chars) {
  55. int start; int end;
  56. segment_utils.GetCharStartEndBytes(text, c, &start, &end);
  57. starts->push_back(start);
  58. ends->push_back(end);
  59. }
  60. }
  61. // Test the GetChars function.
  62. TEST(SegmenterUtilsTest, GetCharsTest) {
  63. // Create test sentence.
  64. const Sentence sentence = GetKoSentence();
  65. vector<tensorflow::StringPiece> chars;
  66. SegmenterUtils::GetUTF8Chars(sentence.text(), &chars);
  67. // Check the number of characters is correct.
  68. CHECK_EQ(chars.size(), 12);
  69. vector<int> starts;
  70. vector<int> ends;
  71. GetStartEndBytes(sentence.text(), chars, &starts, &ends);
  72. // Check start positions.
  73. CHECK_EQ(starts[0], 0);
  74. CHECK_EQ(starts[1], 3);
  75. CHECK_EQ(starts[2], 6);
  76. CHECK_EQ(starts[3], 9);
  77. CHECK_EQ(starts[4], 12);
  78. CHECK_EQ(starts[5], 13);
  79. CHECK_EQ(starts[6], 14);
  80. CHECK_EQ(starts[7], 15);
  81. CHECK_EQ(starts[8], 16);
  82. CHECK_EQ(starts[9], 17);
  83. CHECK_EQ(starts[10], 20);
  84. CHECK_EQ(starts[11], 23);
  85. // Check end positions.
  86. CHECK_EQ(ends[0], 2);
  87. CHECK_EQ(ends[1], 5);
  88. CHECK_EQ(ends[2], 8);
  89. CHECK_EQ(ends[3], 11);
  90. CHECK_EQ(ends[4], 12);
  91. CHECK_EQ(ends[5], 13);
  92. CHECK_EQ(ends[6], 14);
  93. CHECK_EQ(ends[7], 15);
  94. CHECK_EQ(ends[8], 16);
  95. CHECK_EQ(ends[9], 19);
  96. CHECK_EQ(ends[10], 22);
  97. CHECK_EQ(ends[11], 25);
  98. }
  99. // Test the SetCharsAsTokens function.
  100. TEST(SegmenterUtilsTest, SetCharsAsTokensTest) {
  101. // Create test sentence.
  102. const Sentence sentence = GetKoSentence();
  103. vector<tensorflow::StringPiece> chars;
  104. SegmenterUtils segment_utils;
  105. segment_utils.GetUTF8Chars(sentence.text(), &chars);
  106. vector<int> starts;
  107. vector<int> ends;
  108. GetStartEndBytes(sentence.text(), chars, &starts, &ends);
  109. // Check that the new docs word, start and end positions are properly set.
  110. Sentence new_sentence;
  111. segment_utils.SetCharsAsTokens(sentence.text(), chars, &new_sentence);
  112. CHECK_EQ(new_sentence.token_size(), chars.size());
  113. for (int t = 0; t < sentence.token_size(); ++t) {
  114. CHECK_EQ(new_sentence.token(t).word(), chars[t]);
  115. CHECK_EQ(new_sentence.token(t).start(), starts[t]);
  116. CHECK_EQ(new_sentence.token(t).end(), ends[t]);
  117. }
  118. // Re-running should remove the old tokens.
  119. segment_utils.SetCharsAsTokens(sentence.text(), chars, &new_sentence);
  120. CHECK_EQ(new_sentence.token_size(), chars.size());
  121. for (int t = 0; t < sentence.token_size(); ++t) {
  122. CHECK_EQ(new_sentence.token(t).word(), chars[t]);
  123. CHECK_EQ(new_sentence.token(t).start(), starts[t]);
  124. CHECK_EQ(new_sentence.token(t).end(), ends[t]);
  125. }
  126. }
  127. } // namespace syntaxnet