morpher_transitions.cc 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Morpher transition system.
  13. //
  // This transition system has one type of action:
  //  - The SHIFT action pushes the next input token to the stack and
  //    advances to the next input token, assigning a morphology label to the
  //    token that was shifted.
  //
  // The transition system operates with parser actions encoded as integers:
  //  - A SHIFT action is encoded as the integer id of the label it assigns;
  //    ids start from 0.
  21. #include <string>
  22. #include "syntaxnet/morphology_label_set.h"
  23. #include "syntaxnet/parser_features.h"
  24. #include "syntaxnet/parser_state.h"
  25. #include "syntaxnet/parser_transitions.h"
  26. #include "syntaxnet/sentence_features.h"
  27. #include "syntaxnet/shared_store.h"
  28. #include "syntaxnet/task_context.h"
  29. #include "syntaxnet/term_frequency_map.h"
  30. #include "syntaxnet/utils.h"
  31. #include "tensorflow/core/lib/strings/strcat.h"
  32. namespace syntaxnet {
  33. class MorphologyTransitionState : public ParserTransitionState {
  34. public:
  35. explicit MorphologyTransitionState(const MorphologyLabelSet *label_set)
  36. : label_set_(label_set) {}
  37. explicit MorphologyTransitionState(const MorphologyTransitionState *state)
  38. : MorphologyTransitionState(state->label_set_) {
  39. tag_ = state->tag_;
  40. gold_tag_ = state->gold_tag_;
  41. }
  42. // Clones the transition state by returning a new object.
  43. ParserTransitionState *Clone() const override {
  44. return new MorphologyTransitionState(this);
  45. }
  46. // Reads gold tags for each token.
  47. void Init(ParserState *state) override {
  48. tag_.resize(state->sentence().token_size(), -1);
  49. gold_tag_.resize(state->sentence().token_size(), -1);
  50. for (int pos = 0; pos < state->sentence().token_size(); ++pos) {
  51. const Token &token = state->GetToken(pos);
  52. // NOTE: we allow token to not have a TokenMorphology extension or for the
  53. // TokenMorphology to be absent from the label_set_ because this can
  54. // happen at test time.
  55. gold_tag_[pos] = label_set_->LookupExisting(
  56. token.GetExtension(TokenMorphology::morphology));
  57. }
  58. }
  59. // Returns the tag assigned to a given token.
  60. int Tag(int index) const {
  61. DCHECK_GE(index, 0);
  62. DCHECK_LT(index, tag_.size());
  63. return index == -1 ? -1 : tag_[index];
  64. }
  65. // Sets this tag on the token at index.
  66. void SetTag(int index, int tag) {
  67. DCHECK_GE(index, 0);
  68. DCHECK_LT(index, tag_.size());
  69. tag_[index] = tag;
  70. }
  71. // Returns the gold tag for a given token.
  72. int GoldTag(int index) const {
  73. DCHECK_GE(index, -1);
  74. DCHECK_LT(index, gold_tag_.size());
  75. return index == -1 ? -1 : gold_tag_[index];
  76. }
  77. // Returns the proto corresponding to the tag, or an empty proto if the tag is
  78. // not found.
  79. const TokenMorphology &TagAsProto(int tag) const {
  80. if (tag >= 0 && tag < label_set_->Size()) {
  81. return label_set_->Lookup(tag);
  82. }
  83. return TokenMorphology::default_instance();
  84. }
  85. // Adds transition state specific annotations to the document.
  86. void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
  87. Sentence *sentence) const override {
  88. for (int i = 0; i < tag_.size(); ++i) {
  89. Token *token = sentence->mutable_token(i);
  90. *token->MutableExtension(TokenMorphology::morphology) =
  91. TagAsProto(Tag(i));
  92. }
  93. }
  94. // Whether a parsed token should be considered correct for evaluation.
  95. bool IsTokenCorrect(const ParserState &state, int index) const override {
  96. return GoldTag(index) == Tag(index);
  97. }
  98. // Returns a human readable string representation of this state.
  99. string ToString(const ParserState &state) const override {
  100. string str;
  101. for (int i = state.StackSize(); i > 0; --i) {
  102. const string &word = state.GetToken(state.Stack(i - 1)).word();
  103. if (i != state.StackSize() - 1) str.append(" ");
  104. tensorflow::strings::StrAppend(
  105. &str, word, "[",
  106. TagAsProto(Tag(state.StackSize() - i)).ShortDebugString(), "]");
  107. }
  108. for (int i = state.Next(); i < state.NumTokens(); ++i) {
  109. tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
  110. }
  111. return str;
  112. }
  113. private:
  114. // Currently assigned morphological analysis for each token in this sentence.
  115. std::vector<int> tag_;
  116. // Gold morphological analysis from the input document.
  117. std::vector<int> gold_tag_;
  118. // Tag map used for conversions between integer and string representations
  119. // part of speech tags. Not owned.
  120. const MorphologyLabelSet *label_set_ = nullptr;
  121. TF_DISALLOW_COPY_AND_ASSIGN(MorphologyTransitionState);
  122. };
  123. class MorphologyTransitionSystem : public ParserTransitionSystem {
  124. public:
  125. ~MorphologyTransitionSystem() override { SharedStore::Release(label_set_); }
  126. // Determines tag map location.
  127. void Setup(TaskContext *context) override {
  128. context->GetInput("morph-label-set");
  129. }
  130. // Reads tag map and tag to category map.
  131. void Init(TaskContext *context) override {
  132. const string fname =
  133. TaskContext::InputFile(*context->GetInput("morph-label-set"));
  134. label_set_ =
  135. SharedStoreUtils::GetWithDefaultName<MorphologyLabelSet>(fname);
  136. }
  137. // The SHIFT action uses the same value as the corresponding action type.
  138. static ParserAction ShiftAction(int tag) { return tag; }
  139. // The morpher transition system doesn't look at the dependency tree, so it
  140. // allows non-projective trees.
  141. bool AllowsNonProjective() const override { return true; }
  142. // Returns the number of action types.
  143. int NumActionTypes() const override { return 1; }
  144. // Returns the number of possible actions.
  145. int NumActions(int num_labels) const override { return label_set_->Size(); }
  146. // The default action for a given state is assigning the most frequent tag.
  147. ParserAction GetDefaultAction(const ParserState &state) const override {
  148. return ShiftAction(0);
  149. }
  150. // Returns the next gold action for a given state according to the
  151. // underlying annotated sentence.
  152. ParserAction GetNextGoldAction(const ParserState &state) const override {
  153. if (!state.EndOfInput()) {
  154. return ShiftAction(TransitionState(state).GoldTag(state.Next()));
  155. }
  156. return ShiftAction(0);
  157. }
  158. // Checks if the action is allowed in a given parser state.
  159. bool IsAllowedAction(ParserAction action,
  160. const ParserState &state) const override {
  161. return !state.EndOfInput();
  162. }
  163. // Makes a shift by pushing the next input token on the stack and moving to
  164. // the next position.
  165. void PerformActionWithoutHistory(ParserAction action,
  166. ParserState *state) const override {
  167. DCHECK(!state->EndOfInput());
  168. if (!state->EndOfInput()) {
  169. MutableTransitionState(state)->SetTag(state->Next(), action);
  170. state->Push(state->Next());
  171. state->Advance();
  172. }
  173. }
  174. // We are in a final state when we reached the end of the input and the stack
  175. // is empty.
  176. bool IsFinalState(const ParserState &state) const override {
  177. return state.EndOfInput();
  178. }
  179. // Returns a string representation of a parser action.
  180. string ActionAsString(ParserAction action,
  181. const ParserState &state) const override {
  182. return tensorflow::strings::StrCat(
  183. "SHIFT(", label_set_->Lookup(action).ShortDebugString(), ")");
  184. }
  185. // No state is deterministic in this transition system.
  186. bool IsDeterministicState(const ParserState &state) const override {
  187. return false;
  188. }
  189. // Returns a new transition state to be used to enhance the parser state.
  190. ParserTransitionState *NewTransitionState(bool training_mode) const override {
  191. return new MorphologyTransitionState(label_set_);
  192. }
  193. // Downcasts the const ParserTransitionState in ParserState to a const
  194. // MorphologyTransitionState.
  195. static const MorphologyTransitionState &TransitionState(
  196. const ParserState &state) {
  197. return *static_cast<const MorphologyTransitionState *>(
  198. state.transition_state());
  199. }
  200. // Downcasts the ParserTransitionState in ParserState to an
  201. // MorphologyTransitionState.
  202. static MorphologyTransitionState *MutableTransitionState(ParserState *state) {
  203. return static_cast<MorphologyTransitionState *>(
  204. state->mutable_transition_state());
  205. }
  206. // Input for the tag map. Not owned.
  207. TaskInput *input_label_set_ = nullptr;
  208. // Tag map used for conversions between integer and string representations
  209. // morphology labels. Owned through SharedStore.
  210. const MorphologyLabelSet *label_set_;
  211. };
  212. REGISTER_TRANSITION_SYSTEM("morpher", MorphologyTransitionSystem);
  213. // Feature function for retrieving the tag assigned to a token by the tagger
  214. // transition system.
  215. class PredictedMorphTagFeatureFunction : public ParserIndexFeatureFunction {
  216. public:
  217. PredictedMorphTagFeatureFunction() {}
  218. // Determines tag map location.
  219. void Setup(TaskContext *context) override {
  220. context->GetInput("morph-label-set", "recordio", "token-morphology");
  221. }
  222. // Reads tag map.
  223. void Init(TaskContext *context) override {
  224. const string fname =
  225. TaskContext::InputFile(*context->GetInput("morph-label-set"));
  226. label_set_ = SharedStore::Get<MorphologyLabelSet>(fname, fname);
  227. set_feature_type(new FullLabelFeatureType(name(), label_set_));
  228. }
  229. // Gets the MorphologyTransitionState from the parser state and reads the
  230. // assigned
  231. // tag at the focus index. Returns -1 if the focus is not within the sentence.
  232. FeatureValue Compute(const WorkspaceSet &workspaces, const ParserState &state,
  233. int focus, const FeatureVector *result) const override {
  234. if (focus < 0 || focus >= state.sentence().token_size()) return -1;
  235. return static_cast<const MorphologyTransitionState *>(
  236. state.transition_state())
  237. ->Tag(focus);
  238. }
  239. private:
  240. // Tag map used for conversions between integer and string representations
  241. // part of speech tags. Owned through SharedStore.
  242. const MorphologyLabelSet *label_set_;
  243. TF_DISALLOW_COPY_AND_ASSIGN(PredictedMorphTagFeatureFunction);
  244. };
  245. REGISTER_PARSER_IDX_FEATURE_FUNCTION("pred-morph-tag",
  246. PredictedMorphTagFeatureFunction);
  247. } // namespace syntaxnet