sentence_features.cc 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/sentence_features.h"
  13. #include "syntaxnet/registry.h"
  14. #include "util/utf8/unicodetext.h"
  15. namespace syntaxnet {
  16. TermFrequencyMapFeature::~TermFrequencyMapFeature() {
  17. if (term_map_ != nullptr) {
  18. SharedStore::Release(term_map_);
  19. term_map_ = nullptr;
  20. }
  21. }
  22. void TermFrequencyMapFeature::Setup(TaskContext *context) {
  23. TokenLookupFeature::Setup(context);
  24. context->GetInput(input_name_, "text", "");
  25. }
  26. void TermFrequencyMapFeature::Init(TaskContext *context) {
  27. min_freq_ = GetIntParameter("min-freq", 0);
  28. max_num_terms_ = GetIntParameter("max-num-terms", 0);
  29. file_name_ = context->InputFile(*context->GetInput(input_name_));
  30. term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
  31. file_name_, min_freq_, max_num_terms_);
  32. TokenLookupFeature::Init(context);
  33. }
  34. string TermFrequencyMapFeature::GetFeatureValueName(FeatureValue value) const {
  35. if (value == UnknownValue()) return "<UNKNOWN>";
  36. if (value >= 0 && value < (NumValues() - 1)) {
  37. return term_map_->GetTerm(value);
  38. }
  39. LOG(ERROR) << "Invalid feature value: " << value;
  40. return "<INVALID>";
  41. }
  42. string TermFrequencyMapFeature::WorkspaceName() const {
  43. return SharedStoreUtils::CreateDefaultName("term-frequency-map", input_name_,
  44. min_freq_, max_num_terms_);
  45. }
  46. string Hyphen::GetFeatureValueName(FeatureValue value) const {
  47. switch (value) {
  48. case NO_HYPHEN:
  49. return "NO_HYPHEN";
  50. case HAS_HYPHEN:
  51. return "HAS_HYPHEN";
  52. }
  53. return "<INVALID>";
  54. }
  55. FeatureValue Hyphen::ComputeValue(const Token &token) const {
  56. const string &word = token.word();
  57. return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
  58. }
  59. string Digit::GetFeatureValueName(FeatureValue value) const {
  60. switch (value) {
  61. case NO_DIGIT:
  62. return "NO_DIGIT";
  63. case SOME_DIGIT:
  64. return "SOME_DIGIT";
  65. case ALL_DIGIT:
  66. return "ALL_DIGIT";
  67. }
  68. return "<INVALID>";
  69. }
  70. FeatureValue Digit::ComputeValue(const Token &token) const {
  71. const string &word = token.word();
  72. bool has_digit = isdigit(word[0]);
  73. bool all_digit = has_digit;
  74. for (size_t i = 1; i < word.length(); ++i) {
  75. bool char_is_digit = isdigit(word[i]);
  76. all_digit = all_digit && char_is_digit;
  77. has_digit = has_digit || char_is_digit;
  78. if (!all_digit && has_digit) return SOME_DIGIT;
  79. }
  80. if (!all_digit) return NO_DIGIT;
  81. return ALL_DIGIT;
  82. }
  83. AffixTableFeature::AffixTableFeature(AffixTable::Type type)
  84. : type_(type) {
  85. if (type == AffixTable::PREFIX) {
  86. input_name_ = "prefix-table";
  87. } else {
  88. input_name_ = "suffix-table";
  89. }
  90. }
  91. AffixTableFeature::~AffixTableFeature() {
  92. SharedStore::Release(affix_table_);
  93. affix_table_ = nullptr;
  94. }
  95. string AffixTableFeature::WorkspaceName() const {
  96. return SharedStoreUtils::CreateDefaultName(
  97. "affix-table", input_name_, type_, affix_length_);
  98. }
  99. // Utility function to create a new affix table without changing constructors,
  100. // to be called by the SharedStore.
  101. static AffixTable *CreateAffixTable(const string &filename,
  102. AffixTable::Type type) {
  103. AffixTable *affix_table = new AffixTable(type, 1);
  104. std::unique_ptr<tensorflow::RandomAccessFile> file;
  105. TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
  106. ProtoRecordReader reader(file.release());
  107. affix_table->Read(&reader);
  108. return affix_table;
  109. }
  110. void AffixTableFeature::Setup(TaskContext *context) {
  111. context->GetInput(input_name_, "recordio", "affix-table");
  112. affix_length_ = GetIntParameter("length", 0);
  113. CHECK_GE(affix_length_, 0)
  114. << "Length must be specified for affix preprocessor.";
  115. TokenLookupFeature::Setup(context);
  116. }
  117. void AffixTableFeature::Init(TaskContext *context) {
  118. string filename = context->InputFile(*context->GetInput(input_name_));
  119. // Get the shared AffixTable object.
  120. std::function<AffixTable *()> closure =
  121. std::bind(CreateAffixTable, filename, type_);
  122. affix_table_ = SharedStore::ClosureGetOrDie(filename, &closure);
  123. CHECK_GE(affix_table_->max_length(), affix_length_)
  124. << "Affixes of length " << affix_length_ << " needed, but the affix "
  125. <<"table only provides affixes of length <= "
  126. << affix_table_->max_length() << ".";
  127. TokenLookupFeature::Init(context);
  128. }
  129. FeatureValue AffixTableFeature::ComputeValue(const Token &token) const {
  130. const string &word = token.word();
  131. UnicodeText text;
  132. text.PointToUTF8(word.c_str(), word.size());
  133. if (affix_length_ > text.size()) return UnknownValue();
  134. UnicodeText::const_iterator start, end;
  135. if (type_ == AffixTable::PREFIX) {
  136. start = end = text.begin();
  137. for (int i = 0; i < affix_length_; ++i) ++end;
  138. } else {
  139. start = end = text.end();
  140. for (int i = 0; i < affix_length_; ++i) --start;
  141. }
  142. string affix(start.utf8_data(), end.utf8_data() - start.utf8_data());
  143. int affix_id = affix_table_->AffixId(affix);
  144. return affix_id == -1 ? UnknownValue() : affix_id;
  145. }
  146. string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
  147. if (value == UnknownValue()) return "<UNKNOWN>";
  148. if (value >= 0 && value < UnknownValue()) {
  149. return affix_table_->AffixForm(value);
  150. }
  151. LOG(ERROR) << "Invalid feature value: " << value;
  152. return "<INVALID>";
  153. }
  154. // Registry for the Sentence + token index feature functions.
       // The macro is invoked with a descriptive string and the SentenceFeature
       // type. NOTE(review): the macro is presumably provided by
       // syntaxnet/registry.h, included at the top of this file -- confirm.
  155. REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);
  156. // Register the features defined in the header.
       // Each line below pairs a string name with the feature class registered
       // under that name. NOTE(review): Word, LowercaseWord, Tag, Offset,
       // PrefixFeature and SuffixFeature are not defined in this file;
       // presumably they come from sentence_features.h -- confirm.
  157. REGISTER_SENTENCE_IDX_FEATURE("word", Word);
  158. REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
  159. REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
  160. REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
  161. REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
  162. REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
  163. REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
  164. REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);
  165. } // namespace syntaxnet