syntaxnet_transition_state.h 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
#include <memory>
#include <utility>
#include <vector>

#include "dragnn/core/interfaces/cloneable_transition_state.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "dragnn/protos/trace.pb.h"
#include "syntaxnet/base.h"
#include "syntaxnet/parser_state.h"
  24. namespace syntaxnet {
  25. namespace dragnn {
  26. class SyntaxNetTransitionState
  27. : public CloneableTransitionState<SyntaxNetTransitionState> {
  28. public:
  29. // Create a SyntaxNetTransitionState to wrap this nlp_saft::ParserState.
  30. SyntaxNetTransitionState(std::unique_ptr<ParserState> parser_state,
  31. SyntaxNetSentence *sentence);
  32. // Initialize this TransitionState from a previous TransitionState. The
  33. // ParentBeamIndex is the location of that previous TransitionState in the
  34. // provided beam.
  35. void Init(const TransitionState &parent) override;
  36. // Produces a new state with the same backing data as this state.
  37. std::unique_ptr<SyntaxNetTransitionState> Clone() const override;
  38. // Return the beam index of the state passed into the initializer of this
  39. // TransitionState.
  40. const int ParentBeamIndex() const override;
  41. // Get the current beam index for this state.
  42. const int GetBeamIndex() const override;
  43. // Set the current beam index for this state.
  44. void SetBeamIndex(const int index) override;
  45. // Get the score associated with this transition state.
  46. const float GetScore() const override;
  47. // Set the score associated with this transition state.
  48. void SetScore(const float score) override;
  49. // Depicts this state as an HTML-language string.
  50. string HTMLRepresentation() const override;
  51. // **** END INHERITED INTERFACE ****
  52. // TODO(googleuser): Make these comments actually mean something.
  53. // Data accessor.
  54. int step_for_token(int token) {
  55. if (token < 0 || token >= step_for_token_.size()) {
  56. return -1;
  57. } else {
  58. return step_for_token_.at(token);
  59. }
  60. }
  61. // Data setter.
  62. void set_step_for_token(int token, int step) {
  63. step_for_token_.insert(step_for_token_.begin() + token, step);
  64. }
  65. // Data accessor.
  66. int parent_step_for_token(int token) {
  67. if (token < 0 || token >= step_for_token_.size()) {
  68. return -1;
  69. } else {
  70. return parent_step_for_token_.at(token);
  71. }
  72. }
  73. // Data setter.
  74. void set_parent_step_for_token(int token, int parent_step) {
  75. parent_step_for_token_.insert(parent_step_for_token_.begin() + token,
  76. parent_step);
  77. }
  78. // Data accessor.
  79. int parent_for_token(int token) {
  80. if (token < 0 || token >= step_for_token_.size()) {
  81. return -1;
  82. } else {
  83. return parent_for_token_.at(token);
  84. }
  85. }
  86. // Data setter.
  87. void set_parent_for_token(int token, int parent) {
  88. parent_for_token_.insert(parent_for_token_.begin() + token, parent);
  89. }
  90. // Accessor for the underlying nlp_saft::ParserState.
  91. ParserState *parser_state() { return parser_state_.get(); }
  92. // Accessor for the underlying sentence object.
  93. SyntaxNetSentence *sentence() { return sentence_; }
  94. ComponentTrace *mutable_trace() {
  95. CHECK(trace_) << "Trace is not initialized";
  96. return trace_.get();
  97. }
  98. void set_trace(std::unique_ptr<ComponentTrace> trace) {
  99. trace_ = std::move(trace);
  100. }
  101. private:
  102. // Underlying ParserState object that is being wrapped.
  103. std::unique_ptr<ParserState> parser_state_;
  104. // Sentence object that is being examined with this state.
  105. SyntaxNetSentence *sentence_;
  106. // The current score of this state.
  107. float score_;
  108. // The current beam index of this state.
  109. int current_beam_index_;
  110. // The parent beam index for this state.
  111. int parent_beam_index_;
  112. // Maintains a list of which steps in the history correspond to
  113. // representations for each of the tokens on the stack.
  114. std::vector<int> step_for_token_;
  115. // Maintains a list of which steps in the history correspond to the actions
  116. // that assigned a parent for tokens when reduced.
  117. std::vector<int> parent_step_for_token_;
  118. // Maintain the parent index of a token in the system.
  119. std::vector<int> parent_for_token_;
  120. // Trace of the history to produce this state.
  121. std::unique_ptr<ComponentTrace> trace_;
  122. };
  123. } // namespace dragnn
  124. } // namespace syntaxnet
  125. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_