| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274 |
- // Copyright 2017 Google Inc. All Rights Reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- // =============================================================================
- #include "dragnn/components/syntaxnet/syntaxnet_component.h"
- #include "dragnn/core/input_batch_cache.h"
- #include "dragnn/core/test/generic.h"
- #include "dragnn/core/test/mock_transition_state.h"
- #include "dragnn/io/sentence_input_batch.h"
- #include "syntaxnet/sentence.pb.h"
- #include "tensorflow/core/lib/core/errors.h"
- #include "tensorflow/core/lib/core/status.h"
- #include "tensorflow/core/lib/io/path.h"
- #include "tensorflow/core/platform/env.h"
- #include "tensorflow/core/platform/protobuf.h"
- #include "tensorflow/core/platform/test.h"
- // This test suite is intended to validate the contracts that the DRAGNN
- // system expects from all transition state subclasses. Developers creating
- // new TransitionStates should copy this test and modify it as necessary,
- // using it to ensure their state conforms to DRAGNN expectations.
- namespace syntaxnet {
- namespace dragnn {
- namespace {
// Text-format Sentence used as the default short input. It has three tokens:
// "Sentence", "0" and ".". The tests parse it with TextFormat::ParseFromString.
- const char kSentence0[] = R"(
- token {
- word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
- break_level: NO_BREAK
- }
- token {
- word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
- break_level: SPACE_BREAK
- }
- token {
- word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
- break_level: NO_BREAK
- }
- )";
// Second three-token Sentence; it differs from kSentence0 only in its middle
// token ("1" instead of "0"). Used as the second element of two-item batches.
- const char kSentence1[] = R"(
- token {
- word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
- break_level: NO_BREAK
- }
- token {
- word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
- break_level: SPACE_BREAK
- }
- token {
- word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
- break_level: NO_BREAK
- }
- )";
// Five-token Sentence ("Sentence", "1", "2", "3", "."). Used by the tests that
// need an input longer than kSentence0.
- const char kLongSentence[] = R"(
- token {
- word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
- break_level: NO_BREAK
- }
- token {
- word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
- break_level: SPACE_BREAK
- }
- token {
- word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
- break_level: SPACE_BREAK
- }
- token {
- word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
- break_level: SPACE_BREAK
- }
- token {
- word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
- break_level: NO_BREAK
- }
- )";
- } // namespace
- using testing::Return;
- class SyntaxNetComponentTest : public ::testing::Test {
- public:
- std::unique_ptr<SyntaxNetComponent> CreateParser(
- const std::vector<std::vector<const TransitionState *>> &states,
- const std::vector<string> &data) {
- constexpr int kBeamSize = 2;
- return CreateParserWithBeamSize(kBeamSize, states, data);
- }
- std::unique_ptr<SyntaxNetComponent> CreateParserWithBeamSize(
- int beam_size,
- const std::vector<std::vector<const TransitionState *>> &states,
- const std::vector<string> &data) {
- // Get the master spec proto from the test data directory.
- MasterSpec master_spec;
- string file_name = tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- "master_spec.textproto");
- TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
- &master_spec));
- // Get all the resource protos from the test data directory.
- for (Resource &resource :
- *(master_spec.mutable_component(0)->mutable_resource())) {
- resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- resource.part(0).file_pattern()));
- }
- data_.reset(new InputBatchCache(data));
- // Create a parser component with the specified beam size.
- std::unique_ptr<SyntaxNetComponent> parser_component(
- new SyntaxNetComponent());
- parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
- parser_component->InitializeData(states, beam_size, data_.get());
- return parser_component;
- }
- const std::vector<Beam<SyntaxNetTransitionState> *> GetBeams(
- SyntaxNetComponent *component) const {
- std::vector<Beam<SyntaxNetTransitionState> *> return_vector;
- for (const auto &beam : component->batch_) {
- return_vector.push_back(beam.get());
- }
- return return_vector;
- }
- std::unique_ptr<InputBatchCache> data_;
- };
- TEST_F(SyntaxNetComponentTest, AdvancesFromOracleAndTerminates) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- auto test_parser = CreateParser({}, {sentence_0_str});
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromOracle();
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
- // Make sure the parser doesn't segfault.
- test_parser->FinalizeData();
- }
- TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionAndTerminates) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- auto test_parser = CreateParser({}, {sentence_0_str});
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBeamSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromPrediction(transition_matrix,
- kNumPossibleTransitions * kBeamSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
- // Prepare to validate the batched beams.
- auto beam = test_parser->GetBeam();
- // All beams should only have one element.
- for (const auto &per_beam : beam) {
- EXPECT_EQ(per_beam.size(), 1);
- }
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- // Make sure the parser doesn't segfault.
- test_parser->FinalizeData();
- // TODO(googleuser): What should the finalized data look like?
- }
- TEST_F(SyntaxNetComponentTest, RetainsPassedTransitionStateData) {
- // Create and initialize the state->
- MockTransitionState mock_state_one;
- constexpr int kParentBeamIndexOne = 1138;
- constexpr float kParentScoreOne = 7.2;
- EXPECT_CALL(mock_state_one, GetBeamIndex())
- .WillRepeatedly(Return(kParentBeamIndexOne));
- EXPECT_CALL(mock_state_one, GetScore())
- .WillRepeatedly(Return(kParentScoreOne));
- MockTransitionState mock_state_two;
- constexpr int kParentBeamIndexTwo = 1123;
- constexpr float kParentScoreTwo = 42.03;
- EXPECT_CALL(mock_state_two, GetBeamIndex())
- .WillRepeatedly(Return(kParentBeamIndexTwo));
- EXPECT_CALL(mock_state_two, GetScore())
- .WillRepeatedly(Return(kParentScoreTwo));
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- auto test_parser =
- CreateParser({{&mock_state_one, &mock_state_two}}, {sentence_0_str});
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBeamSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromPrediction(transition_matrix,
- kNumPossibleTransitions * kBeamSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
- // The final states should have kExpectedNumTransitions * kTransitionValue,
- // plus the higher parent state score (from state two).
- auto beam = test_parser->GetBeam();
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions + kParentScoreTwo);
- // Make sure that the parent state is reported correctly.
- EXPECT_EQ(test_parser->GetSourceBeamIndex(0, 0), kParentBeamIndexTwo);
- // Make sure the parser doesn't segfault.
- test_parser->FinalizeData();
- // TODO(googleuser): What should the finalized data look like?
- }
- TEST_F(SyntaxNetComponentTest, AdvancesFromPredictionForMultiSentenceBatches) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- auto test_parser = CreateParser({}, {sentence_0_str, sentence_1_str});
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 2;
- constexpr int kBeamSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
- EXPECT_EQ(test_parser->StepsTaken(1), kExpectedNumTransitions);
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- auto beam = test_parser->GetBeam();
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- EXPECT_EQ(beam.at(1).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- // Make sure the parser doesn't segfault.
- test_parser->FinalizeData();
- // TODO(googleuser): What should the finalized data look like?
- }
- TEST_F(SyntaxNetComponentTest,
- AdvancesFromPredictionForVaryingLengthSentences) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- auto test_parser = CreateParser({}, {sentence_0_str, long_sentence_str});
- constexpr int kNumTokensInSentence = 3;
- constexpr int kNumTokensInLongSentence = 5;
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 2;
- constexpr int kBeamSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- constexpr int kExpectedNumTransitions = kNumTokensInLongSentence * 2;
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kNumTokensInSentence * 2);
- EXPECT_EQ(test_parser->StepsTaken(1), kNumTokensInLongSentence * 2);
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- auto beam = test_parser->GetBeam();
- // The first sentence is shorter, so it should have a lower final score.
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kNumTokensInSentence * 2);
- EXPECT_EQ(beam.at(1).at(0)->GetScore(),
- kTransitionValue * kNumTokensInLongSentence * 2);
- // Make sure the parser doesn't segfault.
- test_parser->FinalizeData();
- // TODO(googleuser): What should the finalized data look like?
- }
- TEST_F(SyntaxNetComponentTest, ResetAllowsReductionInBatchSize) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- // Get the master spec proto from the test data directory.
- MasterSpec master_spec;
- string file_name = tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- "master_spec.textproto");
- TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
- &master_spec));
- // Get all the resource protos from the test data directory.
- for (Resource &resource :
- *(master_spec.mutable_component(0)->mutable_resource())) {
- resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- resource.part(0).file_pattern()));
- }
- // Create an input batch cache with a large batch size.
- constexpr int kBeamSize = 2;
- std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
- {sentence_0_str, sentence_0_str, sentence_0_str, sentence_0_str}));
- std::unique_ptr<SyntaxNetComponent> parser_component(
- new SyntaxNetComponent());
- parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
- parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
- // Reset the component and pass in a new input batch that is smaller.
- parser_component->ResetComponent();
- std::unique_ptr<InputBatchCache> small_batch_data(new InputBatchCache(
- {long_sentence_str, long_sentence_str, long_sentence_str}));
- parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 3;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- constexpr int kNumTokensInSentence = 5;
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(parser_component->IsTerminal());
- parser_component->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(parser_component->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
- EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
- EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- auto beam = parser_component->GetBeam();
- // The beam should be of batch size 3.
- EXPECT_EQ(beam.size(), 3);
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- EXPECT_EQ(beam.at(1).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- EXPECT_EQ(beam.at(2).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- // Make sure the parser doesn't segfault.
- parser_component->FinalizeData();
- }
- TEST_F(SyntaxNetComponentTest, ResetAllowsIncreaseInBatchSize) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- // Get the master spec proto from the test data directory.
- MasterSpec master_spec;
- string file_name = tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- "master_spec.textproto");
- TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
- &master_spec));
- // Get all the resource protos from the test data directory.
- for (Resource &resource :
- *(master_spec.mutable_component(0)->mutable_resource())) {
- resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- resource.part(0).file_pattern()));
- }
- // Create an input batch cache with a small batch size.
- constexpr int kBeamSize = 2;
- std::unique_ptr<InputBatchCache> small_batch_data(
- new InputBatchCache(sentence_0_str));
- std::unique_ptr<SyntaxNetComponent> parser_component(
- new SyntaxNetComponent());
- parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
- parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
- // Reset the component and pass in a new input batch that is larger.
- parser_component->ResetComponent();
- std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
- {long_sentence_str, long_sentence_str, long_sentence_str}));
- parser_component->InitializeData({}, kBeamSize, large_batch_data.get());
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 3;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- constexpr int kNumTokensInSentence = 5;
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(parser_component->IsTerminal());
- parser_component->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(parser_component->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(parser_component->StepsTaken(0), kExpectedNumTransitions);
- EXPECT_EQ(parser_component->StepsTaken(1), kExpectedNumTransitions);
- EXPECT_EQ(parser_component->StepsTaken(2), kExpectedNumTransitions);
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- auto beam = parser_component->GetBeam();
- // The beam should be of batch size 3.
- EXPECT_EQ(beam.size(), 3);
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- EXPECT_EQ(beam.at(1).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- EXPECT_EQ(beam.at(2).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- // Make sure the parser doesn't segfault.
- parser_component->FinalizeData();
- }
- TEST_F(SyntaxNetComponentTest, ResetCausesBeamToReset) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- auto test_parser = CreateParser({}, {sentence_0_str});
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBeamSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Transition the expected number of times.
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- test_parser->AdvanceFromPrediction(transition_matrix,
- kNumPossibleTransitions * kBeamSize);
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // Check that the component is reporting 2N steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), kExpectedNumTransitions);
- // The final states should have kExpectedNumTransitions * kTransitionValue.
- auto beam = test_parser->GetBeam();
- EXPECT_EQ(beam.at(0).at(0)->GetScore(),
- kTransitionValue * kExpectedNumTransitions);
- // Reset the test parser and give it new data.
- test_parser->ResetComponent();
- std::unique_ptr<InputBatchCache> new_data(
- new InputBatchCache(long_sentence_str));
- test_parser->InitializeData({}, kBeamSize, new_data.get());
- // Check that the component is not terminal.
- EXPECT_FALSE(test_parser->IsTerminal());
- // Check that the component is reporting 0 steps taken.
- EXPECT_EQ(test_parser->StepsTaken(0), 0);
- // The states should have 0 as their score.
- auto new_beam = test_parser->GetBeam();
- EXPECT_EQ(new_beam.at(0).at(0)->GetScore(), 0);
- }
- TEST_F(SyntaxNetComponentTest, AdjustingMaxBeamSizeAdjustsSizeForAllBeams) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- // Get the master spec proto from the test data directory.
- MasterSpec master_spec;
- string file_name = tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- "master_spec.textproto");
- TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
- &master_spec));
- // Get all the resource protos from the test data directory.
- for (Resource &resource :
- *(master_spec.mutable_component(0)->mutable_resource())) {
- resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- resource.part(0).file_pattern()));
- }
- // Create an input batch cache with a small batch size.
- constexpr int kBeamSize = 2;
- std::unique_ptr<InputBatchCache> small_batch_data(
- new InputBatchCache(sentence_0_str));
- std::unique_ptr<SyntaxNetComponent> parser_component(
- new SyntaxNetComponent());
- parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
- parser_component->InitializeData({}, kBeamSize, small_batch_data.get());
- // Make sure all the beams in the batch have max size 2.
- for (const auto &beam : GetBeams(parser_component.get())) {
- EXPECT_EQ(beam->max_size(), kBeamSize);
- }
- // Reset the component and pass in a new input batch that is larger, with
- // a higher beam size.
- constexpr int kNewBeamSize = 5;
- parser_component->ResetComponent();
- std::unique_ptr<InputBatchCache> large_batch_data(new InputBatchCache(
- {long_sentence_str, long_sentence_str, long_sentence_str}));
- parser_component->InitializeData({}, kNewBeamSize, large_batch_data.get());
- // Make sure all the beams in the batch now have max size 5.
- for (const auto &beam : GetBeams(parser_component.get())) {
- EXPECT_EQ(beam->max_size(), kNewBeamSize);
- }
- }
- TEST_F(SyntaxNetComponentTest, SettingBeamSizeZeroFails) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence long_sentence;
- TextFormat::ParseFromString(kLongSentence, &long_sentence);
- string long_sentence_str;
- long_sentence.SerializeToString(&long_sentence_str);
- // Get the master spec proto from the test data directory.
- MasterSpec master_spec;
- string file_name = tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- "master_spec.textproto");
- TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
- &master_spec));
- // Get all the resource protos from the test data directory.
- for (Resource &resource :
- *(master_spec.mutable_component(0)->mutable_resource())) {
- resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
- test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
- resource.part(0).file_pattern()));
- }
- // Create an input batch cache with a small batch size.
- constexpr int kBeamSize = 0;
- std::unique_ptr<InputBatchCache> small_batch_data(
- new InputBatchCache(sentence_0_str));
- std::unique_ptr<SyntaxNetComponent> parser_component(
- new SyntaxNetComponent());
- parser_component->InitializeComponent(*(master_spec.mutable_component(0)));
- EXPECT_DEATH(
- parser_component->InitializeData({}, kBeamSize, small_batch_data.get()),
- "must be greater than 0");
- }
- TEST_F(SyntaxNetComponentTest, ExportsFixedFeaturesWithPadding) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // Get and check the raw link features.
- vector<int32> indices;
- auto indices_fn = [&indices](int size) {
- indices.resize(size);
- return indices.data();
- };
- vector<int64> ids;
- auto ids_fn = [&ids](int size) {
- ids.resize(size);
- return ids.data();
- };
- vector<float> weights;
- auto weights_fn = [&weights](int size) {
- weights.resize(size);
- return weights.data();
- };
- constexpr int kChannelId = 0;
- const int num_features =
- test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
- // The raw features for each beam object should be [single, single].
- // There is also padding expected in this beam - there is only one
- // element in each beam (so two elements total; batch is two). Thus, we expect
- // 0,1 and 6,7 to be filled with one element each.
- constexpr int kExpectedOutputSize = 4;
- const vector<int32> expected_indices({0, 1, 6, 7});
- const vector<int64> expected_ids({0, 12, 0, 12});
- const vector<float> expected_weights({1.0, 1.0, 1.0, 1.0});
- EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
- EXPECT_EQ(num_features, kExpectedOutputSize);
- EXPECT_EQ(expected_indices, indices);
- EXPECT_EQ(expected_ids, ids);
- EXPECT_EQ(expected_weights, weights);
- }
- TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Advance twice, so that the underlying parser fills the beam.
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- // Get and check the raw link features.
- vector<int32> indices;
- auto indices_fn = [&indices](int size) {
- indices.resize(size);
- return indices.data();
- };
- vector<int64> ids;
- auto ids_fn = [&ids](int size) {
- ids.resize(size);
- return ids.data();
- };
- vector<float> weights;
- auto weights_fn = [&weights](int size) {
- weights.resize(size);
- return weights.data();
- };
- constexpr int kChannelId = 0;
- const int num_features =
- test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
- constexpr int kExpectedOutputSize = 12;
- const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
- const vector<int64> expected_ids({7, 50, 12, 7, 12, 7, 7, 50, 12, 7, 12, 7});
- const vector<float> expected_weights(
- {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
- EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
- EXPECT_EQ(num_features, kExpectedOutputSize);
- EXPECT_EQ(expected_indices, indices);
- EXPECT_EQ(expected_ids, ids);
- EXPECT_EQ(expected_weights, weights);
- }
- TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Replace the first several options with varying scores to test sorting.
- constexpr int kBatchOffset = kNumPossibleTransitions * kBeamSize;
- transition_matrix[0] = 3 * kTransitionValue;
- transition_matrix[1] = 3 * kTransitionValue;
- transition_matrix[2] = 4 * kTransitionValue;
- transition_matrix[3] = 4 * kTransitionValue;
- transition_matrix[4] = 2 * kTransitionValue;
- transition_matrix[5] = 2 * kTransitionValue;
- transition_matrix[kBatchOffset + 0] = 3 * kTransitionValue;
- transition_matrix[kBatchOffset + 1] = 3 * kTransitionValue;
- transition_matrix[kBatchOffset + 2] = 4 * kTransitionValue;
- transition_matrix[kBatchOffset + 3] = 4 * kTransitionValue;
- transition_matrix[kBatchOffset + 4] = 2 * kTransitionValue;
- transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue;
- // Advance twice, so that the underlying parser fills the beam.
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- // Get and check the raw link features.
- vector<int32> indices;
- auto indices_fn = [&indices](int size) {
- indices.resize(size);
- return indices.data();
- };
- vector<int64> ids;
- auto ids_fn = [&ids](int size) {
- ids.resize(size);
- return ids.data();
- };
- vector<float> weights;
- auto weights_fn = [&weights](int size) {
- weights.resize(size);
- return weights.data();
- };
- constexpr int kChannelId = 0;
- const int num_features =
- test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
- // In this case, all even features and all odd features are identical.
- constexpr int kExpectedOutputSize = 12;
- const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
- const vector<int64> expected_ids({12, 7, 7, 50, 12, 7, 12, 7, 7, 50, 12, 7});
- const vector<float> expected_weights(
- {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
- EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
- EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
- EXPECT_EQ(num_features, kExpectedOutputSize);
- EXPECT_EQ(expected_indices, indices);
- EXPECT_EQ(expected_ids, ids);
- EXPECT_EQ(expected_weights, weights);
- }
- TEST_F(SyntaxNetComponentTest, ExportsBulkFixedFeatures) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // Get and check the raw link features.
- vector<vector<int32>> indices;
- auto indices_fn = [&indices](int channel, int size) {
- indices.resize(channel + 1);
- indices[channel].resize(size);
- return indices[channel].data();
- };
- vector<vector<int64>> ids;
- auto ids_fn = [&ids](int channel, int size) {
- ids.resize(channel + 1);
- ids[channel].resize(size);
- return ids[channel].data();
- };
- vector<vector<float>> weights;
- auto weights_fn = [&weights](int channel, int size) {
- weights.resize(channel + 1);
- weights[channel].resize(size);
- return weights[channel].data();
- };
- BulkFeatureExtractor extractor(indices_fn, ids_fn, weights_fn);
- const int num_steps = test_parser->BulkGetFixedFeatures(extractor);
- // There should be 6 steps (2N, where N is the longest number of tokens).
- EXPECT_EQ(num_steps, 6);
- // These are empirically derived.
- const vector<int32> expected_ch0_indices({0, 36, 18, 54, 1, 37, 19, 55,
- 2, 38, 20, 56, 3, 39, 21, 57,
- 4, 40, 22, 58, 5, 41, 23, 59});
- const vector<int64> expected_ch0_ids({0, 12, 0, 12, 12, 7, 12, 7,
- 7, 50, 7, 50, 7, 50, 7, 50,
- 50, 50, 50, 50, 50, 50, 50, 50});
- const vector<float> expected_ch0_weights(
- {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
- const vector<int32> expected_ch1_indices(
- {0, 36, 72, 18, 54, 90, 1, 37, 73, 19, 55, 91, 2, 38, 74, 20, 56, 92,
- 3, 39, 75, 21, 57, 93, 4, 40, 76, 22, 58, 94, 5, 41, 77, 23, 59, 95});
- const vector<int64> expected_ch1_ids(
- {51, 0, 12, 51, 0, 12, 0, 12, 7, 0, 12, 7, 12, 7, 50, 12, 7, 50,
- 12, 7, 50, 12, 7, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50, 7, 50, 50});
- const vector<float> expected_ch1_weights(
- {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
- 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1});
- EXPECT_EQ(indices[0], expected_ch0_indices);
- EXPECT_EQ(ids[0], expected_ch0_ids);
- EXPECT_EQ(weights[0], expected_ch0_weights);
- EXPECT_EQ(indices[1], expected_ch1_indices);
- EXPECT_EQ(ids[1], expected_ch1_ids);
- EXPECT_EQ(weights[1], expected_ch1_weights);
- }
- TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeaturesWithPadding) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- constexpr int kBatchSize = 2;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // Get and check the raw link features.
- constexpr int kNumLinkFeatures = 2;
- auto link_features = test_parser->GetRawLinkFeatures(0);
- EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
- EXPECT_EQ(link_features.at(0).feature_value(), -1);
- EXPECT_EQ(link_features.at(0).batch_idx(), 0);
- EXPECT_EQ(link_features.at(0).beam_idx(), 0);
- EXPECT_EQ(link_features.at(1).feature_value(), -2);
- EXPECT_EQ(link_features.at(1).batch_idx(), 0);
- EXPECT_EQ(link_features.at(1).beam_idx(), 0);
- // These are padding, so we do not expect them to have a feature value.
- EXPECT_FALSE(link_features.at(2).has_feature_value());
- EXPECT_FALSE(link_features.at(2).has_batch_idx());
- EXPECT_FALSE(link_features.at(2).has_beam_idx());
- EXPECT_FALSE(link_features.at(3).has_feature_value());
- EXPECT_FALSE(link_features.at(3).has_batch_idx());
- EXPECT_FALSE(link_features.at(3).has_beam_idx());
- EXPECT_FALSE(link_features.at(4).has_feature_value());
- EXPECT_FALSE(link_features.at(4).has_batch_idx());
- EXPECT_FALSE(link_features.at(4).has_beam_idx());
- EXPECT_FALSE(link_features.at(5).has_feature_value());
- EXPECT_FALSE(link_features.at(5).has_batch_idx());
- EXPECT_FALSE(link_features.at(5).has_beam_idx());
- EXPECT_EQ(link_features.at(6).feature_value(), -1);
- EXPECT_EQ(link_features.at(6).batch_idx(), 1);
- EXPECT_EQ(link_features.at(6).beam_idx(), 0);
- EXPECT_EQ(link_features.at(7).feature_value(), -2);
- EXPECT_EQ(link_features.at(7).batch_idx(), 1);
- EXPECT_EQ(link_features.at(7).beam_idx(), 0);
- // These are padding, so we do not expect them to have a feature value.
- EXPECT_FALSE(link_features.at(8).has_feature_value());
- EXPECT_FALSE(link_features.at(8).has_batch_idx());
- EXPECT_FALSE(link_features.at(8).has_beam_idx());
- EXPECT_FALSE(link_features.at(9).has_feature_value());
- EXPECT_FALSE(link_features.at(9).has_batch_idx());
- EXPECT_FALSE(link_features.at(9).has_beam_idx());
- EXPECT_FALSE(link_features.at(10).has_feature_value());
- EXPECT_FALSE(link_features.at(10).has_batch_idx());
- EXPECT_FALSE(link_features.at(10).has_beam_idx());
- EXPECT_FALSE(link_features.at(11).has_feature_value());
- EXPECT_FALSE(link_features.at(11).has_batch_idx());
- EXPECT_FALSE(link_features.at(11).has_beam_idx());
- }
- TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- Sentence sentence_1;
- TextFormat::ParseFromString(kSentence1, &sentence_1);
- string sentence_1_str;
- sentence_1.SerializeToString(&sentence_1_str);
- constexpr int kBeamSize = 3;
- auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
- // There are 93 possible transitions for any given state. Create a transition
- // array with a score of 10.0 for each transition.
- constexpr int kBatchSize = 2;
- constexpr int kNumPossibleTransitions = 93;
- constexpr float kTransitionValue = 10.0;
- float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
- for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
- transition_matrix[i] = kTransitionValue;
- }
- // Advance twice, so that the underlying parser fills the beam.
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- test_parser->AdvanceFromPrediction(
- transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
- // Get and check the raw link features.
- constexpr int kNumLinkFeatures = 2;
- auto link_features = test_parser->GetRawLinkFeatures(0);
- EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
- // These should index into batch 0.
- EXPECT_EQ(link_features.at(0).feature_value(), 1);
- EXPECT_EQ(link_features.at(0).batch_idx(), 0);
- EXPECT_EQ(link_features.at(0).beam_idx(), 0);
- EXPECT_EQ(link_features.at(1).feature_value(), 0);
- EXPECT_EQ(link_features.at(1).batch_idx(), 0);
- EXPECT_EQ(link_features.at(1).beam_idx(), 0);
- EXPECT_EQ(link_features.at(2).feature_value(), -1);
- EXPECT_EQ(link_features.at(2).batch_idx(), 0);
- EXPECT_EQ(link_features.at(2).beam_idx(), 1);
- EXPECT_EQ(link_features.at(3).feature_value(), -2);
- EXPECT_EQ(link_features.at(3).batch_idx(), 0);
- EXPECT_EQ(link_features.at(3).beam_idx(), 1);
- EXPECT_EQ(link_features.at(4).feature_value(), -1);
- EXPECT_EQ(link_features.at(4).batch_idx(), 0);
- EXPECT_EQ(link_features.at(4).beam_idx(), 2);
- EXPECT_EQ(link_features.at(5).feature_value(), -2);
- EXPECT_EQ(link_features.at(5).batch_idx(), 0);
- EXPECT_EQ(link_features.at(5).beam_idx(), 2);
- // These should index into batch 1.
- EXPECT_EQ(link_features.at(6).feature_value(), 1);
- EXPECT_EQ(link_features.at(6).batch_idx(), 1);
- EXPECT_EQ(link_features.at(6).beam_idx(), 0);
- EXPECT_EQ(link_features.at(7).feature_value(), 0);
- EXPECT_EQ(link_features.at(7).batch_idx(), 1);
- EXPECT_EQ(link_features.at(7).beam_idx(), 0);
- EXPECT_EQ(link_features.at(8).feature_value(), -1);
- EXPECT_EQ(link_features.at(8).batch_idx(), 1);
- EXPECT_EQ(link_features.at(8).beam_idx(), 1);
- EXPECT_EQ(link_features.at(9).feature_value(), -2);
- EXPECT_EQ(link_features.at(9).batch_idx(), 1);
- EXPECT_EQ(link_features.at(9).beam_idx(), 1);
- EXPECT_EQ(link_features.at(10).feature_value(), -1);
- EXPECT_EQ(link_features.at(10).batch_idx(), 1);
- EXPECT_EQ(link_features.at(10).beam_idx(), 2);
- EXPECT_EQ(link_features.at(11).feature_value(), -2);
- EXPECT_EQ(link_features.at(11).batch_idx(), 1);
- EXPECT_EQ(link_features.at(11).beam_idx(), 2);
- }
- TEST_F(SyntaxNetComponentTest, AdvancesFromOracleWithTracing) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- constexpr int kBeamSize = 1;
- auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
- test_parser->InitializeTracing();
- constexpr int kNumTokensInSentence = 3;
- // The master spec will initialize a parser, so expect 2*N transitions.
- constexpr int kExpectedNumTransitions = kNumTokensInSentence * 2;
- constexpr int kFixedFeatureChannels = 1;
- for (int i = 0; i < kExpectedNumTransitions; ++i) {
- EXPECT_FALSE(test_parser->IsTerminal());
- vector<int32> indices;
- auto indices_fn = [&indices](int size) {
- indices.resize(size);
- return indices.data();
- };
- vector<int64> ids;
- auto ids_fn = [&ids](int size) {
- ids.resize(size);
- return ids.data();
- };
- vector<float> weights;
- auto weights_fn = [&weights](int size) {
- weights.resize(size);
- return weights.data();
- };
- for (int j = 0; j < kFixedFeatureChannels; ++j) {
- test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, j);
- }
- auto features = test_parser->GetRawLinkFeatures(0);
- // Make some fake translations to test visualization.
- for (int j = 0; j < features.size(); ++j) {
- features[j].set_step_idx(j < i ? j : -1);
- }
- test_parser->AddTranslatedLinkFeaturesToTrace(features, 0);
- test_parser->AdvanceFromOracle();
- }
- // At this point, the test parser should be terminal.
- EXPECT_TRUE(test_parser->IsTerminal());
- // TODO(googleuser): Add EXPECT_EQ here instead of printing.
- std::vector<std::vector<ComponentTrace>> traces =
- test_parser->GetTraceProtos();
- for (auto &batch_trace : traces) {
- for (auto &trace : batch_trace) {
- LOG(INFO) << "trace:" << std::endl << trace.DebugString();
- }
- }
- }
- TEST_F(SyntaxNetComponentTest, NoTracingDropsFeatureNames) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- constexpr int kBeamSize = 1;
- const auto test_parser =
- CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
- const auto link_features = test_parser->GetRawLinkFeatures(0);
- // The fml associated with the channel is "stack.focus stack(1).focus".
- // Both features should lack the feature_name field.
- EXPECT_EQ(link_features.size(), 2);
- EXPECT_FALSE(link_features.at(0).has_feature_name());
- EXPECT_FALSE(link_features.at(1).has_feature_name());
- }
- TEST_F(SyntaxNetComponentTest, TracingOutputsFeatureNames) {
- // Create an empty input batch and beam vector to initialize the parser.
- Sentence sentence_0;
- TextFormat::ParseFromString(kSentence0, &sentence_0);
- string sentence_0_str;
- sentence_0.SerializeToString(&sentence_0_str);
- constexpr int kBeamSize = 1;
- auto test_parser = CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str});
- test_parser->InitializeTracing();
- const auto link_features = test_parser->GetRawLinkFeatures(0);
- // The fml associated with the channel is "stack.focus stack(1).focus".
- EXPECT_EQ(link_features.size(), 2);
- EXPECT_EQ(link_features.at(0).feature_name(), "stack.focus");
- EXPECT_EQ(link_features.at(1).feature_name(), "stack(1).focus");
- }
- } // namespace dragnn
- } // namespace syntaxnet
|