syntaxnet_component_test.cc

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/components/syntaxnet/syntaxnet_component.h"

#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"

// This test suite is intended to validate the contracts that the DRAGNN
// system expects from all transition state subclasses. Developers creating
// new TransitionStates should copy this test and modify it as necessary,
// using it to ensure their state conforms to DRAGNN expectations.

namespace syntaxnet {
namespace dragnn {
namespace {
const char kSentence0[] = R"(
token {
  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
  break_level: NO_BREAK
}
token {
  word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
  break_level: SPACE_BREAK
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
  break_level: NO_BREAK
}
)";

const char kSentence1[] = R"(
token {
  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
  break_level: NO_BREAK
}
token {
  word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
  break_level: SPACE_BREAK
}
token {
  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
  break_level: NO_BREAK
}
)";

const char kLongSentence[] = R"(
token {
  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
  break_level: NO_BREAK
}
token {
  word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
  break_level: SPACE_BREAK
}
token {
  word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
  break_level: SPACE_BREAK
}
token {
  word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
  break_level: SPACE_BREAK
}
token {
  word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
  break_level: NO_BREAK
}
)";
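
// Hedged sketch of a convenience helper, not part of the original tests and
// not wired into them: it captures the parse-and-serialize boilerplate that
// each TEST_F below repeats inline. The helper name is illustrative only; it
// assumes TextFormat and Sentence are usable here exactly as they are inside
// the tests.
string SerializedSentence(const string &text_proto) {
  // Parse the text-format proto, then serialize it to the wire format that
  // InputBatchCache and CreateParser() expect.
  Sentence sentence;
  TextFormat::ParseFromString(text_proto, &sentence);
  string serialized;
  sentence.SerializeToString(&serialized);
  return serialized;
}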

}  // namespace

using testing::Return;

class SyntaxNetComponentTest : public ::testing::Test {
 public:
  std::unique_ptr<SyntaxNetComponent> CreateParser(
      const std::vector<std::vector<const TransitionState *>> &states,
      const std::vector<string> &data) {
    constexpr int kBeamSize = 2;
    return CreateParserWithBeamSize(kBeamSize, states, data);
  }

  std::unique_ptr<SyntaxNetComponent> CreateParserWithBeamSize(
      int beam_size,
      const std::vector<std::vector<const TransitionState *>> &states,
      const std::vector<string> &data) {
    // Get the master spec proto from the test data directory.
    MasterSpec master_spec;
    string file_name = tensorflow::io::JoinPath(
        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
        "master_spec.textproto");
    TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
                                          &master_spec));

    // Point the resources in the master spec at the test data directory.
    for (Resource &resource :
         *(master_spec.mutable_component(0)->mutable_resource())) {
      resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
          test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
          resource.part(0).file_pattern()));
    }

    data_.reset(new InputBatchCache(data));

    // Create a parser component with the specified beam size.
    std::unique_ptr<SyntaxNetComponent> parser_component(
        new SyntaxNetComponent());
    parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
    parser_component->InitializeData(states, beam_size, data_.get());
    return parser_component;
  }

  const std::vector<Beam<SyntaxNetTransitionState> *> GetBeams(
      SyntaxNetComponent *component) const {
    std::vector<Beam<SyntaxNetTransitionState> *> return_vector;
    for (const auto &beam : component->batch_) {
      return_vector.push_back(beam.get());
    }
    return return_vector;
  }

  std::unique_ptr<InputBatchCache> data_;
};
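
// Note on the score arrays used below (an observation drawn from these tests,
// not a documented API contract): AdvanceFromPrediction() is handed a flat
// float array of kNumPossibleTransitions * beam_size * batch_size scores, and
// the kBatchOffset arithmetic in AdvancesAccordingToHighestWeightedInputOption
// suggests a batch-major layout, with each batch element owning
// beam_size * kNumPossibleTransitions consecutive entries.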

TEST_F(SyntaxNetComponentTest, AdvancesFromOracleAndTerminates) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  auto test_parser = CreateParser({}, {sentence_0_str});
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromOracle();
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
  // Make sure the parser doesn't segfault.
  test_parser->FinalizeData();
}

TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  auto test_parser = CreateParser({}, {sentence_0_str});
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBeamSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromPrediction(transition_matrix,
                                       kNumPossibleTransitions * kBeamSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
  // Prepare to validate the batched beams.
  auto beam = test_parser->GetBeam();
  // All beams should only have one element.
  for (const auto &per_beam : beam) {
    EXPECT_EQ(per_beam.size(), 1);
  }
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  // Make sure the parser doesn't segfault.
  test_parser->FinalizeData();
  // TODO(googleuser): What should the finalized data look like?
}

TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
  // Create and initialize the mock parent states.
  MockTransitionState mock_state_one;
  constexpr int kParentBeamIndexOne = 1138;
  constexpr float kParentScoreOne = 7.2;
  EXPECT_CALL(mock_state_one, GetBeamIndex())
      .WillRepeatedly(Return(kParentBeamIndexOne));
  EXPECT_CALL(mock_state_one, GetScore())
      .WillRepeatedly(Return(kParentScoreOne));
  MockTransitionState mock_state_two;
  constexpr int kParentBeamIndexTwo = 1123;
  constexpr float kParentScoreTwo = 42.03;
  EXPECT_CALL(mock_state_two, GetBeamIndex())
      .WillRepeatedly(Return(kParentBeamIndexTwo));
  EXPECT_CALL(mock_state_two, GetScore())
      .WillRepeatedly(Return(kParentScoreTwo));
  // Create the input batch used to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  auto test_parser =
      CreateParser({{&mock_state_one, &mock_state_two}}, {sentence_0_str});
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBeamSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromPrediction(transition_matrix,
                                       kNumPossibleTransitions * kBeamSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
  // The final states should have kExpectedNumTransitions * kTransitionValue,
  // plus the higher parent state score (from state two).
  auto beam = test_parser->GetBeam();
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions + kParentScoreTwo);
  // Make sure that the parent state is reported correctly.
  EXPECT_EQ(test_parser->GetSourceBeamIndex(0, 0), kParentBeamIndexTwo);
  // Make sure the parser doesn't segfault.
  test_parser->FinalizeData();
  // TODO(googleuser): What should the finalized data look like?
}

TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  auto test_parser = CreateParser({}, {sentence_0_str, sentence_1_str});
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 2;
  constexpr int kBeamSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromPrediction(
        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
  EXPECT_EQ(test_parser->StepsTaken(1), kExpectedNumTransitions);
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  auto beam = test_parser->GetBeam();
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  EXPECT_EQ(beam.at(1).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  // Make sure the parser doesn't segfault.
  test_parser->FinalizeData();
  // TODO(googleuser): What should the finalized data look like?
}

TEST_F(SyntaxNetComponentTest,
       AdvancesFromPredictionForVaryingLengthSentences) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  auto test_parser = CreateParser({}, {sentence_0_str, long_sentence_str});
  constexpr int kNumTokensInSentence = 3;
  constexpr int kNumTokensInLongSentence = 5;
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 2;
  constexpr int kBeamSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2;
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromPrediction(
        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kNumTokensInSentence * 2);
  EXPECT_EQ(test_parser->StepsTaken(1), kNumTokensInLongSentence * 2);
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  auto beam = test_parser->GetBeam();
  // The first sentence is shorter, so it should have a lower final score.
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kNumTokensInSentence * 2);
  EXPECT_EQ(beam.at(1).at(0)->GetScore(),
            kTransitionValue * kNumTokensInLongSentence * 2);
  // Make sure the parser doesn't segfault.
  test_parser->FinalizeData();
  // TODO(googleuser): What should the finalized data look like?
}

TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) {
  // Create the serialized sentences used to build the input batches below.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  // Get the master spec proto from the test data directory.
  MasterSpec master_spec;
  string file_name = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
      "master_spec.textproto");
  TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
                                        &master_spec));
  // Point the resources in the master spec at the test data directory.
  for (Resource &resource :
       *(master_spec.mutable_component(0)->mutable_resource())) {
    resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
        resource.part(0).file_pattern()));
  }
  // Create an input batch cache with a large batch size.
  constexpr int kBeamSize = 2;
  std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
      {sentence_0_str, sentence_0_str, sentence_0_str, sentence_0_str}));
  std::unique_ptr<SyntaxNetComponent> parser_component(
      new SyntaxNetComponent());
  parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
  parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
  // Reset the component and pass in a new input batch that is smaller.
  parser_component->ResetComponent();
  std::unique_ptr<InputBatchCache> small_batch_data(new InputBatchCache(
      {long_sentence_str, long_sentence_str, long_sentence_str}));
  parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 3;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  constexpr int kNumTokensInSentence = 5;
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(parser_component->IsTerminal());
    parser_component->AdvanceFromPrediction(
        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(parser_component->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
  EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
  EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  auto beam = parser_component->GetBeam();
  // The beam should be of batch size 3.
  EXPECT_EQ(beam.size(), 3);
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  EXPECT_EQ(beam.at(1).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  EXPECT_EQ(beam.at(2).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  // Make sure the parser doesn't segfault.
  parser_component->FinalizeData();
}

TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) {
  // Create the serialized sentences used to build the input batches below.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  // Get the master spec proto from the test data directory.
  MasterSpec master_spec;
  string file_name = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
      "master_spec.textproto");
  TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
                                        &master_spec));
  // Point the resources in the master spec at the test data directory.
  for (Resource &resource :
       *(master_spec.mutable_component(0)->mutable_resource())) {
    resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
        resource.part(0).file_pattern()));
  }
  // Create an input batch cache with a small batch size.
  constexpr int kBeamSize = 2;
  std::unique_ptr<InputBatchCache> small_batch_data(
      new InputBatchCache(sentence_0_str));
  std::unique_ptr<SyntaxNetComponent> parser_component(
      new SyntaxNetComponent());
  parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
  parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
  // Reset the component and pass in a new input batch that is larger.
  parser_component->ResetComponent();
  std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
      {long_sentence_str, long_sentence_str, long_sentence_str}));
  parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 3;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  constexpr int kNumTokensInSentence = 5;
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(parser_component->IsTerminal());
    parser_component->AdvanceFromPrediction(
        transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(parser_component->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
  EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
  EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  auto beam = parser_component->GetBeam();
  // The beam should be of batch size 3.
  EXPECT_EQ(beam.size(), 3);
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  EXPECT_EQ(beam.at(1).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  EXPECT_EQ(beam.at(2).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  // Make sure the parser doesn't segfault.
  parser_component->FinalizeData();
}

TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  auto test_parser = CreateParser({}, {sentence_0_str});
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBeamSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Transition the expected number of times.
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    test_parser->AdvanceFromPrediction(transition_matrix,
                                       kNumPossibleTransitions * kBeamSize);
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // Check that the component is reporting 2N steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
  // The final states should have kExpectedNumTransitions * kTransitionValue.
  auto beam = test_parser->GetBeam();
  EXPECT_EQ(beam.at(0).at(0)->GetScore(),
            kTransitionValue * kExpectedNumTransitions);
  // Reset the test parser and give it new data.
  test_parser->ResetComponent();
  std::unique_ptr<InputBatchCache> new_data(
      new InputBatchCache(long_sentence_str));
  test_parser->InitializeData({}, kBeamSize, new_data.get());
  // Check that the component is not terminal.
  EXPECT_FALSE(test_parser->IsTerminal());
  // Check that the component is reporting 0 steps taken.
  EXPECT_EQ(test_parser->StepsTaken(0), 0);
  // The states should have 0 as their score.
  auto new_beam = test_parser->GetBeam();
  EXPECT_EQ(new_beam.at(0).at(0)->GetScore(), 0);
}

TEST_F(SyntaxNetComponentTest, AdjustingMaxBeamSizeAdjustsSizeForAllBeams) {
  // Create the serialized sentences used to build the input batches below.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  // Get the master spec proto from the test data directory.
  MasterSpec master_spec;
  string file_name = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
      "master_spec.textproto");
  TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
                                        &master_spec));
  // Point the resources in the master spec at the test data directory.
  for (Resource &resource :
       *(master_spec.mutable_component(0)->mutable_resource())) {
    resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
        resource.part(0).file_pattern()));
  }
  // Create an input batch cache with a small batch size.
  constexpr int kBeamSize = 2;
  std::unique_ptr<InputBatchCache> small_batch_data(
      new InputBatchCache(sentence_0_str));
  std::unique_ptr<SyntaxNetComponent> parser_component(
      new SyntaxNetComponent());
  parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
  parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
  // Make sure all the beams in the batch have max size 2.
  for (const auto &beam : GetBeams(parser_component.get())) {
    EXPECT_EQ(beam->max_size(), kBeamSize);
  }
  // Reset the component and pass in a new input batch that is larger, with
  // a higher beam size.
  constexpr int kNewBeamSize = 5;
  parser_component->ResetComponent();
  std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
      {long_sentence_str, long_sentence_str, long_sentence_str}));
  parser_component->InitializeData({}, kNewBeamSize, large_batch_data.get());
  // Make sure all the beams in the batch now have max size 5.
  for (const auto &beam : GetBeams(parser_component.get())) {
    EXPECT_EQ(beam->max_size(), kNewBeamSize);
  }
}

TEST_F(SyntaxNetComponentTest, SettingBeamSizeZeroFails) {
  // Create the serialized sentences used to build the input batches below.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence long_sentence;
  TextFormat::ParseFromString(kLongSentence, &long_sentence);
  string long_sentence_str;
  long_sentence.SerializeToString(&long_sentence_str);
  // Get the master spec proto from the test data directory.
  MasterSpec master_spec;
  string file_name = tensorflow::io::JoinPath(
      test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
      "master_spec.textproto");
  TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
                                        &master_spec));
  // Point the resources in the master spec at the test data directory.
  for (Resource &resource :
       *(master_spec.mutable_component(0)->mutable_resource())) {
    resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
        test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
        resource.part(0).file_pattern()));
  }
  // Create an input batch cache with a small batch size, and a beam size of 0.
  constexpr int kBeamSize = 0;
  std::unique_ptr<InputBatchCache> small_batch_data(
      new InputBatchCache(sentence_0_str));
  std::unique_ptr<SyntaxNetComponent> parser_component(
      new SyntaxNetComponent());
  parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
  EXPECT_DEATH(
      parser_component->InitializeData({}, kBeamSize, small_batch_data.get()),
      "must be greater than 0");
}

TEST_F(SyntaxNetComponentTest, ExportsFixedFeaturesWithPadding) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // Get and check the raw fixed features.
  vector<int32> indices;
  auto indices_fn = [&indices](int size) {
    indices.resize(size);
    return indices.data();
  };
  vector<int64> ids;
  auto ids_fn = [&ids](int size) {
    ids.resize(size);
    return ids.data();
  };
  vector<float> weights;
  auto weights_fn = [&weights](int size) {
    weights.resize(size);
    return weights.data();
  };
  constexpr int kChannelId = 0;
  const int num_features =
      test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
  // The raw features for each beam element should be [single, single].
  // Padding is also expected: each beam currently holds only one element (so
  // two elements across the batch of two), and with a beam size of 3 each
  // batch element owns six feature slots. Thus, only indices 0,1 and 6,7 are
  // filled, with one element each.
  constexpr int kExpectedOutputSize = 4;
  const vector<int32> expected_indices({0, 1, 6, 7});
  const vector<int64> expected_ids({0, 12, 0, 12});
  const vector<float> expected_weights({1.0, 1.0, 1.0, 1.0});
  EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
  EXPECT_EQ(num_features, kExpectedOutputSize);
  EXPECT_EQ(expected_indices, indices);
  EXPECT_EQ(expected_ids, ids);
  EXPECT_EQ(expected_weights, weights);
}

TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Advance twice, so that the underlying parser fills the beam.
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  // Get and check the raw fixed features.
  vector<int32> indices;
  auto indices_fn = [&indices](int size) {
    indices.resize(size);
    return indices.data();
  };
  vector<int64> ids;
  auto ids_fn = [&ids](int size) {
    ids.resize(size);
    return ids.data();
  };
  vector<float> weights;
  auto weights_fn = [&weights](int size) {
    weights.resize(size);
    return weights.data();
  };
  constexpr int kChannelId = 0;
  const int num_features =
      test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
  constexpr int kExpectedOutputSize = 12;
  const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  const vector<int64> expected_ids({7, 50, 12, 7, 12, 7, 7, 50, 12, 7, 12, 7});
  const vector<float> expected_weights(
      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
  EXPECT_EQ(num_features, kExpectedOutputSize);
  EXPECT_EQ(expected_indices, indices);
  EXPECT_EQ(expected_ids, ids);
  EXPECT_EQ(expected_weights, weights);
}

TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Replace the first several options with varying scores to test sorting.
  constexpr int kBatchOffset = kNumPossibleTransitions * kBeamSize;
  transition_matrix[0] = 3 * kTransitionValue;
  transition_matrix[1] = 3 * kTransitionValue;
  transition_matrix[2] = 4 * kTransitionValue;
  transition_matrix[3] = 4 * kTransitionValue;
  transition_matrix[4] = 2 * kTransitionValue;
  transition_matrix[5] = 2 * kTransitionValue;
  transition_matrix[kBatchOffset + 0] = 3 * kTransitionValue;
  transition_matrix[kBatchOffset + 1] = 3 * kTransitionValue;
  transition_matrix[kBatchOffset + 2] = 4 * kTransitionValue;
  transition_matrix[kBatchOffset + 3] = 4 * kTransitionValue;
  transition_matrix[kBatchOffset + 4] = 2 * kTransitionValue;
  transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue;
  // Advance twice, so that the underlying parser fills the beam.
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  // Get and check the raw fixed features.
  vector<int32> indices;
  auto indices_fn = [&indices](int size) {
    indices.resize(size);
    return indices.data();
  };
  vector<int64> ids;
  auto ids_fn = [&ids](int size) {
    ids.resize(size);
    return ids.data();
  };
  vector<float> weights;
  auto weights_fn = [&weights](int size) {
    weights.resize(size);
    return weights.data();
  };
  constexpr int kChannelId = 0;
  const int num_features =
      test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
  // In this case, all even features and all odd features are identical.
  constexpr int kExpectedOutputSize = 12;
  const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
  const vector<int64> expected_ids({12, 7, 7, 50, 12, 7, 12, 7, 7, 50, 12, 7});
  const vector<float> expected_weights(
      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
  EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
  EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
  EXPECT_EQ(num_features, kExpectedOutputSize);
  EXPECT_EQ(expected_indices, indices);
  EXPECT_EQ(expected_ids, ids);
  EXPECT_EQ(expected_weights, weights);
}

TEST_F(SyntaxNetComponentTest, ExportsBulkFixedFeatures) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // Get and check the bulk fixed features.
  vector<vector<int32>> indices;
  auto indices_fn = [&indices](int channel, int size) {
    indices.resize(channel + 1);
    indices[channel].resize(size);
    return indices[channel].data();
  };
  vector<vector<int64>> ids;
  auto ids_fn = [&ids](int channel, int size) {
    ids.resize(channel + 1);
    ids[channel].resize(size);
    return ids[channel].data();
  };
  vector<vector<float>> weights;
  auto weights_fn = [&weights](int channel, int size) {
    weights.resize(channel + 1);
    weights[channel].resize(size);
    return weights[channel].data();
  };
  BulkFeatureExtractor extractor(indices_fn, ids_fn, weights_fn);
  const int num_steps = test_parser->BulkGetFixedFeatures(extractor);
  // There should be 6 steps (2N, where N is the token count of the longest
  // sentence in the batch).
  EXPECT_EQ(num_steps, 6);
  // These are empirically derived.
  const vector<int32> expected_ch0_indices({0, 36, 18, 54, 1, 37, 19, 55,
                                            2, 38, 20, 56, 3, 39, 21, 57,
                                            4, 40, 22, 58, 5, 41, 23, 59});
  const vector<int64> expected_ch0_ids({0, 12, 0, 12, 12, 7, 12, 7,
                                        7, 50, 7, 50, 7, 50, 7, 50,
                                        50, 50, 50, 50, 50, 50, 50, 50});
  const vector<float> expected_ch0_weights(
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  const vector<int32> expected_ch1_indices(
      {0, 36, 72, 18, 54, 90, 1, 37, 73, 19, 55, 91, 2, 38, 74, 20, 56, 92,
       3, 39, 75, 21, 57, 93, 4, 40, 76, 22, 58, 94, 5, 41, 77, 23, 59, 95});
  const vector<int64> expected_ch1_ids(
      {51, 0, 12, 51, 0, 12, 0, 12, 7, 0, 12, 7, 12, 7, 50, 12, 7, 50,
       12, 7, 50, 12, 7, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50});
  const vector<float> expected_ch1_weights(
      {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
       1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
  EXPECT_EQ(indices[0], expected_ch0_indices);
  EXPECT_EQ(ids[0], expected_ch0_ids);
  EXPECT_EQ(weights[0], expected_ch0_weights);
  EXPECT_EQ(indices[1], expected_ch1_indices);
  EXPECT_EQ(ids[1], expected_ch1_ids);
  EXPECT_EQ(weights[1], expected_ch1_weights);
}

TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeaturesWithPadding) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  constexpr int kBatchSize = 2;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // Get and check the raw link features.
  constexpr int kNumLinkFeatures = 2;
  auto link_features = test_parser->GetRawLinkFeatures(0);
  EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
  EXPECT_EQ(link_features.at(0).feature_value(), -1);
  EXPECT_EQ(link_features.at(0).batch_idx(), 0);
  EXPECT_EQ(link_features.at(0).beam_idx(), 0);
  EXPECT_EQ(link_features.at(1).feature_value(), -2);
  EXPECT_EQ(link_features.at(1).batch_idx(), 0);
  EXPECT_EQ(link_features.at(1).beam_idx(), 0);
  // These are padding, so we do not expect them to have a feature value.
  EXPECT_FALSE(link_features.at(2).has_feature_value());
  EXPECT_FALSE(link_features.at(2).has_batch_idx());
  EXPECT_FALSE(link_features.at(2).has_beam_idx());
  EXPECT_FALSE(link_features.at(3).has_feature_value());
  EXPECT_FALSE(link_features.at(3).has_batch_idx());
  EXPECT_FALSE(link_features.at(3).has_beam_idx());
  EXPECT_FALSE(link_features.at(4).has_feature_value());
  EXPECT_FALSE(link_features.at(4).has_batch_idx());
  EXPECT_FALSE(link_features.at(4).has_beam_idx());
  EXPECT_FALSE(link_features.at(5).has_feature_value());
  EXPECT_FALSE(link_features.at(5).has_batch_idx());
  EXPECT_FALSE(link_features.at(5).has_beam_idx());
  EXPECT_EQ(link_features.at(6).feature_value(), -1);
  EXPECT_EQ(link_features.at(6).batch_idx(), 1);
  EXPECT_EQ(link_features.at(6).beam_idx(), 0);
  EXPECT_EQ(link_features.at(7).feature_value(), -2);
  EXPECT_EQ(link_features.at(7).batch_idx(), 1);
  EXPECT_EQ(link_features.at(7).beam_idx(), 0);
  // These are padding, so we do not expect them to have a feature value.
  EXPECT_FALSE(link_features.at(8).has_feature_value());
  EXPECT_FALSE(link_features.at(8).has_batch_idx());
  EXPECT_FALSE(link_features.at(8).has_beam_idx());
  EXPECT_FALSE(link_features.at(9).has_feature_value());
  EXPECT_FALSE(link_features.at(9).has_batch_idx());
  EXPECT_FALSE(link_features.at(9).has_beam_idx());
  EXPECT_FALSE(link_features.at(10).has_feature_value());
  EXPECT_FALSE(link_features.at(10).has_batch_idx());
  EXPECT_FALSE(link_features.at(10).has_beam_idx());
  EXPECT_FALSE(link_features.at(11).has_feature_value());
  EXPECT_FALSE(link_features.at(11).has_batch_idx());
  EXPECT_FALSE(link_features.at(11).has_beam_idx());
}

TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  Sentence sentence_1;
  TextFormat::ParseFromString(kSentence1, &sentence_1);
  string sentence_1_str;
  sentence_1.SerializeToString(&sentence_1_str);
  constexpr int kBeamSize = 3;
  auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
  // There are 93 possible transitions for any given state. Create a transition
  // array with a score of 10.0 for each transition.
  constexpr int kBatchSize = 2;
  constexpr int kNumPossibleTransitions = 93;
  constexpr float kTransitionValue = 10.0;
  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
    transition_matrix[i] = kTransitionValue;
  }
  // Advance twice, so that the underlying parser fills the beam.
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  test_parser->AdvanceFromPrediction(
      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
  // Get and check the raw link features.
  constexpr int kNumLinkFeatures = 2;
  auto link_features = test_parser->GetRawLinkFeatures(0);
  EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
  // These should index into batch 0.
  EXPECT_EQ(link_features.at(0).feature_value(), 1);
  EXPECT_EQ(link_features.at(0).batch_idx(), 0);
  EXPECT_EQ(link_features.at(0).beam_idx(), 0);
  EXPECT_EQ(link_features.at(1).feature_value(), 0);
  EXPECT_EQ(link_features.at(1).batch_idx(), 0);
  EXPECT_EQ(link_features.at(1).beam_idx(), 0);
  EXPECT_EQ(link_features.at(2).feature_value(), -1);
  EXPECT_EQ(link_features.at(2).batch_idx(), 0);
  EXPECT_EQ(link_features.at(2).beam_idx(), 1);
  EXPECT_EQ(link_features.at(3).feature_value(), -2);
  EXPECT_EQ(link_features.at(3).batch_idx(), 0);
  EXPECT_EQ(link_features.at(3).beam_idx(), 1);
  EXPECT_EQ(link_features.at(4).feature_value(), -1);
  EXPECT_EQ(link_features.at(4).batch_idx(), 0);
  EXPECT_EQ(link_features.at(4).beam_idx(), 2);
  EXPECT_EQ(link_features.at(5).feature_value(), -2);
  EXPECT_EQ(link_features.at(5).batch_idx(), 0);
  EXPECT_EQ(link_features.at(5).beam_idx(), 2);
  // These should index into batch 1.
  EXPECT_EQ(link_features.at(6).feature_value(), 1);
  EXPECT_EQ(link_features.at(6).batch_idx(), 1);
  EXPECT_EQ(link_features.at(6).beam_idx(), 0);
  EXPECT_EQ(link_features.at(7).feature_value(), 0);
  EXPECT_EQ(link_features.at(7).batch_idx(), 1);
  EXPECT_EQ(link_features.at(7).beam_idx(), 0);
  EXPECT_EQ(link_features.at(8).feature_value(), -1);
  EXPECT_EQ(link_features.at(8).batch_idx(), 1);
  EXPECT_EQ(link_features.at(8).beam_idx(), 1);
  EXPECT_EQ(link_features.at(9).feature_value(), -2);
  EXPECT_EQ(link_features.at(9).batch_idx(), 1);
  EXPECT_EQ(link_features.at(9).beam_idx(), 1);
  EXPECT_EQ(link_features.at(10).feature_value(), -1);
  EXPECT_EQ(link_features.at(10).batch_idx(), 1);
  EXPECT_EQ(link_features.at(10).beam_idx(), 2);
  EXPECT_EQ(link_features.at(11).feature_value(), -2);
  EXPECT_EQ(link_features.at(11).batch_idx(), 1);
  EXPECT_EQ(link_features.at(11).beam_idx(), 2);
}

TEST_F(SyntaxNetComponentTest, AdvancesFromOracleWithTracing) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  constexpr int kBeamSize = 1;
  auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
  test_parser->InitializeTracing();
  constexpr int kNumTokensInSentence = 3;
  // The master spec will initialize a parser, so expect 2*N transitions.
  constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
  constexpr int kFixedFeatureChannels = 1;
  for (int i = 0; i < kExpectedNumTransitions; ++i) {
    EXPECT_FALSE(test_parser->IsTerminal());
    vector<int32> indices;
    auto indices_fn = [&indices](int size) {
      indices.resize(size);
      return indices.data();
    };
    vector<int64> ids;
    auto ids_fn = [&ids](int size) {
      ids.resize(size);
      return ids.data();
    };
    vector<float> weights;
    auto weights_fn = [&weights](int size) {
      weights.resize(size);
      return weights.data();
    };
    for (int j = 0; j < kFixedFeatureChannels; ++j) {
      test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, j);
    }
    auto features = test_parser->GetRawLinkFeatures(0);
    // Make some fake translations to test visualization.
    for (int j = 0; j < features.size(); ++j) {
      features[j].set_step_idx(j < i ? j : -1);
    }
    test_parser->AddTranslatedLinkFeaturesToTrace(features, 0);
    test_parser->AdvanceFromOracle();
  }
  // At this point, the test parser should be terminal.
  EXPECT_TRUE(test_parser->IsTerminal());
  // TODO(googleuser): Add EXPECT_EQ here instead of printing.
  std::vector<std::vector<ComponentTrace>> traces =
      test_parser->GetTraceProtos();
  for (auto &batch_trace : traces) {
    for (auto &trace : batch_trace) {
      LOG(INFO) << "trace:" << std::endl << trace.DebugString();
    }
  }
}

TEST_F(SyntaxNetComponentTest, NoTracingDropsFeatureNames) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  constexpr int kBeamSize = 1;
  const auto test_parser =
      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
  const auto link_features = test_parser->GetRawLinkFeatures(0);
  // The FML associated with the channel is "stack.focus stack(1).focus".
  // Without tracing, both features should lack the feature_name field.
  EXPECT_EQ(link_features.size(), 2);
  EXPECT_FALSE(link_features.at(0).has_feature_name());
  EXPECT_FALSE(link_features.at(1).has_feature_name());
}

TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) {
  // Create an input batch and an empty beam vector to initialize the parser.
  Sentence sentence_0;
  TextFormat::ParseFromString(kSentence0, &sentence_0);
  string sentence_0_str;
  sentence_0.SerializeToString(&sentence_0_str);
  constexpr int kBeamSize = 1;
  auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
  test_parser->InitializeTracing();
  const auto link_features = test_parser->GetRawLinkFeatures(0);
  // The FML associated with the channel is "stack.focus stack(1).focus".
  EXPECT_EQ(link_features.size(), 2);
  EXPECT_EQ(link_features.at(0).feature_name(), "stack.focus");
  EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus");
}

}  // namespace dragnn
}  // namespace syntaxnet