parser_features.h 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Sentence-based features for the transition parser.
  13. #ifndef SYNTAXNET_PARSER_FEATURES_H_
  14. #define SYNTAXNET_PARSER_FEATURES_H_
  15. #include <string>
  16. #include "syntaxnet/feature_extractor.h"
  17. #include "syntaxnet/feature_types.h"
  18. #include "syntaxnet/parser_state.h"
  19. #include "syntaxnet/task_context.h"
  20. #include "syntaxnet/workspace.h"
  21. namespace syntaxnet {
  22. // A union used to represent discrete and continuous feature values.
  23. union FloatFeatureValue {
  24. public:
  25. explicit FloatFeatureValue(FeatureValue v) : discrete_value(v) {}
  26. FloatFeatureValue(uint32 i, float w) : id(i), weight(w) {}
  27. FeatureValue discrete_value;
  28. struct {
  29. uint32 id;
  30. float weight;
  31. };
  32. };
  33. typedef FeatureFunction<ParserState> ParserFeatureFunction;
  34. // Feature function for the transition parser based on a parser state object and
  35. // a token index. This typically extracts information from a given token.
  36. typedef FeatureFunction<ParserState, int> ParserIndexFeatureFunction;
  // Registration helpers, one per kind of parser feature function: the first
  // registers under ParserFeatureFunction, the second under
  // ParserIndexFeatureFunction.
  #define REGISTER_PARSER_FEATURE_FUNCTION(name, component)                  \
    REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserFeatureFunction, name, component)

  #define REGISTER_PARSER_IDX_FEATURE_FUNCTION(name, component)              \
    REGISTER_SYNTAXNET_FEATURE_FUNCTION(ParserIndexFeatureFunction, name,    \
                                        component)
  43. // Alias for locator type that takes a parser state, and produces a focus
  44. // integer that can be used on nested ParserIndexFeature objects.
  45. template<class DER>
  46. using ParserLocator = FeatureAddFocusLocator<DER, ParserState, int>;
  47. // Alias for Locator type features that take (ParserState, int) signatures and
  48. // call other ParserIndexFeatures.
  49. template<class DER>
  50. using ParserIndexLocator = FeatureLocator<DER, ParserState, int>;
  51. // Feature extractor for the transition parser based on a parser state object.
  52. typedef FeatureExtractor<ParserState> ParserFeatureExtractor;
  53. // A simple wrapper FeatureType that adds a special "<ROOT>" type.
  54. class RootFeatureType : public FeatureType {
  55. public:
  56. // Creates a RootFeatureType that wraps a given type and adds the special
  57. // "<ROOT>" value in root_value.
  58. RootFeatureType(const string &name, const FeatureType &wrapped_type,
  59. int root_value);
  60. // Returns the feature value name, but with the special "<ROOT>" value.
  61. string GetFeatureValueName(FeatureValue value) const override;
  62. // Returns the original number of features plus one for the "<ROOT>" value.
  63. FeatureValue GetDomainSize() const override;
  64. private:
  65. // A wrapped type that handles everything else besides "<ROOT>".
  66. const FeatureType &wrapped_type_;
  67. // The reserved root value.
  68. int root_value_;
  69. };
  70. // Simple feature function that wraps a Sentence based feature
  71. // function. It adds a "<ROOT>" feature value that is triggered whenever the
  72. // focus is the special root token. This class is sub-classed based on the
  73. // extracted arguments of the nested function.
  74. template<class F>
  75. class ParserSentenceFeatureFunction : public ParserIndexFeatureFunction {
  76. public:
  77. // Instantiates and sets up the nested feature.
  78. void Setup(TaskContext *context) override {
  79. this->feature_.set_descriptor(this->descriptor());
  80. this->feature_.set_prefix(this->prefix());
  81. this->feature_.set_extractor(this->extractor());
  82. feature_.Setup(context);
  83. }
  84. // Initializes the nested feature and sets feature type.
  85. void Init(TaskContext *context) override {
  86. feature_.Init(context);
  87. num_base_values_ = feature_.GetFeatureType()->GetDomainSize();
  88. set_feature_type(new RootFeatureType(
  89. name(), *feature_.GetFeatureType(), RootValue()));
  90. }
  91. // Passes workspace requests and preprocessing to the nested feature.
  92. void RequestWorkspaces(WorkspaceRegistry *registry) override {
  93. feature_.RequestWorkspaces(registry);
  94. }
  95. void Preprocess(WorkspaceSet *workspaces, ParserState *state) const override {
  96. feature_.Preprocess(workspaces, state->mutable_sentence());
  97. }
  98. protected:
  99. // Returns the special value to represent a root token.
  100. FeatureValue RootValue() const { return num_base_values_; }
  101. // Store the number of base values from the wrapped function so compute the
  102. // root value.
  103. int num_base_values_;
  104. // The wrapped feature.
  105. F feature_;
  106. };
  107. // Specialization of ParserSentenceFeatureFunction that calls the nested feature
  108. // with (Sentence, int) arguments based on the current integer focus.
  109. template<class F>
  110. class BasicParserSentenceFeatureFunction :
  111. public ParserSentenceFeatureFunction<F> {
  112. public:
  113. FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
  114. int focus, const FeatureVector *result) const override {
  115. if (focus == -1) return this->RootValue();
  116. return this->feature_.Compute(workspaces, state.sentence(), focus, result);
  117. }
  118. };
  119. } // namespace syntaxnet
  120. #endif // SYNTAXNET_PARSER_FEATURES_H_