sentence_features.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/sentence_features.h"
  13. #include "syntaxnet/char_properties.h"
  14. #include "syntaxnet/registry.h"
  15. #include "util/utf8/unicodetext.h"
  16. #include "util/utf8/unilib.h"
  17. #include "util/utf8/unilib_utf8_utils.h"
  18. namespace syntaxnet {
  19. TermFrequencyMapFeature::~TermFrequencyMapFeature() {
  20. if (term_map_ != nullptr) {
  21. SharedStore::Release(term_map_);
  22. term_map_ = nullptr;
  23. }
  24. }
  25. void TermFrequencyMapFeature::Setup(TaskContext *context) {
  26. TokenLookupFeature::Setup(context);
  27. context->GetInput(input_name_, "text", "");
  28. }
  29. void TermFrequencyMapFeature::Init(TaskContext *context) {
  30. min_freq_ = GetIntParameter("min-freq", 0);
  31. max_num_terms_ = GetIntParameter("max-num-terms", 0);
  32. file_name_ = context->InputFile(*context->GetInput(input_name_));
  33. term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
  34. file_name_, min_freq_, max_num_terms_);
  35. TokenLookupFeature::Init(context);
  36. }
  37. string TermFrequencyMapFeature::GetFeatureValueName(FeatureValue value) const {
  38. if (value == UnknownValue()) return "<UNKNOWN>";
  39. if (value >= 0 && value < (NumValues() - 1)) {
  40. return term_map_->GetTerm(value);
  41. }
  42. LOG(ERROR) << "Invalid feature value: " << value;
  43. return "<INVALID>";
  44. }
  45. string TermFrequencyMapFeature::WorkspaceName() const {
  46. return SharedStoreUtils::CreateDefaultName("term-frequency-map", input_name_,
  47. min_freq_, max_num_terms_);
  48. }
  49. TermFrequencyMapSetFeature::~TermFrequencyMapSetFeature() {
  50. if (term_map_ != nullptr) {
  51. SharedStore::Release(term_map_);
  52. term_map_ = nullptr;
  53. }
  54. }
  55. void TermFrequencyMapSetFeature::Setup(TaskContext *context) {
  56. context->GetInput(input_name_, "text", "");
  57. }
  58. void TermFrequencyMapSetFeature::Init(TaskContext *context) {
  59. min_freq_ = GetIntParameter("min-freq", 0);
  60. max_num_terms_ = GetIntParameter("max-num-terms", 0);
  61. file_name_ = context->InputFile(*context->GetInput(input_name_));
  62. term_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
  63. file_name_, min_freq_, max_num_terms_);
  64. TokenLookupSetFeature::Init(context);
  65. }
  66. string TermFrequencyMapSetFeature::WorkspaceName() const {
  67. return SharedStoreUtils::CreateDefaultName(
  68. "term-frequency-map-set", input_name_, min_freq_, max_num_terms_);
  69. }
  70. namespace {
  71. void GetUTF8Chars(const string &word, vector<tensorflow::StringPiece> *chars) {
  72. UnicodeText text;
  73. text.PointToUTF8(word.c_str(), word.size());
  74. for (UnicodeText::const_iterator it = text.begin(); it != text.end(); ++it) {
  75. chars->push_back(tensorflow::StringPiece(it.utf8_data(), it.utf8_length()));
  76. }
  77. }
  78. int UTF8FirstLetterNumBytes(const char *utf8_str) {
  79. if (*utf8_str == '\0') return 0;
  80. return UniLib::OneCharLen(utf8_str);
  81. }
  82. } // namespace
  83. void CharNgram::GetTokenIndices(const Token &token, vector<int> *values) const {
  84. values->clear();
  85. vector<tensorflow::StringPiece> char_sp;
  86. if (use_terminators_) char_sp.push_back("^");
  87. GetUTF8Chars(token.word(), &char_sp);
  88. if (use_terminators_) char_sp.push_back("$");
  89. for (int start = 0; start < char_sp.size(); ++start) {
  90. string char_ngram;
  91. for (int index = 0;
  92. index < max_char_ngram_length_ && start + index < char_sp.size();
  93. ++index) {
  94. tensorflow::StringPiece c = char_sp[start + index];
  95. if (c == " ") break; // Never add char ngrams containing spaces.
  96. tensorflow::strings::StrAppend(&char_ngram, c);
  97. int value = LookupIndex(char_ngram);
  98. if (value != -1) { // Skip unknown values.
  99. values->push_back(value);
  100. }
  101. }
  102. }
  103. }
  104. void MorphologySet::GetTokenIndices(const Token &token,
  105. vector<int> *values) const {
  106. values->clear();
  107. const TokenMorphology &token_morphology =
  108. token.GetExtension(TokenMorphology::morphology);
  109. for (const TokenMorphology::Attribute &att : token_morphology.attribute()) {
  110. int value =
  111. LookupIndex(tensorflow::strings::StrCat(att.name(), "=", att.value()));
  112. if (value != -1) { // Skip unknown values.
  113. values->push_back(value);
  114. }
  115. }
  116. }
  117. string Hyphen::GetFeatureValueName(FeatureValue value) const {
  118. switch (value) {
  119. case NO_HYPHEN:
  120. return "NO_HYPHEN";
  121. case HAS_HYPHEN:
  122. return "HAS_HYPHEN";
  123. }
  124. return "<INVALID>";
  125. }
  126. FeatureValue Hyphen::ComputeValue(const Token &token) const {
  127. const string &word = token.word();
  128. return (word.find('-') < word.length() ? HAS_HYPHEN : NO_HYPHEN);
  129. }
  130. void Capitalization::Setup(TaskContext *context) {
  131. utf8_ = (GetParameter("utf8") == "true");
  132. }
  133. // Runs ComputeValue for each token in the sentence.
  134. void Capitalization::Preprocess(WorkspaceSet *workspaces,
  135. Sentence *sentence) const {
  136. if (workspaces->Has<VectorIntWorkspace>(Workspace())) return;
  137. VectorIntWorkspace *workspace =
  138. new VectorIntWorkspace(sentence->token_size());
  139. for (int i = 0; i < sentence->token_size(); ++i) {
  140. const int value = ComputeValueWithFocus(sentence->token(i), i);
  141. workspace->set_element(i, value);
  142. }
  143. workspaces->Set<VectorIntWorkspace>(Workspace(), workspace);
  144. }
  145. string Capitalization::GetFeatureValueName(FeatureValue value) const {
  146. switch (value) {
  147. case LOWERCASE:
  148. return "LOWERCASE";
  149. case UPPERCASE:
  150. return "UPPERCASE";
  151. case CAPITALIZED:
  152. return "CAPITALIZED";
  153. case CAPITALIZED_SENTENCE_INITIAL:
  154. return "CAPITALIZED_SENTENCE_INITIAL";
  155. case NON_ALPHABETIC:
  156. return "NON_ALPHABETIC";
  157. }
  158. return "<INVALID>";
  159. }
  160. FeatureValue Capitalization::ComputeValueWithFocus(const Token &token,
  161. int focus) const {
  162. const string &word = token.word();
  163. // Check whether there is an uppercase or lowercase character.
  164. bool has_upper = false;
  165. bool has_lower = false;
  166. if (utf8_) {
  167. LOG(FATAL) << "Not implemented.";
  168. } else {
  169. const char *str = word.c_str();
  170. for (int i = 0; i < word.length(); ++i) {
  171. const char c = str[i];
  172. has_upper = (has_upper || (c >= 'A' && c <= 'Z'));
  173. has_lower = (has_lower || (c >= 'a' && c <= 'z'));
  174. }
  175. }
  176. // Compute simple values.
  177. if (!has_upper && has_lower) return LOWERCASE;
  178. if (has_upper && !has_lower) return UPPERCASE;
  179. if (!has_upper && !has_lower) return NON_ALPHABETIC;
  180. // Else has_upper && has_lower; a normal capitalized word. Check the break
  181. // level to determine whether the capitalized word is sentence-initial.
  182. const bool sentence_initial = (focus == 0);
  183. return sentence_initial ? CAPITALIZED_SENTENCE_INITIAL : CAPITALIZED;
  184. }
  185. string PunctuationAmount::GetFeatureValueName(FeatureValue value) const {
  186. switch (value) {
  187. case NO_PUNCTUATION:
  188. return "NO_PUNCTUATION";
  189. case SOME_PUNCTUATION:
  190. return "SOME_PUNCTUATION";
  191. case ALL_PUNCTUATION:
  192. return "ALL_PUNCTUATION";
  193. }
  194. return "<INVALID>";
  195. }
  196. FeatureValue PunctuationAmount::ComputeValue(const Token &token) const {
  197. const string &word = token.word();
  198. bool has_punctuation = false;
  199. bool all_punctuation = true;
  200. const char *start = word.c_str();
  201. const char *end = word.c_str() + word.size();
  202. while (start < end) {
  203. int char_length = UTF8FirstLetterNumBytes(start);
  204. bool char_is_punct = is_punctuation_or_symbol(start, char_length);
  205. all_punctuation &= char_is_punct;
  206. has_punctuation |= char_is_punct;
  207. if (!all_punctuation && has_punctuation) return SOME_PUNCTUATION;
  208. start += char_length;
  209. }
  210. if (!all_punctuation) return NO_PUNCTUATION;
  211. return ALL_PUNCTUATION;
  212. }
  213. string Quote::GetFeatureValueName(FeatureValue value) const {
  214. switch (value) {
  215. case NO_QUOTE:
  216. return "NO_QUOTE";
  217. case OPEN_QUOTE:
  218. return "OPEN_QUOTE";
  219. case CLOSE_QUOTE:
  220. return "CLOSE_QUOTE";
  221. case UNKNOWN_QUOTE:
  222. return "UNKNOWN_QUOTE";
  223. }
  224. return "<INVALID>";
  225. }
  226. FeatureValue Quote::ComputeValue(const Token &token) const {
  227. const string &word = token.word();
  228. // Penn Treebank open and close quotes are multi-character.
  229. if (word == "``") return OPEN_QUOTE;
  230. if (word == "''") return CLOSE_QUOTE;
  231. if (word.length() == 1) {
  232. int char_len = UTF8FirstLetterNumBytes(word.c_str());
  233. bool is_open = is_open_quote(word.c_str(), char_len);
  234. bool is_close = is_close_quote(word.c_str(), char_len);
  235. if (is_open && !is_close) return OPEN_QUOTE;
  236. if (is_close && !is_open) return CLOSE_QUOTE;
  237. if (is_open && is_close) return UNKNOWN_QUOTE;
  238. }
  239. return NO_QUOTE;
  240. }
  241. void Quote::Preprocess(WorkspaceSet *workspaces, Sentence *sentence) const {
  242. if (workspaces->Has<VectorIntWorkspace>(Workspace())) return;
  243. VectorIntWorkspace *workspace =
  244. new VectorIntWorkspace(sentence->token_size());
  245. // For double quote ", it is unknown whether they are open or closed without
  246. // looking at the prior tokens in the sentence. in_quote is true iff an odd
  247. // number of " marks have been seen so far in the sentence (similar to the
  248. // behavior of some tokenizers).
  249. bool in_quote = false;
  250. for (int i = 0; i < sentence->token_size(); ++i) {
  251. int quote_type = ComputeValue(sentence->token(i));
  252. if (quote_type == UNKNOWN_QUOTE) {
  253. // Update based on in_quote and flip in_quote.
  254. quote_type = in_quote ? CLOSE_QUOTE : OPEN_QUOTE;
  255. in_quote = !in_quote;
  256. }
  257. workspace->set_element(i, quote_type);
  258. }
  259. workspaces->Set<VectorIntWorkspace>(Workspace(), workspace);
  260. }
  261. string Digit::GetFeatureValueName(FeatureValue value) const {
  262. switch (value) {
  263. case NO_DIGIT:
  264. return "NO_DIGIT";
  265. case SOME_DIGIT:
  266. return "SOME_DIGIT";
  267. case ALL_DIGIT:
  268. return "ALL_DIGIT";
  269. }
  270. return "<INVALID>";
  271. }
  272. FeatureValue Digit::ComputeValue(const Token &token) const {
  273. const string &word = token.word();
  274. bool has_digit = isdigit(word[0]);
  275. bool all_digit = has_digit;
  276. for (size_t i = 1; i < word.length(); ++i) {
  277. bool char_is_digit = isdigit(word[i]);
  278. all_digit = all_digit && char_is_digit;
  279. has_digit = has_digit || char_is_digit;
  280. if (!all_digit && has_digit) return SOME_DIGIT;
  281. }
  282. if (!all_digit) return NO_DIGIT;
  283. return ALL_DIGIT;
  284. }
  285. AffixTableFeature::AffixTableFeature(AffixTable::Type type)
  286. : type_(type) {
  287. if (type == AffixTable::PREFIX) {
  288. input_name_ = "prefix-table";
  289. } else {
  290. input_name_ = "suffix-table";
  291. }
  292. }
  293. AffixTableFeature::~AffixTableFeature() {
  294. SharedStore::Release(affix_table_);
  295. affix_table_ = nullptr;
  296. }
  297. string AffixTableFeature::WorkspaceName() const {
  298. return SharedStoreUtils::CreateDefaultName(
  299. "affix-table", input_name_, type_, affix_length_);
  300. }
  301. // Utility function to create a new affix table without changing constructors,
  302. // to be called by the SharedStore.
  303. static AffixTable *CreateAffixTable(const string &filename,
  304. AffixTable::Type type) {
  305. AffixTable *affix_table = new AffixTable(type, 1);
  306. std::unique_ptr<tensorflow::RandomAccessFile> file;
  307. TF_CHECK_OK(tensorflow::Env::Default()->NewRandomAccessFile(filename, &file));
  308. ProtoRecordReader reader(file.release());
  309. affix_table->Read(&reader);
  310. return affix_table;
  311. }
  312. void AffixTableFeature::Setup(TaskContext *context) {
  313. context->GetInput(input_name_, "recordio", "affix-table");
  314. affix_length_ = GetIntParameter("length", 0);
  315. CHECK_GE(affix_length_, 0) << "Length must be specified for affix feature.";
  316. TokenLookupFeature::Setup(context);
  317. }
  318. void AffixTableFeature::Init(TaskContext *context) {
  319. string filename = context->InputFile(*context->GetInput(input_name_));
  320. // Get the shared AffixTable object.
  321. std::function<AffixTable *()> closure =
  322. std::bind(CreateAffixTable, filename, type_);
  323. affix_table_ = SharedStore::ClosureGetOrDie(filename, &closure);
  324. CHECK_GE(affix_table_->max_length(), affix_length_)
  325. << "Affixes of length " << affix_length_ << " needed, but the affix "
  326. <<"table only provides affixes of length <= "
  327. << affix_table_->max_length() << ".";
  328. TokenLookupFeature::Init(context);
  329. }
  330. FeatureValue AffixTableFeature::ComputeValue(const Token &token) const {
  331. const string &word = token.word();
  332. UnicodeText text;
  333. text.PointToUTF8(word.c_str(), word.size());
  334. if (affix_length_ > text.size()) return UnknownValue();
  335. UnicodeText::const_iterator start, end;
  336. if (type_ == AffixTable::PREFIX) {
  337. start = end = text.begin();
  338. for (int i = 0; i < affix_length_; ++i) ++end;
  339. } else {
  340. start = end = text.end();
  341. for (int i = 0; i < affix_length_; ++i) --start;
  342. }
  343. string affix(start.utf8_data(), end.utf8_data() - start.utf8_data());
  344. int affix_id = affix_table_->AffixId(affix);
  345. return affix_id == -1 ? UnknownValue() : affix_id;
  346. }
  347. string AffixTableFeature::GetFeatureValueName(FeatureValue value) const {
  348. if (value == UnknownValue()) return "<UNKNOWN>";
  349. if (value >= 0 && value < UnknownValue()) {
  350. return affix_table_->AffixForm(value);
  351. }
  352. LOG(ERROR) << "Invalid feature value: " << value;
  353. return "<INVALID>";
  354. }
  // Registry for the Sentence + token index feature functions.
  REGISTER_CLASS_REGISTRY("sentence+index feature function", SentenceFeature);

  // Register the features defined in the header. The quoted string is
  // presumably the name used to select each feature in feature specifications
  // — confirm against the registry macros.
  REGISTER_SENTENCE_IDX_FEATURE("word", Word);
  REGISTER_SENTENCE_IDX_FEATURE("char", Char);
  REGISTER_SENTENCE_IDX_FEATURE("lcword", LowercaseWord);
  REGISTER_SENTENCE_IDX_FEATURE("tag", Tag);
  REGISTER_SENTENCE_IDX_FEATURE("offset", Offset);
  REGISTER_SENTENCE_IDX_FEATURE("hyphen", Hyphen);
  REGISTER_SENTENCE_IDX_FEATURE("digit", Digit);
  REGISTER_SENTENCE_IDX_FEATURE("prefix", PrefixFeature);
  REGISTER_SENTENCE_IDX_FEATURE("suffix", SuffixFeature);
  REGISTER_SENTENCE_IDX_FEATURE("char-ngram", CharNgram);
  REGISTER_SENTENCE_IDX_FEATURE("morphology-set", MorphologySet);
  REGISTER_SENTENCE_IDX_FEATURE("capitalization", Capitalization);
  REGISTER_SENTENCE_IDX_FEATURE("punctuation-amount", PunctuationAmount);
  REGISTER_SENTENCE_IDX_FEATURE("quote", Quote);
  372. } // namespace syntaxnet