char_shift_transitions_test.cc 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/char_shift_transitions.h"
  13. #include <memory>
  14. #include "syntaxnet/parser_features.h"
  15. #include "syntaxnet/parser_state.h"
  16. #include "syntaxnet/parser_transitions.h"
  17. #include "syntaxnet/task_context.h"
  18. #include "syntaxnet/term_frequency_map.h"
  19. #include "syntaxnet/workspace.h"
  20. #include "tensorflow/core/platform/test.h"
  21. namespace syntaxnet {
  22. class CharShiftTransitionTest : public ::testing::Test {
  23. public:
  24. void SetUp() override {
  25. const char *str_sentence =
  26. "text: 'I saw a man with a กขค.' "
  27. "token { word: 'I' start: 0 end: 0 tag: 'PRP' category: 'PRON'"
  28. " head: 1 label: 'nsubj' break_level: NO_BREAK } "
  29. "token { word: 'saw' start: 2 end: 4 tag: 'VBD' category: 'VERB'"
  30. " label: 'ROOT' break_level: SPACE_BREAK } "
  31. "token { word: 'a' start: 6 end: 6 tag: 'DT' category: 'DET'"
  32. " head: 3 label: 'det' break_level: SPACE_BREAK } "
  33. "token { word: 'man' start: 8 end: 10 tag: 'NN' category: 'NOUN'"
  34. " head: 1 label: 'dobj' break_level: SPACE_BREAK } "
  35. "token { word: 'with' start: 12 end: 15 tag: 'IN' category: 'ADP'"
  36. " head: 1 label: 'prep' break_level: SPACE_BREAK } "
  37. "token { word: 'a' start: 17 end: 17 tag: 'DT' category: 'DET'"
  38. " head: 6 label: 'det' break_level: SPACE_BREAK } "
  39. "token { word: 'กขค' start: 19 end: 27 tag: 'NN' category: "
  40. "'NOUN'"
  41. " head: 4 label: 'pobj' break_level: SPACE_BREAK } "
  42. "token { word: '.' start: 28 end: 28 tag: '.' category: '.'"
  43. " head: 1 label: 'p' break_level: NO_BREAK } ";
  44. TextFormat::ParseFromString(str_sentence, &sentence_);
  45. // Populates char-map manually.
  46. char_map_.Increment(" ");
  47. char_map_.Increment("I");
  48. char_map_.Increment("s");
  49. char_map_.Increment("a");
  50. char_map_.Increment("w");
  51. char_map_.Increment("a");
  52. char_map_.Increment("m");
  53. char_map_.Increment("a");
  54. char_map_.Increment("n");
  55. char_map_.Increment("w");
  56. char_map_.Increment("i");
  57. char_map_.Increment("t");
  58. char_map_.Increment("h");
  59. char_map_.Increment("a");
  60. char_map_.Increment("ก");
  61. char_map_.Increment("ข");
  62. char_map_.Increment("ค");
  63. string char_map_filename =
  64. tensorflow::strings::StrCat(tensorflow::testing::TmpDir(), "char-map");
  65. char_map_.Save(char_map_filename);
  66. AddInputToContext("char-map", char_map_filename, "text", "");
  67. }
  68. void PrepareCharTransition(bool left_to_right) {
  69. context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
  70. transition_system_.reset(ParserTransitionSystem::Create("char-shift-only"));
  71. transition_system_->Setup(&context_);
  72. // Parser state.
  73. state_.reset(new ParserState(
  74. &sentence_, transition_system_->NewTransitionState(true), &label_map_));
  75. char_state_ = reinterpret_cast<const CharShiftTransitionState *>(
  76. state_->transition_state());
  77. }
  78. void PrepareShiftTransition(bool left_to_right) {
  79. context_.SetParameter("left-to-right", left_to_right ? "true" : "false");
  80. transition_system_.reset(ParserTransitionSystem::Create("shift-only"));
  81. transition_system_->Setup(&context_);
  82. state_.reset(new ParserState(
  83. &sentence_, transition_system_->NewTransitionState(true), &label_map_));
  84. char_state_ = nullptr;
  85. }
  86. void PrepareExtractor(const string &feature_name) {
  87. extractor_.Parse(feature_name);
  88. extractor_.Setup(&context_);
  89. extractor_.Init(&context_);
  90. extractor_.RequestWorkspaces(&registry_);
  91. workspace_.Reset(registry_);
  92. extractor_.Preprocess(&workspace_, state_.get());
  93. }
  94. void AddInputToContext(const string &name, const string &file_pattern,
  95. const string &file_format,
  96. const string &record_format) {
  97. TaskInput *input = context_.GetInput(name);
  98. TaskInput::Part *part = input->add_part();
  99. part->set_file_pattern(file_pattern);
  100. part->set_file_format(file_format);
  101. part->set_record_format(record_format);
  102. }
  103. protected:
  104. string MultiFeatureString(const FeatureVector &result) {
  105. std::vector<string> values;
  106. for (int i = 0; i < result.size(); ++i) {
  107. values.push_back(result.type(i)->GetFeatureValueName(result.value(i)));
  108. }
  109. return utils::Join(values, ",");
  110. }
  111. Sentence sentence_;
  112. TaskContext context_;
  113. TaskInput *input_label_map_ = nullptr;
  114. TaskInput *input_tag_map_ = nullptr;
  115. TermFrequencyMap char_map_;
  116. TermFrequencyMap label_map_;
  117. std::unique_ptr<ParserTransitionSystem> transition_system_;
  118. std::unique_ptr<ParserState> state_;
  119. const CharShiftTransitionState *char_state_;
  120. ParserFeatureExtractor extractor_;
  121. WorkspaceRegistry registry_;
  122. WorkspaceSet workspace_;
  123. };
  124. TEST_F(CharShiftTransitionTest, LRShift) {
  125. PrepareCharTransition(true);
  126. int expected_next = 0;
  127. const std::vector<string> expected_actions = {
  128. "I:I", " :I", "s:saw", "a:saw", "w:saw", " :saw",
  129. "a:a", " :a", "m:man", "a:man", "n:man", " :man",
  130. "w:with", "i:with", "t:with", "h:with", " :with", "a:a",
  131. " :a", "ก:กขค", "ข:กขค", "ค:กขค", ".:."};
  132. EXPECT_EQ(char_state_->Next(), expected_next);
  133. while (!transition_system_->IsFinalState(*state_)) {
  134. ParserAction action = transition_system_->GetNextGoldAction(*state_);
  135. EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state_));
  136. EXPECT_EQ(transition_system_->ActionAsString(action, *state_),
  137. expected_actions[expected_next]);
  138. transition_system_->PerformActionWithoutHistory(action, state_.get());
  139. ++expected_next;
  140. EXPECT_EQ(char_state_->Next(), expected_next);
  141. }
  142. }
  143. TEST_F(CharShiftTransitionTest, RLShift) {
  144. PrepareCharTransition(false);
  145. int expected_next = 22;
  146. const std::vector<string> expected_actions = {
  147. "I:I", " :saw", "s:saw", "a:saw", "w:saw", " :a",
  148. "a:a", " :man", "m:man", "a:man", "n:man", " :with",
  149. "w:with", "i:with", "t:with", "h:with", " :a", "a:a",
  150. " :กขค", "ก:กขค", "ข:กขค", "ค:กขค", ".:."};
  151. EXPECT_EQ(char_state_->Next(), expected_next);
  152. while (!transition_system_->IsFinalState(*state_)) {
  153. ParserAction action = transition_system_->GetNextGoldAction(*state_);
  154. EXPECT_TRUE(transition_system_->IsAllowedAction(action, *state_));
  155. EXPECT_EQ(transition_system_->ActionAsString(action, *state_),
  156. expected_actions[expected_next]);
  157. transition_system_->PerformActionWithoutHistory(action, state_.get());
  158. --expected_next;
  159. EXPECT_EQ(char_state_->Next(), expected_next);
  160. }
  161. }
  162. TEST_F(CharShiftTransitionTest, TextChar) {
  163. PrepareCharTransition(true);
  164. PrepareExtractor(
  165. "char-input(-2).text-char "
  166. "char-input(-1).text-char "
  167. "char-input.text-char "
  168. "char-input(1).text-char "
  169. "char-input(2).text-char ");
  170. FeatureVector features;
  171. ParserAction action;
  172. // "I s"
  173. features.clear();
  174. extractor_.ExtractFeatures(workspace_, *state_, &features);
  175. EXPECT_EQ(MultiFeatureString(features), "I,<SPACE>,s");
  176. // "I sa"
  177. action = transition_system_->GetNextGoldAction(*state_);
  178. transition_system_->PerformActionWithoutHistory(action, state_.get());
  179. // " saw"
  180. action = transition_system_->GetNextGoldAction(*state_);
  181. transition_system_->PerformActionWithoutHistory(action, state_.get());
  182. // "."
  183. while (!transition_system_->IsFinalState(*state_)) {
  184. action = transition_system_->GetNextGoldAction(*state_);
  185. transition_system_->PerformActionWithoutHistory(action, state_.get());
  186. }
  187. features.clear();
  188. extractor_.ExtractFeatures(workspace_, *state_, &features);
  189. EXPECT_EQ(MultiFeatureString(features), "ค,<UNKNOWN>");
  190. }
  191. TEST_F(CharShiftTransitionTest, LastCharFocus) {
  192. PrepareShiftTransition(true);
  193. PrepareExtractor(
  194. "input(-1).last-char-focus "
  195. "input.last-char-focus "
  196. "input(1).last-char-focus "
  197. "input(2).last-char-focus ");
  198. FeatureVector features;
  199. ParserAction action;
  200. // "I saw a"
  201. features.clear();
  202. extractor_.ExtractFeatures(workspace_, *state_, &features);
  203. EXPECT_EQ(MultiFeatureString(features), ",0,4,6");
  204. // "I saw a man"
  205. action = transition_system_->GetNextGoldAction(*state_);
  206. transition_system_->PerformActionWithoutHistory(action, state_.get());
  207. features.clear();
  208. extractor_.ExtractFeatures(workspace_, *state_, &features);
  209. EXPECT_EQ(MultiFeatureString(features), "0,4,6,10");
  210. // "saw a man with"
  211. action = transition_system_->GetNextGoldAction(*state_);
  212. transition_system_->PerformActionWithoutHistory(action, state_.get());
  213. // "."
  214. while (!transition_system_->IsFinalState(*state_)) {
  215. action = transition_system_->GetNextGoldAction(*state_);
  216. transition_system_->PerformActionWithoutHistory(action, state_.get());
  217. }
  218. features.clear();
  219. extractor_.ExtractFeatures(workspace_, *state_, &features);
  220. EXPECT_EQ(MultiFeatureString(features), "22,,,");
  221. }
  222. } // namespace syntaxnet