shared_store_test.cc 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. #include "syntaxnet/shared_store.h"
  13. #include <string>
  14. #include "syntaxnet/utils.h"
  15. #include <gmock/gmock.h>
  16. #include "tensorflow/core/lib/core/threadpool.h"
  17. using ::testing::_;
  18. namespace syntaxnet {
  19. struct NoArgs {
  20. NoArgs() {
  21. LOG(INFO) << "Calling NoArgs()";
  22. }
  23. };
  24. struct OneArg {
  25. string name;
  26. explicit OneArg(const string &n) : name(n) {
  27. LOG(INFO) << "Calling OneArg(" << name << ")";
  28. }
  29. };
  30. struct TwoArgs {
  31. string name;
  32. int age;
  33. TwoArgs(const string &n, int a) : name(n), age(a) {
  34. LOG(INFO) << "Calling TwoArgs(" << name << ", " << age << ")";
  35. }
  36. };
  37. struct Slow {
  38. string lengthy;
  39. Slow() {
  40. LOG(INFO) << "Calling Slow()";
  41. lengthy.assign(50 << 20, 'L'); // 50MB of the letter 'L'
  42. }
  43. };
  44. struct CountCalls {
  45. CountCalls() {
  46. LOG(INFO) << "Calling CountCalls()";
  47. ++constructor_calls;
  48. }
  49. ~CountCalls() {
  50. LOG(INFO) << "Calling ~CountCalls()";
  51. ++destructor_calls;
  52. }
  53. static void Reset() {
  54. constructor_calls = 0;
  55. destructor_calls = 0;
  56. }
  57. static int constructor_calls;
  58. static int destructor_calls;
  59. };
  60. int CountCalls::constructor_calls = 0;
  61. int CountCalls::destructor_calls = 0;
  62. class PointerSet {
  63. public:
  64. PointerSet() { }
  65. void Add(const void *p) {
  66. mutex_lock l(mu_);
  67. pointers_.insert(p);
  68. }
  69. int size() {
  70. mutex_lock l(mu_);
  71. return pointers_.size();
  72. }
  73. private:
  74. mutex mu_;
  75. unordered_set<const void *> pointers_;
  76. };
  77. class SharedStoreTest : public testing::Test {
  78. protected:
  79. ~SharedStoreTest() override {
  80. // Clear the shared store after each test, otherwise objects created
  81. // in one test may interfere with other tests.
  82. SharedStore::Clear();
  83. }
  84. };
  85. // Verify that we can call constructors with varying numbers and types of args.
  86. TEST_F(SharedStoreTest, ConstructorArgs) {
  87. SharedStore::Get<NoArgs>("no args");
  88. SharedStore::Get<OneArg>("one arg", "Fred");
  89. SharedStore::Get<TwoArgs>("two args", "Pebbles", 2);
  90. }
  91. // Verify that an object with a given key is created only once.
  92. TEST_F(SharedStoreTest, Shared) {
  93. const NoArgs *ob1 = SharedStore::Get<NoArgs>("first");
  94. const NoArgs *ob2 = SharedStore::Get<NoArgs>("second");
  95. const NoArgs *ob3 = SharedStore::Get<NoArgs>("first");
  96. EXPECT_EQ(ob1, ob3);
  97. EXPECT_NE(ob1, ob2);
  98. EXPECT_NE(ob2, ob3);
  99. }
  100. // Verify that objects with the same name but different types do not collide.
  101. TEST_F(SharedStoreTest, DifferentTypes) {
  102. const NoArgs *ob1 = SharedStore::Get<NoArgs>("same");
  103. const OneArg *ob2 = SharedStore::Get<OneArg>("same", "foo");
  104. const TwoArgs *ob3 = SharedStore::Get<TwoArgs>("same", "bar", 5);
  105. EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob2));
  106. EXPECT_NE(static_cast<const void *>(ob1), static_cast<const void *>(ob3));
  107. EXPECT_NE(static_cast<const void *>(ob2), static_cast<const void *>(ob3));
  108. }
  109. // Factory method to make a OneArg.
  110. OneArg *MakeOneArg(const string &n) {
  111. return new OneArg(n);
  112. }
  113. TEST_F(SharedStoreTest, ClosureGet) {
  114. std::function<OneArg *()> closure1 = std::bind(MakeOneArg, "Al");
  115. std::function<OneArg *()> closure2 = std::bind(MakeOneArg, "Al");
  116. const OneArg *ob1 = SharedStore::ClosureGet("first", &closure1);
  117. const OneArg *ob2 = SharedStore::ClosureGet("first", &closure2);
  118. EXPECT_EQ("Al", ob1->name);
  119. EXPECT_EQ(ob1, ob2);
  120. }
  121. TEST_F(SharedStoreTest, PermanentCallback) {
  122. std::function<OneArg *()> closure = std::bind(MakeOneArg, "Al");
  123. const OneArg *ob1 = SharedStore::ClosureGet("first", &closure);
  124. const OneArg *ob2 = SharedStore::ClosureGet("first", &closure);
  125. EXPECT_EQ("Al", ob1->name);
  126. EXPECT_EQ(ob1, ob2);
  127. }
  128. // Factory method to "make" a NoArgs by simply returning an input pointer.
  129. NoArgs *BogusMakeNoArgs(NoArgs *ob) {
  130. return ob;
  131. }
  132. // Create a CountCalls object, pretend it failed, and return null.
  133. CountCalls *MakeFailedCountCalls() {
  134. CountCalls *ob = new CountCalls;
  135. delete ob;
  136. return nullptr;
  137. }
  138. // Verify that ClosureGet() only calls the closure for a given key once,
  139. // even if the closure fails.
  140. TEST_F(SharedStoreTest, FailedClosureGet) {
  141. CountCalls::Reset();
  142. std::function<CountCalls *()> closure1(MakeFailedCountCalls);
  143. std::function<CountCalls *()> closure2(MakeFailedCountCalls);
  144. const CountCalls *ob1 = SharedStore::ClosureGet("first", &closure1);
  145. const CountCalls *ob2 = SharedStore::ClosureGet("first", &closure2);
  146. EXPECT_EQ(nullptr, ob1);
  147. EXPECT_EQ(nullptr, ob2);
  148. EXPECT_EQ(1, CountCalls::constructor_calls);
  149. }
  150. typedef SharedStoreTest SharedStoreDeathTest;
  151. TEST_F(SharedStoreDeathTest, ClosureGetOrDie) {
  152. NoArgs *empty = nullptr;
  153. std::function<NoArgs *()> closure = std::bind(BogusMakeNoArgs, empty);
  154. EXPECT_DEATH(SharedStore::ClosureGetOrDie("first", &closure), "nullptr");
  155. }
  156. TEST_F(SharedStoreTest, Release) {
  157. const OneArg *ob1 = SharedStore::Get<OneArg>("first", "Fred");
  158. const OneArg *ob2 = SharedStore::Get<OneArg>("first", "Fred");
  159. EXPECT_EQ(ob1, ob2);
  160. EXPECT_TRUE(SharedStore::Release(ob1)); // now refcount = 1
  161. EXPECT_TRUE(SharedStore::Release(ob1)); // now object is deleted
  162. EXPECT_FALSE(SharedStore::Release(ob1)); // now object is not found
  163. EXPECT_TRUE(SharedStore::Release(nullptr)); // release(nullptr) returns true
  164. }
  165. TEST_F(SharedStoreTest, Clear) {
  166. CountCalls::Reset();
  167. SharedStore::Get<CountCalls>("first");
  168. SharedStore::Get<CountCalls>("second");
  169. SharedStore::Get<CountCalls>("first");
  170. // Test that the constructor and destructor are each called exactly once
  171. // for each key in the shared store.
  172. SharedStore::Clear();
  173. EXPECT_EQ(2, CountCalls::constructor_calls);
  174. EXPECT_EQ(2, CountCalls::destructor_calls);
  175. }
  176. void GetSharedObject(PointerSet *ps) {
  177. // Gets a shared object whose constructor takes a long time.
  178. const Slow *ob = SharedStore::Get<Slow>("first");
  179. // Collects the pointer we got. Later, we'll check whether SharedStore
  180. // mistakenly called the constructor more than once.
  181. ps->Add(static_cast<const void *>(ob));
  182. }
  183. // If multiple parallel threads all access an object with the same key,
  184. // only one object is created.
  185. TEST_F(SharedStoreTest, ThreadSafety) {
  186. const int kNumThreads = 20;
  187. tensorflow::thread::ThreadPool *pool = new tensorflow::thread::ThreadPool(
  188. tensorflow::Env::Default(), "ThreadSafetyPool", kNumThreads);
  189. PointerSet ps;
  190. for (int i = 0; i < kNumThreads; ++i) {
  191. std::function<void()> closure = std::bind(GetSharedObject, &ps);
  192. pool->Schedule(closure);
  193. }
  194. // Waits for closures to finish, then delete the pool.
  195. delete pool;
  196. // Expects only one object to have been created across all threads.
  197. EXPECT_EQ(1, ps.size());
  198. }
  199. } // namespace syntaxnet