feature_extractor.h 25 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
// Generic feature extractor for extracting features from objects. The feature
// extractor can be used for extracting features from any object. The feature
// extractor and feature function classes are template classes that have to
// be instantiated for extracting features from a specific object type.
//
// A feature extractor consists of a hierarchy of feature functions. Each
// feature function extracts one or more feature type and value pairs from the
// object.
//
// The feature extractor has a modular design where new feature functions can be
// registered as components. The feature extractor is initialized from a
// descriptor represented by a protocol buffer. The feature extractor can also
// be initialized from a text-based source specification of the feature
// extractor. Feature specification parsers can be added as components. By
// default the feature extractor can be read from an ASCII protocol buffer or in
// a simple feature modeling language (FML).
//
// A feature function is invoked with a focus. Nested feature functions can be
// invoked with another focus determined by the parent feature function.
  30. #ifndef SYNTAXNET_FEATURE_EXTRACTOR_H_
  31. #define SYNTAXNET_FEATURE_EXTRACTOR_H_
  32. #include <memory>
  33. #include <string>
  34. #include <vector>
  35. #include "syntaxnet/feature_extractor.pb.h"
  36. #include "syntaxnet/feature_types.h"
  37. #include "syntaxnet/proto_io.h"
  38. #include "syntaxnet/registry.h"
  39. #include "syntaxnet/task_context.h"
  40. #include "syntaxnet/utils.h"
  41. #include "syntaxnet/workspace.h"
  42. #include "tensorflow/core/lib/core/status.h"
  43. #include "tensorflow/core/lib/core/stringpiece.h"
  44. #include "tensorflow/core/lib/io/record_reader.h"
  45. #include "tensorflow/core/lib/io/record_writer.h"
  46. #include "tensorflow/core/lib/strings/strcat.h"
  47. #include "tensorflow/core/platform/env.h"
  48. namespace syntaxnet {
  49. // Use the same type for feature values as is used for predicated.
  50. typedef int64 Predicate;
  51. typedef Predicate FeatureValue;
  52. // Output feature model in FML format.
  53. void ToFMLFunction(const FeatureFunctionDescriptor &function, string *output);
  54. void ToFML(const FeatureFunctionDescriptor &function, string *output);
  55. // A feature vector contains feature type and value pairs.
  56. class FeatureVector {
  57. public:
  58. FeatureVector() {}
  59. // Adds feature type and value pair to feature vector.
  60. void add(FeatureType *type, FeatureValue value) {
  61. features_.emplace_back(type, value);
  62. }
  63. // Removes all elements from the feature vector.
  64. void clear() { features_.clear(); }
  65. // Returns the number of elements in the feature vector.
  66. int size() const { return features_.size(); }
  67. // Reserves space in the underlying feature vector.
  68. void reserve(int n) { features_.reserve(n); }
  69. // Returns feature type for an element in the feature vector.
  70. FeatureType *type(int index) const { return features_[index].type; }
  71. // Returns feature value for an element in the feature vector.
  72. FeatureValue value(int index) const { return features_[index].value; }
  73. private:
  74. // Structure for holding feature type and value pairs.
  75. struct Element {
  76. Element() : type(nullptr), value(-1) {}
  77. Element(FeatureType *t, FeatureValue v) : type(t), value(v) {}
  78. FeatureType *type;
  79. FeatureValue value;
  80. };
  81. // Array for storing feature vector elements.
  82. std::vector<Element> features_;
  83. TF_DISALLOW_COPY_AND_ASSIGN(FeatureVector);
  84. };
  85. // The generic feature extractor is the type-independent part of a feature
  86. // extractor. This holds the descriptor for the feature extractor and the
  87. // collection of feature types used in the feature extractor. The feature
  88. // types are not available until FeatureExtractor<>::Init() has been called.
  89. class GenericFeatureExtractor {
  90. public:
  91. GenericFeatureExtractor();
  92. virtual ~GenericFeatureExtractor();
  93. // Initializes the feature extractor from a source representation of the
  94. // feature extractor. The first line is used for determining the feature
  95. // specification language. If the first line starts with #! followed by a name
  96. // then this name is used for instantiating a feature specification parser
  97. // with that name. If the language cannot be detected this way it falls back
  98. // to using the default language supplied.
  99. void Parse(const string &source);
  100. // Returns the feature extractor descriptor.
  101. const FeatureExtractorDescriptor &descriptor() const { return descriptor_; }
  102. FeatureExtractorDescriptor *mutable_descriptor() { return &descriptor_; }
  103. // Returns the number of feature types in the feature extractor. Invalid
  104. // before Init() has been called.
  105. int feature_types() const { return feature_types_.size(); }
  106. // Returns all feature types names used by the extractor. The names are
  107. // added to the types_names array. Invalid before Init() has been called.
  108. void GetFeatureTypeNames(std::vector<string> *type_names) const;
  109. // Returns a feature type used in the extractor. Invalid before Init() has
  110. // been called.
  111. const FeatureType *feature_type(int index) const {
  112. return feature_types_[index];
  113. }
  114. // Returns the feature domain size of this feature extractor.
  115. // NOTE: The way that domain size is calculated is, for some, unintuitive. It
  116. // is the largest domain size of any feature type.
  117. FeatureValue GetDomainSize() const;
  118. protected:
  119. // Initializes the feature types used by the extractor. Called from
  120. // FeatureExtractor<>::Init().
  121. void InitializeFeatureTypes();
  122. private:
  123. // Initializes the top-level feature functions.
  124. virtual void InitializeFeatureFunctions() = 0;
  125. // Returns all feature types used by the extractor. The feature types are
  126. // added to the result array.
  127. virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const = 0;
  128. // Descriptor for the feature extractor. This is a protocol buffer that
  129. // contains all the information about the feature extractor. The feature
  130. // functions are initialized from the information in the descriptor.
  131. FeatureExtractorDescriptor descriptor_;
  132. // All feature types used by the feature extractor. The collection of all the
  133. // feature types describes the feature space of the feature set produced by
  134. // the feature extractor. Not owned.
  135. std::vector<FeatureType *> feature_types_;
  136. };
  137. // The generic feature function is the type-independent part of a feature
  138. // function. Each feature function is associated with the descriptor that it is
  139. // instantiated from. The feature types associated with this feature function
  140. // will be established by the time FeatureExtractor<>::Init() completes.
  141. class GenericFeatureFunction {
  142. public:
  143. // A feature value that represents the absence of a value.
  144. static constexpr FeatureValue kNone = -1;
  145. GenericFeatureFunction();
  146. virtual ~GenericFeatureFunction();
  147. // Sets up the feature function. NB: FeatureTypes of nested functions are not
  148. // guaranteed to be available until Init().
  149. virtual void Setup(TaskContext *context) {}
  150. // Initializes the feature function. NB: The FeatureType of this function must
  151. // be established when this method completes.
  152. virtual void Init(TaskContext *context) {}
  153. // Requests workspaces from a registry to obtain indices into a WorkspaceSet
  154. // for any Workspace objects used by this feature function. NB: This will be
  155. // called after Init(), so it can depend on resources and arguments.
  156. virtual void RequestWorkspaces(WorkspaceRegistry *registry) {}
  157. // Appends the feature types produced by the feature function to types. The
  158. // default implementation appends feature_type(), if non-null. Invalid
  159. // before Init() has been called.
  160. virtual void GetFeatureTypes(std::vector<FeatureType *> *types) const;
  161. // Returns the feature type for feature produced by this feature function. If
  162. // the feature function produces features of different types this returns
  163. // null. Invalid before Init() has been called.
  164. virtual FeatureType *GetFeatureType() const;
  165. // Returns the name of the registry used for creating the feature function.
  166. // This can be used for checking if two feature functions are of the same
  167. // kind.
  168. virtual const char *RegistryName() const = 0;
  169. // Returns the value of a named parameter in the feature functions descriptor.
  170. // If the named parameter is not found the global parameters are searched.
  171. string GetParameter(const string &name) const;
  172. int GetIntParameter(const string &name, int default_value) const;
  173. bool GetBoolParameter(const string &name, bool default_value) const;
  174. // Returns the FML function description for the feature function, i.e. the
  175. // name and parameters without the nested features.
  176. string FunctionName() const {
  177. string output;
  178. ToFMLFunction(*descriptor_, &output);
  179. return output;
  180. }
  181. // Returns the prefix for nested feature functions. This is the prefix of this
  182. // feature function concatenated with the feature function name.
  183. string SubPrefix() const {
  184. return prefix_.empty() ? FunctionName() : prefix_ + "." + FunctionName();
  185. }
  186. // Returns/sets the feature extractor this function belongs to.
  187. GenericFeatureExtractor *extractor() const { return extractor_; }
  188. void set_extractor(GenericFeatureExtractor *extractor) {
  189. extractor_ = extractor;
  190. }
  191. // Returns/sets the feature function descriptor.
  192. FeatureFunctionDescriptor *descriptor() const { return descriptor_; }
  193. void set_descriptor(FeatureFunctionDescriptor *descriptor) {
  194. descriptor_ = descriptor;
  195. }
  196. // Returns a descriptive name for the feature function. The name is taken from
  197. // the descriptor for the feature function. If the name is empty or the
  198. // feature function is a variable the name is the FML representation of the
  199. // feature, including the prefix.
  200. string name() const {
  201. string output;
  202. if (descriptor_->name().empty()) {
  203. if (!prefix_.empty()) {
  204. output.append(prefix_);
  205. output.append(".");
  206. }
  207. ToFML(*descriptor_, &output);
  208. } else {
  209. output = descriptor_->name();
  210. }
  211. tensorflow::StringPiece stripped(output);
  212. utils::RemoveWhitespaceContext(&stripped);
  213. return stripped.ToString();
  214. }
  215. // Returns the argument from the feature function descriptor. It defaults to
  216. // 0 if the argument has not been specified.
  217. int argument() const {
  218. return descriptor_->has_argument() ? descriptor_->argument() : 0;
  219. }
  220. // Returns/sets/clears function name prefix.
  221. const string &prefix() const { return prefix_; }
  222. void set_prefix(const string &prefix) { prefix_ = prefix; }
  223. protected:
  224. // Returns the feature type for single-type feature functions.
  225. FeatureType *feature_type() const { return feature_type_; }
  226. // Sets the feature type for single-type feature functions. This takes
  227. // ownership of feature_type. Can only be called once.
  228. void set_feature_type(FeatureType *feature_type) {
  229. CHECK(feature_type_ == nullptr);
  230. feature_type_ = feature_type;
  231. }
  232. private:
  233. // Feature extractor this feature function belongs to. Not owned.
  234. GenericFeatureExtractor *extractor_ = nullptr;
  235. // Descriptor for feature function. Not owned.
  236. FeatureFunctionDescriptor *descriptor_ = nullptr;
  237. // Feature type for features produced by this feature function. If the
  238. // feature function produces features of multiple feature types this is null
  239. // and the feature function must return it's feature types in
  240. // GetFeatureTypes(). Owned.
  241. FeatureType *feature_type_ = nullptr;
  242. // Prefix used for sub-feature types of this function.
  243. string prefix_;
  244. };
  245. // Feature function that can extract features from an object. Templated on
  246. // two type arguments:
  247. //
  248. // OBJ: The "object" from which features are extracted; e.g., a sentence. This
  249. // should be a plain type, rather than a reference or pointer.
  250. //
  251. // ARGS: A set of 0 or more types that are used to "index" into some part of the
  252. // object that should be extracted, e.g. an int token index for a sentence
  253. // object. This should not be a reference type.
  254. template<class OBJ, class ...ARGS>
  255. class FeatureFunction
  256. : public GenericFeatureFunction,
  257. public RegisterableClass< FeatureFunction<OBJ, ARGS...> > {
  258. public:
  259. using Self = FeatureFunction<OBJ, ARGS...>;
  260. // Preprocesses the object. This will be called prior to calling Evaluate()
  261. // or Compute() on that object.
  262. virtual void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {}
  263. // Appends features computed from the object and focus to the result. The
  264. // default implementation delegates to Compute(), adding a single value if
  265. // available. Multi-valued feature functions must override this method.
  266. virtual void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
  267. ARGS... args, FeatureVector *result) const {
  268. FeatureValue value = Compute(workspaces, object, args..., result);
  269. if (value != kNone) result->add(feature_type(), value);
  270. }
  271. // Returns a feature value computed from the object and focus, or kNone if no
  272. // value is computed. Single-valued feature functions only need to override
  273. // this method.
  274. virtual FeatureValue Compute(const WorkspaceSet &workspaces,
  275. const OBJ &object,
  276. ARGS... args,
  277. const FeatureVector *fv) const {
  278. return kNone;
  279. }
  280. // Instantiates a new feature function in a feature extractor from a feature
  281. // descriptor.
  282. static Self *Instantiate(GenericFeatureExtractor *extractor,
  283. FeatureFunctionDescriptor *fd,
  284. const string &prefix) {
  285. Self *f = Self::Create(fd->type());
  286. f->set_extractor(extractor);
  287. f->set_descriptor(fd);
  288. f->set_prefix(prefix);
  289. return f;
  290. }
  291. // Returns the name of the registry for the feature function.
  292. const char *RegistryName() const override {
  293. return Self::registry()->name;
  294. }
  295. private:
  296. // Special feature function class for resolving variable references. The type
  297. // of the feature function is used for resolving the variable reference. When
  298. // evaluated it will either get the feature value(s) from the variable portion
  299. // of the feature vector, if present, or otherwise it will call the referenced
  300. // feature extractor function directly to extract the feature(s).
  301. class Reference;
  302. };
  303. // Base class for features with nested feature functions. The nested functions
  304. // are of type NES, which may be different from the type of the parent function.
  305. // NB: NestedFeatureFunction will ensure that all initialization of nested
  306. // functions takes place during Setup() and Init() -- after the nested features
  307. // are initialized, the parent feature is initialized via SetupNested() and
  308. // InitNested(). Alternatively, a derived classes that overrides Setup() and
  309. // Init() directly should call Parent::Setup(), Parent::Init(), etc. first.
  310. //
  311. // Note: NestedFeatureFunction cannot know how to call Preprocess, Evaluate, or
  312. // Compute, since the nested functions may be of a different type.
  313. template<class NES, class OBJ, class ...ARGS>
  314. class NestedFeatureFunction : public FeatureFunction<OBJ, ARGS...> {
  315. public:
  316. using Parent = NestedFeatureFunction<NES, OBJ, ARGS...>;
  317. // Clean up nested functions.
  318. ~NestedFeatureFunction() override { utils::STLDeleteElements(&nested_); }
  319. // By default, just appends the nested feature types.
  320. void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
  321. CHECK(!this->nested().empty())
  322. << "Nested features require nested features to be defined.";
  323. for (auto *function : nested_) function->GetFeatureTypes(types);
  324. }
  325. // Sets up the nested features.
  326. void Setup(TaskContext *context) override {
  327. CreateNested(this->extractor(), this->descriptor(), &nested_,
  328. this->SubPrefix());
  329. for (auto *function : nested_) function->Setup(context);
  330. SetupNested(context);
  331. }
  332. // Sets up this NestedFeatureFunction specifically.
  333. virtual void SetupNested(TaskContext *context) {}
  334. // Initializes the nested features.
  335. void Init(TaskContext *context) override {
  336. for (auto *function : nested_) function->Init(context);
  337. InitNested(context);
  338. }
  339. // Initializes this NestedFeatureFunction specifically.
  340. virtual void InitNested(TaskContext *context) {}
  341. // Gets all the workspaces needed for the nested functions.
  342. void RequestWorkspaces(WorkspaceRegistry *registry) override {
  343. for (auto *function : nested_) function->RequestWorkspaces(registry);
  344. }
  345. // Returns the list of nested feature functions.
  346. const std::vector<NES *> &nested() const { return nested_; }
  347. // Instantiates nested feature functions for a feature function. Creates and
  348. // initializes one feature function for each sub-descriptor in the feature
  349. // descriptor.
  350. static void CreateNested(GenericFeatureExtractor *extractor,
  351. FeatureFunctionDescriptor *fd,
  352. std::vector<NES *> *functions,
  353. const string &prefix) {
  354. for (int i = 0; i < fd->feature_size(); ++i) {
  355. FeatureFunctionDescriptor *sub = fd->mutable_feature(i);
  356. NES *f = NES::Instantiate(extractor, sub, prefix);
  357. functions->push_back(f);
  358. }
  359. }
  360. protected:
  361. // The nested feature functions, if any, in order of declaration in the
  362. // feature descriptor. Owned.
  363. std::vector<NES *> nested_;
  364. };
  365. // Base class for a nested feature function that takes nested features with the
  366. // same signature as these features, i.e. a meta feature. For this class, we can
  367. // provide preprocessing of the nested features.
  368. template<class OBJ, class ...ARGS>
  369. class MetaFeatureFunction : public NestedFeatureFunction<
  370. FeatureFunction<OBJ, ARGS...>, OBJ, ARGS...> {
  371. public:
  372. // Preprocesses using the nested features.
  373. void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
  374. for (auto *function : this->nested_) {
  375. function->Preprocess(workspaces, object);
  376. }
  377. }
  378. };
  379. // Template for a special type of locator: The locator of type
  380. // FeatureFunction<OBJ, ARGS...> calls nested functions of type
  381. // FeatureFunction<OBJ, IDX, ARGS...>, where the derived class DER is
  382. // responsible for translating by providing the following:
  383. //
  384. // // Gets the new additional focus.
  385. // IDX GetFocus(const WorkspaceSet &workspaces, const OBJ &object);
  386. //
  387. // This is useful to e.g. add a token focus to a parser state based on some
  388. // desired property of that state.
  389. template<class DER, class OBJ, class IDX, class ...ARGS>
  390. class FeatureAddFocusLocator : public NestedFeatureFunction<
  391. FeatureFunction<OBJ, IDX, ARGS...>, OBJ, ARGS...> {
  392. public:
  393. void Preprocess(WorkspaceSet *workspaces, OBJ *object) const override {
  394. for (auto *function : this->nested_) {
  395. function->Preprocess(workspaces, object);
  396. }
  397. }
  398. void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
  399. ARGS... args, FeatureVector *result) const override {
  400. IDX focus = static_cast<const DER *>(this)->GetFocus(
  401. workspaces, object, args...);
  402. for (auto *function : this->nested()) {
  403. function->Evaluate(workspaces, object, focus, args..., result);
  404. }
  405. }
  406. // Returns the first nested feature's computed value.
  407. FeatureValue Compute(const WorkspaceSet &workspaces,
  408. const OBJ &object,
  409. ARGS... args,
  410. const FeatureVector *result) const override {
  411. IDX focus = static_cast<const DER *>(this)->GetFocus(
  412. workspaces, object, args...);
  413. return this->nested()[0]->Compute(
  414. workspaces, object, focus, args..., result);
  415. }
  416. };
  417. // CRTP feature locator class. This is a meta feature that modifies ARGS and
  418. // then calls the nested feature functions with the modified ARGS. Note that in
  419. // order for this template to work correctly, all of ARGS must be types for
  420. // which the reference operator & can be interpreted as a pointer to the
  421. // argument. The derived class DER must implement the UpdateFocus method which
  422. // takes pointers to the ARGS arguments:
  423. //
  424. // // Updates the current arguments.
  425. // void UpdateArgs(const OBJ &object, ARGS *...args) const;
  426. template<class DER, class OBJ, class ...ARGS>
  427. class FeatureLocator : public MetaFeatureFunction<OBJ, ARGS...> {
  428. public:
  429. // Feature locators have an additional check that there is no intrinsic type.
  430. void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
  431. CHECK(this->feature_type() == nullptr)
  432. << "FeatureLocators should not have an intrinsic type.";
  433. MetaFeatureFunction<OBJ, ARGS...>::GetFeatureTypes(types);
  434. }
  435. // Evaluates the locator.
  436. void Evaluate(const WorkspaceSet &workspaces, const OBJ &object,
  437. ARGS... args, FeatureVector *result) const override {
  438. static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
  439. for (auto *function : this->nested()) {
  440. function->Evaluate(workspaces, object, args..., result);
  441. }
  442. }
  443. // Returns the first nested feature's computed value.
  444. FeatureValue Compute(const WorkspaceSet &workspaces, const OBJ &object,
  445. ARGS... args,
  446. const FeatureVector *result) const override {
  447. static_cast<const DER *>(this)->UpdateArgs(workspaces, object, &args...);
  448. return this->nested()[0]->Compute(workspaces, object, args..., result);
  449. }
  450. };
  451. // Feature extractor for extracting features from objects of a certain class.
  452. // Template type parameters are as defined for FeatureFunction.
  453. template<class OBJ, class ...ARGS>
  454. class FeatureExtractor : public GenericFeatureExtractor {
  455. public:
  456. // Feature function type for top-level functions in the feature extractor.
  457. typedef FeatureFunction<OBJ, ARGS...> Function;
  458. typedef FeatureExtractor<OBJ, ARGS...> Self;
  459. // Feature locator type for the feature extractor.
  460. template<class DER>
  461. using Locator = FeatureLocator<DER, OBJ, ARGS...>;
  462. // Initializes feature extractor.
  463. FeatureExtractor() {}
  464. ~FeatureExtractor() override { utils::STLDeleteElements(&functions_); }
  465. // Sets up the feature extractor. Note that only top-level functions exist
  466. // until Setup() is called. This does not take ownership over the context,
  467. // which must outlive this.
  468. void Setup(TaskContext *context) {
  469. for (Function *function : functions_) function->Setup(context);
  470. }
  471. // Initializes the feature extractor. Must be called after Setup(). This
  472. // does not take ownership over the context, which must outlive this.
  473. void Init(TaskContext *context) {
  474. for (Function *function : functions_) function->Init(context);
  475. this->InitializeFeatureTypes();
  476. }
  477. // Requests workspaces from the registry. Must be called after Init(), and
  478. // before Preprocess(). Does not take ownership over registry. This should be
  479. // the same registry used to initialize the WorkspaceSet used in Preprocess()
  480. // and ExtractFeatures(). NB: This is a different ordering from that used in
  481. // SentenceFeatureRepresentation style feature computation.
  482. void RequestWorkspaces(WorkspaceRegistry *registry) {
  483. for (auto *function : functions_) function->RequestWorkspaces(registry);
  484. }
  485. // Preprocesses the object using feature functions for the phase. Must be
  486. // called before any calls to ExtractFeatures() on that object and phase.
  487. void Preprocess(WorkspaceSet *workspaces, OBJ *object) const {
  488. for (Function *function : functions_) {
  489. function->Preprocess(workspaces, object);
  490. }
  491. }
  492. // Extracts features from an object with a focus. This invokes all the
  493. // top-level feature functions in the feature extractor. Only feature
  494. // functions belonging to the specified phase are invoked.
  495. void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &object,
  496. ARGS... args, FeatureVector *result) const {
  497. result->reserve(this->feature_types());
  498. // Extract features.
  499. for (int i = 0; i < functions_.size(); ++i) {
  500. functions_[i]->Evaluate(workspaces, object, args..., result);
  501. }
  502. }
  503. private:
  504. // Creates and initializes all feature functions in the feature extractor.
  505. void InitializeFeatureFunctions() override {
  506. // Create all top-level feature functions.
  507. for (int i = 0; i < descriptor().feature_size(); ++i) {
  508. FeatureFunctionDescriptor *fd = mutable_descriptor()->mutable_feature(i);
  509. Function *function = Function::Instantiate(this, fd, "");
  510. functions_.push_back(function);
  511. }
  512. }
  513. // Collect all feature types used in the feature extractor.
  514. void GetFeatureTypes(std::vector<FeatureType *> *types) const override {
  515. for (int i = 0; i < functions_.size(); ++i) {
  516. functions_[i]->GetFeatureTypes(types);
  517. }
  518. }
  519. // Top-level feature functions (and variables) in the feature extractor.
  520. // Owned.
  521. std::vector<Function *> functions_;
  522. };
  523. #define REGISTER_SYNTAXNET_FEATURE_FUNCTION(base, name, component) \
  524. REGISTER_SYNTAXNET_CLASS_COMPONENT(base, name, component)
  525. } // namespace syntaxnet
  526. #endif // SYNTAXNET_FEATURE_EXTRACTOR_H_