input_batch_cache.h 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
  3. #include <memory>
  4. #include <string>
  5. #include <typeindex>
  6. #include "dragnn/core/interfaces/input_batch.h"
  7. #include "tensorflow/core/platform/logging.h"
  8. namespace syntaxnet {
  9. namespace dragnn {
  10. // A InputBatchCache holds data converted to a DRAGNN internal representation.
  11. // It performs the conversion lazily via Data objects and caches the result.
  12. class InputBatchCache {
  13. public:
  14. // Create an empty cache..
  15. InputBatchCache() : stored_type_(std::type_index(typeid(void))) {}
  16. // Create a InputBatchCache from a single example. This copies the string.
  17. explicit InputBatchCache(const string &data)
  18. : stored_type_(std::type_index(typeid(void))), source_data_({data}) {}
  19. // Create a InputBatchCache from a vector of examples. The vector is copied.
  20. explicit InputBatchCache(const std::vector<string> &data)
  21. : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
  22. // Adds a single string to the cache. Only useable before GetAs() has been
  23. // called.
  24. void AddData(const string &data) {
  25. CHECK(stored_type_ == std::type_index(typeid(void)))
  26. << "You may not add data to an InputBatchCache after the cache has "
  27. "been converted via GetAs().";
  28. source_data_.emplace_back(data);
  29. }
  30. // Convert the stored strings into protos and return them in a specific
  31. // InputBatch subclass. T should always be of type InputBatch. After this
  32. // method is called once, all further calls must be of the same data type.
  33. template <class T>
  34. T *GetAs() {
  35. if (!converted_data_) {
  36. stored_type_ = std::type_index(typeid(T));
  37. converted_data_.reset(new T());
  38. converted_data_->SetData(source_data_);
  39. }
  40. CHECK(std::type_index(typeid(T)) == stored_type_)
  41. << "Attempted to convert to two object types! Existing object type was "
  42. << stored_type_.name() << ", new object type was "
  43. << std::type_index(typeid(T)).name();
  44. return dynamic_cast<T *>(converted_data_.get());
  45. }
  46. // Return the serialized representation of the data held in the input batch
  47. // object within this cache.
  48. const std::vector<string> SerializedData() const {
  49. CHECK(converted_data_) << "Cannot return batch without data.";
  50. return converted_data_->GetSerializedData();
  51. }
  52. private:
  53. // The typeid of the stored data.
  54. std::type_index stored_type_;
  55. // The raw data.
  56. std::vector<string> source_data_;
  57. // The converted data, contained in an InputBatch object.
  58. std::unique_ptr<InputBatch> converted_data_;
  59. };
  60. } // namespace dragnn
  61. } // namespace syntaxnet
  62. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_