// parser_features.cc (9.1 KB)
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/parser_features.h"
  13. #include <string>
  14. #include "syntaxnet/registry.h"
  15. #include "syntaxnet/sentence_features.h"
  16. #include "syntaxnet/workspace.h"
  17. namespace syntaxnet {
  18. // Registry for the parser feature functions.
  19. REGISTER_CLASS_REGISTRY("parser feature function", ParserFeatureFunction);
  20. // Registry for the parser state + token index feature functions.
  21. REGISTER_CLASS_REGISTRY("parser+index feature function",
  22. ParserIndexFeatureFunction);
  23. RootFeatureType::RootFeatureType(const string &name,
  24. const FeatureType &wrapped_type,
  25. int root_value)
  26. : FeatureType(name), wrapped_type_(wrapped_type), root_value_(root_value) {}
  27. string RootFeatureType::GetFeatureValueName(FeatureValue value) const {
  28. if (value == root_value_) return "<ROOT>";
  29. return wrapped_type_.GetFeatureValueName(value);
  30. }
  31. FeatureValue RootFeatureType::GetDomainSize() const {
  32. return wrapped_type_.GetDomainSize() + 1;
  33. }
  34. // Parser feature locator for accessing the remaining input tokens in the parser
  35. // state. It takes the offset relative to the current input token as argument.
  36. // Negative values represent tokens to the left, positive values to the right
  37. // and 0 (the default argument value) represents the current input token.
  38. class InputParserLocator : public ParserLocator<InputParserLocator> {
  39. public:
  40. // Gets the new focus.
  41. int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
  42. const int offset = argument();
  43. return state.Input(offset);
  44. }
  45. };
  46. REGISTER_PARSER_FEATURE_FUNCTION("input", InputParserLocator);
  47. // Parser feature locator for accessing the stack in the parser state. The
  48. // argument represents the position on the stack, 0 being the top of the stack.
  49. class StackParserLocator : public ParserLocator<StackParserLocator> {
  50. public:
  51. // Gets the new focus.
  52. int GetFocus(const WorkspaceSet &workspaces, const ParserState &state) const {
  53. const int position = argument();
  54. return state.Stack(position);
  55. }
  56. };
  57. REGISTER_PARSER_FEATURE_FUNCTION("stack", StackParserLocator);
  58. // Parser feature locator for locating the head of the focus token. The argument
  59. // specifies the number of times the head function is applied. Please note that
  60. // this operates on partially built dependency trees.
  61. class HeadFeatureLocator : public ParserIndexLocator<HeadFeatureLocator> {
  62. public:
  63. // Updates the current focus to a new location. If the initial focus is
  64. // outside the range of the sentence, returns -2.
  65. void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
  66. int *focus) const {
  67. if (*focus < -1 || *focus >= state.sentence().token_size()) {
  68. *focus = -2;
  69. return;
  70. }
  71. const int levels = argument();
  72. *focus = state.Parent(*focus, levels);
  73. }
  74. };
  75. REGISTER_PARSER_IDX_FEATURE_FUNCTION("head", HeadFeatureLocator);
  76. // Parser feature locator for locating children of the focus token. The argument
  77. // specifies the number of times the leftmost (when the argument is < 0) or
  78. // rightmost (when the argument > 0) child function is applied. Please note that
  79. // this operates on partially built dependency trees.
  80. class ChildFeatureLocator : public ParserIndexLocator<ChildFeatureLocator> {
  81. public:
  82. // Updates the current focus to a new location. If the initial focus is
  83. // outside the range of the sentence, returns -2.
  84. void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
  85. int *focus) const {
  86. if (*focus < -1 || *focus >= state.sentence().token_size()) {
  87. *focus = -2;
  88. return;
  89. }
  90. const int levels = argument();
  91. if (levels < 0) {
  92. *focus = state.LeftmostChild(*focus, -levels);
  93. } else {
  94. *focus = state.RightmostChild(*focus, levels);
  95. }
  96. }
  97. };
  98. REGISTER_PARSER_IDX_FEATURE_FUNCTION("child", ChildFeatureLocator);
  99. // Parser feature locator for locating siblings of the focus token. The argument
  100. // specifies the sibling position relative to the focus token: a negative value
  101. // triggers a search to the left, while a positive value one to the right.
  102. // Please note that this operates on partially built dependency trees.
  103. class SiblingFeatureLocator
  104. : public ParserIndexLocator<SiblingFeatureLocator> {
  105. public:
  106. // Updates the current focus to a new location. If the initial focus is
  107. // outside the range of the sentence, returns -2.
  108. void UpdateArgs(const WorkspaceSet &workspaces, const ParserState &state,
  109. int *focus) const {
  110. if (*focus < -1 || *focus >= state.sentence().token_size()) {
  111. *focus = -2;
  112. return;
  113. }
  114. const int position = argument();
  115. if (position < 0) {
  116. *focus = state.LeftSibling(*focus, -position);
  117. } else {
  118. *focus = state.RightSibling(*focus, position);
  119. }
  120. }
  121. };
  122. REGISTER_PARSER_IDX_FEATURE_FUNCTION("sibling", SiblingFeatureLocator);
  123. // Feature function for computing the label from focus token. Note that this
  124. // does not use the precomputed values, since we get the labels from the stack;
  125. // the reason it utilizes sentence_features::Label is to obtain the label map.
  126. class LabelFeatureFunction : public BasicParserSentenceFeatureFunction<Label> {
  127. public:
  128. // Computes the label of the relation between the focus token and its parent.
  129. // Valid focus values range from -1 to sentence->size() - 1, inclusively.
  130. FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
  131. int focus, const FeatureVector *result) const override {
  132. if (focus == -1) return RootValue();
  133. if (focus < -1 || focus >= state.sentence().token_size()) {
  134. return feature_.NumValues();
  135. }
  136. const int label = state.Label(focus);
  137. return label == -1 ? RootValue() : label;
  138. }
  139. };
  140. REGISTER_PARSER_IDX_FEATURE_FUNCTION("label", LabelFeatureFunction);
  141. typedef BasicParserSentenceFeatureFunction<Word> WordFeatureFunction;
  142. REGISTER_PARSER_IDX_FEATURE_FUNCTION("word", WordFeatureFunction);
  143. typedef BasicParserSentenceFeatureFunction<Char> CharFeatureFunction;
  144. REGISTER_PARSER_IDX_FEATURE_FUNCTION("char", CharFeatureFunction);
  145. typedef BasicParserSentenceFeatureFunction<Tag> TagFeatureFunction;
  146. REGISTER_PARSER_IDX_FEATURE_FUNCTION("tag", TagFeatureFunction);
  147. typedef BasicParserSentenceFeatureFunction<Digit> DigitFeatureFunction;
  148. REGISTER_PARSER_IDX_FEATURE_FUNCTION("digit", DigitFeatureFunction);
  149. typedef BasicParserSentenceFeatureFunction<Hyphen> HyphenFeatureFunction;
  150. REGISTER_PARSER_IDX_FEATURE_FUNCTION("hyphen", HyphenFeatureFunction);
  151. typedef BasicParserSentenceFeatureFunction<Capitalization>
  152. CapitalizationFeatureFunction;
  153. REGISTER_PARSER_IDX_FEATURE_FUNCTION("capitalization",
  154. CapitalizationFeatureFunction);
  155. typedef BasicParserSentenceFeatureFunction<PunctuationAmount>
  156. PunctuationAmountFeatureFunction;
  157. REGISTER_PARSER_IDX_FEATURE_FUNCTION("punctuation-amount",
  158. PunctuationAmountFeatureFunction);
  159. typedef BasicParserSentenceFeatureFunction<Quote>
  160. QuoteFeatureFunction;
  161. REGISTER_PARSER_IDX_FEATURE_FUNCTION("quote",
  162. QuoteFeatureFunction);
  163. typedef BasicParserSentenceFeatureFunction<PrefixFeature> PrefixFeatureFunction;
  164. REGISTER_PARSER_IDX_FEATURE_FUNCTION("prefix", PrefixFeatureFunction);
  165. typedef BasicParserSentenceFeatureFunction<SuffixFeature> SuffixFeatureFunction;
  166. REGISTER_PARSER_IDX_FEATURE_FUNCTION("suffix", SuffixFeatureFunction);
  167. // Parser feature function that can use nested sentence feature functions for
  168. // feature extraction.
  169. class ParserTokenFeatureFunction : public NestedFeatureFunction<
  170. FeatureFunction<Sentence, int>, ParserState, int> {
  171. public:
  172. void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
  173. for (auto *function : nested_) {
  174. function->Preprocess(workspaces, state->mutable_sentence());
  175. }
  176. }
  177. void Evaluate(const WorkspaceSet &workspaces, const ParserState &state,
  178. int focus, FeatureVector *result) const override {
  179. for (auto *function : nested_) {
  180. function->Evaluate(workspaces, state.sentence(), focus, result);
  181. }
  182. }
  183. // Returns the first nested feature's computed value.
  184. FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
  185. int focus, const FeatureVector *result) const override {
  186. if (nested_.empty()) return -1;
  187. return nested_[0]->Compute(workspaces, state.sentence(), focus, result);
  188. }
  189. };
  190. REGISTER_PARSER_IDX_FEATURE_FUNCTION("token",
  191. ParserTokenFeatureFunction);
  192. } // namespace syntaxnet