input_batch_cache.h 3.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
  #include <memory>
  #include <string>
  #include <type_traits>
  #include <typeindex>
  #include <typeinfo>
  #include <utility>
  #include <vector>

  #include "dragnn/core/interfaces/input_batch.h"
  #include "tensorflow/core/platform/logging.h"
  22. namespace syntaxnet {
  23. namespace dragnn {
  24. // A InputBatchCache holds data converted to a DRAGNN internal representation.
  25. // It performs the conversion lazily via Data objects and caches the result.
  26. class InputBatchCache {
  27. public:
  28. // Creates an empty cache.
  29. InputBatchCache() : stored_type_(std::type_index(typeid(void))) {}
  30. // Creates a InputBatchCache from a single example. This copies the string.
  31. explicit InputBatchCache(const string &data)
  32. : stored_type_(std::type_index(typeid(void))), source_data_({data}) {}
  33. // Creates a InputBatchCache from a vector of examples. The vector is copied.
  34. explicit InputBatchCache(const std::vector<string> &data)
  35. : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
  36. // Adds a single string to the cache. Only useable before GetAs() has been
  37. // called.
  38. void AddData(const string &data) {
  39. CHECK(stored_type_ == std::type_index(typeid(void)))
  40. << "You may not add data to an InputBatchCache after the cache has "
  41. "been converted via GetAs().";
  42. source_data_.emplace_back(data);
  43. }
  44. // Converts the stored strings into protos and return them in a specific
  45. // InputBatch subclass. T should always be of type InputBatch. After this
  46. // method is called once, all further calls must be of the same data type.
  47. template <class T>
  48. T *GetAs() {
  49. if (!converted_data_) {
  50. stored_type_ = std::type_index(typeid(T));
  51. converted_data_.reset(new T());
  52. converted_data_->SetData(source_data_);
  53. }
  54. CHECK(std::type_index(typeid(T)) == stored_type_)
  55. << "Attempted to convert to two object types! Existing object type was "
  56. << stored_type_.name() << ", new object type was "
  57. << std::type_index(typeid(T)).name();
  58. return dynamic_cast<T *>(converted_data_.get());
  59. }
  60. // Returns the serialized representation of the data held in the input batch
  61. // object within this cache.
  62. const std::vector<string> SerializedData() const {
  63. CHECK(converted_data_) << "Cannot return batch without data.";
  64. return converted_data_->GetSerializedData();
  65. }
  66. private:
  67. // The typeid of the stored data.
  68. std::type_index stored_type_;
  69. // The raw data.
  70. std::vector<string> source_data_;
  71. // The converted data, contained in an InputBatch object.
  72. std::unique_ptr<InputBatch> converted_data_;
  73. };
  74. } // namespace dragnn
  75. } // namespace syntaxnet
  76. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_