syntaxnet_transition_state_test.cc 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
#include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "dragnn/components/syntaxnet/syntaxnet_component.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/test/generic.h"
#include "dragnn/core/test/mock_transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/protos/spec.pb.h"
#include "syntaxnet/sentence.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/test.h"
  15. // This test suite is intended to validate the contracts that the DRAGNN
  16. // system expects from all transition state subclasses. Developers creating
  17. // new TransitionStates should copy this test and modify it as necessary,
  18. // using it to ensure their state conforms to DRAGNN expectations.
  19. namespace syntaxnet {
  20. namespace dragnn {
  21. namespace {
  22. const char kSentence0[] = R"(
  23. token {
  24. word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
  25. break_level: NO_BREAK
  26. }
  27. token {
  28. word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
  29. break_level: SPACE_BREAK
  30. }
  31. token {
  32. word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
  33. break_level: NO_BREAK
  34. }
  35. )";
  36. } // namespace
  37. using testing::Return;
  38. class SyntaxNetTransitionStateTest : public ::testing::Test {
  39. public:
  40. std::unique_ptr<SyntaxNetTransitionState> CreateState() {
  41. // Get the master spec proto from the test data directory.
  42. MasterSpec master_spec;
  43. string file_name = tensorflow::io::JoinPath(
  44. test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
  45. "master_spec.textproto");
  46. TF_CHECK_OK(tensorflow::ReadTextProto(tensorflow::Env::Default(), file_name,
  47. &master_spec));
  48. // Get all the resource protos from the test data directory.
  49. for (Resource &resource :
  50. *(master_spec.mutable_component(0)->mutable_resource())) {
  51. resource.mutable_part(0)->set_file_pattern(tensorflow::io::JoinPath(
  52. test::GetTestDataPrefix(), "dragnn/components/syntaxnet/testdata",
  53. resource.part(0).file_pattern()));
  54. }
  55. // Create an empty input batch and beam vector to initialize the parser.
  56. Sentence sentence_0;
  57. TextFormat::ParseFromString(kSentence0, &sentence_0);
  58. string sentence_0_str;
  59. sentence_0.SerializeToString(&sentence_0_str);
  60. data_.reset(new InputBatchCache(sentence_0_str));
  61. SentenceInputBatch *sentences = data_->GetAs<SentenceInputBatch>();
  62. // Create a parser comoponent that will generate a parser state for this
  63. // test.
  64. SyntaxNetComponent component;
  65. component.InitializeComponent(*(master_spec.mutable_component(0)));
  66. std::vector<std::vector<const TransitionState *>> states;
  67. constexpr int kBeamSize = 1;
  68. component.InitializeData(states, kBeamSize, data_.get());
  69. // Get a transition state from the component.
  70. std::unique_ptr<SyntaxNetTransitionState> test_state =
  71. component.CreateState(&(sentences->data()->at(0)));
  72. return test_state;
  73. }
  74. std::unique_ptr<InputBatchCache> data_;
  75. };
  76. // Validates the consistency of the beam index setter and getter.
  77. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetBeamIndex) {
  78. // Create and initialize a test state.
  79. MockTransitionState mock_state;
  80. auto test_state = CreateState();
  81. test_state->Init(mock_state);
  82. constexpr int kOldBeamIndex = 12;
  83. test_state->SetBeamIndex(kOldBeamIndex);
  84. EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
  85. constexpr int kNewBeamIndex = 7;
  86. test_state->SetBeamIndex(kNewBeamIndex);
  87. EXPECT_EQ(test_state->GetBeamIndex(), kNewBeamIndex);
  88. }
  89. // Validates the consistency of the score setter and getter.
  90. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetScore) {
  91. // Create and initialize a test state.
  92. MockTransitionState mock_state;
  93. auto test_state = CreateState();
  94. test_state->Init(mock_state);
  95. constexpr float kOldScore = 12.1;
  96. test_state->SetScore(kOldScore);
  97. EXPECT_EQ(test_state->GetScore(), kOldScore);
  98. constexpr float kNewScore = 7.2;
  99. test_state->SetScore(kNewScore);
  100. EXPECT_EQ(test_state->GetScore(), kNewScore);
  101. }
  102. // This test ensures that the initializing state's current index is saved
  103. // as the parent beam index of the state being initialized.
  104. TEST_F(SyntaxNetTransitionStateTest, ReportsParentBeamIndex) {
  105. // Create a mock transition state that wil report a specific current index.
  106. // This index should become the parent state index for the test state.
  107. MockTransitionState mock_state;
  108. constexpr int kParentBeamIndex = 1138;
  109. EXPECT_CALL(mock_state, GetBeamIndex())
  110. .WillRepeatedly(Return(kParentBeamIndex));
  111. auto test_state = CreateState();
  112. test_state->Init(mock_state);
  113. EXPECT_EQ(test_state->ParentBeamIndex(), kParentBeamIndex);
  114. }
  115. // This test ensures that the initializing state's current score is saved
  116. // as the current score of the state being initialized.
  117. TEST_F(SyntaxNetTransitionStateTest, InitializationCopiesParentScore) {
  118. // Create a mock transition state that wil report a specific current index.
  119. // This index should become the parent state index for the test state.
  120. MockTransitionState mock_state;
  121. constexpr float kParentScore = 24.12;
  122. EXPECT_CALL(mock_state, GetScore()).WillRepeatedly(Return(kParentScore));
  123. auto test_state = CreateState();
  124. test_state->Init(mock_state);
  125. EXPECT_EQ(test_state->GetScore(), kParentScore);
  126. }
  127. // This test ensures that calling Clone maintains the state data (parent beam
  128. // index, beam index, score, etc.) of the state that was cloned.
  129. TEST_F(SyntaxNetTransitionStateTest, CloningMaintainsState) {
  130. // Create and initialize the state->
  131. MockTransitionState mock_state;
  132. constexpr int kParentBeamIndex = 1138;
  133. EXPECT_CALL(mock_state, GetBeamIndex())
  134. .WillRepeatedly(Return(kParentBeamIndex));
  135. auto test_state = CreateState();
  136. test_state->Init(mock_state);
  137. // Validate the internal state of the test state.
  138. constexpr float kOldScore = 20.0;
  139. test_state->SetScore(kOldScore);
  140. EXPECT_EQ(test_state->GetScore(), kOldScore);
  141. constexpr int kOldBeamIndex = 12;
  142. test_state->SetBeamIndex(kOldBeamIndex);
  143. EXPECT_EQ(test_state->GetBeamIndex(), kOldBeamIndex);
  144. auto clone = test_state->Clone();
  145. // The clone should have identical state to the old state.
  146. EXPECT_EQ(clone->ParentBeamIndex(), kParentBeamIndex);
  147. EXPECT_EQ(clone->GetScore(), kOldScore);
  148. EXPECT_EQ(clone->GetBeamIndex(), kOldBeamIndex);
  149. }
  150. // Validates the consistency of the step_for_token setter and getter.
  151. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetStepForToken) {
  152. // Create and initialize a test state.
  153. MockTransitionState mock_state;
  154. auto test_state = CreateState();
  155. test_state->Init(mock_state);
  156. constexpr int kStepForTokenZero = 12;
  157. constexpr int kStepForTokenTwo = 34;
  158. test_state->set_step_for_token(0, kStepForTokenZero);
  159. test_state->set_step_for_token(2, kStepForTokenTwo);
  160. // Expect that the set tokens return values and the unset steps return the
  161. // default.
  162. constexpr int kDefaultValue = -1;
  163. EXPECT_EQ(kStepForTokenZero, test_state->step_for_token(0));
  164. EXPECT_EQ(kDefaultValue, test_state->step_for_token(1));
  165. EXPECT_EQ(kStepForTokenTwo, test_state->step_for_token(2));
  166. // Expect that out of bound accesses will return the default. (There are only
  167. // 3 tokens in the backing sentence, so token 3 and greater are out of bound.)
  168. EXPECT_EQ(kDefaultValue, test_state->step_for_token(-1));
  169. EXPECT_EQ(kDefaultValue, test_state->step_for_token(3));
  170. }
  171. // Validates the consistency of the parent_step_for_token setter and getter.
  172. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentStepForToken) {
  173. // Create and initialize a test state.
  174. MockTransitionState mock_state;
  175. auto test_state = CreateState();
  176. test_state->Init(mock_state);
  177. constexpr int kStepForTokenZero = 12;
  178. constexpr int kStepForTokenTwo = 34;
  179. test_state->set_parent_step_for_token(0, kStepForTokenZero);
  180. test_state->set_parent_step_for_token(2, kStepForTokenTwo);
  181. // Expect that the set tokens return values and the unset steps return the
  182. // default.
  183. constexpr int kDefaultValue = -1;
  184. EXPECT_EQ(kStepForTokenZero, test_state->parent_step_for_token(0));
  185. EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(1));
  186. EXPECT_EQ(kStepForTokenTwo, test_state->parent_step_for_token(2));
  187. // Expect that out of bound accesses will return the default. (There are only
  188. // 3 tokens in the backing sentence, so token 3 and greater are out of bound.)
  189. EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(-1));
  190. EXPECT_EQ(kDefaultValue, test_state->parent_step_for_token(3));
  191. }
  192. // Validates the consistency of the parent_for_token setter and getter.
  193. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetParentForToken) {
  194. // Create and initialize a test state.
  195. MockTransitionState mock_state;
  196. auto test_state = CreateState();
  197. test_state->Init(mock_state);
  198. constexpr int kParentForTokenZero = 12;
  199. constexpr int kParentForTokenTwo = 34;
  200. test_state->set_parent_for_token(0, kParentForTokenZero);
  201. test_state->set_parent_for_token(2, kParentForTokenTwo);
  202. // Expect that the set tokens return values and the unset steps return the
  203. // default.
  204. constexpr int kDefaultValue = -1;
  205. EXPECT_EQ(kParentForTokenZero, test_state->parent_for_token(0));
  206. EXPECT_EQ(kDefaultValue, test_state->parent_for_token(1));
  207. EXPECT_EQ(kParentForTokenTwo, test_state->parent_for_token(2));
  208. // Expect that out of bound accesses will return the default. (There are only
  209. // 3 tokens in the backing sentence, so token 3 and greater are out of bound.)
  210. EXPECT_EQ(kDefaultValue, test_state->parent_for_token(-1));
  211. EXPECT_EQ(kDefaultValue, test_state->parent_for_token(3));
  212. }
  213. // Validates the consistency of trace proto setter/getter.
  214. TEST_F(SyntaxNetTransitionStateTest, CanSetAndGetTrace) {
  215. // Create and initialize a test state.
  216. MockTransitionState mock_state;
  217. auto test_state = CreateState();
  218. test_state->Init(mock_state);
  219. const string kTestComponentName = "test";
  220. std::unique_ptr<ComponentTrace> trace;
  221. trace.reset(new ComponentTrace());
  222. trace->set_name(kTestComponentName);
  223. test_state->set_trace(std::move(trace));
  224. EXPECT_EQ(trace.get(), nullptr);
  225. EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
  226. // Should be preserved when cloing.
  227. auto cloned_state = test_state->Clone();
  228. EXPECT_EQ(cloned_state->mutable_trace()->name(), kTestComponentName);
  229. EXPECT_EQ(test_state->mutable_trace()->name(), kTestComponentName);
  230. }
  231. } // namespace dragnn
  232. } // namespace syntaxnet