// shared_store.h (8.3 KB)
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Utility for creating read-only objects once and sharing them across threads.
  13. #ifndef SYNTAXNET_SHARED_STORE_H_
  14. #define SYNTAXNET_SHARED_STORE_H_
  15. #include <functional>
  16. #include <string>
  17. #include <typeindex>
  18. #include <unordered_map>
  19. #include <utility>
  20. #include "syntaxnet/utils.h"
  21. #include "tensorflow/core/lib/core/stringpiece.h"
  22. #include "tensorflow/core/lib/strings/strcat.h"
  23. namespace syntaxnet {
  24. class SharedStore {
  25. public:
  26. // Returns an existing object with type T and name 'name' if it exists, else
  27. // creates one with "new T(args...)". Note: Objects will be indexed under
  28. // their typeid + name, so names only have to be unique within a given type.
  29. template <typename T, typename ...Args>
  30. static const T *Get(const string &name,
  31. Args &&...args); // NOLINT(build/c++11)
  32. // Like Get(), but creates the object with "closure->Run()". If the closure
  33. // returns null, we store a null in the SharedStore, but note that Release()
  34. // cannot be used to remove it. This is because Release() finds the object
  35. // by associative lookup, and there may be more than one null value, so we
  36. // don't know which one to release. If the closure returns a duplicate value
  37. // (one that is pointer-equal to an object already in the SharedStore),
  38. // we disregard it and store null instead -- otherwise associative lookup
  39. // would again fail (and the reference counts would be wrong).
  40. template <typename T>
  41. static const T *ClosureGet(const string &name, std::function<T *()> *closure);
  42. // Like ClosureGet(), but check-fails if ClosureGet() would return null.
  43. template <typename T>
  44. static const T *ClosureGetOrDie(const string &name,
  45. std::function<T *()> *closure);
  46. // Release an object that was acquired by Get(). When its reference count
  47. // hits 0, the object will be deleted. Returns true if the object was found.
  48. // Does nothing and returns true if the object is null.
  49. static bool Release(const void *object);
  50. // Delete all objects in the shared store.
  51. static void Clear();
  52. private:
  53. // A shared object.
  54. struct SharedObject {
  55. void *object;
  56. std::function<void()> delete_callback;
  57. int refcount;
  58. SharedObject(void *o, std::function<void()> d)
  59. : object(o), delete_callback(std::move(d)), refcount(1) {}
  60. };
  61. // A map from keys to shared objects.
  62. typedef std::unordered_map<string, SharedObject> SharedObjectMap;
  63. // Return the shared object map.
  64. static SharedObjectMap *shared_object_map();
  65. // Return the string to use for indexing an object in the shared store.
  66. template <typename T>
  67. static string GetSharedKey(const string &name);
  68. // Delete an object of type T.
  69. template <typename T>
  70. static void DeleteObject(T *object);
  71. // Add an object to the shared object map. Return the object.
  72. template <typename T>
  73. static T *StoreObject(const string &key, T *object);
  74. // Increment the reference count of an object in the map. Return the object.
  75. template <typename T>
  76. static T *IncrementRefCountOfObject(SharedObjectMap::iterator it);
  77. // Map from keys to shared objects.
  78. static SharedObjectMap *shared_object_map_;
  79. static mutex shared_object_map_mutex_;
  80. TF_DISALLOW_COPY_AND_ASSIGN(SharedStore);
  81. };
  82. template <typename T>
  83. string SharedStore::GetSharedKey(const string &name) {
  84. const std::type_index id = std::type_index(typeid(T));
  85. return tensorflow::strings::StrCat(id.name(), "_", name);
  86. }
  87. template <typename T>
  88. void SharedStore::DeleteObject(T *object) {
  89. delete object;
  90. }
  91. template <typename T>
  92. T *SharedStore::StoreObject(const string &key, T *object) {
  93. std::function<void()> delete_cb =
  94. std::bind(SharedStore::DeleteObject<T>, object);
  95. SharedObject so(object, delete_cb);
  96. shared_object_map()->insert(std::make_pair(key, so));
  97. return object;
  98. }
  99. template <typename T>
  100. T *SharedStore::IncrementRefCountOfObject(SharedObjectMap::iterator it) {
  101. it->second.refcount++;
  102. return static_cast<T *>(it->second.object);
  103. }
  104. template <typename T, typename ...Args>
  105. const T *SharedStore::Get(const string &name,
  106. Args &&...args) { // NOLINT(build/c++11)
  107. mutex_lock l(shared_object_map_mutex_);
  108. const string key = GetSharedKey<T>(name);
  109. SharedObjectMap::iterator it = shared_object_map()->find(key);
  110. return (it == shared_object_map()->end()) ?
  111. StoreObject<T>(key, new T(std::forward<Args>(args)...)) :
  112. IncrementRefCountOfObject<T>(it);
  113. }
  114. template <typename T>
  115. const T *SharedStore::ClosureGet(const string &name,
  116. std::function<T *()> *closure) {
  117. mutex_lock l(shared_object_map_mutex_);
  118. const string key = GetSharedKey<T>(name);
  119. SharedObjectMap::iterator it = shared_object_map()->find(key);
  120. if (it == shared_object_map()->end()) {
  121. // Creates a new object by calling the closure.
  122. T *object = (*closure)();
  123. if (object == nullptr) {
  124. LOG(ERROR) << "Closure returned a null pointer";
  125. } else {
  126. for (SharedObjectMap::iterator it = shared_object_map()->begin();
  127. it != shared_object_map()->end(); ++it) {
  128. if (it->second.object == object) {
  129. LOG(ERROR)
  130. << "Closure returned duplicate pointer: "
  131. << "keys " << it->first << " and " << key;
  132. // Not a memory leak to discard pointer, since we have another copy.
  133. object = nullptr;
  134. break;
  135. }
  136. }
  137. }
  138. return StoreObject<T>(key, object);
  139. } else {
  140. return IncrementRefCountOfObject<T>(it);
  141. }
  142. }
  143. template <typename T>
  144. const T *SharedStore::ClosureGetOrDie(const string &name,
  145. std::function<T *()> *closure) {
  146. const T *object = ClosureGet<T>(name, closure);
  147. CHECK(object != nullptr);
  148. return object;
  149. }
  150. // A collection of utility functions for working with the shared store.
  151. class SharedStoreUtils {
  152. public:
  153. // Returns a shared object registered using a default name that is created
  154. // from the constructor args.
  155. //
  156. // NB: This function does not guarantee a one-to-one relationship between
  157. // sets of constructor args and names. See warnings on CreateDefaultName().
  158. // It is the caller's responsibility to ensure that the args provided will
  159. // result in unique names.
  160. template <class T, class... Args>
  161. static const T *GetWithDefaultName(Args &&... args) { // NOLINT(build/c++11)
  162. return SharedStore::Get<T>(CreateDefaultName(std::forward<Args>(args)...),
  163. std::forward<Args>(args)...);
  164. }
  165. // Returns a string name representing the args. Implemented via a pair of
  166. // overloaded functions to achieve compile-time recursion.
  167. //
  168. // WARNING: It is possible for instances of different types to have the same
  169. // string representation. For example,
  170. //
  171. // CreateDefaultName(1) == CreateDefaultName(1ULL)
  172. //
  173. template <class First, class... Rest>
  174. static string CreateDefaultName(First &&first,
  175. Rest &&... rest) { // NOLINT(build/c++11)
  176. return tensorflow::strings::StrCat(
  177. ToString<First>(std::forward<First>(first)), ",",
  178. CreateDefaultName(std::forward<Rest>(rest)...));
  179. }
  180. static string CreateDefaultName();
  181. private:
  182. // Returns a string representing the input. The generic implementation uses
  183. // StrCat(), and overloads are provided for selected types.
  184. template <class T>
  185. static string ToString(T input) {
  186. return tensorflow::strings::StrCat(input);
  187. }
  188. static string ToString(const string &input);
  189. static string ToString(const char *input);
  190. static string ToString(tensorflow::StringPiece input);
  191. static string ToString(bool input);
  192. static string ToString(float input);
  193. static string ToString(double input);
  194. TF_DISALLOW_COPY_AND_ASSIGN(SharedStoreUtils);
  195. };
  196. } // namespace syntaxnet
  197. #endif // SYNTAXNET_SHARED_STORE_H_