shift_transitions.cc 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Shift transition system.
  13. //
  14. // This transition system has one type of action:
  15. // - The SHIFT action pushes the next input token to the stack and
  16. // advances to the next input token.
  17. //
  18. // For this very simple transition system, the TransitionState carries no
  19. // parsing information; it only remembers the direction of the transitions.
  20. // The system is used to compute look-ahead in DRAGNN: the representations
  21. // it produces are consumed by downstream tasks.
  22. #include <string>
  23. #include "syntaxnet/parser_features.h"
  24. #include "syntaxnet/parser_state.h"
  25. #include "syntaxnet/parser_transitions.h"
  26. #include "syntaxnet/sentence_features.h"
  27. #include "syntaxnet/shared_store.h"
  28. #include "syntaxnet/task_context.h"
  29. #include "syntaxnet/term_frequency_map.h"
  30. #include "syntaxnet/utils.h"
  31. #include "tensorflow/core/lib/strings/strcat.h"
  32. namespace syntaxnet {
  33. class ShiftTransitionState : public ParserTransitionState {
  34. public:
  35. explicit ShiftTransitionState(bool left_to_right)
  36. : left_to_right_(left_to_right) {}
  37. explicit ShiftTransitionState(const ShiftTransitionState *state)
  38. : left_to_right_(state->left_to_right_) {}
  39. ParserTransitionState *Clone() const override {
  40. return new ShiftTransitionState(this);
  41. }
  42. // Set the initial value of next in ParserState.
  43. void Init(ParserState *state) override {
  44. if (!left_to_right_) {
  45. // Start from the last word of the sentence if we transit from right to
  46. // left.
  47. state->Advance(state->sentence().token_size() - 1);
  48. }
  49. }
  50. bool IsTokenCorrect(const ParserState &state, int index) const override {
  51. return true;
  52. }
  53. string ToString(const ParserState &state) const override { return ""; }
  54. private:
  55. bool left_to_right_;
  56. };
  57. class ShiftTransitionSystem : public ParserTransitionSystem {
  58. public:
  59. static const ParserAction kShiftAction = 0;
  60. // Determines the direction of the system.
  61. void Setup(TaskContext *context) override {
  62. // TODO(googleuser): Use FetchDeprecated.
  63. if (context->Get("left-to-right", "<NOT-SET>") == "<NOT-SET>") {
  64. left_to_right_ = context->Get("left_to_right", true);
  65. } else {
  66. left_to_right_ = context->Get("left-to-right", true);
  67. LOG(WARNING) << "'left-to-right' parameter set: this is DEPRECATED. "
  68. << "Use 'left_to_right' instead.";
  69. }
  70. }
  71. // The shift transition system doesn't actually look at the dependency tree,
  72. // so it does allow non-projective trees.
  73. bool AllowsNonProjective() const override { return true; }
  74. // Returns the number of action types.
  75. int NumActionTypes() const override { return 1; }
  76. // Returns the number of possible actions.
  77. int NumActions(int num_labels) const override { return 1; }
  78. ParserAction GetDefaultAction(const ParserState &state) const override {
  79. return kShiftAction;
  80. }
  81. // At anytime, the gold action is to shift.
  82. ParserAction GetNextGoldAction(const ParserState &state) const override {
  83. return kShiftAction;
  84. }
  85. // Checks if the action is allowed in a given parser state.
  86. bool IsAllowedAction(ParserAction action,
  87. const ParserState &state) const override {
  88. return left_to_right_ ? (!state.EndOfInput()) : (state.Next() >= 0);
  89. }
  90. // Makes a shift by pushing the next input token on the stack and moving to
  91. // the next position.
  92. void PerformActionWithoutHistory(ParserAction action,
  93. ParserState *state) const override {
  94. DCHECK(!IsFinalState(*state));
  95. if (!IsFinalState(*state)) {
  96. int next = state->Next();
  97. state->Push(next);
  98. next = left_to_right_ ? (next + 1) : (next - 1);
  99. state->Advance(next);
  100. }
  101. }
  102. bool IsFinalState(const ParserState &state) const override {
  103. return left_to_right_ ? state.EndOfInput() : (state.Next() < 0);
  104. }
  105. // Returns a string representation of a parser action.
  106. string ActionAsString(ParserAction action,
  107. const ParserState &state) const override {
  108. string current_word = state.GetToken(state.Next()).word();
  109. return current_word;
  110. }
  111. // All states are deterministic in this transition system.
  112. bool IsDeterministicState(const ParserState &state) const override {
  113. return true;
  114. }
  115. // Returns a new transition state.
  116. ParserTransitionState *NewTransitionState(bool training_mode) const override {
  117. return new ShiftTransitionState(get_direction());
  118. }
  119. bool get_direction() const { return left_to_right_; }
  120. private:
  121. bool left_to_right_ = true;
  122. };
  // Registers ShiftTransitionSystem under the name "shift-only".
  // NOTE(review): the macro is defined outside this file; presumably this is
  // what makes the system selectable by that name -- confirm.
  123. REGISTER_TRANSITION_SYSTEM("shift-only", ShiftTransitionSystem);
  124. } // namespace syntaxnet