// sentence_input_batch_test.cc
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
  15. #include "dragnn/io/sentence_input_batch.h"
  16. #include "dragnn/core/test/generic.h"
  17. #include "syntaxnet/sentence.pb.h"
  18. #include "tensorflow/core/platform/test.h"
  19. namespace syntaxnet {
  20. namespace dragnn {
  21. using syntaxnet::test::EqualsProto;
  22. TEST(SentenceInputBatchTest, ConvertsFromStringifiedProtos) {
  23. // Create some distinct Sentence protos.
  24. Sentence sentence_one;
  25. sentence_one.set_docid("foo");
  26. Sentence sentence_two;
  27. sentence_two.set_docid("bar");
  28. std::vector<Sentence> protos({sentence_one, sentence_two});
  29. // Create stringified versions.
  30. std::vector<string> strings;
  31. for (const auto &sentence : protos) {
  32. string str;
  33. sentence.SerializeToString(&str);
  34. strings.push_back(str);
  35. }
  36. // Create a SentenceInputBatch. The data inside it should match the protos.
  37. SentenceInputBatch set;
  38. set.SetData(strings);
  39. auto converted_data = set.data();
  40. for (int i = 0; i < protos.size(); ++i) {
  41. EXPECT_THAT(*(converted_data->at(i).sentence()), EqualsProto(protos.at(i)));
  42. EXPECT_NE(converted_data->at(i).workspace(), nullptr);
  43. }
  44. // Get the data back out. The strings should be identical.
  45. auto output = set.GetSerializedData();
  46. EXPECT_EQ(output.size(), strings.size());
  47. EXPECT_NE(output.size(), 0);
  48. for (int i = 0; i < output.size(); ++i) {
  49. EXPECT_EQ(strings.at(i), output.at(i));
  50. }
  51. }
  52. TEST(SentenceInputBatchTest, BadlyFormedProtosDie) {
  53. // Create a input batch with malformed data. This should cause a CHECK fail.
  54. SentenceInputBatch set;
  55. EXPECT_DEATH(set.SetData({"BADLY FORMATTED DATA. SHOULD CAUSE A CHECK"}),
  56. "Unable to parse string input");
  57. }
  58. } // namespace dragnn
  59. } // namespace syntaxnet