123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789 |
- // Copyright 2017 Google Inc. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // =============================================================================
- #include "dragnn/core/beam.h"
- #include "dragnn/core/interfaces/cloneable_transition_state.h"
- #include "dragnn/core/interfaces/transition_state.h"
- #include "dragnn/core/test/mock_transition_state.h"
- #include <gmock/gmock.h>
- #include "tensorflow/core/platform/test.h"
- namespace syntaxnet {
- namespace dragnn {
- using testing::MockFunction;
- using testing::Return;
- using testing::Ne;
- using testing::_;
- namespace {
- // *****************************************************************************
- // Test-internal class definitions.
- // *****************************************************************************
- // Create a very basic transition state to test the beam. All it does is keep
- // track of its current beam index and score, as well as providing a field
- // for the transition function to write in what transition occurred.
- // Note that this class does not fulfill the entire TransitionState contract,
- // since it is only used in this particular test.
- class TestTransitionState
- : public CloneableTransitionState<TestTransitionState> {
- public:
- TestTransitionState() {}
- void Init(const TransitionState &parent) override {}
- std::unique_ptr<TestTransitionState> Clone() const override {
- std::unique_ptr<TestTransitionState> ptr(new TestTransitionState());
- return ptr;
- }
- const int ParentBeamIndex() const override { return parent_beam_index_; }
- // Get the current beam index for this state.
- const int GetBeamIndex() const override { return beam_index_; }
- // Set the current beam index for this state.
- void SetBeamIndex(const int index) override { beam_index_ = index; }
- // Get the score associated with this transition state.
- const float GetScore() const override { return score_; }
- // Set the score associated with this transition state.
- void SetScore(const float score) override { score_ = score; }
- // Depicts this state as an HTML-language string.
- string HTMLRepresentation() const override { return ""; }
- int parent_beam_index_;
- int beam_index_;
- float score_;
- int transition_action_;
- };
- // This transition function annotates a TestTransitionState with the action that
- // was chosen for the transition.
- auto transition_function = [](TestTransitionState *state, int action) {
- TestTransitionState *cast_state = dynamic_cast<TestTransitionState *>(state);
- cast_state->transition_action_ = action;
- };
- // Create oracle and permission functions that do nothing.
- auto null_oracle = [](TestTransitionState *) { return 0; };
- auto null_permissions = [](TestTransitionState *, int) { return true; };
- auto null_finality = [](TestTransitionState *) { return false; };
- // Create a unique_ptr with a test transition state in it and set its initial
- // score.
- std::unique_ptr<TestTransitionState> CreateState(float score) {
- std::unique_ptr<TestTransitionState> state;
- state.reset(new TestTransitionState());
- state->SetScore(score);
- return state;
- }
- // Convenience accessor for the action field in TestTransitionState.
- int GetTransition(const TransitionState *state) {
- return (dynamic_cast<const TestTransitionState *>(state))->transition_action_;
- }
- // Convenience accessor for the parent_beam_index_ field in TestTransitionState.
- void SetParentBeamIndex(TransitionState *state, int index) {
- (dynamic_cast<TestTransitionState *>(state))->parent_beam_index_ = index;
- }
- } // namespace
- // *****************************************************************************
- // Tests begin here.
- // *****************************************************************************
- TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
- // Create a matrix of transitions.
- constexpr int kNumTransitions = 4;
- constexpr int kMatrixSize = kNumTransitions;
- constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
- constexpr int kBestTransition = 2;
- constexpr float kOldScore = 3.0;
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScore));
- constexpr int kBeamSize = 1;
- Beam<TestTransitionState> beam(kBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), kBeamSize);
- // Make sure the state has performed the expected transition.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition);
- // Make sure the state has had its score updated properly.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
- // Make sure that the beam_state accessor actually accesses the beam.
- EXPECT_EQ(beam.beam().at(0), beam.beam_state(0));
- // Validate the beam history field.
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 0);
- }
- TEST(BeamTest, AdvancingCreatesNewTransitions) {
- // Create a matrix of transitions.
- constexpr int kMaxBeamSize = 8;
- constexpr int kNumTransitions = 4;
- constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
- constexpr float matrix[kMatrixSize] = {
- 30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
- constexpr float kOldScore = 4.0;
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScore));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 4);
- // Make sure the state has performed the expected transition.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(1)), 0);
- EXPECT_EQ(GetTransition(beam.beam().at(2)), 1);
- EXPECT_EQ(GetTransition(beam.beam().at(3)), 3);
- // Make sure the state has had its score updated properly.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
- EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]);
- EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]);
- EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- // In this case, we expect the top 4 results to have come from state 0 and
- // the remaining 4 slots to be empty (-1).
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 0);
- EXPECT_EQ(history.at(1).at(1), 0);
- EXPECT_EQ(history.at(1).at(2), 0);
- EXPECT_EQ(history.at(1).at(3), 0);
- EXPECT_EQ(history.at(1).at(4), -1);
- EXPECT_EQ(history.at(1).at(5), -1);
- EXPECT_EQ(history.at(1).at(6), -1);
- EXPECT_EQ(history.at(1).at(7), -1);
- }
- TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
- // Create a matrix of transitions.
- constexpr int kMaxBeamSize = 8;
- constexpr int kNumTransitions = 4;
- constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
- constexpr float matrix[kMatrixSize] = {
- 30.0, 20.0, 40.0, 10.0, // State 0
- 31.0, 21.0, 41.0, 11.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
- constexpr float kOldScores[] = {5.0, 7.0};
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScores[0]));
- states.push_back(CreateState(kOldScores[1]));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 8);
- // Make sure the state has performed the expected transition.
- // Note that the transition index is not the index into the matrix, but rather
- // the index into the matrix 'row' for that state.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
- EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
- EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
- EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
- EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
- EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
- // Make sure the state has had its score updated properly.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
- EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
- EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
- EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
- EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
- EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
- EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
- EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- // Validate the history at this step.
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 1);
- EXPECT_EQ(history.at(1).at(1), 0);
- EXPECT_EQ(history.at(1).at(2), 1);
- EXPECT_EQ(history.at(1).at(3), 0);
- EXPECT_EQ(history.at(1).at(4), 1);
- EXPECT_EQ(history.at(1).at(5), 0);
- EXPECT_EQ(history.at(1).at(6), 1);
- EXPECT_EQ(history.at(1).at(7), 0);
- }
- TEST(BeamTest, AdvancingDropsLowValuePredictions) {
- // Create a matrix of transitions.
- constexpr int kNumTransitions = 4;
- constexpr int kMaxBeamSize = 8;
- constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
- constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0, // State 0
- 31.0, 21.0, 41.0, 11.0, // State 1
- 32.0, 22.0, 42.0, 12.0, // State 2
- 33.0, 23.0, 43.0, 13.0, // State 3
- 34.0, 24.0, 44.0, 14.0, // State 4
- 35.0, 25.0, 45.0, 15.0, // State 5
- 36.0, 26.0, 46.0, 16.0, // State 6
- 37.0, 27.0, 47.0, 17.0}; // State 7
- constexpr float kOldScores[] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScores[0]));
- states.push_back(CreateState(kOldScores[1]));
- states.push_back(CreateState(kOldScores[2]));
- states.push_back(CreateState(kOldScores[3]));
- states.push_back(CreateState(kOldScores[4]));
- states.push_back(CreateState(kOldScores[5]));
- states.push_back(CreateState(kOldScores[6]));
- states.push_back(CreateState(kOldScores[7]));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 8);
- // Make sure the state has performed the expected transition.
- // In this case, every state will perform transition 2.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(2)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(3)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(4)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(5)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(6)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(7)), 2);
- // Make sure the state has had its score updated properly. (Note that row
- // 0 had the smallest transition score, so it ends up on the bottom of the
- // beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
- // correct state row and we add 2 since that was the transition index.
- EXPECT_EQ(beam.beam().at(0)->GetScore(),
- kOldScores[7] + matrix[7 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(1)->GetScore(),
- kOldScores[6] + matrix[6 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(2)->GetScore(),
- kOldScores[5] + matrix[5 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(3)->GetScore(),
- kOldScores[4] + matrix[4 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(4)->GetScore(),
- kOldScores[3] + matrix[3 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(5)->GetScore(),
- kOldScores[2] + matrix[2 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(6)->GetScore(),
- kOldScores[1] + matrix[1 * kNumTransitions + 2]);
- EXPECT_EQ(beam.beam().at(7)->GetScore(),
- kOldScores[0] + matrix[0 * kNumTransitions + 2]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 7);
- EXPECT_EQ(history.at(1).at(1), 6);
- EXPECT_EQ(history.at(1).at(2), 5);
- EXPECT_EQ(history.at(1).at(3), 4);
- EXPECT_EQ(history.at(1).at(4), 3);
- EXPECT_EQ(history.at(1).at(5), 2);
- EXPECT_EQ(history.at(1).at(6), 1);
- EXPECT_EQ(history.at(1).at(7), 0);
- }
- TEST(BeamTest, AdvancesFromOracleWithSingleBeam) {
- // Create an oracle function for this state.
- constexpr int kOracleLabel = 3;
- auto oracle_function = [](TransitionState *) { return kOracleLabel; };
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(0.0));
- constexpr int kBeamSize = 1;
- Beam<TestTransitionState> beam(kBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- oracle_function);
- beam.Init(std::move(states));
- beam.AdvanceFromOracle();
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), kBeamSize);
- // Make sure the state has performed the expected transition.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), kOracleLabel);
- // Make sure the state has had its score held to 0.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), 0.0);
- // Make sure that the beam index field is consistent with the actual beam idx.
- EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
- // Validate the beam history field.
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 0);
- }
- TEST(BeamTest, AdvancesFromOracleWithMultipleStates) {
- constexpr int kMaxBeamSize = 8;
- // Create a beam with 8 transition states.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- for (int i = 0; i < kMaxBeamSize; ++i) {
- // This is nonzero to test the oracle holding scores to 0.
- states.push_back(CreateState(10.0));
- }
- std::vector<int> expected_actions;
- // Create an oracle function for this state. Use mocks for finer control.
- testing::MockFunction<int(TestTransitionState *)> mock_oracle_function;
- for (int i = 0; i < kMaxBeamSize; ++i) {
- // We expect each state to be queried for its oracle label,
- // and then to be transitioned in place with its oracle label.
- int oracle_label = i % 3; // 3 is arbitrary.
- EXPECT_CALL(mock_oracle_function, Call(states.at(i).get()))
- .WillOnce(Return(oracle_label));
- expected_actions.push_back(oracle_label);
- }
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- mock_oracle_function.AsStdFunction());
- beam.Init(std::move(states));
- beam.AdvanceFromOracle();
- // Make sure the state has performed the expected transition, has had its
- // score held to 0, and is self consistent.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(GetTransition(beam.beam().at(i)), expected_actions.at(i));
- EXPECT_EQ(beam.beam().at(i)->GetScore(), 0.0);
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- auto history = beam.history();
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(history.at(1).at(i), i);
- }
- }
- TEST(BeamTest, ReportsNonFinality) {
- constexpr int kMaxBeamSize = 8;
- // Create a beam with 8 transition states.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- for (int i = 0; i < kMaxBeamSize; ++i) {
- // This is nonzero to test the oracle holding scores to 0.
- states.push_back(CreateState(10.0));
- }
- std::vector<int> expected_actions;
- // Create a finality function for this state. Use mocks for finer control.
- testing::MockFunction<int(TestTransitionState *)> mock_finality_function;
- // Make precisely one call return false, which should cause IsFinal
- // to report false.
- constexpr int incomplete_state = 3;
- EXPECT_CALL(mock_finality_function, Call(states.at(incomplete_state).get()))
- .WillOnce(Return(false));
- EXPECT_CALL(mock_finality_function,
- Call(Ne(states.at(incomplete_state).get())))
- .WillRepeatedly(Return(true));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
- transition_function, null_oracle);
- beam.Init(std::move(states));
- EXPECT_FALSE(beam.IsTerminal());
- }
- TEST(BeamTest, ReportsFinality) {
- constexpr int kMaxBeamSize = 8;
- // Create a beam with 8 transition states.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- for (int i = 0; i < kMaxBeamSize; ++i) {
- // This is nonzero to test the oracle holding scores to 0.
- states.push_back(CreateState(10.0));
- }
- std::vector<int> expected_actions;
- // Create a finality function for this state. Use mocks for finer control.
- testing::MockFunction<int(TransitionState *)> mock_finality_function;
- // All calls will return true, so IsFinal should return true.
- EXPECT_CALL(mock_finality_function, Call(_)).WillRepeatedly(Return(true));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
- transition_function, null_oracle);
- beam.Init(std::move(states));
- EXPECT_TRUE(beam.IsTerminal());
- }
- TEST(BeamTest, IgnoresForbiddenTransitionActions) {
- // Create a matrix of transitions.
- constexpr int kMaxBeamSize = 4;
- constexpr int kNumTransitions = 4;
- constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
- constexpr float matrix[kMatrixSize] = {
- 10.0, 1000.0, 40.0, 30.0, 00.0, 0000.0, 00.0, 00.0,
- 00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0};
- constexpr float kOldScore = 4.0;
- // Create the beam.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScore));
- // Forbid the second transition (index 1).
- testing::MockFunction<int(TestTransitionState *, int)>
- mock_permission_function;
- EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 0))
- .WillOnce(Return(true));
- EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 1))
- .WillOnce(Return(false));
- EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 2))
- .WillOnce(Return(true));
- EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 3))
- .WillOnce(Return(true));
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(mock_permission_function.AsStdFunction(), null_finality,
- transition_function, null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 3);
- // Make sure the state has performed the expected transition.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(1)), 3);
- EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
- // Make sure the state has had its score updated properly.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
- EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[3]);
- EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[0]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- // In this case, we expect the top 3 results to have come from state 0 and
- // the remaining 3 slots to be empty (-1).
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 0);
- EXPECT_EQ(history.at(1).at(1), 0);
- EXPECT_EQ(history.at(1).at(2), 0);
- EXPECT_EQ(history.at(1).at(3), -1);
- }
- TEST(BeamTest, BadlySizedMatrixDies) {
- // Create a matrix of transitions.
- constexpr int kNumTransitions = 4;
- constexpr int kMatrixSize = 4; // We have a max beam size of 4; should be 16.
- constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(0.0));
- states.push_back(CreateState(0.0));
- constexpr int kMaxBeamSize = 8;
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // This matrix should have 8 elements, not 4, so this should die.
- EXPECT_DEATH(beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions),
- "Transition matrix size does not match max beam size \\* number "
- "of state transitions");
- }
- TEST(BeamTest, BadlySizedBeamInitializationDies) {
- // Create an initialization beam too large for the max beam size.
- constexpr int kMaxBeamSize = 4;
- std::vector<std::unique_ptr<TestTransitionState>> states;
- for (int i = 0; i < kMaxBeamSize + 1; ++i) {
- states.push_back(CreateState(0.0));
- }
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- // Try to initialize the beam; this should die.
- EXPECT_DEATH(beam.Init(std::move(states)),
- "Attempted to initialize a beam with more states");
- }
- TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) {
- // Create a standard beam.
- constexpr int kMaxBeamSize = 4;
- std::vector<std::unique_ptr<TestTransitionState>> states;
- for (int i = 0; i < kMaxBeamSize; ++i) {
- states.push_back(CreateState(0.0));
- }
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // Verify that all beam indices have been initialized.
- for (int i = 0; i < kMaxBeamSize; ++i) {
- EXPECT_EQ(i, beam.beam_state(i)->GetBeamIndex());
- }
- }
- TEST(BeamTest, FindPreviousIndexTracesHistory) {
- // Create a matrix of transitions.
- constexpr int kNumTransitions = 4;
- constexpr int kMaxBeamSize = 8;
- constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
- constexpr float matrix[kMatrixSize] = {
- 30.0, 20.0, 40.0, 10.0, // State 0
- 31.0, 21.0, 41.0, 11.0, // State 1
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
- 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
- constexpr float kOldScores[] = {5.0, 7.0};
- constexpr int kParentBeamIndices[] = {1138, 42};
- // Create the beam and transition it.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(kOldScores[0]));
- states.push_back(CreateState(kOldScores[1]));
- // Set parent beam indices.
- SetParentBeamIndex(states.at(0).get(), kParentBeamIndices[0]);
- SetParentBeamIndex(states.at(1).get(), kParentBeamIndices[1]);
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 8);
- // Make sure the state has performed the expected transition.
- // Note that the transition index is not the index into the matrix, but rather
- // the index into the matrix 'row' for that state.
- EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
- EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
- EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
- EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
- EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
- EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
- EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
- // Make sure the state has had its score updated properly.
- EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
- EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
- EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
- EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
- EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
- EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
- EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
- EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
- // Make sure that the beam index field is consistent with the actual beam idx.
- for (int i = 0; i < beam.beam().size(); ++i) {
- EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
- }
- // Validate the history at this step.
- auto history = beam.history();
- EXPECT_EQ(history.at(1).at(0), 1);
- EXPECT_EQ(history.at(1).at(1), 0);
- EXPECT_EQ(history.at(1).at(2), 1);
- EXPECT_EQ(history.at(1).at(3), 0);
- EXPECT_EQ(history.at(1).at(4), 1);
- EXPECT_EQ(history.at(1).at(5), 0);
- EXPECT_EQ(history.at(1).at(6), 1);
- EXPECT_EQ(history.at(1).at(7), 0);
- EXPECT_EQ(history.at(0).at(0), kParentBeamIndices[0]);
- EXPECT_EQ(history.at(0).at(1), kParentBeamIndices[1]);
- EXPECT_EQ(history.at(0).at(2), -1);
- EXPECT_EQ(history.at(0).at(3), -1);
- EXPECT_EQ(history.at(0).at(4), -1);
- EXPECT_EQ(history.at(0).at(5), -1);
- EXPECT_EQ(history.at(0).at(6), -1);
- EXPECT_EQ(history.at(0).at(7), -1);
- // Make sure that FindPreviousIndex can read through the history from step 1
- // to step 0.
- constexpr int kDesiredIndex = 0;
- constexpr int kCurrentIndexOne = 4;
- EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexOne, kDesiredIndex),
- kParentBeamIndices[1]);
- constexpr int kCurrentIndexTwo = 7;
- EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexTwo, kDesiredIndex),
- kParentBeamIndices[0]);
- }
- TEST(BeamTest, FindPreviousIndexReturnsInError) {
- // Create the beam. This now has only one history state, 0.
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(0.0));
- constexpr int kMaxBeamSize = 8;
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // If the requested step is greater than the number of steps taken, expect -1.
- EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
- // If the requested step is less than 0, expect -1.
- EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
- // If the requested index does not have a state, expect -1.
- EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
- // If the requested index is less than 0, expect -1.
- EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
- // If the requested index is larger than the maximum beam size -1, expect -1.
- EXPECT_EQ(beam.FindPreviousIndex(0, kMaxBeamSize), -1);
- }
- TEST(BeamTest, ResetClearsBeamState) {
- // Create the beam
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(1.0));
- constexpr int kMaxBeamSize = 8;
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // Validate the new beam.
- EXPECT_EQ(beam.beam().size(), 1);
- // Reset the beam.
- beam.Reset();
- // Validate the now-reset beam, which should be empty.
- EXPECT_EQ(beam.beam().size(), 0);
- }
- TEST(BeamTest, ResetClearsBeamHistory) {
- // Create the beam
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(1.0));
- constexpr int kMaxBeamSize = 8;
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // Validate the new beam history.
- EXPECT_EQ(beam.history().size(), 1);
- // Reset the beam.
- beam.Reset();
- // Validate the now-reset beam history, which should be empty.
- EXPECT_EQ(beam.history().size(), 0);
- }
- TEST(BeamTest, SettingMaxSizeResetsBeam) {
- // Create the beam
- std::vector<std::unique_ptr<TestTransitionState>> states;
- states.push_back(CreateState(1.0));
- constexpr int kMaxBeamSize = 8;
- Beam<TestTransitionState> beam(kMaxBeamSize);
- beam.SetFunctions(null_permissions, null_finality, transition_function,
- null_oracle);
- beam.Init(std::move(states));
- // Validate the new beam history.
- EXPECT_EQ(beam.history().size(), 1);
- // Reset the beam.
- constexpr int kNewMaxBeamSize = 4;
- beam.SetMaxSize(kNewMaxBeamSize);
- EXPECT_EQ(beam.max_size(), kNewMaxBeamSize);
- // Validate the now-reset beam history, which should be empty.
- EXPECT_EQ(beam.history().size(), 0);
- }
- } // namespace dragnn
- } // namespace syntaxnet
|