reader_ops.cc 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include <math.h>
  13. #include <deque>
  14. #include <memory>
  15. #include <string>
  16. #include <unordered_map>
  17. #include <vector>
  18. #include "syntaxnet/base.h"
  19. #include "syntaxnet/feature_extractor.h"
  20. #include "syntaxnet/parser_state.h"
  21. #include "syntaxnet/parser_transitions.h"
  22. #include "syntaxnet/sentence.pb.h"
  23. #include "syntaxnet/sentence_batch.h"
  24. #include "syntaxnet/shared_store.h"
  25. #include "syntaxnet/sparse.pb.h"
  26. #include "syntaxnet/task_context.h"
  27. #include "syntaxnet/task_spec.pb.h"
  28. #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
  29. #include "tensorflow/core/framework/op_kernel.h"
  30. #include "tensorflow/core/framework/tensor.h"
  31. #include "tensorflow/core/framework/tensor_shape.h"
  32. #include "tensorflow/core/lib/core/status.h"
  33. #include "tensorflow/core/lib/io/table.h"
  34. #include "tensorflow/core/lib/io/table_options.h"
  35. #include "tensorflow/core/lib/strings/stringprintf.h"
  36. #include "tensorflow/core/platform/env.h"
  37. using tensorflow::DEVICE_CPU;
  38. using tensorflow::DT_FLOAT;
  39. using tensorflow::DT_INT32;
  40. using tensorflow::DT_INT64;
  41. using tensorflow::DT_STRING;
  42. using tensorflow::DataType;
  43. using tensorflow::OpKernel;
  44. using tensorflow::OpKernelConstruction;
  45. using tensorflow::OpKernelContext;
  46. using tensorflow::Tensor;
  47. using tensorflow::TensorShape;
  48. using tensorflow::error::OUT_OF_RANGE;
  49. using tensorflow::errors::InvalidArgument;
  50. namespace syntaxnet {
  51. class ParsingReader : public OpKernel {
  52. public:
  53. explicit ParsingReader(OpKernelConstruction *context) : OpKernel(context) {
  54. string file_path, corpus_name;
  55. OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
  56. OP_REQUIRES_OK(context, context->GetAttr("feature_size", &feature_size_));
  57. OP_REQUIRES_OK(context, context->GetAttr("batch_size", &max_batch_size_));
  58. OP_REQUIRES_OK(context, context->GetAttr("corpus_name", &corpus_name));
  59. OP_REQUIRES_OK(context, context->GetAttr("arg_prefix", &arg_prefix_));
  60. // Reads task context from file.
  61. string data;
  62. OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
  63. file_path, &data));
  64. OP_REQUIRES(context,
  65. TextFormat::ParseFromString(data, task_context_.mutable_spec()),
  66. InvalidArgument("Could not parse task context at ", file_path));
  67. // Set up the batch reader.
  68. sentence_batch_.reset(
  69. new SentenceBatch(max_batch_size_, corpus_name));
  70. sentence_batch_->Init(&task_context_);
  71. // Set up the parsing features and transition system.
  72. states_.resize(max_batch_size_);
  73. workspaces_.resize(max_batch_size_);
  74. features_.reset(new ParserEmbeddingFeatureExtractor(arg_prefix_));
  75. features_->Setup(&task_context_);
  76. transition_system_.reset(ParserTransitionSystem::Create(task_context_.Get(
  77. features_->GetParamName("transition_system"), "arc-standard")));
  78. transition_system_->Setup(&task_context_);
  79. features_->Init(&task_context_);
  80. features_->RequestWorkspaces(&workspace_registry_);
  81. transition_system_->Init(&task_context_);
  82. string label_map_path =
  83. TaskContext::InputFile(*task_context_.GetInput("label-map"));
  84. label_map_ = SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(
  85. label_map_path, 0, 0);
  86. // Checks number of feature groups matches the task context.
  87. const int required_size = features_->embedding_dims().size();
  88. OP_REQUIRES(
  89. context, feature_size_ == required_size,
  90. InvalidArgument("Task context requires feature_size=", required_size));
  91. }
  92. ~ParsingReader() override { SharedStore::Release(label_map_); }
  93. // Creates a new ParserState if there's another sentence to be read.
  94. virtual void AdvanceSentence(int index) {
  95. states_[index].reset();
  96. if (sentence_batch_->AdvanceSentence(index)) {
  97. states_[index].reset(new ParserState(
  98. sentence_batch_->sentence(index),
  99. transition_system_->NewTransitionState(true), label_map_));
  100. workspaces_[index].Reset(workspace_registry_);
  101. features_->Preprocess(&workspaces_[index], states_[index].get());
  102. }
  103. }
  104. void Compute(OpKernelContext *context) override {
  105. mutex_lock lock(mu_);
  106. // Advances states to the next positions.
  107. PerformActions(context);
  108. // Advances any final states to the next sentences.
  109. for (int i = 0; i < max_batch_size_; ++i) {
  110. if (state(i) == nullptr) continue;
  111. // Switches to the next sentence if we're at a final state.
  112. while (transition_system_->IsFinalState(*state(i))) {
  113. VLOG(2) << "Advancing sentence " << i;
  114. AdvanceSentence(i);
  115. if (state(i) == nullptr) break; // EOF has been reached
  116. }
  117. }
  118. // Rewinds if no states remain in the batch (we need to re-wind the corpus).
  119. if (sentence_batch_->size() == 0) {
  120. ++num_epochs_;
  121. LOG(INFO) << "Starting epoch " << num_epochs_;
  122. sentence_batch_->Rewind();
  123. for (int i = 0; i < max_batch_size_; ++i) AdvanceSentence(i);
  124. }
  125. // Create the outputs for each feature space.
  126. std::vector<Tensor *> feature_outputs(features_->NumEmbeddings());
  127. for (size_t i = 0; i < feature_outputs.size(); ++i) {
  128. OP_REQUIRES_OK(context, context->allocate_output(
  129. i, TensorShape({sentence_batch_->size(),
  130. features_->FeatureSize(i)}),
  131. &feature_outputs[i]));
  132. }
  133. // Populate feature outputs.
  134. for (int i = 0, index = 0; i < max_batch_size_; ++i) {
  135. if (states_[i] == nullptr) continue;
  136. // Extract features from the current parser state, and fill up the
  137. // available batch slots.
  138. std::vector<std::vector<SparseFeatures>> features =
  139. features_->ExtractSparseFeatures(workspaces_[i], *states_[i]);
  140. for (size_t feature_space = 0; feature_space < features.size();
  141. ++feature_space) {
  142. int feature_size = features[feature_space].size();
  143. CHECK(feature_size == features_->FeatureSize(feature_space));
  144. auto features_output = feature_outputs[feature_space]->matrix<string>();
  145. for (int k = 0; k < feature_size; ++k) {
  146. features_output(index, k) =
  147. features[feature_space][k].SerializeAsString();
  148. }
  149. }
  150. ++index;
  151. }
  152. // Return the number of epochs.
  153. Tensor *epoch_output;
  154. OP_REQUIRES_OK(context, context->allocate_output(
  155. feature_size_, TensorShape({}), &epoch_output));
  156. auto num_epochs = epoch_output->scalar<int32>();
  157. num_epochs() = num_epochs_;
  158. // Create outputs specific to this reader.
  159. AddAdditionalOutputs(context);
  160. }
  161. protected:
  162. // Peforms any relevant actions on the parser states, typically either
  163. // the gold action or a predicted action from decoding.
  164. virtual void PerformActions(OpKernelContext *context) = 0;
  165. // Adds outputs specific to this reader starting at additional_output_index().
  166. virtual void AddAdditionalOutputs(OpKernelContext *context) const = 0;
  167. // Returns the output type specification of the this base class.
  168. std::vector<DataType> default_outputs() const {
  169. std::vector<DataType> output_types(feature_size_, DT_STRING);
  170. output_types.push_back(DT_INT32);
  171. return output_types;
  172. }
  173. // Accessors.
  174. int max_batch_size() const { return max_batch_size_; }
  175. int batch_size() const { return sentence_batch_->size(); }
  176. int additional_output_index() const { return feature_size_ + 1; }
  177. ParserState *state(int i) const { return states_[i].get(); }
  178. const ParserTransitionSystem &transition_system() const {
  179. return *transition_system_;
  180. }
  181. // Parser task context.
  182. const TaskContext &task_context() const { return task_context_; }
  183. const string &arg_prefix() const { return arg_prefix_; }
  184. private:
  185. // Task context used to configure this op.
  186. TaskContext task_context_;
  187. // Prefix for context parameters.
  188. string arg_prefix_;
  189. // mutex to synchronize access to Compute.
  190. mutex mu_;
  191. // How many times the document source has been rewinded.
  192. int num_epochs_ = 0;
  193. // How many sentences this op can be processing at any given time.
  194. int max_batch_size_ = 1;
  195. // Number of feature groups in the brain parser features.
  196. int feature_size_ = -1;
  197. // Batch of sentences, and the corresponding parser states.
  198. std::unique_ptr<SentenceBatch> sentence_batch_;
  199. // Batch: ParserState objects.
  200. std::vector<std::unique_ptr<ParserState>> states_;
  201. // Batch: WorkspaceSet objects.
  202. std::vector<WorkspaceSet> workspaces_;
  203. // Dependency label map used in transition system.
  204. const TermFrequencyMap *label_map_;
  205. // Transition system.
  206. std::unique_ptr<ParserTransitionSystem> transition_system_;
  207. // Typed feature extractor for embeddings.
  208. std::unique_ptr<ParserEmbeddingFeatureExtractor> features_;
  209. // Internal workspace registry for use in feature extraction.
  210. WorkspaceRegistry workspace_registry_;
  211. TF_DISALLOW_COPY_AND_ASSIGN(ParsingReader);
  212. };
  213. class GoldParseReader : public ParsingReader {
  214. public:
  215. explicit GoldParseReader(OpKernelConstruction *context)
  216. : ParsingReader(context) {
  217. // Sets up number and type of inputs and outputs.
  218. std::vector<DataType> output_types = default_outputs();
  219. output_types.push_back(DT_INT32);
  220. OP_REQUIRES_OK(context, context->MatchSignature({}, output_types));
  221. }
  222. private:
  223. // Always performs the next gold action for each state.
  224. void PerformActions(OpKernelContext *context) override {
  225. for (int i = 0; i < max_batch_size(); ++i) {
  226. if (state(i) != nullptr) {
  227. transition_system().PerformAction(
  228. transition_system().GetNextGoldAction(*state(i)), state(i));
  229. }
  230. }
  231. }
  232. // Adds the list of gold actions for each state as an additional output.
  233. void AddAdditionalOutputs(OpKernelContext *context) const override {
  234. Tensor *actions_output;
  235. OP_REQUIRES_OK(context, context->allocate_output(
  236. additional_output_index(),
  237. TensorShape({batch_size()}), &actions_output));
  238. // Add all gold actions for non-null states as an additional output.
  239. auto gold_actions = actions_output->vec<int32>();
  240. for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
  241. if (state(i) != nullptr) {
  242. const int gold_action =
  243. transition_system().GetNextGoldAction(*state(i));
  244. gold_actions(batch_index++) = gold_action;
  245. }
  246. }
  247. }
  248. TF_DISALLOW_COPY_AND_ASSIGN(GoldParseReader);
  249. };
  250. REGISTER_KERNEL_BUILDER(Name("GoldParseReader").Device(DEVICE_CPU),
  251. GoldParseReader);
  252. // DecodedParseReader parses sentences using transition scores computed
  253. // by a TensorFlow network. This op additionally computes a token correctness
  254. // evaluation metric which can be used to select hyperparameter settings and
  255. // training stopping point.
  256. //
  257. // The notion of correct token is determined by the transition system, e.g.
  258. // a tagger will return POS tag accuracy, while an arc-standard parser will
  259. // return UAS.
  260. //
  261. // Which tokens should be scored is controlled by the '<arg_prefix>_scoring'
  262. // task parameter. Possible values are
  263. // - 'default': skips tokens with only punctuation in the tag name.
  264. // - 'conllx': skips tokens with only punctuation in the surface form.
  265. // - 'ignore_parens': same as conllx, but skipping parentheses as well.
  266. // - '': scores all tokens.
  267. class DecodedParseReader : public ParsingReader {
  268. public:
  269. explicit DecodedParseReader(OpKernelConstruction *context)
  270. : ParsingReader(context) {
  271. // Sets up number and type of inputs and outputs.
  272. std::vector<DataType> output_types = default_outputs();
  273. output_types.push_back(DT_INT32);
  274. output_types.push_back(DT_STRING);
  275. OP_REQUIRES_OK(context, context->MatchSignature({DT_FLOAT}, output_types));
  276. // Gets scoring parameters.
  277. scoring_type_ = task_context().Get(
  278. tensorflow::strings::StrCat(arg_prefix(), "_scoring"), "");
  279. }
  280. private:
  281. void AdvanceSentence(int index) override {
  282. ParsingReader::AdvanceSentence(index);
  283. if (state(index)) {
  284. docids_.push_front(state(index)->sentence().docid());
  285. }
  286. }
  287. // Tallies the # of correct and incorrect tokens for a given ParserState.
  288. void ComputeTokenAccuracy(const ParserState &state) {
  289. for (int i = 0; i < state.sentence().token_size(); ++i) {
  290. const Token &token = state.GetToken(i);
  291. if (utils::PunctuationUtil::ScoreToken(token.word(), token.tag(),
  292. scoring_type_)) {
  293. ++num_tokens_;
  294. if (state.IsTokenCorrect(i)) ++num_correct_;
  295. }
  296. }
  297. }
  298. // Performs the allowed action with the highest score on the given state.
  299. // Also records the accuracy whenver a terminal action is taken.
  300. void PerformActions(OpKernelContext *context) override {
  301. auto scores_matrix = context->input(0).matrix<float>();
  302. num_tokens_ = 0;
  303. num_correct_ = 0;
  304. for (int i = 0, batch_index = 0; i < max_batch_size(); ++i) {
  305. ParserState *state = this->state(i);
  306. if (state != nullptr) {
  307. int best_action = 0;
  308. float best_score = -INFINITY;
  309. for (int action = 0; action < scores_matrix.dimension(1); ++action) {
  310. float score = scores_matrix(batch_index, action);
  311. if (score > best_score &&
  312. transition_system().IsAllowedAction(action, *state)) {
  313. best_action = action;
  314. best_score = score;
  315. }
  316. }
  317. transition_system().PerformAction(best_action, state);
  318. // Update the # of scored correct tokens if this is the last state
  319. // in the sentence and save the annotated document.
  320. if (transition_system().IsFinalState(*state)) {
  321. ComputeTokenAccuracy(*state);
  322. sentence_map_[state->sentence().docid()] = state->sentence();
  323. state->AddParseToDocument(&sentence_map_[state->sentence().docid()]);
  324. }
  325. ++batch_index;
  326. }
  327. }
  328. }
  329. // Adds the evaluation metrics and annotated documents as additional outputs,
  330. // if there were any terminal states.
  331. void AddAdditionalOutputs(OpKernelContext *context) const override {
  332. Tensor *counts_output;
  333. OP_REQUIRES_OK(context,
  334. context->allocate_output(additional_output_index(),
  335. TensorShape({2}), &counts_output));
  336. auto eval_metrics = counts_output->vec<int32>();
  337. eval_metrics(0) = num_tokens_;
  338. eval_metrics(1) = num_correct_;
  339. // Output annotated documents for each state. To preserve order, repeatedly
  340. // pull from the back of the docids queue as long as the sentences have been
  341. // completely processed. If the next document has not been completely
  342. // processed yet, then the docid will not be found in 'sentence_map_'.
  343. std::vector<Sentence> sentences;
  344. while (!docids_.empty() &&
  345. sentence_map_.find(docids_.back()) != sentence_map_.end()) {
  346. sentences.emplace_back(sentence_map_[docids_.back()]);
  347. sentence_map_.erase(docids_.back());
  348. docids_.pop_back();
  349. }
  350. Tensor *annotated_output;
  351. OP_REQUIRES_OK(context,
  352. context->allocate_output(
  353. additional_output_index() + 1,
  354. TensorShape({static_cast<int64>(sentences.size())}),
  355. &annotated_output));
  356. auto document_output = annotated_output->vec<string>();
  357. for (size_t i = 0; i < sentences.size(); ++i) {
  358. document_output(i) = sentences[i].SerializeAsString();
  359. }
  360. }
  361. // State for eval metric computation.
  362. int num_tokens_ = 0;
  363. int num_correct_ = 0;
  364. // Parameter for deciding which tokens to score.
  365. string scoring_type_;
  366. mutable std::deque<string> docids_;
  367. mutable std::map<string, Sentence> sentence_map_;
  368. TF_DISALLOW_COPY_AND_ASSIGN(DecodedParseReader);
  369. };
  370. REGISTER_KERNEL_BUILDER(Name("DecodedParseReader").Device(DEVICE_CPU),
  371. DecodedParseReader);
  372. class WordEmbeddingInitializer : public OpKernel {
  373. public:
  374. explicit WordEmbeddingInitializer(OpKernelConstruction *context)
  375. : OpKernel(context) {
  376. string file_path, data;
  377. OP_REQUIRES_OK(context, context->GetAttr("task_context", &file_path));
  378. OP_REQUIRES_OK(context, ReadFileToString(tensorflow::Env::Default(),
  379. file_path, &data));
  380. OP_REQUIRES(context,
  381. TextFormat::ParseFromString(data, task_context_.mutable_spec()),
  382. InvalidArgument("Could not parse task context at ", file_path));
  383. OP_REQUIRES_OK(context, context->GetAttr("vectors", &vectors_path_));
  384. OP_REQUIRES_OK(context,
  385. context->GetAttr("embedding_init", &embedding_init_));
  386. // Sets up number and type of inputs and outputs.
  387. OP_REQUIRES_OK(context, context->MatchSignature({}, {DT_FLOAT}));
  388. }
  389. void Compute(OpKernelContext *context) override {
  390. // Loads words from vocabulary with mapping to ids.
  391. string path = TaskContext::InputFile(*task_context_.GetInput("word-map"));
  392. const TermFrequencyMap *word_map =
  393. SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);
  394. unordered_map<string, int64> vocab;
  395. for (int i = 0; i < word_map->Size(); ++i) {
  396. vocab[word_map->GetTerm(i)] = i;
  397. }
  398. // Creates a reader pointing to a local copy of the vectors recordio.
  399. string tmp_vectors_path;
  400. OP_REQUIRES_OK(context, CopyToTmpPath(vectors_path_, &tmp_vectors_path));
  401. ProtoRecordReader reader(tmp_vectors_path);
  402. // Loads the embedding vectors into a matrix.
  403. Tensor *embedding_matrix = nullptr;
  404. TokenEmbedding embedding;
  405. while (reader.Read(&embedding) == tensorflow::Status::OK()) {
  406. if (embedding_matrix == nullptr) {
  407. const int embedding_size = embedding.vector().values_size();
  408. OP_REQUIRES_OK(
  409. context, context->allocate_output(
  410. 0, TensorShape({word_map->Size() + 3, embedding_size}),
  411. &embedding_matrix));
  412. embedding_matrix->matrix<float>()
  413. .setRandom<Eigen::internal::NormalRandomGenerator<float>>();
  414. embedding_matrix->matrix<float>() =
  415. embedding_matrix->matrix<float>() * static_cast<float>(
  416. embedding_init_ / sqrt(embedding_size));
  417. }
  418. if (vocab.find(embedding.token()) != vocab.end()) {
  419. SetNormalizedRow(embedding.vector(), vocab[embedding.token()],
  420. embedding_matrix);
  421. }
  422. }
  423. }
  424. private:
  425. // Sets embedding_matrix[row] to a normalized version of the given vector.
  426. void SetNormalizedRow(const TokenEmbedding::Vector &vector, const int row,
  427. Tensor *embedding_matrix) {
  428. float norm = 0.0f;
  429. for (int col = 0; col < vector.values_size(); ++col) {
  430. float val = vector.values(col);
  431. norm += val * val;
  432. }
  433. norm = sqrt(norm);
  434. for (int col = 0; col < vector.values_size(); ++col) {
  435. embedding_matrix->matrix<float>()(row, col) = vector.values(col) / norm;
  436. }
  437. }
  438. // Copies the file at source_path to a temporary file and sets tmp_path to the
  439. // temporary file's location. This is helpful since reading from non local
  440. // files with a record reader can be very slow.
  441. static tensorflow::Status CopyToTmpPath(const string &source_path,
  442. string *tmp_path) {
  443. // Opens source file.
  444. std::unique_ptr<tensorflow::RandomAccessFile> source_file;
  445. TF_RETURN_IF_ERROR(tensorflow::Env::Default()->NewRandomAccessFile(
  446. source_path, &source_file));
  447. // Creates destination file.
  448. std::unique_ptr<tensorflow::WritableFile> target_file;
  449. *tmp_path = tensorflow::strings::Printf(
  450. "/tmp/%d.%lld", getpid(), tensorflow::Env::Default()->NowMicros());
  451. TF_RETURN_IF_ERROR(
  452. tensorflow::Env::Default()->NewWritableFile(*tmp_path, &target_file));
  453. // Performs copy.
  454. tensorflow::Status s;
  455. const size_t kBytesToRead = 10 << 20; // 10MB at a time.
  456. string scratch;
  457. scratch.resize(kBytesToRead);
  458. for (uint64 offset = 0; s.ok(); offset += kBytesToRead) {
  459. tensorflow::StringPiece data;
  460. s.Update(source_file->Read(offset, kBytesToRead, &data, &scratch[0]));
  461. target_file->Append(data);
  462. }
  463. if (s.code() == OUT_OF_RANGE) {
  464. return tensorflow::Status::OK();
  465. } else {
  466. return s;
  467. }
  468. }
  469. // Task context used to configure this op.
  470. TaskContext task_context_;
  471. // Embedding vectors that are not found in the input sstable are initialized
  472. // randomly from a normal distribution with zero mean and
  473. // std dev = embedding_init_ / sqrt(embedding_size).
  474. float embedding_init_ = 1.f;
  475. // Path to recordio with word embedding vectors.
  476. string vectors_path_;
  477. };
  478. REGISTER_KERNEL_BUILDER(Name("WordEmbeddingInitializer").Device(DEVICE_CPU),
  479. WordEmbeddingInitializer);
  480. } // namespace syntaxnet