syntaxnet_transition_state.cc 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
  16. #include "tensorflow/core/lib/strings/strcat.h"
  17. #include "tensorflow/core/platform/logging.h"
  18. namespace syntaxnet {
  19. namespace dragnn {
  20. SyntaxNetTransitionState::SyntaxNetTransitionState(
  21. std::unique_ptr<ParserState> parser_state, SyntaxNetSentence *sentence)
  22. : parser_state_(std::move(parser_state)), sentence_(sentence) {
  23. score_ = 0;
  24. current_beam_index_ = -1;
  25. parent_beam_index_ = 0;
  26. step_for_token_.resize(sentence->sentence()->token_size(), -1);
  27. parent_for_token_.resize(sentence->sentence()->token_size(), -1);
  28. parent_step_for_token_.resize(sentence->sentence()->token_size(), -1);
  29. }
  30. void SyntaxNetTransitionState::Init(const TransitionState &parent) {
  31. score_ = parent.GetScore();
  32. parent_beam_index_ = parent.GetBeamIndex();
  33. }
  34. std::unique_ptr<SyntaxNetTransitionState> SyntaxNetTransitionState::Clone()
  35. const {
  36. // Create a new state from a clone of the underlying parser state.
  37. std::unique_ptr<ParserState> cloned_state(parser_state_->Clone());
  38. std::unique_ptr<SyntaxNetTransitionState> new_state(
  39. new SyntaxNetTransitionState(std::move(cloned_state), sentence_));
  40. // Copy relevant data members and set non-copied ones to flag values.
  41. new_state->score_ = score_;
  42. new_state->current_beam_index_ = current_beam_index_;
  43. new_state->parent_beam_index_ = parent_beam_index_;
  44. new_state->step_for_token_ = step_for_token_;
  45. new_state->parent_step_for_token_ = parent_step_for_token_;
  46. new_state->parent_for_token_ = parent_for_token_;
  47. // Copy trace if it exists.
  48. if (trace_) {
  49. new_state->trace_.reset(new ComponentTrace(*trace_));
  50. }
  51. return new_state;
  52. }
  53. const int SyntaxNetTransitionState::ParentBeamIndex() const {
  54. return parent_beam_index_;
  55. }
  56. const int SyntaxNetTransitionState::GetBeamIndex() const {
  57. return current_beam_index_;
  58. }
  59. void SyntaxNetTransitionState::SetBeamIndex(const int index) {
  60. current_beam_index_ = index;
  61. }
  62. const float SyntaxNetTransitionState::GetScore() const { return score_; }
  63. void SyntaxNetTransitionState::SetScore(const float score) { score_ = score; }
  64. string SyntaxNetTransitionState::HTMLRepresentation() const {
  65. // Crude HTML string showing the stack and the word on the input.
  66. string html = "Stack: ";
  67. for (int i = parser_state_->StackSize() - 1; i >= 0; --i) {
  68. const int word_idx = parser_state_->Stack(i);
  69. if (word_idx >= 0) {
  70. tensorflow::strings::StrAppend(
  71. &html, parser_state_->GetToken(word_idx).word(), " ");
  72. }
  73. }
  74. tensorflow::strings::StrAppend(&html, "| Input: ");
  75. const int word_idx = parser_state_->Input(0);
  76. if (word_idx >= 0) {
  77. tensorflow::strings::StrAppend(
  78. &html, parser_state_->GetToken(word_idx).word(), " ");
  79. }
  80. return html;
  81. }
  82. } // namespace dragnn
  83. } // namespace syntaxnet