// parser_transitions.h (7.7 KB)
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Transition system for the transition-based dependency parser.
  13. #ifndef SYNTAXNET_PARSER_TRANSITIONS_H_
  14. #define SYNTAXNET_PARSER_TRANSITIONS_H_
  15. #include <string>
  16. #include <vector>
  17. #include "syntaxnet/registry.h"
  18. #include "syntaxnet/utils.h"
  // Forward declarations of the TensorFlow record I/O classes that appear only
  // as pointer parameters in this header.
  namespace tensorflow {
  namespace io {

  class RecordReader;
  class RecordWriter;

  }  // namespace io
  }  // namespace tensorflow
  25. namespace syntaxnet {
  // Forward declarations; this header refers to these classes only through
  // pointers and references.
  class Sentence;
  class ParserState;
  class TaskContext;
  // The transition system encodes each parser action as an integer.
  using ParserAction = int;
  31. // Transition system-specific state. Transition systems can subclass this to
  32. // preprocess the parser state and/or to keep additional information during
  33. // parsing.
  34. class ParserTransitionState {
  35. public:
  36. virtual ~ParserTransitionState() {}
  37. // Clones the transition state.
  38. virtual ParserTransitionState *Clone() const = 0;
  39. // Initializes a parser state for the transition system.
  40. virtual void Init(ParserState *state) = 0;
  41. virtual void AddParseToDocument(const ParserState &state,
  42. bool rewrite_root_labels,
  43. Sentence *sentence) const {}
  44. // Whether a parsed token should be considered correct for evaluation.
  45. virtual bool IsTokenCorrect(const ParserState &state, int index) const = 0;
  46. // Returns a human readable string representation of this state.
  47. virtual string ToString(const ParserState &state) const = 0;
  48. };
  49. // A transition system is used for handling the parser state transitions. During
  50. // training the transition system is used for extracting a canonical sequence of
  51. // transitions for an annotated sentence. During parsing the transition system
  52. // is used for applying the predicted transitions to the parse state and
  53. // therefore build the parse tree for the sentence. Transition systems can be
  54. // implemented by subclassing this abstract class and registered using the
  55. // REGISTER_TRANSITION_SYSTEM macro.
  56. class ParserTransitionSystem
  57. : public RegisterableClass<ParserTransitionSystem> {
  58. public:
  59. // Construction and cleanup.
  60. ParserTransitionSystem() {}
  61. virtual ~ParserTransitionSystem() {}
  62. // Sets up the transition system. If inputs are needed, this is the place to
  63. // specify them.
  64. virtual void Setup(TaskContext *context) {}
  65. // Initializes the transition system.
  66. virtual void Init(TaskContext *context) {}
  67. // Reads the transition system from disk.
  68. virtual void Read(tensorflow::io::RecordReader *reader) {}
  69. // Writes the transition system to disk.
  70. virtual void Write(tensorflow::io::RecordWriter *writer) const {}
  71. // Returns the number of action types.
  72. virtual int NumActionTypes() const = 0;
  73. // Returns the number of actions.
  74. virtual int NumActions(int num_labels) const = 0;
  75. // Internally creates the set of outcomes (when transition systems support a
  76. // variable number of actions).
  77. virtual void CreateOutcomeSet(int num_labels) {}
  78. // Returns the default action for a given state.
  79. virtual ParserAction GetDefaultAction(const ParserState &state) const = 0;
  80. // Returns the next gold action for the parser during training using the
  81. // dependency relations found in the underlying annotated sentence.
  82. virtual ParserAction GetNextGoldAction(const ParserState &state) const = 0;
  83. // Returns all next gold actions for the parser during training using the
  84. // dependency relations found in the underlying annotated sentence.
  85. virtual void GetAllNextGoldActions(const ParserState &state,
  86. std::vector<ParserAction> *actions) const {
  87. ParserAction action = GetNextGoldAction(state);
  88. *actions = {action};
  89. }
  90. // Internally counts all next gold actions from the current parser state.
  91. virtual void CountAllNextGoldActions(const ParserState &state) {}
  92. // Returns the number of atomic actions within the specified ParserAction.
  93. virtual int ActionLength(ParserAction action) const { return 1; }
  94. // Returns true if the action is allowed in the given parser state.
  95. virtual bool IsAllowedAction(ParserAction action,
  96. const ParserState &state) const = 0;
  97. // Performs the specified action on a given parser state. The action is not
  98. // saved in the state's history.
  99. virtual void PerformActionWithoutHistory(ParserAction action,
  100. ParserState *state) const = 0;
  101. // Performs the specified action on a given parser state. The action is saved
  102. // in the state's history.
  103. void PerformAction(ParserAction action, ParserState *state) const;
  104. // Returns true if a given state is deterministic.
  105. virtual bool IsDeterministicState(const ParserState &state) const = 0;
  106. // Returns true if no more actions can be applied to a given parser state.
  107. virtual bool IsFinalState(const ParserState &state) const = 0;
  108. // Returns a string representation of a parser action.
  109. virtual string ActionAsString(ParserAction action,
  110. const ParserState &state) const = 0;
  111. // Returns a new transition state that can be used to put additional
  112. // information in a parser state. By specifying if we are in training_mode
  113. // (true) or not (false), we can construct a different transition state
  114. // depending on whether we are training a model or parsing new documents. A
  115. // null return value means we don't need to add anything to the parser state.
  116. virtual ParserTransitionState *NewTransitionState(bool training_mode) const {
  117. return nullptr;
  118. }
  119. // Whether to back off to the best allowable transition rather than the
  120. // default action when the highest scoring action is not allowed. Some
  121. // transition systems do not degrade gracefully to the default action and so
  122. // should return true for this function.
  123. virtual bool BackOffToBestAllowableTransition() const { return false; }
  124. // Whether the system returns multiple gold transitions from a single
  125. // configuration.
  126. virtual bool ReturnsMultipleGoldTransitions() const { return false; }
  127. // Whether the system allows non-projective trees.
  128. virtual bool AllowsNonProjective() const { return false; }
  129. // Action meta data: get pointers to token indices based on meta-info about
  130. // (state, action) pairs. NOTE: the following interface is somewhat
  131. // experimental and may be subject to change. Use with caution and ask
  132. // googleuser@ for details.
  133. // Whether or not the system supports computing meta-data about actions.
  134. virtual bool SupportsActionMetaData() const { return false; }
  135. // Get the index of the child that would be created by this action. -1 for
  136. // no child created.
  137. virtual int ChildIndex(const ParserState &state,
  138. const ParserAction &action) const {
  139. return -1;
  140. }
  141. // Get the index of the parent that would gain a new child by this action. -1
  142. // for no parent modified.
  143. virtual int ParentIndex(const ParserState &state,
  144. const ParserAction &action) const {
  145. return -1;
  146. }
  147. private:
  148. TF_DISALLOW_COPY_AND_ASSIGN(ParserTransitionSystem);
  149. };
  // Registration macro for transition systems. It forwards |type| and
  // |component| to REGISTER_SYNTAXNET_CLASS_COMPONENT, with
  // ParserTransitionSystem as the first argument.
  // NOTE(review): presumably |type| is the name to register under and
  // |component| is the implementing subclass -- confirm in
  // syntaxnet/registry.h.
  #define REGISTER_TRANSITION_SYSTEM(type, component)                    \
    REGISTER_SYNTAXNET_CLASS_COMPONENT(ParserTransitionSystem, type,     \
                                       component)
  152. } // namespace syntaxnet
  153. #endif // SYNTAXNET_PARSER_TRANSITIONS_H_