beam_test.cc 30 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/core/beam.h"
  16. #include "dragnn/core/interfaces/cloneable_transition_state.h"
  17. #include "dragnn/core/interfaces/transition_state.h"
  18. #include "dragnn/core/test/mock_transition_state.h"
  19. #include <gmock/gmock.h>
  20. #include "tensorflow/core/platform/test.h"
  21. namespace syntaxnet {
  22. namespace dragnn {
  23. using testing::MockFunction;
  24. using testing::Return;
  25. using testing::Ne;
  26. using testing::_;
  27. namespace {
  28. // *****************************************************************************
  29. // Test-internal class definitions.
  30. // *****************************************************************************
  31. // Create a very basic transition state to test the beam. All it does is keep
  32. // track of its current beam index and score, as well as providing a field
  33. // for the transition function to write in what transition occurred.
  34. // Note that this class does not fulfill the entire TransitionState contract,
  35. // since it is only used in this particular test.
  36. class TestTransitionState
  37. : public CloneableTransitionState<TestTransitionState> {
  38. public:
  39. TestTransitionState() {}
  40. void Init(const TransitionState &parent) override {}
  41. std::unique_ptr<TestTransitionState> Clone() const override {
  42. std::unique_ptr<TestTransitionState> ptr(new TestTransitionState());
  43. return ptr;
  44. }
  45. const int ParentBeamIndex() const override { return parent_beam_index_; }
  46. // Get the current beam index for this state.
  47. const int GetBeamIndex() const override { return beam_index_; }
  48. // Set the current beam index for this state.
  49. void SetBeamIndex(const int index) override { beam_index_ = index; }
  50. // Get the score associated with this transition state.
  51. const float GetScore() const override { return score_; }
  52. // Set the score associated with this transition state.
  53. void SetScore(const float score) override { score_ = score; }
  54. // Depicts this state as an HTML-language string.
  55. string HTMLRepresentation() const override { return ""; }
  56. int parent_beam_index_;
  57. int beam_index_;
  58. float score_;
  59. int transition_action_;
  60. };
  61. // This transition function annotates a TestTransitionState with the action that
  62. // was chosen for the transition.
  63. auto transition_function = [](TestTransitionState *state, int action) {
  64. TestTransitionState *cast_state = dynamic_cast<TestTransitionState *>(state);
  65. cast_state->transition_action_ = action;
  66. };
  67. // Create oracle and permission functions that do nothing.
  68. auto null_oracle = [](TestTransitionState *) { return 0; };
  69. auto null_permissions = [](TestTransitionState *, int) { return true; };
  70. auto null_finality = [](TestTransitionState *) { return false; };
  71. // Create a unique_ptr with a test transition state in it and set its initial
  72. // score.
  73. std::unique_ptr<TestTransitionState> CreateState(float score) {
  74. std::unique_ptr<TestTransitionState> state;
  75. state.reset(new TestTransitionState());
  76. state->SetScore(score);
  77. return state;
  78. }
  79. // Convenience accessor for the action field in TestTransitionState.
  80. int GetTransition(const TransitionState *state) {
  81. return (dynamic_cast<const TestTransitionState *>(state))->transition_action_;
  82. }
  83. // Convenience accessor for the parent_beam_index_ field in TestTransitionState.
  84. void SetParentBeamIndex(TransitionState *state, int index) {
  85. (dynamic_cast<TestTransitionState *>(state))->parent_beam_index_ = index;
  86. }
  87. } // namespace
  88. // *****************************************************************************
  89. // Tests begin here.
  90. // *****************************************************************************
  91. TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  92. // Create a matrix of transitions.
  93. constexpr int kNumTransitions = 4;
  94. constexpr int kMatrixSize = kNumTransitions;
  95. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  96. constexpr int kBestTransition = 2;
  97. constexpr float kOldScore = 3.0;
  98. // Create the beam and transition it.
  99. std::vector<std::unique_ptr<TestTransitionState>> states;
  100. states.push_back(CreateState(kOldScore));
  101. constexpr int kBeamSize = 1;
  102. Beam<TestTransitionState> beam(kBeamSize);
  103. beam.SetFunctions(null_permissions, null_finality, transition_function,
  104. null_oracle);
  105. beam.Init(std::move(states));
  106. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  107. // Validate the new beam.
  108. EXPECT_EQ(beam.beam().size(), kBeamSize);
  109. // Make sure the state has performed the expected transition.
  110. EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition);
  111. // Make sure the state has had its score updated properly.
  112. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]);
  113. // Make sure that the beam index field is consistent with the actual beam idx.
  114. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
  115. // Make sure that the beam_state accessor actually accesses the beam.
  116. EXPECT_EQ(beam.beam().at(0), beam.beam_state(0));
  117. // Validate the beam history field.
  118. auto history = beam.history();
  119. EXPECT_EQ(history.at(1).at(0), 0);
  120. }
  121. TEST(BeamTest, AdvancingCreatesNewTransitions) {
  122. // Create a matrix of transitions.
  123. constexpr int kMaxBeamSize = 8;
  124. constexpr int kNumTransitions = 4;
  125. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  126. constexpr float matrix[kMatrixSize] = {
  127. 30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  128. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  129. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  130. constexpr float kOldScore = 4.0;
  131. // Create the beam and transition it.
  132. std::vector<std::unique_ptr<TestTransitionState>> states;
  133. states.push_back(CreateState(kOldScore));
  134. Beam<TestTransitionState> beam(kMaxBeamSize);
  135. beam.SetFunctions(null_permissions, null_finality, transition_function,
  136. null_oracle);
  137. beam.Init(std::move(states));
  138. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  139. // Validate the new beam.
  140. EXPECT_EQ(beam.beam().size(), 4);
  141. // Make sure the state has performed the expected transition.
  142. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  143. EXPECT_EQ(GetTransition(beam.beam().at(1)), 0);
  144. EXPECT_EQ(GetTransition(beam.beam().at(2)), 1);
  145. EXPECT_EQ(GetTransition(beam.beam().at(3)), 3);
  146. // Make sure the state has had its score updated properly.
  147. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
  148. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]);
  149. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]);
  150. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]);
  151. // Make sure that the beam index field is consistent with the actual beam idx.
  152. for (int i = 0; i < beam.beam().size(); ++i) {
  153. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  154. }
  155. // In this case, we expect the top 4 results to have come from state 0 and
  156. // the remaining 4 slots to be empty (-1).
  157. auto history = beam.history();
  158. EXPECT_EQ(history.at(1).at(0), 0);
  159. EXPECT_EQ(history.at(1).at(1), 0);
  160. EXPECT_EQ(history.at(1).at(2), 0);
  161. EXPECT_EQ(history.at(1).at(3), 0);
  162. EXPECT_EQ(history.at(1).at(4), -1);
  163. EXPECT_EQ(history.at(1).at(5), -1);
  164. EXPECT_EQ(history.at(1).at(6), -1);
  165. EXPECT_EQ(history.at(1).at(7), -1);
  166. }
  167. TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  168. // Create a matrix of transitions.
  169. constexpr int kMaxBeamSize = 8;
  170. constexpr int kNumTransitions = 4;
  171. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  172. constexpr float matrix[kMatrixSize] = {
  173. 30.0, 20.0, 40.0, 10.0, // State 0
  174. 31.0, 21.0, 41.0, 11.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  175. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  176. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  177. constexpr float kOldScores[] = {5.0, 7.0};
  178. // Create the beam and transition it.
  179. std::vector<std::unique_ptr<TestTransitionState>> states;
  180. states.push_back(CreateState(kOldScores[0]));
  181. states.push_back(CreateState(kOldScores[1]));
  182. Beam<TestTransitionState> beam(kMaxBeamSize);
  183. beam.SetFunctions(null_permissions, null_finality, transition_function,
  184. null_oracle);
  185. beam.Init(std::move(states));
  186. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  187. // Validate the new beam.
  188. EXPECT_EQ(beam.beam().size(), 8);
  189. // Make sure the state has performed the expected transition.
  190. // Note that the transition index is not the index into the matrix, but rather
  191. // the index into the matrix 'row' for that state.
  192. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  193. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  194. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  195. EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
  196. EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
  197. EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
  198. EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
  199. EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
  200. // Make sure the state has had its score updated properly.
  201. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
  202. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
  203. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
  204. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
  205. EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
  206. EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
  207. EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
  208. EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
  209. // Make sure that the beam index field is consistent with the actual beam idx.
  210. for (int i = 0; i < beam.beam().size(); ++i) {
  211. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  212. }
  213. // Validate the history at this step.
  214. auto history = beam.history();
  215. EXPECT_EQ(history.at(1).at(0), 1);
  216. EXPECT_EQ(history.at(1).at(1), 0);
  217. EXPECT_EQ(history.at(1).at(2), 1);
  218. EXPECT_EQ(history.at(1).at(3), 0);
  219. EXPECT_EQ(history.at(1).at(4), 1);
  220. EXPECT_EQ(history.at(1).at(5), 0);
  221. EXPECT_EQ(history.at(1).at(6), 1);
  222. EXPECT_EQ(history.at(1).at(7), 0);
  223. }
  224. TEST(BeamTest, AdvancingDropsLowValuePredictions) {
  225. // Create a matrix of transitions.
  226. constexpr int kNumTransitions = 4;
  227. constexpr int kMaxBeamSize = 8;
  228. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  229. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0, // State 0
  230. 31.0, 21.0, 41.0, 11.0, // State 1
  231. 32.0, 22.0, 42.0, 12.0, // State 2
  232. 33.0, 23.0, 43.0, 13.0, // State 3
  233. 34.0, 24.0, 44.0, 14.0, // State 4
  234. 35.0, 25.0, 45.0, 15.0, // State 5
  235. 36.0, 26.0, 46.0, 16.0, // State 6
  236. 37.0, 27.0, 47.0, 17.0}; // State 7
  237. constexpr float kOldScores[] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
  238. // Create the beam and transition it.
  239. std::vector<std::unique_ptr<TestTransitionState>> states;
  240. states.push_back(CreateState(kOldScores[0]));
  241. states.push_back(CreateState(kOldScores[1]));
  242. states.push_back(CreateState(kOldScores[2]));
  243. states.push_back(CreateState(kOldScores[3]));
  244. states.push_back(CreateState(kOldScores[4]));
  245. states.push_back(CreateState(kOldScores[5]));
  246. states.push_back(CreateState(kOldScores[6]));
  247. states.push_back(CreateState(kOldScores[7]));
  248. Beam<TestTransitionState> beam(kMaxBeamSize);
  249. beam.SetFunctions(null_permissions, null_finality, transition_function,
  250. null_oracle);
  251. beam.Init(std::move(states));
  252. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  253. // Validate the new beam.
  254. EXPECT_EQ(beam.beam().size(), 8);
  255. // Make sure the state has performed the expected transition.
  256. // In this case, every state will perform transition 2.
  257. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  258. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  259. EXPECT_EQ(GetTransition(beam.beam().at(2)), 2);
  260. EXPECT_EQ(GetTransition(beam.beam().at(3)), 2);
  261. EXPECT_EQ(GetTransition(beam.beam().at(4)), 2);
  262. EXPECT_EQ(GetTransition(beam.beam().at(5)), 2);
  263. EXPECT_EQ(GetTransition(beam.beam().at(6)), 2);
  264. EXPECT_EQ(GetTransition(beam.beam().at(7)), 2);
  265. // Make sure the state has had its score updated properly. (Note that row
  266. // 0 had the smallest transition score, so it ends up on the bottom of the
  267. // beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
  268. // correct state row and we add 2 since that was the transition index.
  269. EXPECT_EQ(beam.beam().at(0)->GetScore(),
  270. kOldScores[7] + matrix[7 * kNumTransitions + 2]);
  271. EXPECT_EQ(beam.beam().at(1)->GetScore(),
  272. kOldScores[6] + matrix[6 * kNumTransitions + 2]);
  273. EXPECT_EQ(beam.beam().at(2)->GetScore(),
  274. kOldScores[5] + matrix[5 * kNumTransitions + 2]);
  275. EXPECT_EQ(beam.beam().at(3)->GetScore(),
  276. kOldScores[4] + matrix[4 * kNumTransitions + 2]);
  277. EXPECT_EQ(beam.beam().at(4)->GetScore(),
  278. kOldScores[3] + matrix[3 * kNumTransitions + 2]);
  279. EXPECT_EQ(beam.beam().at(5)->GetScore(),
  280. kOldScores[2] + matrix[2 * kNumTransitions + 2]);
  281. EXPECT_EQ(beam.beam().at(6)->GetScore(),
  282. kOldScores[1] + matrix[1 * kNumTransitions + 2]);
  283. EXPECT_EQ(beam.beam().at(7)->GetScore(),
  284. kOldScores[0] + matrix[0 * kNumTransitions + 2]);
  285. // Make sure that the beam index field is consistent with the actual beam idx.
  286. for (int i = 0; i < beam.beam().size(); ++i) {
  287. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  288. }
  289. auto history = beam.history();
  290. EXPECT_EQ(history.at(1).at(0), 7);
  291. EXPECT_EQ(history.at(1).at(1), 6);
  292. EXPECT_EQ(history.at(1).at(2), 5);
  293. EXPECT_EQ(history.at(1).at(3), 4);
  294. EXPECT_EQ(history.at(1).at(4), 3);
  295. EXPECT_EQ(history.at(1).at(5), 2);
  296. EXPECT_EQ(history.at(1).at(6), 1);
  297. EXPECT_EQ(history.at(1).at(7), 0);
  298. }
  299. TEST(BeamTest, AdvancesFromOracleWithSingleBeam) {
  300. // Create an oracle function for this state.
  301. constexpr int kOracleLabel = 3;
  302. auto oracle_function = [](TransitionState *) { return kOracleLabel; };
  303. // Create the beam and transition it.
  304. std::vector<std::unique_ptr<TestTransitionState>> states;
  305. states.push_back(CreateState(0.0));
  306. constexpr int kBeamSize = 1;
  307. Beam<TestTransitionState> beam(kBeamSize);
  308. beam.SetFunctions(null_permissions, null_finality, transition_function,
  309. oracle_function);
  310. beam.Init(std::move(states));
  311. beam.AdvanceFromOracle();
  312. // Validate the new beam.
  313. EXPECT_EQ(beam.beam().size(), kBeamSize);
  314. // Make sure the state has performed the expected transition.
  315. EXPECT_EQ(GetTransition(beam.beam().at(0)), kOracleLabel);
  316. // Make sure the state has had its score held to 0.
  317. EXPECT_EQ(beam.beam().at(0)->GetScore(), 0.0);
  318. // Make sure that the beam index field is consistent with the actual beam idx.
  319. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
  320. // Validate the beam history field.
  321. auto history = beam.history();
  322. EXPECT_EQ(history.at(1).at(0), 0);
  323. }
  324. TEST(BeamTest, AdvancesFromOracleWithMultipleStates) {
  325. constexpr int kMaxBeamSize = 8;
  326. // Create a beam with 8 transition states.
  327. std::vector<std::unique_ptr<TestTransitionState>> states;
  328. for (int i = 0; i < kMaxBeamSize; ++i) {
  329. // This is nonzero to test the oracle holding scores to 0.
  330. states.push_back(CreateState(10.0));
  331. }
  332. std::vector<int> expected_actions;
  333. // Create an oracle function for this state. Use mocks for finer control.
  334. testing::MockFunction<int(TestTransitionState *)> mock_oracle_function;
  335. for (int i = 0; i < kMaxBeamSize; ++i) {
  336. // We expect each state to be queried for its oracle label,
  337. // and then to be transitioned in place with its oracle label.
  338. int oracle_label = i % 3; // 3 is arbitrary.
  339. EXPECT_CALL(mock_oracle_function, Call(states.at(i).get()))
  340. .WillOnce(Return(oracle_label));
  341. expected_actions.push_back(oracle_label);
  342. }
  343. Beam<TestTransitionState> beam(kMaxBeamSize);
  344. beam.SetFunctions(null_permissions, null_finality, transition_function,
  345. mock_oracle_function.AsStdFunction());
  346. beam.Init(std::move(states));
  347. beam.AdvanceFromOracle();
  348. // Make sure the state has performed the expected transition, has had its
  349. // score held to 0, and is self consistent.
  350. for (int i = 0; i < beam.beam().size(); ++i) {
  351. EXPECT_EQ(GetTransition(beam.beam().at(i)), expected_actions.at(i));
  352. EXPECT_EQ(beam.beam().at(i)->GetScore(), 0.0);
  353. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  354. }
  355. auto history = beam.history();
  356. for (int i = 0; i < beam.beam().size(); ++i) {
  357. EXPECT_EQ(history.at(1).at(i), i);
  358. }
  359. }
  360. TEST(BeamTest, ReportsNonFinality) {
  361. constexpr int kMaxBeamSize = 8;
  362. // Create a beam with 8 transition states.
  363. std::vector<std::unique_ptr<TestTransitionState>> states;
  364. for (int i = 0; i < kMaxBeamSize; ++i) {
  365. // This is nonzero to test the oracle holding scores to 0.
  366. states.push_back(CreateState(10.0));
  367. }
  368. std::vector<int> expected_actions;
  369. // Create a finality function for this state. Use mocks for finer control.
  370. testing::MockFunction<int(TestTransitionState *)> mock_finality_function;
  371. // Make precisely one call return false, which should cause IsFinal
  372. // to report false.
  373. constexpr int incomplete_state = 3;
  374. EXPECT_CALL(mock_finality_function, Call(states.at(incomplete_state).get()))
  375. .WillOnce(Return(false));
  376. EXPECT_CALL(mock_finality_function,
  377. Call(Ne(states.at(incomplete_state).get())))
  378. .WillRepeatedly(Return(true));
  379. Beam<TestTransitionState> beam(kMaxBeamSize);
  380. beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
  381. transition_function, null_oracle);
  382. beam.Init(std::move(states));
  383. EXPECT_FALSE(beam.IsTerminal());
  384. }
  385. TEST(BeamTest, ReportsFinality) {
  386. constexpr int kMaxBeamSize = 8;
  387. // Create a beam with 8 transition states.
  388. std::vector<std::unique_ptr<TestTransitionState>> states;
  389. for (int i = 0; i < kMaxBeamSize; ++i) {
  390. // This is nonzero to test the oracle holding scores to 0.
  391. states.push_back(CreateState(10.0));
  392. }
  393. std::vector<int> expected_actions;
  394. // Create a finality function for this state. Use mocks for finer control.
  395. testing::MockFunction<int(TransitionState *)> mock_finality_function;
  396. // All calls will return true, so IsFinal should return true.
  397. EXPECT_CALL(mock_finality_function, Call(_)).WillRepeatedly(Return(true));
  398. Beam<TestTransitionState> beam(kMaxBeamSize);
  399. beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
  400. transition_function, null_oracle);
  401. beam.Init(std::move(states));
  402. EXPECT_TRUE(beam.IsTerminal());
  403. }
  404. TEST(BeamTest, IgnoresForbiddenTransitionActions) {
  405. // Create a matrix of transitions.
  406. constexpr int kMaxBeamSize = 4;
  407. constexpr int kNumTransitions = 4;
  408. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  409. constexpr float matrix[kMatrixSize] = {
  410. 10.0, 1000.0, 40.0, 30.0, 00.0, 0000.0, 00.0, 00.0,
  411. 00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0};
  412. constexpr float kOldScore = 4.0;
  413. // Create the beam.
  414. std::vector<std::unique_ptr<TestTransitionState>> states;
  415. states.push_back(CreateState(kOldScore));
  416. // Forbid the second transition (index 1).
  417. testing::MockFunction<int(TestTransitionState *, int)>
  418. mock_permission_function;
  419. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 0))
  420. .WillOnce(Return(true));
  421. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 1))
  422. .WillOnce(Return(false));
  423. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 2))
  424. .WillOnce(Return(true));
  425. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 3))
  426. .WillOnce(Return(true));
  427. Beam<TestTransitionState> beam(kMaxBeamSize);
  428. beam.SetFunctions(mock_permission_function.AsStdFunction(), null_finality,
  429. transition_function, null_oracle);
  430. beam.Init(std::move(states));
  431. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  432. // Validate the new beam.
  433. EXPECT_EQ(beam.beam().size(), 3);
  434. // Make sure the state has performed the expected transition.
  435. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  436. EXPECT_EQ(GetTransition(beam.beam().at(1)), 3);
  437. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  438. // Make sure the state has had its score updated properly.
  439. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
  440. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[3]);
  441. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[0]);
  442. // Make sure that the beam index field is consistent with the actual beam idx.
  443. for (int i = 0; i < beam.beam().size(); ++i) {
  444. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  445. }
  446. // In this case, we expect the top 3 results to have come from state 0 and
  447. // the remaining 3 slots to be empty (-1).
  448. auto history = beam.history();
  449. EXPECT_EQ(history.at(1).at(0), 0);
  450. EXPECT_EQ(history.at(1).at(1), 0);
  451. EXPECT_EQ(history.at(1).at(2), 0);
  452. EXPECT_EQ(history.at(1).at(3), -1);
  453. }
  454. TEST(BeamTest, BadlySizedMatrixDies) {
  455. // Create a matrix of transitions.
  456. constexpr int kNumTransitions = 4;
  457. constexpr int kMatrixSize = 4; // We have a max beam size of 4; should be 16.
  458. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  459. // Create the beam and transition it.
  460. std::vector<std::unique_ptr<TestTransitionState>> states;
  461. states.push_back(CreateState(0.0));
  462. states.push_back(CreateState(0.0));
  463. constexpr int kMaxBeamSize = 8;
  464. Beam<TestTransitionState> beam(kMaxBeamSize);
  465. beam.SetFunctions(null_permissions, null_finality, transition_function,
  466. null_oracle);
  467. beam.Init(std::move(states));
  468. // This matrix should have 8 elements, not 4, so this should die.
  469. EXPECT_DEATH(beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions),
  470. "Transition matrix size does not match max beam size \\* number "
  471. "of state transitions");
  472. }
  473. TEST(BeamTest, BadlySizedBeamInitializationDies) {
  474. // Create an initialization beam too large for the max beam size.
  475. constexpr int kMaxBeamSize = 4;
  476. std::vector<std::unique_ptr<TestTransitionState>> states;
  477. for (int i = 0; i < kMaxBeamSize + 1; ++i) {
  478. states.push_back(CreateState(0.0));
  479. }
  480. Beam<TestTransitionState> beam(kMaxBeamSize);
  481. beam.SetFunctions(null_permissions, null_finality, transition_function,
  482. null_oracle);
  483. // Try to initialize the beam; this should die.
  484. EXPECT_DEATH(beam.Init(std::move(states)),
  485. "Attempted to initialize a beam with more states");
  486. }
  487. TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) {
  488. // Create a standard beam.
  489. constexpr int kMaxBeamSize = 4;
  490. std::vector<std::unique_ptr<TestTransitionState>> states;
  491. for (int i = 0; i < kMaxBeamSize; ++i) {
  492. states.push_back(CreateState(0.0));
  493. }
  494. Beam<TestTransitionState> beam(kMaxBeamSize);
  495. beam.SetFunctions(null_permissions, null_finality, transition_function,
  496. null_oracle);
  497. beam.Init(std::move(states));
  498. // Verify that all beam indices have been initialized.
  499. for (int i = 0; i < kMaxBeamSize; ++i) {
  500. EXPECT_EQ(i, beam.beam_state(i)->GetBeamIndex());
  501. }
  502. }
  503. TEST(BeamTest, FindPreviousIndexTracesHistory) {
  504. // Create a matrix of transitions.
  505. constexpr int kNumTransitions = 4;
  506. constexpr int kMaxBeamSize = 8;
  507. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  508. constexpr float matrix[kMatrixSize] = {
  509. 30.0, 20.0, 40.0, 10.0, // State 0
  510. 31.0, 21.0, 41.0, 11.0, // State 1
  511. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  512. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  513. constexpr float kOldScores[] = {5.0, 7.0};
  514. constexpr int kParentBeamIndices[] = {1138, 42};
  515. // Create the beam and transition it.
  516. std::vector<std::unique_ptr<TestTransitionState>> states;
  517. states.push_back(CreateState(kOldScores[0]));
  518. states.push_back(CreateState(kOldScores[1]));
  519. // Set parent beam indices.
  520. SetParentBeamIndex(states.at(0).get(), kParentBeamIndices[0]);
  521. SetParentBeamIndex(states.at(1).get(), kParentBeamIndices[1]);
  522. Beam<TestTransitionState> beam(kMaxBeamSize);
  523. beam.SetFunctions(null_permissions, null_finality, transition_function,
  524. null_oracle);
  525. beam.Init(std::move(states));
  526. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  527. // Validate the new beam.
  528. EXPECT_EQ(beam.beam().size(), 8);
  529. // Make sure the state has performed the expected transition.
  530. // Note that the transition index is not the index into the matrix, but rather
  531. // the index into the matrix 'row' for that state.
  532. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  533. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  534. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  535. EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
  536. EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
  537. EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
  538. EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
  539. EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
  540. // Make sure the state has had its score updated properly.
  541. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
  542. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
  543. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
  544. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
  545. EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
  546. EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
  547. EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
  548. EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
  549. // Make sure that the beam index field is consistent with the actual beam idx.
  550. for (int i = 0; i < beam.beam().size(); ++i) {
  551. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  552. }
  553. // Validate the history at this step.
  554. auto history = beam.history();
  555. EXPECT_EQ(history.at(1).at(0), 1);
  556. EXPECT_EQ(history.at(1).at(1), 0);
  557. EXPECT_EQ(history.at(1).at(2), 1);
  558. EXPECT_EQ(history.at(1).at(3), 0);
  559. EXPECT_EQ(history.at(1).at(4), 1);
  560. EXPECT_EQ(history.at(1).at(5), 0);
  561. EXPECT_EQ(history.at(1).at(6), 1);
  562. EXPECT_EQ(history.at(1).at(7), 0);
  563. EXPECT_EQ(history.at(0).at(0), kParentBeamIndices[0]);
  564. EXPECT_EQ(history.at(0).at(1), kParentBeamIndices[1]);
  565. EXPECT_EQ(history.at(0).at(2), -1);
  566. EXPECT_EQ(history.at(0).at(3), -1);
  567. EXPECT_EQ(history.at(0).at(4), -1);
  568. EXPECT_EQ(history.at(0).at(5), -1);
  569. EXPECT_EQ(history.at(0).at(6), -1);
  570. EXPECT_EQ(history.at(0).at(7), -1);
  571. // Make sure that FindPreviousIndex can read through the history from step 1
  572. // to step 0.
  573. constexpr int kDesiredIndex = 0;
  574. constexpr int kCurrentIndexOne = 4;
  575. EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexOne, kDesiredIndex),
  576. kParentBeamIndices[1]);
  577. constexpr int kCurrentIndexTwo = 7;
  578. EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexTwo, kDesiredIndex),
  579. kParentBeamIndices[0]);
  580. }
  581. TEST(BeamTest, FindPreviousIndexReturnsInError) {
  582. // Create the beam. This now has only one history state, 0.
  583. std::vector<std::unique_ptr<TestTransitionState>> states;
  584. states.push_back(CreateState(0.0));
  585. constexpr int kMaxBeamSize = 8;
  586. Beam<TestTransitionState> beam(kMaxBeamSize);
  587. beam.SetFunctions(null_permissions, null_finality, transition_function,
  588. null_oracle);
  589. beam.Init(std::move(states));
  590. // If the requested step is greater than the number of steps taken, expect -1.
  591. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
  592. // If the requested step is less than 0, expect -1.
  593. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
  594. // If the requested index does not have a state, expect -1.
  595. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
  596. // If the requested index is less than 0, expect -1.
  597. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
  598. // If the requested index is larger than the maximum beam size -1, expect -1.
  599. EXPECT_EQ(beam.FindPreviousIndex(0, kMaxBeamSize), -1);
  600. }
  601. TEST(BeamTest, ResetClearsBeamState) {
  602. // Create the beam
  603. std::vector<std::unique_ptr<TestTransitionState>> states;
  604. states.push_back(CreateState(1.0));
  605. constexpr int kMaxBeamSize = 8;
  606. Beam<TestTransitionState> beam(kMaxBeamSize);
  607. beam.SetFunctions(null_permissions, null_finality, transition_function,
  608. null_oracle);
  609. beam.Init(std::move(states));
  610. // Validate the new beam.
  611. EXPECT_EQ(beam.beam().size(), 1);
  612. // Reset the beam.
  613. beam.Reset();
  614. // Validate the now-reset beam, which should be empty.
  615. EXPECT_EQ(beam.beam().size(), 0);
  616. }
  617. TEST(BeamTest, ResetClearsBeamHistory) {
  618. // Create the beam
  619. std::vector<std::unique_ptr<TestTransitionState>> states;
  620. states.push_back(CreateState(1.0));
  621. constexpr int kMaxBeamSize = 8;
  622. Beam<TestTransitionState> beam(kMaxBeamSize);
  623. beam.SetFunctions(null_permissions, null_finality, transition_function,
  624. null_oracle);
  625. beam.Init(std::move(states));
  626. // Validate the new beam history.
  627. EXPECT_EQ(beam.history().size(), 1);
  628. // Reset the beam.
  629. beam.Reset();
  630. // Validate the now-reset beam history, which should be empty.
  631. EXPECT_EQ(beam.history().size(), 0);
  632. }
  633. TEST(BeamTest, SettingMaxSizeResetsBeam) {
  634. // Create the beam
  635. std::vector<std::unique_ptr<TestTransitionState>> states;
  636. states.push_back(CreateState(1.0));
  637. constexpr int kMaxBeamSize = 8;
  638. Beam<TestTransitionState> beam(kMaxBeamSize);
  639. beam.SetFunctions(null_permissions, null_finality, transition_function,
  640. null_oracle);
  641. beam.Init(std::move(states));
  642. // Validate the new beam history.
  643. EXPECT_EQ(beam.history().size(), 1);
  644. // Reset the beam.
  645. constexpr int kNewMaxBeamSize = 4;
  646. beam.SetMaxSize(kNewMaxBeamSize);
  647. EXPECT_EQ(beam.max_size(), kNewMaxBeamSize);
  648. // Validate the now-reset beam history, which should be empty.
  649. EXPECT_EQ(beam.history().size(), 0);
  650. }
  651. } // namespace dragnn
  652. } // namespace syntaxnet