beam_test.cc 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774
  1. #include "dragnn/core/beam.h"
  2. #include "dragnn/core/interfaces/cloneable_transition_state.h"
  3. #include "dragnn/core/interfaces/transition_state.h"
  4. #include "dragnn/core/test/mock_transition_state.h"
  5. #include <gmock/gmock.h>
  6. #include "tensorflow/core/platform/test.h"
  7. namespace syntaxnet {
  8. namespace dragnn {
  9. using testing::MockFunction;
  10. using testing::Return;
  11. using testing::Ne;
  12. using testing::_;
  13. namespace {
  14. // *****************************************************************************
  15. // Test-internal class definitions.
  16. // *****************************************************************************
  17. // Create a very basic transition state to test the beam. All it does is keep
  18. // track of its current beam index and score, as well as providing a field
  19. // for the transition function to write in what transition occurred.
  20. // Note that this class does not fulfill the entire TransitionState contract,
  21. // since it is only used in this particular test.
  22. class TestTransitionState
  23. : public CloneableTransitionState<TestTransitionState> {
  24. public:
  25. TestTransitionState() {}
  26. void Init(const TransitionState &parent) override {}
  27. std::unique_ptr<TestTransitionState> Clone() const override {
  28. std::unique_ptr<TestTransitionState> ptr(new TestTransitionState());
  29. return ptr;
  30. }
  31. const int ParentBeamIndex() const override { return parent_beam_index_; }
  32. // Get the current beam index for this state.
  33. const int GetBeamIndex() const override { return beam_index_; }
  34. // Set the current beam index for this state.
  35. void SetBeamIndex(const int index) override { beam_index_ = index; }
  36. // Get the score associated with this transition state.
  37. const float GetScore() const override { return score_; }
  38. // Set the score associated with this transition state.
  39. void SetScore(const float score) override { score_ = score; }
  40. // Depicts this state as an HTML-language string.
  41. string HTMLRepresentation() const override { return ""; }
  42. int parent_beam_index_;
  43. int beam_index_;
  44. float score_;
  45. int transition_action_;
  46. };
  47. // This transition function annotates a TestTransitionState with the action that
  48. // was chosen for the transition.
  49. auto transition_function = [](TestTransitionState *state, int action) {
  50. TestTransitionState *cast_state = dynamic_cast<TestTransitionState *>(state);
  51. cast_state->transition_action_ = action;
  52. };
  53. // Create oracle and permission functions that do nothing.
  54. auto null_oracle = [](TestTransitionState *) { return 0; };
  55. auto null_permissions = [](TestTransitionState *, int) { return true; };
  56. auto null_finality = [](TestTransitionState *) { return false; };
  57. // Create a unique_ptr with a test transition state in it and set its initial
  58. // score.
  59. std::unique_ptr<TestTransitionState> CreateState(float score) {
  60. std::unique_ptr<TestTransitionState> state;
  61. state.reset(new TestTransitionState());
  62. state->SetScore(score);
  63. return state;
  64. }
  65. // Convenience accessor for the action field in TestTransitionState.
  66. int GetTransition(const TransitionState *state) {
  67. return (dynamic_cast<const TestTransitionState *>(state))->transition_action_;
  68. }
  69. // Convenience accessor for the parent_beam_index_ field in TestTransitionState.
  70. void SetParentBeamIndex(TransitionState *state, int index) {
  71. (dynamic_cast<TestTransitionState *>(state))->parent_beam_index_ = index;
  72. }
  73. } // namespace
  74. // *****************************************************************************
  75. // Tests begin here.
  76. // *****************************************************************************
  77. TEST(BeamTest, AdvancesFromPredictionWithSingleBeam) {
  78. // Create a matrix of transitions.
  79. constexpr int kNumTransitions = 4;
  80. constexpr int kMatrixSize = kNumTransitions;
  81. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  82. constexpr int kBestTransition = 2;
  83. constexpr float kOldScore = 3.0;
  84. // Create the beam and transition it.
  85. std::vector<std::unique_ptr<TestTransitionState>> states;
  86. states.push_back(CreateState(kOldScore));
  87. constexpr int kBeamSize = 1;
  88. Beam<TestTransitionState> beam(kBeamSize);
  89. beam.SetFunctions(null_permissions, null_finality, transition_function,
  90. null_oracle);
  91. beam.Init(std::move(states));
  92. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  93. // Validate the new beam.
  94. EXPECT_EQ(beam.beam().size(), kBeamSize);
  95. // Make sure the state has performed the expected transition.
  96. EXPECT_EQ(GetTransition(beam.beam().at(0)), kBestTransition);
  97. // Make sure the state has had its score updated properly.
  98. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[kBestTransition]);
  99. // Make sure that the beam index field is consistent with the actual beam idx.
  100. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
  101. // Make sure that the beam_state accessor actually accesses the beam.
  102. EXPECT_EQ(beam.beam().at(0), beam.beam_state(0));
  103. // Validate the beam history field.
  104. auto history = beam.history();
  105. EXPECT_EQ(history.at(1).at(0), 0);
  106. }
  107. TEST(BeamTest, AdvancingCreatesNewTransitions) {
  108. // Create a matrix of transitions.
  109. constexpr int kMaxBeamSize = 8;
  110. constexpr int kNumTransitions = 4;
  111. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  112. constexpr float matrix[kMatrixSize] = {
  113. 30.0, 20.0, 40.0, 10.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  114. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  115. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  116. constexpr float kOldScore = 4.0;
  117. // Create the beam and transition it.
  118. std::vector<std::unique_ptr<TestTransitionState>> states;
  119. states.push_back(CreateState(kOldScore));
  120. Beam<TestTransitionState> beam(kMaxBeamSize);
  121. beam.SetFunctions(null_permissions, null_finality, transition_function,
  122. null_oracle);
  123. beam.Init(std::move(states));
  124. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  125. // Validate the new beam.
  126. EXPECT_EQ(beam.beam().size(), 4);
  127. // Make sure the state has performed the expected transition.
  128. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  129. EXPECT_EQ(GetTransition(beam.beam().at(1)), 0);
  130. EXPECT_EQ(GetTransition(beam.beam().at(2)), 1);
  131. EXPECT_EQ(GetTransition(beam.beam().at(3)), 3);
  132. // Make sure the state has had its score updated properly.
  133. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
  134. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[0]);
  135. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[1]);
  136. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScore + matrix[3]);
  137. // Make sure that the beam index field is consistent with the actual beam idx.
  138. for (int i = 0; i < beam.beam().size(); ++i) {
  139. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  140. }
  141. // In this case, we expect the top 4 results to have come from state 0 and
  142. // the remaining 4 slots to be empty (-1).
  143. auto history = beam.history();
  144. EXPECT_EQ(history.at(1).at(0), 0);
  145. EXPECT_EQ(history.at(1).at(1), 0);
  146. EXPECT_EQ(history.at(1).at(2), 0);
  147. EXPECT_EQ(history.at(1).at(3), 0);
  148. EXPECT_EQ(history.at(1).at(4), -1);
  149. EXPECT_EQ(history.at(1).at(5), -1);
  150. EXPECT_EQ(history.at(1).at(6), -1);
  151. EXPECT_EQ(history.at(1).at(7), -1);
  152. }
  153. TEST(BeamTest, MultipleElementBeamsAdvanceAllElements) {
  154. // Create a matrix of transitions.
  155. constexpr int kMaxBeamSize = 8;
  156. constexpr int kNumTransitions = 4;
  157. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  158. constexpr float matrix[kMatrixSize] = {
  159. 30.0, 20.0, 40.0, 10.0, // State 0
  160. 31.0, 21.0, 41.0, 11.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  161. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  162. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  163. constexpr float kOldScores[] = {5.0, 7.0};
  164. // Create the beam and transition it.
  165. std::vector<std::unique_ptr<TestTransitionState>> states;
  166. states.push_back(CreateState(kOldScores[0]));
  167. states.push_back(CreateState(kOldScores[1]));
  168. Beam<TestTransitionState> beam(kMaxBeamSize);
  169. beam.SetFunctions(null_permissions, null_finality, transition_function,
  170. null_oracle);
  171. beam.Init(std::move(states));
  172. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  173. // Validate the new beam.
  174. EXPECT_EQ(beam.beam().size(), 8);
  175. // Make sure the state has performed the expected transition.
  176. // Note that the transition index is not the index into the matrix, but rather
  177. // the index into the matrix 'row' for that state.
  178. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  179. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  180. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  181. EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
  182. EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
  183. EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
  184. EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
  185. EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
  186. // Make sure the state has had its score updated properly.
  187. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
  188. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
  189. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
  190. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
  191. EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
  192. EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
  193. EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
  194. EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
  195. // Make sure that the beam index field is consistent with the actual beam idx.
  196. for (int i = 0; i < beam.beam().size(); ++i) {
  197. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  198. }
  199. // Validate the history at this step.
  200. auto history = beam.history();
  201. EXPECT_EQ(history.at(1).at(0), 1);
  202. EXPECT_EQ(history.at(1).at(1), 0);
  203. EXPECT_EQ(history.at(1).at(2), 1);
  204. EXPECT_EQ(history.at(1).at(3), 0);
  205. EXPECT_EQ(history.at(1).at(4), 1);
  206. EXPECT_EQ(history.at(1).at(5), 0);
  207. EXPECT_EQ(history.at(1).at(6), 1);
  208. EXPECT_EQ(history.at(1).at(7), 0);
  209. }
  210. TEST(BeamTest, AdvancingDropsLowValuePredictions) {
  211. // Create a matrix of transitions.
  212. constexpr int kNumTransitions = 4;
  213. constexpr int kMaxBeamSize = 8;
  214. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  215. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0, // State 0
  216. 31.0, 21.0, 41.0, 11.0, // State 1
  217. 32.0, 22.0, 42.0, 12.0, // State 2
  218. 33.0, 23.0, 43.0, 13.0, // State 3
  219. 34.0, 24.0, 44.0, 14.0, // State 4
  220. 35.0, 25.0, 45.0, 15.0, // State 5
  221. 36.0, 26.0, 46.0, 16.0, // State 6
  222. 37.0, 27.0, 47.0, 17.0}; // State 7
  223. constexpr float kOldScores[] = {0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8};
  224. // Create the beam and transition it.
  225. std::vector<std::unique_ptr<TestTransitionState>> states;
  226. states.push_back(CreateState(kOldScores[0]));
  227. states.push_back(CreateState(kOldScores[1]));
  228. states.push_back(CreateState(kOldScores[2]));
  229. states.push_back(CreateState(kOldScores[3]));
  230. states.push_back(CreateState(kOldScores[4]));
  231. states.push_back(CreateState(kOldScores[5]));
  232. states.push_back(CreateState(kOldScores[6]));
  233. states.push_back(CreateState(kOldScores[7]));
  234. Beam<TestTransitionState> beam(kMaxBeamSize);
  235. beam.SetFunctions(null_permissions, null_finality, transition_function,
  236. null_oracle);
  237. beam.Init(std::move(states));
  238. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  239. // Validate the new beam.
  240. EXPECT_EQ(beam.beam().size(), 8);
  241. // Make sure the state has performed the expected transition.
  242. // In this case, every state will perform transition 2.
  243. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  244. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  245. EXPECT_EQ(GetTransition(beam.beam().at(2)), 2);
  246. EXPECT_EQ(GetTransition(beam.beam().at(3)), 2);
  247. EXPECT_EQ(GetTransition(beam.beam().at(4)), 2);
  248. EXPECT_EQ(GetTransition(beam.beam().at(5)), 2);
  249. EXPECT_EQ(GetTransition(beam.beam().at(6)), 2);
  250. EXPECT_EQ(GetTransition(beam.beam().at(7)), 2);
  251. // Make sure the state has had its score updated properly. (Note that row
  252. // 0 had the smallest transition score, so it ends up on the bottom of the
  253. // beam, and so forth.) For the matrix index, N*kNumTransitions gets into the
  254. // correct state row and we add 2 since that was the transition index.
  255. EXPECT_EQ(beam.beam().at(0)->GetScore(),
  256. kOldScores[7] + matrix[7 * kNumTransitions + 2]);
  257. EXPECT_EQ(beam.beam().at(1)->GetScore(),
  258. kOldScores[6] + matrix[6 * kNumTransitions + 2]);
  259. EXPECT_EQ(beam.beam().at(2)->GetScore(),
  260. kOldScores[5] + matrix[5 * kNumTransitions + 2]);
  261. EXPECT_EQ(beam.beam().at(3)->GetScore(),
  262. kOldScores[4] + matrix[4 * kNumTransitions + 2]);
  263. EXPECT_EQ(beam.beam().at(4)->GetScore(),
  264. kOldScores[3] + matrix[3 * kNumTransitions + 2]);
  265. EXPECT_EQ(beam.beam().at(5)->GetScore(),
  266. kOldScores[2] + matrix[2 * kNumTransitions + 2]);
  267. EXPECT_EQ(beam.beam().at(6)->GetScore(),
  268. kOldScores[1] + matrix[1 * kNumTransitions + 2]);
  269. EXPECT_EQ(beam.beam().at(7)->GetScore(),
  270. kOldScores[0] + matrix[0 * kNumTransitions + 2]);
  271. // Make sure that the beam index field is consistent with the actual beam idx.
  272. for (int i = 0; i < beam.beam().size(); ++i) {
  273. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  274. }
  275. auto history = beam.history();
  276. EXPECT_EQ(history.at(1).at(0), 7);
  277. EXPECT_EQ(history.at(1).at(1), 6);
  278. EXPECT_EQ(history.at(1).at(2), 5);
  279. EXPECT_EQ(history.at(1).at(3), 4);
  280. EXPECT_EQ(history.at(1).at(4), 3);
  281. EXPECT_EQ(history.at(1).at(5), 2);
  282. EXPECT_EQ(history.at(1).at(6), 1);
  283. EXPECT_EQ(history.at(1).at(7), 0);
  284. }
  285. TEST(BeamTest, AdvancesFromOracleWithSingleBeam) {
  286. // Create an oracle function for this state.
  287. constexpr int kOracleLabel = 3;
  288. auto oracle_function = [](TransitionState *) { return kOracleLabel; };
  289. // Create the beam and transition it.
  290. std::vector<std::unique_ptr<TestTransitionState>> states;
  291. states.push_back(CreateState(0.0));
  292. constexpr int kBeamSize = 1;
  293. Beam<TestTransitionState> beam(kBeamSize);
  294. beam.SetFunctions(null_permissions, null_finality, transition_function,
  295. oracle_function);
  296. beam.Init(std::move(states));
  297. beam.AdvanceFromOracle();
  298. // Validate the new beam.
  299. EXPECT_EQ(beam.beam().size(), kBeamSize);
  300. // Make sure the state has performed the expected transition.
  301. EXPECT_EQ(GetTransition(beam.beam().at(0)), kOracleLabel);
  302. // Make sure the state has had its score held to 0.
  303. EXPECT_EQ(beam.beam().at(0)->GetScore(), 0.0);
  304. // Make sure that the beam index field is consistent with the actual beam idx.
  305. EXPECT_EQ(beam.beam().at(0)->GetBeamIndex(), 0);
  306. // Validate the beam history field.
  307. auto history = beam.history();
  308. EXPECT_EQ(history.at(1).at(0), 0);
  309. }
  310. TEST(BeamTest, AdvancesFromOracleWithMultipleStates) {
  311. constexpr int kMaxBeamSize = 8;
  312. // Create a beam with 8 transition states.
  313. std::vector<std::unique_ptr<TestTransitionState>> states;
  314. for (int i = 0; i < kMaxBeamSize; ++i) {
  315. // This is nonzero to test the oracle holding scores to 0.
  316. states.push_back(CreateState(10.0));
  317. }
  318. std::vector<int> expected_actions;
  319. // Create an oracle function for this state. Use mocks for finer control.
  320. testing::MockFunction<int(TestTransitionState *)> mock_oracle_function;
  321. for (int i = 0; i < kMaxBeamSize; ++i) {
  322. // We expect each state to be queried for its oracle label,
  323. // and then to be transitioned in place with its oracle label.
  324. int oracle_label = i % 3; // 3 is arbitrary.
  325. EXPECT_CALL(mock_oracle_function, Call(states.at(i).get()))
  326. .WillOnce(Return(oracle_label));
  327. expected_actions.push_back(oracle_label);
  328. }
  329. Beam<TestTransitionState> beam(kMaxBeamSize);
  330. beam.SetFunctions(null_permissions, null_finality, transition_function,
  331. mock_oracle_function.AsStdFunction());
  332. beam.Init(std::move(states));
  333. beam.AdvanceFromOracle();
  334. // Make sure the state has performed the expected transition, has had its
  335. // score held to 0, and is self consistent.
  336. for (int i = 0; i < beam.beam().size(); ++i) {
  337. EXPECT_EQ(GetTransition(beam.beam().at(i)), expected_actions.at(i));
  338. EXPECT_EQ(beam.beam().at(i)->GetScore(), 0.0);
  339. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  340. }
  341. auto history = beam.history();
  342. for (int i = 0; i < beam.beam().size(); ++i) {
  343. EXPECT_EQ(history.at(1).at(i), i);
  344. }
  345. }
  346. TEST(BeamTest, ReportsNonFinality) {
  347. constexpr int kMaxBeamSize = 8;
  348. // Create a beam with 8 transition states.
  349. std::vector<std::unique_ptr<TestTransitionState>> states;
  350. for (int i = 0; i < kMaxBeamSize; ++i) {
  351. // This is nonzero to test the oracle holding scores to 0.
  352. states.push_back(CreateState(10.0));
  353. }
  354. std::vector<int> expected_actions;
  355. // Create a finality function for this state. Use mocks for finer control.
  356. testing::MockFunction<int(TestTransitionState *)> mock_finality_function;
  357. // Make precisely one call return false, which should cause IsFinal
  358. // to report false.
  359. constexpr int incomplete_state = 3;
  360. EXPECT_CALL(mock_finality_function, Call(states.at(incomplete_state).get()))
  361. .WillOnce(Return(false));
  362. EXPECT_CALL(mock_finality_function,
  363. Call(Ne(states.at(incomplete_state).get())))
  364. .WillRepeatedly(Return(true));
  365. Beam<TestTransitionState> beam(kMaxBeamSize);
  366. beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
  367. transition_function, null_oracle);
  368. beam.Init(std::move(states));
  369. EXPECT_FALSE(beam.IsTerminal());
  370. }
  371. TEST(BeamTest, ReportsFinality) {
  372. constexpr int kMaxBeamSize = 8;
  373. // Create a beam with 8 transition states.
  374. std::vector<std::unique_ptr<TestTransitionState>> states;
  375. for (int i = 0; i < kMaxBeamSize; ++i) {
  376. // This is nonzero to test the oracle holding scores to 0.
  377. states.push_back(CreateState(10.0));
  378. }
  379. std::vector<int> expected_actions;
  380. // Create a finality function for this state. Use mocks for finer control.
  381. testing::MockFunction<int(TransitionState *)> mock_finality_function;
  382. // All calls will return true, so IsFinal should return true.
  383. EXPECT_CALL(mock_finality_function, Call(_)).WillRepeatedly(Return(true));
  384. Beam<TestTransitionState> beam(kMaxBeamSize);
  385. beam.SetFunctions(null_permissions, mock_finality_function.AsStdFunction(),
  386. transition_function, null_oracle);
  387. beam.Init(std::move(states));
  388. EXPECT_TRUE(beam.IsTerminal());
  389. }
  390. TEST(BeamTest, IgnoresForbiddenTransitionActions) {
  391. // Create a matrix of transitions.
  392. constexpr int kMaxBeamSize = 4;
  393. constexpr int kNumTransitions = 4;
  394. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  395. constexpr float matrix[kMatrixSize] = {
  396. 10.0, 1000.0, 40.0, 30.0, 00.0, 0000.0, 00.0, 00.0,
  397. 00.0, 0000.0, 00.0, 00.0, 00.0, 0000.0, 00.0, 00.0};
  398. constexpr float kOldScore = 4.0;
  399. // Create the beam.
  400. std::vector<std::unique_ptr<TestTransitionState>> states;
  401. states.push_back(CreateState(kOldScore));
  402. // Forbid the second transition (index 1).
  403. testing::MockFunction<int(TestTransitionState *, int)>
  404. mock_permission_function;
  405. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 0))
  406. .WillOnce(Return(true));
  407. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 1))
  408. .WillOnce(Return(false));
  409. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 2))
  410. .WillOnce(Return(true));
  411. EXPECT_CALL(mock_permission_function, Call(states.at(0).get(), 3))
  412. .WillOnce(Return(true));
  413. Beam<TestTransitionState> beam(kMaxBeamSize);
  414. beam.SetFunctions(mock_permission_function.AsStdFunction(), null_finality,
  415. transition_function, null_oracle);
  416. beam.Init(std::move(states));
  417. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  418. // Validate the new beam.
  419. EXPECT_EQ(beam.beam().size(), 3);
  420. // Make sure the state has performed the expected transition.
  421. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  422. EXPECT_EQ(GetTransition(beam.beam().at(1)), 3);
  423. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  424. // Make sure the state has had its score updated properly.
  425. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScore + matrix[2]);
  426. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScore + matrix[3]);
  427. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScore + matrix[0]);
  428. // Make sure that the beam index field is consistent with the actual beam idx.
  429. for (int i = 0; i < beam.beam().size(); ++i) {
  430. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  431. }
  432. // In this case, we expect the top 3 results to have come from state 0 and
  433. // the remaining 3 slots to be empty (-1).
  434. auto history = beam.history();
  435. EXPECT_EQ(history.at(1).at(0), 0);
  436. EXPECT_EQ(history.at(1).at(1), 0);
  437. EXPECT_EQ(history.at(1).at(2), 0);
  438. EXPECT_EQ(history.at(1).at(3), -1);
  439. }
  440. TEST(BeamTest, BadlySizedMatrixDies) {
  441. // Create a matrix of transitions.
  442. constexpr int kNumTransitions = 4;
  443. constexpr int kMatrixSize = 4; // We have a max beam size of 4; should be 16.
  444. constexpr float matrix[kMatrixSize] = {30.0, 20.0, 40.0, 10.0};
  445. // Create the beam and transition it.
  446. std::vector<std::unique_ptr<TestTransitionState>> states;
  447. states.push_back(CreateState(0.0));
  448. states.push_back(CreateState(0.0));
  449. constexpr int kMaxBeamSize = 8;
  450. Beam<TestTransitionState> beam(kMaxBeamSize);
  451. beam.SetFunctions(null_permissions, null_finality, transition_function,
  452. null_oracle);
  453. beam.Init(std::move(states));
  454. // This matrix should have 8 elements, not 4, so this should die.
  455. EXPECT_DEATH(beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions),
  456. "Transition matrix size does not match max beam size \\* number "
  457. "of state transitions");
  458. }
  459. TEST(BeamTest, BadlySizedBeamInitializationDies) {
  460. // Create an initialization beam too large for the max beam size.
  461. constexpr int kMaxBeamSize = 4;
  462. std::vector<std::unique_ptr<TestTransitionState>> states;
  463. for (int i = 0; i < kMaxBeamSize + 1; ++i) {
  464. states.push_back(CreateState(0.0));
  465. }
  466. Beam<TestTransitionState> beam(kMaxBeamSize);
  467. beam.SetFunctions(null_permissions, null_finality, transition_function,
  468. null_oracle);
  469. // Try to initialize the beam; this should die.
  470. EXPECT_DEATH(beam.Init(std::move(states)),
  471. "Attempted to initialize a beam with more states");
  472. }
  473. TEST(BeamTest, ValidBeamIndicesAfterBeamInitialization) {
  474. // Create a standard beam.
  475. constexpr int kMaxBeamSize = 4;
  476. std::vector<std::unique_ptr<TestTransitionState>> states;
  477. for (int i = 0; i < kMaxBeamSize; ++i) {
  478. states.push_back(CreateState(0.0));
  479. }
  480. Beam<TestTransitionState> beam(kMaxBeamSize);
  481. beam.SetFunctions(null_permissions, null_finality, transition_function,
  482. null_oracle);
  483. beam.Init(std::move(states));
  484. // Verify that all beam indices have been initialized.
  485. for (int i = 0; i < kMaxBeamSize; ++i) {
  486. EXPECT_EQ(i, beam.beam_state(i)->GetBeamIndex());
  487. }
  488. }
  489. TEST(BeamTest, FindPreviousIndexTracesHistory) {
  490. // Create a matrix of transitions.
  491. constexpr int kNumTransitions = 4;
  492. constexpr int kMaxBeamSize = 8;
  493. constexpr int kMatrixSize = kNumTransitions * kMaxBeamSize;
  494. constexpr float matrix[kMatrixSize] = {
  495. 30.0, 20.0, 40.0, 10.0, // State 0
  496. 31.0, 21.0, 41.0, 11.0, // State 1
  497. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0,
  498. 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0, 00.0};
  499. constexpr float kOldScores[] = {5.0, 7.0};
  500. constexpr int kParentBeamIndices[] = {1138, 42};
  501. // Create the beam and transition it.
  502. std::vector<std::unique_ptr<TestTransitionState>> states;
  503. states.push_back(CreateState(kOldScores[0]));
  504. states.push_back(CreateState(kOldScores[1]));
  505. // Set parent beam indices.
  506. SetParentBeamIndex(states.at(0).get(), kParentBeamIndices[0]);
  507. SetParentBeamIndex(states.at(1).get(), kParentBeamIndices[1]);
  508. Beam<TestTransitionState> beam(kMaxBeamSize);
  509. beam.SetFunctions(null_permissions, null_finality, transition_function,
  510. null_oracle);
  511. beam.Init(std::move(states));
  512. beam.AdvanceFromPrediction(matrix, kMatrixSize, kNumTransitions);
  513. // Validate the new beam.
  514. EXPECT_EQ(beam.beam().size(), 8);
  515. // Make sure the state has performed the expected transition.
  516. // Note that the transition index is not the index into the matrix, but rather
  517. // the index into the matrix 'row' for that state.
  518. EXPECT_EQ(GetTransition(beam.beam().at(0)), 2);
  519. EXPECT_EQ(GetTransition(beam.beam().at(1)), 2);
  520. EXPECT_EQ(GetTransition(beam.beam().at(2)), 0);
  521. EXPECT_EQ(GetTransition(beam.beam().at(3)), 0);
  522. EXPECT_EQ(GetTransition(beam.beam().at(4)), 1);
  523. EXPECT_EQ(GetTransition(beam.beam().at(5)), 1);
  524. EXPECT_EQ(GetTransition(beam.beam().at(6)), 3);
  525. EXPECT_EQ(GetTransition(beam.beam().at(7)), 3);
  526. // Make sure the state has had its score updated properly.
  527. EXPECT_EQ(beam.beam().at(0)->GetScore(), kOldScores[1] + matrix[6]);
  528. EXPECT_EQ(beam.beam().at(1)->GetScore(), kOldScores[0] + matrix[2]);
  529. EXPECT_EQ(beam.beam().at(2)->GetScore(), kOldScores[1] + matrix[4]);
  530. EXPECT_EQ(beam.beam().at(3)->GetScore(), kOldScores[0] + matrix[0]);
  531. EXPECT_EQ(beam.beam().at(4)->GetScore(), kOldScores[1] + matrix[5]);
  532. EXPECT_EQ(beam.beam().at(5)->GetScore(), kOldScores[0] + matrix[1]);
  533. EXPECT_EQ(beam.beam().at(6)->GetScore(), kOldScores[1] + matrix[7]);
  534. EXPECT_EQ(beam.beam().at(7)->GetScore(), kOldScores[0] + matrix[3]);
  535. // Make sure that the beam index field is consistent with the actual beam idx.
  536. for (int i = 0; i < beam.beam().size(); ++i) {
  537. EXPECT_EQ(beam.beam().at(i)->GetBeamIndex(), i);
  538. }
  539. // Validate the history at this step.
  540. auto history = beam.history();
  541. EXPECT_EQ(history.at(1).at(0), 1);
  542. EXPECT_EQ(history.at(1).at(1), 0);
  543. EXPECT_EQ(history.at(1).at(2), 1);
  544. EXPECT_EQ(history.at(1).at(3), 0);
  545. EXPECT_EQ(history.at(1).at(4), 1);
  546. EXPECT_EQ(history.at(1).at(5), 0);
  547. EXPECT_EQ(history.at(1).at(6), 1);
  548. EXPECT_EQ(history.at(1).at(7), 0);
  549. EXPECT_EQ(history.at(0).at(0), kParentBeamIndices[0]);
  550. EXPECT_EQ(history.at(0).at(1), kParentBeamIndices[1]);
  551. EXPECT_EQ(history.at(0).at(2), -1);
  552. EXPECT_EQ(history.at(0).at(3), -1);
  553. EXPECT_EQ(history.at(0).at(4), -1);
  554. EXPECT_EQ(history.at(0).at(5), -1);
  555. EXPECT_EQ(history.at(0).at(6), -1);
  556. EXPECT_EQ(history.at(0).at(7), -1);
  557. // Make sure that FindPreviousIndex can read through the history from step 1
  558. // to step 0.
  559. constexpr int kDesiredIndex = 0;
  560. constexpr int kCurrentIndexOne = 4;
  561. EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexOne, kDesiredIndex),
  562. kParentBeamIndices[1]);
  563. constexpr int kCurrentIndexTwo = 7;
  564. EXPECT_EQ(beam.FindPreviousIndex(kCurrentIndexTwo, kDesiredIndex),
  565. kParentBeamIndices[0]);
  566. }
  567. TEST(BeamTest, FindPreviousIndexReturnsInError) {
  568. // Create the beam. This now has only one history state, 0.
  569. std::vector<std::unique_ptr<TestTransitionState>> states;
  570. states.push_back(CreateState(0.0));
  571. constexpr int kMaxBeamSize = 8;
  572. Beam<TestTransitionState> beam(kMaxBeamSize);
  573. beam.SetFunctions(null_permissions, null_finality, transition_function,
  574. null_oracle);
  575. beam.Init(std::move(states));
  576. // If the requested step is greater than the number of steps taken, expect -1.
  577. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
  578. // If the requested step is less than 0, expect -1.
  579. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
  580. // If the requested index does not have a state, expect -1.
  581. EXPECT_EQ(beam.FindPreviousIndex(0, 1), -1);
  582. // If the requested index is less than 0, expect -1.
  583. EXPECT_EQ(beam.FindPreviousIndex(0, -1), -1);
  584. // If the requested index is larger than the maximum beam size -1, expect -1.
  585. EXPECT_EQ(beam.FindPreviousIndex(0, kMaxBeamSize), -1);
  586. }
  587. TEST(BeamTest, ResetClearsBeamState) {
  588. // Create the beam
  589. std::vector<std::unique_ptr<TestTransitionState>> states;
  590. states.push_back(CreateState(1.0));
  591. constexpr int kMaxBeamSize = 8;
  592. Beam<TestTransitionState> beam(kMaxBeamSize);
  593. beam.SetFunctions(null_permissions, null_finality, transition_function,
  594. null_oracle);
  595. beam.Init(std::move(states));
  596. // Validate the new beam.
  597. EXPECT_EQ(beam.beam().size(), 1);
  598. // Reset the beam.
  599. beam.Reset();
  600. // Validate the now-reset beam, which should be empty.
  601. EXPECT_EQ(beam.beam().size(), 0);
  602. }
  603. TEST(BeamTest, ResetClearsBeamHistory) {
  604. // Create the beam
  605. std::vector<std::unique_ptr<TestTransitionState>> states;
  606. states.push_back(CreateState(1.0));
  607. constexpr int kMaxBeamSize = 8;
  608. Beam<TestTransitionState> beam(kMaxBeamSize);
  609. beam.SetFunctions(null_permissions, null_finality, transition_function,
  610. null_oracle);
  611. beam.Init(std::move(states));
  612. // Validate the new beam history.
  613. EXPECT_EQ(beam.history().size(), 1);
  614. // Reset the beam.
  615. beam.Reset();
  616. // Validate the now-reset beam history, which should be empty.
  617. EXPECT_EQ(beam.history().size(), 0);
  618. }
  619. TEST(BeamTest, SettingMaxSizeResetsBeam) {
  620. // Create the beam
  621. std::vector<std::unique_ptr<TestTransitionState>> states;
  622. states.push_back(CreateState(1.0));
  623. constexpr int kMaxBeamSize = 8;
  624. Beam<TestTransitionState> beam(kMaxBeamSize);
  625. beam.SetFunctions(null_permissions, null_finality, transition_function,
  626. null_oracle);
  627. beam.Init(std::move(states));
  628. // Validate the new beam history.
  629. EXPECT_EQ(beam.history().size(), 1);
  630. // Reset the beam.
  631. constexpr int kNewMaxBeamSize = 4;
  632. beam.SetMaxSize(kNewMaxBeamSize);
  633. EXPECT_EQ(beam.max_size(), kNewMaxBeamSize);
  634. // Validate the now-reset beam history, which should be empty.
  635. EXPECT_EQ(beam.history().size(), 0);
  636. }
  637. } // namespace dragnn
  638. } // namespace syntaxnet