// Copyright 2017 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // ============================================================================= #include "dragnn/core/beam.h" #include "dragnn/core/interfaces/cloneable_transition_state.h" #include "dragnn/core/interfaces/transition_state.h" #include "dragnn/core/test/mock_transition_state.h" #include #include "tensorflow/core/platform/test.h" namespace syntaxnet { namespace dragnn { using testing::MockFunction; using testing::Return; using testing::Ne; using testing::_; namespace { // ***************************************************************************** // Test-internal class definitions. // ***************************************************************************** // Create a very basic transition state to test the beam. All it does is keep // track of its current beam index and score, as well as providing a field // for the transition function to write in what transition occurred. // Note that this class does not fulfill the entire TransitionState contract, // since it is only used in this particular test. class TestTransitionState : public CloneableTransitionState { public: TestTransitionState() {} void Init(const TransitionState &parent) override {} std::unique_ptr Clone() const override { std::unique_ptr ptr(new TestTransitionState()); return ptr; } const int ParentBeamIndex() const override { return parent_beam_index_; } // Get the current beam index for this state. const int GetBeamIndex() const override { return beam_index_; } // Set the current beam index for this state. void SetBeamIndex(const int index) override { beam_index_ = index; } // Get the score associated with this transition state. const float GetScore() const override { return score_; } // Set the score associated with this transition state. void SetScore(const float score) override { score_ = score; } // Depicts this state as an HTML-language string. string HTMLRepresentation() const override { return ""; } int parent_beam_index_; int beam_index_; float score_; int transition_action_; }; // This transition function annotates a TestTransitionState with the action that // was chosen for the transition. auto transition_function = [](TestTransitionState *state, int action) { TestTransitionState *cast_state = dynamic_cast(state); cast_state->transition_action_ = action; }; // Create oracle and permission functions that do nothing. auto null_oracle = [](TestTransitionState *) { return 0; }; auto null_permissions = [](TestTransitionState *, int) { return true; }; auto null_finality = [](TestTransitionState *) { return false; }; // Create a unique_ptr with a test transition state in it and set its initial // score. std::unique_ptr CreateState(float score) { std::unique_ptr state; state.reset(new TestTransitionState()); state->SetScore(score); return state; } // Convenience accessor for the action field in TestTransitionState. int GetTransition(const TransitionState *state) { return (dynamic_cast(state))->transition_action_; } // Convenience accessor for the parent_beam_index_ field in TestTransitionState. void SetParentBeamIndex(TransitionState *state, int index) { (dynamic_cast(state))->parent_beam_index_ = index; } } // namespace // ***************************************************************************** // Tests begin here. // ***************************************************************************** TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) { // Create a matrix of transitions. constexpr int kNumTransitions = 4; constexpr int kMatrixSize = kNumTransitions; constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0}; constexpr int kBestTransition = 2; constexpr float kOldScore = 3.0; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(kOldScore)); constexpr int kBeamSize = 1; Beam beam(kBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), kBeamSize); // Make sure the state has performed the expected transition. EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition); // Make sure the state has had its score updated properly. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]); // Make sure that the beam index field is consistent with the actual beam idx. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0); // Make sure that the beam_state accessor actually accesses the beam. EXPECT_EQ(beam.beam().at(0), beam.beam_state(0)); // Validate the beam history field. auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 0); } TEST(BeamTest, AdvancingCreatesNewTransitions) { // Create a matrix of transitions. constexpr int kMaxBeamSize = 8; constexpr int kNumTransitions = 4; constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize; constexpr float matrix[kMatrixSize] = { 30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0}; constexpr float kOldScore = 4.0; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(kOldScore)); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 4); // Make sure the state has performed the expected transition. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2); EXPECT_EQ(GetTransition(beam.beam().at(1)), 0); EXPECT_EQ(GetTransition(beam.beam().at(2)), 1); EXPECT_EQ(GetTransition(beam.beam().at(3)), 3); // Make sure the state has had its score updated properly. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]); EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]); EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]); EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]); // Make sure that the beam index field is consistent with the actual beam idx. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } // In this case, we expect the top 4 results to have come from state 0 and // the remaining 4 slots to be empty (-1). auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 0); EXPECT_EQ(history.at(1).at(1), 0); EXPECT_EQ(history.at(1).at(2), 0); EXPECT_EQ(history.at(1).at(3), 0); EXPECT_EQ(history.at(1).at(4), -1); EXPECT_EQ(history.at(1).at(5), -1); EXPECT_EQ(history.at(1).at(6), -1); EXPECT_EQ(history.at(1).at(7), -1); } TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) { // Create a matrix of transitions. constexpr int kMaxBeamSize = 8; constexpr int kNumTransitions = 4; constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize; constexpr float matrix[kMatrixSize] = { 30.0, 20.0, 40.0, 10.0, // State 0 31.0, 21.0, 41.0, 11.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0}; constexpr float kOldScores[] = {5.0, 7.0}; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(kOldScores[0])); states.push_back(CreateState(kOldScores[1])); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 8); // Make sure the state has performed the expected transition. // Note that the transition index is not the index into the matrix, but rather // the index into the matrix 'row' for that state. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2); EXPECT_EQ(GetTransition(beam.beam().at(1)), 2); EXPECT_EQ(GetTransition(beam.beam().at(2)), 0); EXPECT_EQ(GetTransition(beam.beam().at(3)), 0); EXPECT_EQ(GetTransition(beam.beam().at(4)), 1); EXPECT_EQ(GetTransition(beam.beam().at(5)), 1); EXPECT_EQ(GetTransition(beam.beam().at(6)), 3); EXPECT_EQ(GetTransition(beam.beam().at(7)), 3); // Make sure the state has had its score updated properly. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]); EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]); EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]); EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]); EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]); EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]); EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]); EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]); // Make sure that the beam index field is consistent with the actual beam idx. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } // Validate the history at this step. auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 1); EXPECT_EQ(history.at(1).at(1), 0); EXPECT_EQ(history.at(1).at(2), 1); EXPECT_EQ(history.at(1).at(3), 0); EXPECT_EQ(history.at(1).at(4), 1); EXPECT_EQ(history.at(1).at(5), 0); EXPECT_EQ(history.at(1).at(6), 1); EXPECT_EQ(history.at(1).at(7), 0); } TEST(BeamTest, AdvancingDropsLowValuePredictions) { // Create a matrix of transitions. constexpr int kNumTransitions = 4; constexpr int kMaxBeamSize = 8; constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize; constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0, // State 0 31.0, 21.0, 41.0, 11.0, // State 1 32.0, 22.0, 42.0, 12.0, // State 2 33.0, 23.0, 43.0, 13.0, // State 3 34.0, 24.0, 44.0, 14.0, // State 4 35.0, 25.0, 45.0, 15.0, // State 5 36.0, 26.0, 46.0, 16.0, // State 6 37.0, 27.0, 47.0, 17.0}; // State 7 constexpr float kOldScores[] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8}; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(kOldScores[0])); states.push_back(CreateState(kOldScores[1])); states.push_back(CreateState(kOldScores[2])); states.push_back(CreateState(kOldScores[3])); states.push_back(CreateState(kOldScores[4])); states.push_back(CreateState(kOldScores[5])); states.push_back(CreateState(kOldScores[6])); states.push_back(CreateState(kOldScores[7])); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 8); // Make sure the state has performed the expected transition. // In this case, every state will perform transition 2. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2); EXPECT_EQ(GetTransition(beam.beam().at(1)), 2); EXPECT_EQ(GetTransition(beam.beam().at(2)), 2); EXPECT_EQ(GetTransition(beam.beam().at(3)), 2); EXPECT_EQ(GetTransition(beam.beam().at(4)), 2); EXPECT_EQ(GetTransition(beam.beam().at(5)), 2); EXPECT_EQ(GetTransition(beam.beam().at(6)), 2); EXPECT_EQ(GetTransition(beam.beam().at(7)), 2); // Make sure the state has had its score updated properly. (Note that row // 0 had the smallest transition score, so it ends up on the bottom of the // beam, and so forth.) For the matrix index, N*kNumTransitions gets into the // correct state row and we add 2 since that was the transition index. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[7] + matrix[7 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[6] + matrix[6 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[5] + matrix[5 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[4] + matrix[4 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[3] + matrix[3 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[2] + matrix[2 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[1 * kNumTransitions + 2]); EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[0 * kNumTransitions + 2]); // Make sure that the beam index field is consistent with the actual beam idx. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 7); EXPECT_EQ(history.at(1).at(1), 6); EXPECT_EQ(history.at(1).at(2), 5); EXPECT_EQ(history.at(1).at(3), 4); EXPECT_EQ(history.at(1).at(4), 3); EXPECT_EQ(history.at(1).at(5), 2); EXPECT_EQ(history.at(1).at(6), 1); EXPECT_EQ(history.at(1).at(7), 0); } TEST(BeamTest, AdvancesFromOracleWithSingleBeam) { // Create an oracle function for this state. constexpr int kOracleLabel = 3; auto oracle_function = [](TransitionState *) { return kOracleLabel; }; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(0.0)); constexpr int kBeamSize = 1; Beam beam(kBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, oracle_function); beam.Init(std::move(states)); beam.AdvanceFromOracle(); // Validate the new beam. EXPECT_EQ(beam.beam().size(), kBeamSize); // Make sure the state has performed the expected transition. EXPECT_EQ(GetTransition(beam.beam().at(0)), kOracleLabel); // Make sure the state has had its score held to 0. EXPECT_EQ(beam.beam().at(0)->GetScore(), 0.0); // Make sure that the beam index field is consistent with the actual beam idx. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0); // Validate the beam history field. auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 0); } TEST(BeamTest, AdvancesFromOracleWithMultipleStates) { constexpr int kMaxBeamSize = 8; // Create a beam with 8 transition states. std::vector> states; for (int i = 0; i < kMaxBeamSize; ++i) { // This is nonzero to test the oracle holding scores to 0. states.push_back(CreateState(10.0)); } std::vector expected_actions; // Create an oracle function for this state. Use mocks for finer control. testing::MockFunction mock_oracle_function; for (int i = 0; i < kMaxBeamSize; ++i) { // We expect each state to be queried for its oracle label, // and then to be transitioned in place with its oracle label. int oracle_label = i % 3; // 3 is arbitrary. EXPECT_CALL(mock_oracle_function, Call(states.at(i).get())) .WillOnce(Return(oracle_label)); expected_actions.push_back(oracle_label); } Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, mock_oracle_function.AsStdFunction()); beam.Init(std::move(states)); beam.AdvanceFromOracle(); // Make sure the state has performed the expected transition, has had its // score held to 0, and is self consistent. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(GetTransition(beam.beam().at(i)), expected_actions.at(i)); EXPECT_EQ(beam.beam().at(i)->GetScore(), 0.0); EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } auto history = beam.history(); for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(history.at(1).at(i), i); } } TEST(BeamTest, ReportsNonFinality) { constexpr int kMaxBeamSize = 8; // Create a beam with 8 transition states. std::vector> states; for (int i = 0; i < kMaxBeamSize; ++i) { // This is nonzero to test the oracle holding scores to 0. states.push_back(CreateState(10.0)); } std::vector expected_actions; // Create a finality function for this state. Use mocks for finer control. testing::MockFunction mock_finality_function; // Make precisely one call return false, which should cause IsFinal // to report false. constexpr int incomplete_state = 3; EXPECT_CALL(mock_finality_function, Call(states.at(incomplete_state).get())) .WillOnce(Return(false)); EXPECT_CALL(mock_finality_function, Call(Ne(states.at(incomplete_state).get()))) .WillRepeatedly(Return(true)); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(), transition_function, null_oracle); beam.Init(std::move(states)); EXPECT_FALSE(beam.IsTerminal()); } TEST(BeamTest, ReportsFinality) { constexpr int kMaxBeamSize = 8; // Create a beam with 8 transition states. std::vector> states; for (int i = 0; i < kMaxBeamSize; ++i) { // This is nonzero to test the oracle holding scores to 0. states.push_back(CreateState(10.0)); } std::vector expected_actions; // Create a finality function for this state. Use mocks for finer control. testing::MockFunction mock_finality_function; // All calls will return true, so IsFinal should return true. EXPECT_CALL(mock_finality_function, Call(_)).WillRepeatedly(Return(true)); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(), transition_function, null_oracle); beam.Init(std::move(states)); EXPECT_TRUE(beam.IsTerminal()); } TEST(BeamTest, IgnoresForbiddenTransitionActions) { // Create a matrix of transitions. constexpr int kMaxBeamSize = 4; constexpr int kNumTransitions = 4; constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize; constexpr float matrix[kMatrixSize] = { 10.0, 1000.0, 40.0, 30.0, 00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0}; constexpr float kOldScore = 4.0; // Create the beam. std::vector> states; states.push_back(CreateState(kOldScore)); // Forbid the second transition (index 1). testing::MockFunction mock_permission_function; EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 0)) .WillOnce(Return(true)); EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 1)) .WillOnce(Return(false)); EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 2)) .WillOnce(Return(true)); EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 3)) .WillOnce(Return(true)); Beam beam(kMaxBeamSize); beam.SetFunctions(mock_permission_function.AsStdFunction(), null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 3); // Make sure the state has performed the expected transition. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2); EXPECT_EQ(GetTransition(beam.beam().at(1)), 3); EXPECT_EQ(GetTransition(beam.beam().at(2)), 0); // Make sure the state has had its score updated properly. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]); EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[3]); EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[0]); // Make sure that the beam index field is consistent with the actual beam idx. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } // In this case, we expect the top 3 results to have come from state 0 and // the remaining 3 slots to be empty (-1). auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 0); EXPECT_EQ(history.at(1).at(1), 0); EXPECT_EQ(history.at(1).at(2), 0); EXPECT_EQ(history.at(1).at(3), -1); } TEST(BeamTest, BadlySizedMatrixDies) { // Create a matrix of transitions. constexpr int kNumTransitions = 4; constexpr int kMatrixSize = 4; // We have a max beam size of 4; should be 16. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0}; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(0.0)); states.push_back(CreateState(0.0)); constexpr int kMaxBeamSize = 8; Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // This matrix should have 8 elements, not 4, so this should die. EXPECT_DEATH(beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions), "Transition matrix size does not match max beam size \\* number " "of state transitions"); } TEST(BeamTest, BadlySizedBeamInitializationDies) { // Create an initialization beam too large for the max beam size. constexpr int kMaxBeamSize = 4; std::vector> states; for (int i = 0; i < kMaxBeamSize + 1; ++i) { states.push_back(CreateState(0.0)); } Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); // Try to initialize the beam; this should die. EXPECT_DEATH(beam.Init(std::move(states)), "Attempted to initialize a beam with more states"); } TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) { // Create a standard beam. constexpr int kMaxBeamSize = 4; std::vector> states; for (int i = 0; i < kMaxBeamSize; ++i) { states.push_back(CreateState(0.0)); } Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // Verify that all beam indices have been initialized. for (int i = 0; i < kMaxBeamSize; ++i) { EXPECT_EQ(i, beam.beam_state(i)->GetBeamIndex()); } } TEST(BeamTest, FindPreviousIndexTracesHistory) { // Create a matrix of transitions. constexpr int kNumTransitions = 4; constexpr int kMaxBeamSize = 8; constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize; constexpr float matrix[kMatrixSize] = { 30.0, 20.0, 40.0, 10.0, // State 0 31.0, 21.0, 41.0, 11.0, // State 1 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0}; constexpr float kOldScores[] = {5.0, 7.0}; constexpr int kParentBeamIndices[] = {1138, 42}; // Create the beam and transition it. std::vector> states; states.push_back(CreateState(kOldScores[0])); states.push_back(CreateState(kOldScores[1])); // Set parent beam indices. SetParentBeamIndex(states.at(0).get(), kParentBeamIndices[0]); SetParentBeamIndex(states.at(1).get(), kParentBeamIndices[1]); Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 8); // Make sure the state has performed the expected transition. // Note that the transition index is not the index into the matrix, but rather // the index into the matrix 'row' for that state. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2); EXPECT_EQ(GetTransition(beam.beam().at(1)), 2); EXPECT_EQ(GetTransition(beam.beam().at(2)), 0); EXPECT_EQ(GetTransition(beam.beam().at(3)), 0); EXPECT_EQ(GetTransition(beam.beam().at(4)), 1); EXPECT_EQ(GetTransition(beam.beam().at(5)), 1); EXPECT_EQ(GetTransition(beam.beam().at(6)), 3); EXPECT_EQ(GetTransition(beam.beam().at(7)), 3); // Make sure the state has had its score updated properly. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]); EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]); EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]); EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]); EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]); EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]); EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]); EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]); // Make sure that the beam index field is consistent with the actual beam idx. for (int i = 0; i < beam.beam().size(); ++i) { EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i); } // Validate the history at this step. auto history = beam.history(); EXPECT_EQ(history.at(1).at(0), 1); EXPECT_EQ(history.at(1).at(1), 0); EXPECT_EQ(history.at(1).at(2), 1); EXPECT_EQ(history.at(1).at(3), 0); EXPECT_EQ(history.at(1).at(4), 1); EXPECT_EQ(history.at(1).at(5), 0); EXPECT_EQ(history.at(1).at(6), 1); EXPECT_EQ(history.at(1).at(7), 0); EXPECT_EQ(history.at(0).at(0), kParentBeamIndices[0]); EXPECT_EQ(history.at(0).at(1), kParentBeamIndices[1]); EXPECT_EQ(history.at(0).at(2), -1); EXPECT_EQ(history.at(0).at(3), -1); EXPECT_EQ(history.at(0).at(4), -1); EXPECT_EQ(history.at(0).at(5), -1); EXPECT_EQ(history.at(0).at(6), -1); EXPECT_EQ(history.at(0).at(7), -1); // Make sure that FindPreviousIndex can read through the history from step 1 // to step 0. constexpr int kDesiredIndex = 0; constexpr int kCurrentIndexOne = 4; EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexOne, kDesiredIndex), kParentBeamIndices[1]); constexpr int kCurrentIndexTwo = 7; EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexTwo, kDesiredIndex), kParentBeamIndices[0]); } TEST(BeamTest, FindPreviousIndexReturnsInError) { // Create the beam. This now has only one history state, 0. std::vector> states; states.push_back(CreateState(0.0)); constexpr int kMaxBeamSize = 8; Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // If the requested step is greater than the number of steps taken, expect -1. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1); // If the requested step is less than 0, expect -1. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1); // If the requested index does not have a state, expect -1. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1); // If the requested index is less than 0, expect -1. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1); // If the requested index is larger than the maximum beam size -1, expect -1. EXPECT_EQ(beam.FindPreviousIndex(0, kMaxBeamSize), -1); } TEST(BeamTest, ResetClearsBeamState) { // Create the beam std::vector> states; states.push_back(CreateState(1.0)); constexpr int kMaxBeamSize = 8; Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // Validate the new beam. EXPECT_EQ(beam.beam().size(), 1); // Reset the beam. beam.Reset(); // Validate the now-reset beam, which should be empty. EXPECT_EQ(beam.beam().size(), 0); } TEST(BeamTest, ResetClearsBeamHistory) { // Create the beam std::vector> states; states.push_back(CreateState(1.0)); constexpr int kMaxBeamSize = 8; Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // Validate the new beam history. EXPECT_EQ(beam.history().size(), 1); // Reset the beam. beam.Reset(); // Validate the now-reset beam history, which should be empty. EXPECT_EQ(beam.history().size(), 0); } TEST(BeamTest, SettingMaxSizeResetsBeam) { // Create the beam std::vector> states; states.push_back(CreateState(1.0)); constexpr int kMaxBeamSize = 8; Beam beam(kMaxBeamSize); beam.SetFunctions(null_permissions, null_finality, transition_function, null_oracle); beam.Init(std::move(states)); // Validate the new beam history. EXPECT_EQ(beam.history().size(), 1); // Reset the beam. constexpr int kNewMaxBeamSize = 4; beam.SetMaxSize(kNewMaxBeamSize); EXPECT_EQ(beam.max_size(), kNewMaxBeamSize); // Validate the now-reset beam history, which should be empty. EXPECT_EQ(beam.history().size(), 0); } } // namespace dragnn } // namespace syntaxnet