// arc_standard_transitions.cc
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Arc-standard transition system.
  13. //
  14. // This transition system has three types of actions:
  15. // - The SHIFT action pushes the next input token to the stack and
  16. // advances to the next input token.
  17. // - The LEFT_ARC action adds a dependency relation from first to second token
  18. // on the stack and removes second one.
  19. // - The RIGHT_ARC action adds a dependency relation from second to first token
  20. // on the stack and removes the first one.
  21. //
  22. // The transition system operates with parser actions encoded as integers:
  23. // - A SHIFT action is encoded as 0.
  24. // - A LEFT_ARC action is encoded as an odd number starting from 1.
  25. // - A RIGHT_ARC action is encoded as an even number starting from 2.
  26. #include <string>
  27. #include "syntaxnet/parser_state.h"
  28. #include "syntaxnet/parser_transitions.h"
  29. #include "syntaxnet/utils.h"
  30. #include "tensorflow/core/lib/strings/strcat.h"
  31. namespace syntaxnet {
  32. class ArcStandardTransitionState : public ParserTransitionState {
  33. public:
  34. // Clones the transition state by returning a new object.
  35. ParserTransitionState *Clone() const override {
  36. return new ArcStandardTransitionState();
  37. }
  38. // Pushes the root on the stack before using the parser state in parsing.
  39. void Init(ParserState *state) override { state->Push(-1); }
  40. // Adds transition state specific annotations to the document.
  41. void AddParseToDocument(const ParserState &state, bool rewrite_root_labels,
  42. Sentence *sentence) const override {
  43. for (int i = 0; i < state.NumTokens(); ++i) {
  44. Token *token = sentence->mutable_token(i);
  45. token->set_label(state.LabelAsString(state.Label(i)));
  46. if (state.Head(i) != -1) {
  47. token->set_head(state.Head(i));
  48. } else {
  49. token->clear_head();
  50. if (rewrite_root_labels) {
  51. token->set_label(state.LabelAsString(state.RootLabel()));
  52. }
  53. }
  54. }
  55. }
  56. // Whether a parsed token should be considered correct for evaluation.
  57. bool IsTokenCorrect(const ParserState &state, int index) const override {
  58. return state.GoldHead(index) == state.Head(index);
  59. }
  60. // Returns a human readable string representation of this state.
  61. string ToString(const ParserState &state) const override {
  62. string str;
  63. str.append("[");
  64. for (int i = state.StackSize() - 1; i >= 0; --i) {
  65. const string &word = state.GetToken(state.Stack(i)).word();
  66. if (i != state.StackSize() - 1) str.append(" ");
  67. if (word == "") {
  68. str.append(ParserState::kRootLabel);
  69. } else {
  70. str.append(word);
  71. }
  72. }
  73. str.append("]");
  74. for (int i = state.Next(); i < state.NumTokens(); ++i) {
  75. tensorflow::strings::StrAppend(&str, " ", state.GetToken(i).word());
  76. }
  77. return str;
  78. }
  79. };
  80. class ArcStandardTransitionSystem : public ParserTransitionSystem {
  81. public:
  82. // Action types for the arc-standard transition system.
  83. enum ParserActionType {
  84. SHIFT = 0,
  85. LEFT_ARC = 1,
  86. RIGHT_ARC = 2,
  87. };
  88. // The SHIFT action uses the same value as the corresponding action type.
  89. static ParserAction ShiftAction() { return SHIFT; }
  90. // The LEFT_ARC action converts the label to an odd number greater or equal
  91. // to 1.
  92. static ParserAction LeftArcAction(int label) { return 1 + (label << 1); }
  93. // The RIGHT_ARC action converts the label to an even number greater or equal
  94. // to 2.
  95. static ParserAction RightArcAction(int label) {
  96. return 1 + ((label << 1) | 1);
  97. }
  98. // Extracts the action type from a given parser action.
  99. static ParserActionType ActionType(ParserAction action) {
  100. return static_cast<ParserActionType>(action < 1 ? action
  101. : 1 + (~action & 1));
  102. }
  103. // Extracts the label from a given parser action. If the action is SHIFT,
  104. // returns -1.
  105. static int Label(ParserAction action) {
  106. return action < 1 ? -1 : (action - 1) >> 1;
  107. }
  108. // Returns the number of action types.
  109. int NumActionTypes() const override { return 3; }
  110. // Returns the number of possible actions.
  111. int NumActions(int num_labels) const override { return 1 + 2 * num_labels; }
  112. // The method returns the default action for a given state.
  113. ParserAction GetDefaultAction(const ParserState &state) const override {
  114. // If there are further tokens available in the input then Shift.
  115. if (!state.EndOfInput()) return ShiftAction();
  116. // Do a "reduce".
  117. return RightArcAction(2);
  118. }
  119. // Returns the next gold action for a given state according to the
  120. // underlying annotated sentence.
  121. ParserAction GetNextGoldAction(const ParserState &state) const override {
  122. // If the stack contains less than 2 tokens, the only valid parser action is
  123. // shift.
  124. if (state.StackSize() < 2) {
  125. // It is illegal to request the gold action if the transition system is
  126. // in a terminal state.
  127. CHECK(!state.EndOfInput());
  128. VLOG(2) << "Gold action: SHIFT (stack < 2 tokens)";
  129. return ShiftAction();
  130. }
  131. // If the second token on the stack is the head of the first one,
  132. // return a right arc action.
  133. if (state.GoldHead(state.Stack(0)) == state.Stack(1) &&
  134. DoneChildrenRightOf(state, state.Stack(0))) {
  135. const int gold_label = state.GoldLabel(state.Stack(0));
  136. VLOG(2) << "Gold action: RIGHT_ARC, label:" << gold_label;
  137. return RightArcAction(gold_label);
  138. }
  139. // If the first token on the stack is the head of the second one,
  140. // return a left arc action.
  141. if (state.GoldHead(state.Stack(1)) == state.Top()) {
  142. const int gold_label = state.GoldLabel(state.Stack(1));
  143. VLOG(2) << "Gold action: LEFT_ARC, label:" << gold_label;
  144. return LeftArcAction(gold_label);
  145. }
  146. // Otherwise, shift.
  147. VLOG(2) << "Gold action: SHIFT (default)";
  148. return ShiftAction();
  149. }
  150. // Determines if a token has any children to the right in the sentence.
  151. // Arc standard is a bottom-up parsing method and has to finish all sub-trees
  152. // first.
  153. static bool DoneChildrenRightOf(const ParserState &state, int head) {
  154. int index = state.Next();
  155. int num_tokens = state.sentence().token_size();
  156. while (index < num_tokens) {
  157. // Check if the token at index is the child of head.
  158. int actual_head = state.GoldHead(index);
  159. if (actual_head == head) return false;
  160. // If the head of the token at index is to the right of it there cannot be
  161. // any children in-between, so we can skip forward to the head. Note this
  162. // is only true for projective trees.
  163. if (actual_head > index) {
  164. index = actual_head;
  165. } else {
  166. ++index;
  167. }
  168. }
  169. return true;
  170. }
  171. // Checks if the action is allowed in a given parser state.
  172. bool IsAllowedAction(ParserAction action,
  173. const ParserState &state) const override {
  174. switch (ActionType(action)) {
  175. case SHIFT:
  176. return IsAllowedShift(state);
  177. case LEFT_ARC:
  178. return IsAllowedLeftArc(state);
  179. case RIGHT_ARC:
  180. return IsAllowedRightArc(state);
  181. }
  182. return false;
  183. }
  184. // Returns true if a shift is allowed in the given parser state.
  185. bool IsAllowedShift(const ParserState &state) const {
  186. // We can shift if there are more input tokens.
  187. return !state.EndOfInput();
  188. }
  189. // Returns true if a left-arc is allowed in the given parser state.
  190. bool IsAllowedLeftArc(const ParserState &state) const {
  191. // Left-arc requires two or more tokens on the stack but the first token
  192. // is the root an we do not want and left arc to the root.
  193. return state.StackSize() > 2;
  194. }
  195. // Returns true if a right-arc is allowed in the given parser state.
  196. bool IsAllowedRightArc(const ParserState &state) const {
  197. // Right arc requires three or more tokens on the stack.
  198. return state.StackSize() > 1;
  199. }
  200. // Performs the specified action on a given parser state, without adding the
  201. // action to the state's history.
  202. void PerformActionWithoutHistory(ParserAction action,
  203. ParserState *state) const override {
  204. switch (ActionType(action)) {
  205. case SHIFT:
  206. PerformShift(state);
  207. break;
  208. case LEFT_ARC:
  209. PerformLeftArc(state, Label(action));
  210. break;
  211. case RIGHT_ARC:
  212. PerformRightArc(state, Label(action));
  213. break;
  214. }
  215. }
  216. // Makes a shift by pushing the next input token on the stack and moving to
  217. // the next position.
  218. void PerformShift(ParserState *state) const {
  219. DCHECK(IsAllowedShift(*state));
  220. state->Push(state->Next());
  221. state->Advance();
  222. }
  223. // Makes a left-arc between the two top tokens on stack and pops the second
  224. // token on stack.
  225. void PerformLeftArc(ParserState *state, int label) const {
  226. DCHECK(IsAllowedLeftArc(*state));
  227. int s0 = state->Pop();
  228. state->AddArc(state->Pop(), s0, label);
  229. state->Push(s0);
  230. }
  231. // Makes a right-arc between the two top tokens on stack and pops the stack.
  232. void PerformRightArc(ParserState *state, int label) const {
  233. DCHECK(IsAllowedRightArc(*state));
  234. int s0 = state->Pop();
  235. int s1 = state->Pop();
  236. state->AddArc(s0, s1, label);
  237. state->Push(s1);
  238. }
  239. // We are in a deterministic state when we either reached the end of the input
  240. // or reduced everything from the stack.
  241. bool IsDeterministicState(const ParserState &state) const override {
  242. return state.StackSize() < 2 && !state.EndOfInput();
  243. }
  244. // We are in a final state when we reached the end of the input and the stack
  245. // is empty.
  246. bool IsFinalState(const ParserState &state) const override {
  247. VLOG(2) << "Final state check: EOI: " << state.EndOfInput()
  248. << " Stack size: " << state.StackSize();
  249. return state.EndOfInput() && state.StackSize() < 2;
  250. }
  251. // Returns a string representation of a parser action.
  252. string ActionAsString(ParserAction action,
  253. const ParserState &state) const override {
  254. switch (ActionType(action)) {
  255. case SHIFT:
  256. return "SHIFT";
  257. case LEFT_ARC:
  258. return "LEFT_ARC(" + state.LabelAsString(Label(action)) + ")";
  259. case RIGHT_ARC:
  260. return "RIGHT_ARC(" + state.LabelAsString(Label(action)) + ")";
  261. }
  262. return "UNKNOWN";
  263. }
  264. // Returns a new transition state to be used to enhance the parser state.
  265. ParserTransitionState *NewTransitionState(bool training_mode) const override {
  266. return new ArcStandardTransitionState();
  267. }
  268. // Meta information API. Returns token indices to link parser actions back
  269. // to positions in the input sentence.
  270. bool SupportsActionMetaData() const override { return true; }
  271. // Returns the child of a new arc for reduce actions.
  272. int ChildIndex(const ParserState &state,
  273. const ParserAction &action) const override {
  274. switch (ActionType(action)) {
  275. case SHIFT:
  276. return -1;
  277. case LEFT_ARC: // left arc pops stack(1)
  278. return state.Stack(1);
  279. case RIGHT_ARC:
  280. return state.Stack(0);
  281. default:
  282. LOG(FATAL) << "Invalid parser action: " << action;
  283. }
  284. }
  285. // Returns the parent of a new arc for reduce actions.
  286. int ParentIndex(const ParserState &state,
  287. const ParserAction &action) const override {
  288. switch (ActionType(action)) {
  289. case SHIFT:
  290. return -1;
  291. case LEFT_ARC: // left arc pops stack(1)
  292. return state.Stack(0);
  293. case RIGHT_ARC:
  294. return state.Stack(1);
  295. default:
  296. LOG(FATAL) << "Invalid parser action: " << action;
  297. }
  298. }
  299. };
  300. REGISTER_TRANSITION_SYSTEM("arc-standard", ArcStandardTransitionSystem);
  301. } // namespace syntaxnet