lexicon_builder.cc

/* Copyright 2016 Google Inc. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <stddef.h>

#include <memory>
#include <string>
#include <vector>

#include "syntaxnet/affix.h"
#include "syntaxnet/char_ngram_string_extractor.h"
#include "syntaxnet/embedding_feature_extractor.h"
#include "syntaxnet/feature_extractor.h"
#include "syntaxnet/parser_transitions.h"
#include "syntaxnet/proto_io.h"
#include "syntaxnet/segmenter_utils.h"
#include "syntaxnet/sentence.pb.h"
#include "syntaxnet/sentence_batch.h"
#include "syntaxnet/shared_store.h"
#include "syntaxnet/task_context.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/term_frequency_map.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"

// A task that collects term statistics over a corpus and saves a set of
// term maps; these saved mappings are used to map strings to ints in both the
// chunker trainer and the chunker processors.

using tensorflow::DEVICE_CPU;
using tensorflow::DT_INT32;
using tensorflow::DT_STRING;
using tensorflow::OpKernel;
using tensorflow::OpKernelConstruction;
using tensorflow::OpKernelContext;
using tensorflow::Tensor;
using tensorflow::TensorShape;
using tensorflow::errors::InvalidArgument;
using tensorflow::protobuf::TextFormat;

namespace syntaxnet {
namespace {
// Helper function to load the TaskSpec either from the `task_context`
// or the `task_context_str` arguments of the op.
void LoadSpec(OpKernelConstruction *context, TaskSpec *task_spec) {
  string file_path, data;
  OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
  if (!file_path.empty()) {
    OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
                                             file_path, &data));
  } else {
    OP_REQUIRES_OK(context, context->GetAttr("task_context_str", &data));
  }
  OP_REQUIRES(context, TextFormat::ParseFromString(data, task_spec),
              InvalidArgument("Could not parse task context from ", data));
}
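
// For reference, the data parsed above is a TaskSpec protocol buffer in text
// format. A minimal sketch of such a spec (hypothetical field values; see
// task_spec.proto for the authoritative schema):
//
//   input {
//     name: 'training-corpus'
//     part { file_pattern: '/path/to/corpus.conll' }
//   }
//
// Each named input declared this way can then be fetched below via
// task_context_.GetInput(...).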
class LexiconBuilder : public OpKernel {
 public:
  explicit LexiconBuilder(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name_));
    OP_REQUIRES_OK(context, context->GetAttr("lexicon_max_prefix_length",
                                             &max_prefix_length_));
    OP_REQUIRES_OK(context, context->GetAttr("lexicon_max_suffix_length",
                                             &max_suffix_length_));
    LoadSpec(context, task_context_.mutable_spec());

    int min_length, max_length;
    OP_REQUIRES_OK(context, context->GetAttr("lexicon_min_char_ngram_length",
                                             &min_length));
    OP_REQUIRES_OK(context, context->GetAttr("lexicon_max_char_ngram_length",
                                             &max_length));
    bool add_terminators, mark_boundaries;
    OP_REQUIRES_OK(context,
                   context->GetAttr("lexicon_char_ngram_include_terminators",
                                    &add_terminators));
    OP_REQUIRES_OK(context,
                   context->GetAttr("lexicon_char_ngram_mark_boundaries",
                                    &mark_boundaries));
    char_ngram_string_extractor_.set_min_length(min_length);
    char_ngram_string_extractor_.set_max_length(max_length);
    char_ngram_string_extractor_.set_add_terminators(add_terminators);
    char_ngram_string_extractor_.set_mark_boundaries(mark_boundaries);
    char_ngram_string_extractor_.Setup(task_context_);
  }
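
  // For example, with min_length = 1 and max_length = 2 the extractor emits
  // the n-grams c, a, t, ca, at for the word "cat"; terminators and boundary
  // marks (when enabled) add distinguished symbols at the word edges (the
  // exact marker characters are defined by CharNgramStringExtractor).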

  // Counts term frequencies.
  void Compute(OpKernelContext *context) override {
    // Term frequency maps to be populated by the corpus.
    TermFrequencyMap words;
    TermFrequencyMap lcwords;
    TermFrequencyMap tags;
    TermFrequencyMap categories;
    TermFrequencyMap labels;
    TermFrequencyMap chars;
    TermFrequencyMap char_ngrams;

    // Affix tables to be populated by the corpus.
    AffixTable prefixes(AffixTable::PREFIX, max_prefix_length_);
    AffixTable suffixes(AffixTable::SUFFIX, max_suffix_length_);

    // Tag-to-category mapping.
    TagToCategoryMap tag_to_category;

    // Make a pass over the corpus.
    int64 num_tokens = 0;
    int64 num_documents = 0;
    Sentence *document;
    TextReader corpus(*task_context_.GetInput(corpus_name_), &task_context_);
    while ((document = corpus.Read()) != nullptr) {
      // Gather token information.
      for (int t = 0; t < document->token_size(); ++t) {
        // Get token and lowercased word.
        const Token &token = document->token(t);
        string word = token.word();
        utils::NormalizeDigits(&word);
        string lcword = tensorflow::str_util::Lowercase(word);

        // Make sure the token does not contain a newline.
        CHECK(lcword.find('\n') == string::npos);

        // Increment frequencies (only for terms that exist).
        if (!word.empty() && !HasSpaces(word)) words.Increment(word);
        if (!lcword.empty() && !HasSpaces(lcword)) lcwords.Increment(lcword);
        if (!token.tag().empty()) tags.Increment(token.tag());
        if (!token.category().empty()) categories.Increment(token.category());
        if (!token.label().empty()) labels.Increment(token.label());

        // Add prefixes/suffixes for the current word.
        prefixes.AddAffixesForWord(word.c_str(), word.size());
        suffixes.AddAffixesForWord(word.c_str(), word.size());

        // Add mapping from tag to category.
        tag_to_category.SetCategory(token.tag(), token.category());

        // Add characters.
        std::vector<tensorflow::StringPiece> char_sp;
        SegmenterUtils::GetUTF8Chars(word, &char_sp);
        for (const auto &c : char_sp) {
          const string c_str = c.ToString();
          if (!c_str.empty() && !HasSpaces(c_str)) chars.Increment(c_str);
        }

        // Add character ngrams.
        char_ngram_string_extractor_.Extract(
            word, [&](const string &char_ngram) {
              char_ngrams.Increment(char_ngram);
            });

        // Update the number of processed tokens.
        ++num_tokens;
      }
      delete document;
      ++num_documents;
    }
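
    // Note on normalization: words are digit-normalized before counting
    // (utils::NormalizeDigits maps digit characters to a canonical digit),
    // and lcwords additionally lowercases the normalized form, so the saved
    // maps are keyed by these normalized strings rather than the raw surface
    // forms.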
    LOG(INFO) << "Term maps collected over " << num_tokens << " tokens from "
              << num_documents << " documents";

    // Write mappings to disk.
    words.Save(TaskContext::InputFile(*task_context_.GetInput("word-map")));
    lcwords.Save(
        TaskContext::InputFile(*task_context_.GetInput("lcword-map")));
    tags.Save(TaskContext::InputFile(*task_context_.GetInput("tag-map")));
    categories.Save(
        TaskContext::InputFile(*task_context_.GetInput("category-map")));
    labels.Save(TaskContext::InputFile(*task_context_.GetInput("label-map")));
    chars.Save(TaskContext::InputFile(*task_context_.GetInput("char-map")));

    // Optional, for backwards-compatibility with existing specs.
    TaskInput *char_ngrams_input = task_context_.GetInput("char-ngram-map");
    if (char_ngrams_input->part_size() > 0) {
      char_ngrams.Save(TaskContext::InputFile(*char_ngrams_input));
    }

    // Write affixes to disk.
    WriteAffixTable(prefixes, TaskContext::InputFile(
                                  *task_context_.GetInput("prefix-table")));
    WriteAffixTable(suffixes, TaskContext::InputFile(
                                  *task_context_.GetInput("suffix-table")));

    // Write tag-to-category mapping to disk.
    tag_to_category.Save(
        TaskContext::InputFile(*task_context_.GetInput("tag-to-category")));
  }

 private:
  // Returns true if the word contains spaces.
  static bool HasSpaces(const string &word) {
    for (char c : word) {
      if (c == ' ') return true;
    }
    return false;
  }
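
  // Note that only the ASCII space character is tested above; words
  // containing tabs or other Unicode whitespace would still be counted.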

  // Writes an affix table to a task output.
  static void WriteAffixTable(const AffixTable &affixes,
                              const string &output_file) {
    ProtoRecordWriter writer(output_file);
    affixes.Write(&writer);
  }

  // Name of the context input to compute lexicons.
  string corpus_name_;

  // Max length for prefix table.
  int max_prefix_length_;

  // Max length for suffix table.
  int max_suffix_length_;

  // Extractor for character n-gram strings.
  CharNgramStringExtractor char_ngram_string_extractor_;

  // Task context used to configure this op.
  TaskContext task_context_;
};

REGISTER_KERNEL_BUILDER(Name("LexiconBuilder").Device(DEVICE_CPU),
                        LexiconBuilder);
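
// From Python, this kernel is typically run once before training through the
// generated op wrapper (e.g. gen_parser_ops.lexicon_builder in the SyntaxNet
// training scripts; name given for orientation only) to materialize the term
// maps and affix tables that later ops read back from disk.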

class FeatureSize : public OpKernel {
 public:
  explicit FeatureSize(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
    OP_REQUIRES_OK(context, context->MatchSignature(
                                {}, {DT_INT32, DT_INT32, DT_INT32, DT_INT32}));
    LoadSpec(context, task_context_.mutable_spec());

    // See comment at the bottom of Compute() below.
    const string label_map_path =
        TaskContext::InputFile(*task_context_.GetInput("label-map"));
    label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
        label_map_path, 0, 0);
  }

  ~FeatureSize() override { SharedStore::Release(label_map_); }

  void Compute(OpKernelContext *context) override {
    // Computes feature sizes.
    ParserEmbeddingFeatureExtractor features(arg_prefix_);
    features.Setup(&task_context_);
    features.Init(&task_context_);
    const int num_embeddings = features.NumEmbeddings();

    Tensor *feature_sizes = nullptr;
    Tensor *domain_sizes = nullptr;
    Tensor *embedding_dims = nullptr;
    Tensor *num_actions = nullptr;
    TF_CHECK_OK(context->allocate_output(0, TensorShape({num_embeddings}),
                                         &feature_sizes));
    TF_CHECK_OK(context->allocate_output(1, TensorShape({num_embeddings}),
                                         &domain_sizes));
    TF_CHECK_OK(context->allocate_output(2, TensorShape({num_embeddings}),
                                         &embedding_dims));
    TF_CHECK_OK(context->allocate_output(3, TensorShape({}), &num_actions));
    for (int i = 0; i < num_embeddings; ++i) {
      feature_sizes->vec<int32>()(i) = features.FeatureSize(i);
      domain_sizes->vec<int32>()(i) = features.EmbeddingSize(i);
      embedding_dims->vec<int32>()(i) = features.EmbeddingDims(i);
    }

    // Computes number of actions in the transition system.
    std::unique_ptr<ParserTransitionSystem> transition_system(
        ParserTransitionSystem::Create(task_context_.Get(
            features.GetParamName("transition_system"), "arc-standard")));
    transition_system->Setup(&task_context_);
    transition_system->Init(&task_context_);

    // Note: label_map_->Size() is ignored by non-parser transition systems.
    // So even though we read the parser's label-map (output value tags and
    // their frequency), this function works for other transition systems.
    num_actions->scalar<int32>()() =
        transition_system->NumActions(label_map_->Size());
  }

 private:
  // Task context used to configure this op.
  TaskContext task_context_;

  // Dependency label map used in transition system.
  const TermFrequencyMap *label_map_;

  // Prefix for context parameters.
  string arg_prefix_;
};

REGISTER_KERNEL_BUILDER(Name("FeatureSize").Device(DEVICE_CPU), FeatureSize);
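
// For each embedding group i, the op thus reports how many features feed the
// group (feature_sizes), the size of the group's vocabulary (domain_sizes),
// and the requested embedding dimensionality (embedding_dims), plus the
// scalar number of transition-system actions; callers would typically use
// these four outputs to shape the network's embedding matrices and softmax
// layer.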

class FeatureVocab : public OpKernel {
 public:
  explicit FeatureVocab(OpKernelConstruction *context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("embedding_name", &embedding_name_));
    OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_STRING}));
    LoadSpec(context, task_context_.mutable_spec());
  }

  void Compute(OpKernelContext *context) override {
    // Computes the vocabulary of the named embedding.
    ParserEmbeddingFeatureExtractor features(arg_prefix_);
    features.Setup(&task_context_);
    features.Init(&task_context_);
    const std::vector<string> mapped_words =
        features.GetMappingsForEmbedding(embedding_name_);
    Tensor *vocab = nullptr;
    const int64 size = mapped_words.size();
    TF_CHECK_OK(context->allocate_output(0, TensorShape({size}), &vocab));
    for (int i = 0; i < size; ++i) {
      vocab->vec<string>()(i) = mapped_words[i];
    }
  }

 private:
  // Task context used to configure this op.
  TaskContext task_context_;

  // Prefix for context parameters.
  string arg_prefix_;

  // Name of the embedding for which the vocabulary is to be extracted.
  string embedding_name_;
};

REGISTER_KERNEL_BUILDER(Name("FeatureVocab").Device(DEVICE_CPU), FeatureVocab);
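
// One plausible use of this op (an assumption, not shown in this file):
// exporting the string vocabulary of an embedding so that trained embedding
// matrices can be labeled, e.g. for nearest-neighbor inspection or for
// visualization metadata.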

}  // namespace
}  // namespace syntaxnet