parser_state.cc 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/parser_state.h"
  13. #include "syntaxnet/kbest_syntax.pb.h"
  14. #include "syntaxnet/sentence.pb.h"
  15. #include "syntaxnet/term_frequency_map.h"
  16. #include "syntaxnet/utils.h"
  17. namespace syntaxnet {
  // Out-of-class definition of the ParserState::kRootLabel string constant.
  const char ParserState::kRootLabel[] = "ROOT";
  19. ParserState::ParserState(Sentence *sentence,
  20. ParserTransitionState *transition_state,
  21. const TermFrequencyMap *label_map)
  22. : sentence_(sentence),
  23. num_tokens_(sentence->token_size()),
  24. transition_state_(transition_state),
  25. label_map_(label_map),
  26. root_label_(kDefaultRootLabel),
  27. next_(0) {
  28. // Initialize the stack. Some transition systems could also push the
  29. // artificial root on the stack, so we make room for that as well.
  30. stack_.reserve(num_tokens_ + 1);
  31. // Allocate space for head indices and labels. Initialize the head for all
  32. // tokens to be the artificial root node, i.e. token -1.
  33. head_.resize(num_tokens_, -1);
  34. label_.resize(num_tokens_, RootLabel());
  35. // Transition system-specific preprocessing.
  36. if (transition_state_ != nullptr) transition_state_->Init(this);
  37. }
  38. ParserState::~ParserState() { delete transition_state_; }
  39. ParserState *ParserState::Clone() const {
  40. ParserState *new_state = new ParserState();
  41. new_state->sentence_ = sentence_;
  42. new_state->num_tokens_ = num_tokens_;
  43. new_state->alternative_ = alternative_;
  44. new_state->transition_state_ =
  45. (transition_state_ == nullptr ? nullptr : transition_state_->Clone());
  46. new_state->label_map_ = label_map_;
  47. new_state->root_label_ = root_label_;
  48. new_state->next_ = next_;
  49. new_state->stack_.assign(stack_.begin(), stack_.end());
  50. new_state->head_.assign(head_.begin(), head_.end());
  51. new_state->label_.assign(label_.begin(), label_.end());
  52. new_state->score_ = score_;
  53. new_state->is_gold_ = is_gold_;
  54. return new_state;
  55. }
  56. int ParserState::RootLabel() const { return root_label_; }
  57. int ParserState::Next() const {
  58. DCHECK_GE(next_, -1);
  59. DCHECK_LE(next_, num_tokens_);
  60. return next_;
  61. }
  62. int ParserState::Input(int offset) const {
  63. int index = next_ + offset;
  64. return index >= -1 && index < num_tokens_ ? index : -2;
  65. }
  66. void ParserState::Advance() {
  67. DCHECK_LT(next_, num_tokens_);
  68. ++next_;
  69. }
  70. bool ParserState::EndOfInput() const { return next_ == num_tokens_; }
  71. void ParserState::Push(int index) {
  72. DCHECK_LE(stack_.size(), num_tokens_);
  73. stack_.push_back(index);
  74. }
  75. int ParserState::Pop() {
  76. DCHECK(!StackEmpty());
  77. const int result = stack_.back();
  78. stack_.pop_back();
  79. return result;
  80. }
  81. int ParserState::Top() const {
  82. DCHECK(!StackEmpty());
  83. return stack_.back();
  84. }
  85. int ParserState::Stack(int position) const {
  86. if (position < 0) return -2;
  87. const int index = stack_.size() - 1 - position;
  88. return (index < 0) ? -2 : stack_[index];
  89. }
  90. int ParserState::StackSize() const { return stack_.size(); }
  91. bool ParserState::StackEmpty() const { return stack_.empty(); }
  92. int ParserState::Head(int index) const {
  93. DCHECK_GE(index, -1);
  94. DCHECK_LT(index, num_tokens_);
  95. return index == -1 ? -1 : head_[index];
  96. }
  97. int ParserState::Label(int index) const {
  98. DCHECK_GE(index, -1);
  99. DCHECK_LT(index, num_tokens_);
  100. return index == -1 ? RootLabel() : label_[index];
  101. }
  102. int ParserState::Parent(int index, int n) const {
  103. // Find the n-th parent by applying the head function n times.
  104. DCHECK_GE(index, -1);
  105. DCHECK_LT(index, num_tokens_);
  106. while (n-- > 0) index = Head(index);
  107. return index;
  108. }
  109. int ParserState::LeftmostChild(int index, int n) const {
  110. DCHECK_GE(index, -1);
  111. DCHECK_LT(index, num_tokens_);
  112. while (n-- > 0) {
  113. // Find the leftmost child by scanning from start until a child is
  114. // encountered.
  115. int i;
  116. for (i = -1; i < index; ++i) {
  117. if (Head(i) == index) break;
  118. }
  119. if (i == index) return -2;
  120. index = i;
  121. }
  122. return index;
  123. }
  124. int ParserState::RightmostChild(int index, int n) const {
  125. DCHECK_GE(index, -1);
  126. DCHECK_LT(index, num_tokens_);
  127. while (n-- > 0) {
  128. // Find the rightmost child by scanning backward from end until a child
  129. // is encountered.
  130. int i;
  131. for (i = num_tokens_ - 1; i > index; --i) {
  132. if (Head(i) == index) break;
  133. }
  134. if (i == index) return -2;
  135. index = i;
  136. }
  137. return index;
  138. }
  139. int ParserState::LeftSibling(int index, int n) const {
  140. // Find the n-th left sibling by scanning left until the n-th child of the
  141. // parent is encountered.
  142. DCHECK_GE(index, -1);
  143. DCHECK_LT(index, num_tokens_);
  144. if (index == -1 && n > 0) return -2;
  145. int i = index;
  146. while (n > 0) {
  147. --i;
  148. if (i == -1) return -2;
  149. if (Head(i) == Head(index)) --n;
  150. }
  151. return i;
  152. }
  153. int ParserState::RightSibling(int index, int n) const {
  154. // Find the n-th right sibling by scanning right until the n-th child of the
  155. // parent is encountered.
  156. DCHECK_GE(index, -1);
  157. DCHECK_LT(index, num_tokens_);
  158. if (index == -1 && n > 0) return -2;
  159. int i = index;
  160. while (n > 0) {
  161. ++i;
  162. if (i == num_tokens_) return -2;
  163. if (Head(i) == Head(index)) --n;
  164. }
  165. return i;
  166. }
  167. void ParserState::AddArc(int index, int head, int label) {
  168. DCHECK_GE(index, 0);
  169. DCHECK_LT(index, num_tokens_);
  170. head_[index] = head;
  171. label_[index] = label;
  172. }
  173. int ParserState::GoldHead(int index) const {
  174. // A valid ParserState index is transformed to a valid Sentence index,
  175. // then the gold head is extracted.
  176. DCHECK_GE(index, -1);
  177. DCHECK_LT(index, num_tokens_);
  178. if (index == -1) return -1;
  179. const int offset = 0;
  180. const int gold_head = GetToken(index).head();
  181. return gold_head == -1 ? -1 : gold_head - offset;
  182. }
  183. int ParserState::GoldLabel(int index) const {
  184. // A valid ParserState index is transformed to a valid Sentence index,
  185. // then the gold label is extracted.
  186. DCHECK_GE(index, -1);
  187. DCHECK_LT(index, num_tokens_);
  188. if (index == -1) return RootLabel();
  189. string gold_label;
  190. gold_label = GetToken(index).label();
  191. return label_map_->LookupIndex(gold_label, RootLabel() /* unknown */);
  192. }
  193. void ParserState::AddParseToDocument(Sentence *sentence,
  194. bool rewrite_root_labels) const {
  195. transition_state_->AddParseToDocument(*this, rewrite_root_labels, sentence);
  196. }
  197. bool ParserState::IsTokenCorrect(int index) const {
  198. return transition_state_->IsTokenCorrect(*this, index);
  199. }
  200. string ParserState::LabelAsString(int label) const {
  201. if (label == root_label_) return "ROOT";
  202. if (label >= 0 && label < label_map_->Size()) {
  203. return label_map_->GetTerm(label);
  204. }
  205. return "";
  206. }
  207. string ParserState::ToString() const {
  208. return transition_state_->ToString(*this);
  209. }
  210. } // namespace syntaxnet