sentence_input_batch_test.cc 1.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #include "dragnn/io/sentence_input_batch.h"
  2. #include "dragnn/core/test/generic.h"
  3. #include "syntaxnet/sentence.pb.h"
  4. #include "tensorflow/core/platform/test.h"
  5. namespace syntaxnet {
  6. namespace dragnn {
  7. using syntaxnet::test::EqualsProto;
  8. TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) {
  9. // Create some distinct Sentence protos.
  10. Sentence sentence_one;
  11. sentence_one.set_docid("foo");
  12. Sentence sentence_two;
  13. sentence_two.set_docid("bar");
  14. std::vector<Sentence> protos({sentence_one, sentence_two});
  15. // Create stringified versions.
  16. std::vector<string> strings;
  17. for (const auto &sentence : protos) {
  18. string str;
  19. sentence.SerializeToString(&str);
  20. strings.push_back(str);
  21. }
  22. // Create a SentenceInputBatch. The data inside it should match the protos.
  23. SentenceInputBatch set;
  24. set.SetData(strings);
  25. auto converted_data = set.data();
  26. for (int i = 0; i < protos.size(); ++i) {
  27. EXPECT_THAT(*(converted_data->at(i).sentence()), EqualsProto(protos.at(i)));
  28. EXPECT_NE(converted_data->at(i).workspace(), nullptr);
  29. }
  30. // Get the data back out. The strings should be identical.
  31. auto output = set.GetSerializedData();
  32. EXPECT_EQ(output.size(), strings.size());
  33. EXPECT_NE(output.size(), 0);
  34. for (int i = 0; i < output.size(); ++i) {
  35. EXPECT_EQ(strings.at(i), output.at(i));
  36. }
  37. }
  38. TEST(SentenceInputBatchTest, BadlyFormedProtosDie) {
  39. // Create a input batch with malformed data. This should cause a CHECK fail.
  40. SentenceInputBatch set;
  41. EXPECT_DEATH(set.SetData({"BADLY FORMATTED DATA. SHOULD CAUSE A CHECK"}),
  42. "Unable to parse string input");
  43. }
  44. } // namespace dragnn
  45. } // namespace syntaxnet