binary_segment_transitions_test.cc

/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
  12. #include "syntaxnet/binary_segment_state.h"
  13. #include "syntaxnet/parser_features.h"
  14. #include "syntaxnet/parser_state.h"
  15. #include "syntaxnet/parser_transitions.h"
  16. #include "syntaxnet/task_context.h"
  17. #include "syntaxnet/term_frequency_map.h"
  18. #include "syntaxnet/workspace.h"
  19. #include "tensorflow/core/platform/test.h"
  20. namespace syntaxnet {
class SegmentationTransitionTest : public ::testing::Test {
 protected:
  void SetUp() override {
    transition_system_ = std::unique_ptr<ParserTransitionSystem>(
        ParserTransitionSystem::Create("binary-segment-transitions"));

    // Prepare a sentence whose tokens are single characters (plus spaces).
    const char *str_sentence = "text: '因为 有 这样' "
        "token { word: '因' start: 0 end: 2 break_level: SPACE_BREAK } "
        "token { word: '为' start: 3 end: 5 break_level: NO_BREAK } "
        "token { word: ' ' start: 6 end: 6 break_level: SPACE_BREAK } "
        "token { word: '有' start: 7 end: 9 break_level: SPACE_BREAK } "
        "token { word: ' ' start: 10 end: 10 break_level: SPACE_BREAK } "
        "token { word: '这' start: 11 end: 13 break_level: SPACE_BREAK } "
        "token { word: '样' start: 14 end: 16 break_level: NO_BREAK } ";
    sentence_ = std::unique_ptr<Sentence>(new Sentence());
    TextFormat::ParseFromString(str_sentence, sentence_.get());
    context_.reset(new TaskContext());

    // Build a word map; "有" occurs only once, so it is pruned when the map is
    // reloaded below with a minimum frequency of 2.
    word_map_.Increment("因为");
    word_map_.Increment("因为");
    word_map_.Increment("有");
    word_map_.Increment("这");
    word_map_.Increment("这");
    word_map_.Increment("样");
    word_map_.Increment("样");
    word_map_.Increment("这样");
    word_map_.Increment("这样");

    // Build a character-bigram map; only "因为" and "这样" occur at least twice.
    bigram_map_.Increment("因为");
    bigram_map_.Increment("因为");
    bigram_map_.Increment("因为");
    bigram_map_.Increment("为有");
    bigram_map_.Increment("这样");
    bigram_map_.Increment("这样");

    string filename = tensorflow::strings::StrCat(
        tensorflow::testing::TmpDir(), "word-map");
    word_map_.Save(filename);
    word_map_.Load(filename, 2, -1);
    AddInputToContext("word-map", filename, "text", "");

    filename = tensorflow::strings::StrCat(
        tensorflow::testing::TmpDir(), "char-ngram-map");
    bigram_map_.Save(filename);
    AddInputToContext("char-ngram-map", filename, "text", "");

    registry_ = std::unique_ptr<WorkspaceRegistry>(new WorkspaceRegistry());
  }
  // Adds an input to the task context.
  void AddInputToContext(const string &name,
                         const string &file_pattern,
                         const string &file_format,
                         const string &record_format) {
    TaskInput *input = context_->GetInput(name);
    TaskInput::Part *part = input->add_part();
    part->set_file_pattern(file_pattern);
    part->set_file_format(file_format);
    part->set_record_format(record_format);
  }

  // Prepares a feature for computations.
  void PrepareFeature(const string &feature_name, ParserState *state) {
    feature_extractor_ = std::unique_ptr<ParserFeatureExtractor>(
        new ParserFeatureExtractor());
    feature_extractor_->Parse(feature_name);
    feature_extractor_->Setup(context_.get());
    feature_extractor_->Init(context_.get());
    feature_extractor_->RequestWorkspaces(registry_.get());
    workspace_.Reset(*registry_);
    feature_extractor_->Preprocess(&workspace_, state);
  }

  // Computes the feature value for the parser state.
  FeatureValue ComputeFeature(const ParserState &state) const {
    FeatureVector result;
    feature_extractor_->ExtractFeatures(workspace_, state, &result);
    return result.size() > 0 ? result.value(0) : -1;
  }

  // Checks that the stack of the state holds the expected start positions.
  void CheckStarts(const ParserState &state, const std::vector<int> &target) {
    ASSERT_EQ(state.StackSize(), target.size());
    for (int i = 0; i < state.StackSize(); ++i) {
      EXPECT_EQ(state.Stack(i), target[i]);
    }
  }
  // The test sentence.
  std::unique_ptr<Sentence> sentence_;

  // Members for testing features.
  std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
  std::unique_ptr<TaskContext> context_;
  std::unique_ptr<WorkspaceRegistry> registry_;
  WorkspaceSet workspace_;

  std::unique_ptr<ParserTransitionSystem> transition_system_;
  TermFrequencyMap label_map_;
  TermFrequencyMap word_map_;
  TermFrequencyMap bigram_map_;
};
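
// Following the gold actions should recover the gold segmentation
// "因为 有 这样": multi-character words are merged and space tokens dropped.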
TEST_F(SegmentationTransitionTest, GoldNextActionTest) {
  BinarySegmentState *segment_state = static_cast<BinarySegmentState *>(
      transition_system_->NewTransitionState(true));
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // Do segmentation by following the gold actions.
  while (transition_system_->IsFinalState(state) == false) {
    ParserAction action = transition_system_->GetNextGoldAction(state);
    transition_system_->PerformActionWithoutHistory(action, &state);
  }

  // Test STARTs.
  CheckStarts(state, {5, 4, 3, 2, 0});

  // Test the annotated tokens.
  segment_state->AddParseToDocument(state, false, sentence_.get());
  ASSERT_EQ(sentence_->token_size(), 3);
  EXPECT_EQ(sentence_->token(0).word(), "因为");
  EXPECT_EQ(sentence_->token(1).word(), "有");
  EXPECT_EQ(sentence_->token(2).word(), "这样");

  // Test start/end annotation of each token.
  EXPECT_EQ(sentence_->token(0).start(), 0);
  EXPECT_EQ(sentence_->token(0).end(), 5);
  EXPECT_EQ(sentence_->token(1).start(), 7);
  EXPECT_EQ(sentence_->token(1).end(), 9);
  EXPECT_EQ(sentence_->token(2).start(), 11);
  EXPECT_EQ(sentence_->token(2).end(), 16);
}
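
// The default action marks every character as a START, so each non-space
// character becomes its own word.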
TEST_F(SegmentationTransitionTest, DefaultActionTest) {
  BinarySegmentState *segment_state = static_cast<BinarySegmentState *>(
      transition_system_->NewTransitionState(true));
  ParserState state(sentence_.get(), segment_state, &label_map_);

  // Do segmentation by always taking the default action.
  while (transition_system_->IsFinalState(state) == false) {
    ParserAction action = transition_system_->GetDefaultAction(state);
    transition_system_->PerformActionWithoutHistory(action, &state);
  }

  // Every character should be a START.
  CheckStarts(state, {6, 5, 4, 3, 2, 1, 0});

  // Every non-space character should be a word.
  segment_state->AddParseToDocument(state, false, sentence_.get());
  ASSERT_EQ(sentence_->token_size(), 5);
  EXPECT_EQ(sentence_->token(0).word(), "因");
  EXPECT_EQ(sentence_->token(1).word(), "为");
  EXPECT_EQ(sentence_->token(2).word(), "有");
  EXPECT_EQ(sentence_->token(3).word(), "这");
  EXPECT_EQ(sentence_->token(4).word(), "样");
}
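
// With char-bigram-min-freq set to 2, only the bigrams 因为 and 这样 survive in
// the map (ids 0 and 1); every other bigram maps to the unknown id, and the
// position past the last character maps to the outside id.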
TEST_F(SegmentationTransitionTest, BigramFeatureTest) {
  context_->SetParameter("char-bigram-min-freq", "2");
  const int unk_id = 2;
  const int outside_id = 3;

  // Prepare a parser state.
  auto state = std::unique_ptr<ParserState>(new ParserState(
      sentence_.get(),
      transition_system_->NewTransitionState(/*training_mode=*/false),
      &label_map_));

  PrepareFeature("input.char-bigram", state.get());
  EXPECT_EQ(0, ComputeFeature(*state));
  PrepareFeature("input(1).char-bigram", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));
  PrepareFeature("input(2).char-bigram", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));
  PrepareFeature("input(3).char-bigram", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));
  PrepareFeature("input(4).char-bigram", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));
  PrepareFeature("input(5).char-bigram", state.get());
  EXPECT_EQ(1, ComputeFeature(*state));
  PrepareFeature("input(6).char-bigram", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));
}
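
// last-word(n) looks up the n-th most recently completed word in the word map:
// positions before the first completed word map to the outside id, and words
// pruned by the minimum frequency (such as 有) map to the unknown id.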
TEST_F(SegmentationTransitionTest, LastWordFeatureTest) {
  const int unk_id = word_map_.Size();
  const int outside_id = unk_id + 1;

  // Prepare a parser state.
  BinarySegmentState *segment_state = new BinarySegmentState();
  auto state = std::unique_ptr<ParserState>(new ParserState(
      sentence_.get(), segment_state, &label_map_));

  // Test the initial state, which contains no words.
  PrepareFeature("last-word(1,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));
  PrepareFeature("last-word(2,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));
  PrepareFeature("last-word(3,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));

  // Test when the state contains only one start.
  segment_state->AddStart(0, state.get());
  PrepareFeature("last-word(1,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));
  PrepareFeature("last-word(2,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));

  // Test when the state contains two starts, which form a complete word and
  // the start of another new word.
  segment_state->AddStart(2, state.get());
  EXPECT_NE(word_map_.LookupIndex("因为", unk_id), unk_id);
  PrepareFeature("last-word(1)", state.get());
  EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));

  // last-word(2) still points to outside.
  PrepareFeature("last-word(2,min-freq=2)", state.get());
  EXPECT_EQ(outside_id, ComputeFeature(*state));

  // Adding more starts leads to the following words:
  // 因为 ' ' 有 ' '
  segment_state->AddStart(3, state.get());
  segment_state->AddStart(4, state.get());

  // Note 有 is pruned from the map since its frequency is less than 2.
  EXPECT_EQ(word_map_.LookupIndex("有", unk_id), unk_id);
  PrepareFeature("last-word(1,min-freq=2)", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));

  // Note that last-word(2) points to ' ', which is also unknown.
  PrepareFeature("last-word(2,min-freq=2)", state.get());
  EXPECT_EQ(unk_id, ComputeFeature(*state));
  PrepareFeature("last-word(3,min-freq=2)", state.get());
  EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));

  // Adding two more starts: "这" and "样".
  segment_state->AddStart(5, state.get());
  segment_state->AddStart(6, state.get());
  PrepareFeature("last-word(1,min-freq=2)", state.get());
  EXPECT_EQ(word_map_.LookupIndex("这", unk_id), ComputeFeature(*state));
}

}  // namespace syntaxnet