syntaxnet_transition_state.h 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
  3. #include <vector>
  4. #include "dragnn/core/interfaces/cloneable_transition_state.h"
  5. #include "dragnn/core/interfaces/transition_state.h"
  6. #include "dragnn/io/syntaxnet_sentence.h"
  7. #include "dragnn/protos/trace.pb.h"
  8. #include "syntaxnet/base.h"
  9. #include "syntaxnet/parser_state.h"
  10. namespace syntaxnet {
  11. namespace dragnn {
  12. class SyntaxNetTransitionState
  13. : public CloneableTransitionState<SyntaxNetTransitionState> {
  14. public:
  15. // Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState.
  16. SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
  17. SyntaxNetSentence *sentence);
  18. // Initialize this TransitionState from a previous TransitionState. The
  19. // ParentBeamIndex is the location of that previous TransitionState in the
  20. // provided beam.
  21. void Init(const TransitionState &parent) override;
  22. // Produces a new state with the same backing data as this state.
  23. std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
  24. // Return the beam index of the state passed into the initializer of this
  25. // TransitionState.
  26. const int ParentBeamIndex() const override;
  27. // Get the current beam index for this state.
  28. const int GetBeamIndex() const override;
  29. // Set the current beam index for this state.
  30. void SetBeamIndex(const int index) override;
  31. // Get the score associated with this transition state.
  32. const float GetScore() const override;
  33. // Set the score associated with this transition state.
  34. void SetScore(const float score) override;
  35. // Depicts this state as an HTML-language string.
  36. string HTMLRepresentation() const override;
  37. // **** END INHERITED INTERFACE ****
  38. // TODO(googleuser): Make these comments actually mean something.
  39. // Data accessor.
  40. int step_for_token(int token) {
  41. if (token < 0 || token >= step_for_token_.size()) {
  42. return -1;
  43. } else {
  44. return step_for_token_.at(token);
  45. }
  46. }
  47. // Data setter.
  48. void set_step_for_token(int token, int step) {
  49. step_for_token_.insert(step_for_token_.begin() + token, step);
  50. }
  51. // Data accessor.
  52. int parent_step_for_token(int token) {
  53. if (token < 0 || token >= step_for_token_.size()) {
  54. return -1;
  55. } else {
  56. return parent_step_for_token_.at(token);
  57. }
  58. }
  59. // Data setter.
  60. void set_parent_step_for_token(int token, int parent_step) {
  61. parent_step_for_token_.insert(parent_step_for_token_.begin() + token,
  62. parent_step);
  63. }
  64. // Data accessor.
  65. int parent_for_token(int token) {
  66. if (token < 0 || token >= step_for_token_.size()) {
  67. return -1;
  68. } else {
  69. return parent_for_token_.at(token);
  70. }
  71. }
  72. // Data setter.
  73. void set_parent_for_token(int token, int parent) {
  74. parent_for_token_.insert(parent_for_token_.begin() + token, parent);
  75. }
  76. // Accessor for the underlying nlp_saft::ParserState.
  77. ParserState *parser_state() { return parser_state_.get(); }
  78. // Accessor for the underlying sentence object.
  79. SyntaxNetSentence *sentence() { return sentence_; }
  80. ComponentTrace *mutable_trace() {
  81. CHECK(trace_) << "Trace is not initialized";
  82. return trace_.get();
  83. }
  84. void set_trace(std::unique_ptr<ComponentTrace> trace) {
  85. trace_ = std::move(trace);
  86. }
  87. private:
  88. // Underlying ParserState object that is being wrapped.
  89. std::unique_ptr<ParserState> parser_state_;
  90. // Sentence object that is being examined with this state.
  91. SyntaxNetSentence *sentence_;
  92. // The current score of this state.
  93. float score_;
  94. // The current beam index of this state.
  95. int current_beam_index_;
  96. // The parent beam index for this state.
  97. int parent_beam_index_;
  98. // Maintains a list of which steps in the history correspond to
  99. // representations for each of the tokens on the stack.
  100. std::vector<int> step_for_token_;
  101. // Maintains a list of which steps in the history correspond to the actions
  102. // that assigned a parent for tokens when reduced.
  103. std::vector<int> parent_step_for_token_;
  104. // Maintain the parent index of a token in the system.
  105. std::vector<int> parent_for_token_;
  106. // Trace of the history to produce this state.
  107. std::unique_ptr<ComponentTrace> trace_;
  108. };
  109. } // namespace dragnn
  110. } // namespace syntaxnet
  111. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_