// compute_session_pool.h
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
  3. #include <memory>
  4. #include "dragnn/core/compute_session.h"
  5. #include "dragnn/protos/spec.pb.h"
  6. #include "tensorflow/core/platform/mutex.h"
  7. namespace syntaxnet {
  8. namespace dragnn {
  9. // This pool creates and manages the reuse of ComputeSession objects.
  10. class ComputeSessionPool {
  11. public:
  12. // Create a ComputeSessionPool that creates ComputeSessions for the given
  13. // MasterSpec and hyperparameters.
  14. ComputeSessionPool(const MasterSpec &master_spec,
  15. const GridPoint &hyperparams);
  16. virtual ~ComputeSessionPool();
  17. // Get a ComputeSession. This function will attempt to use an already-created
  18. // ComputeSession, but if none are available a new one will be created.
  19. std::unique_ptr<ComputeSession> GetSession();
  20. // Returns a ComputeSession to the backing pool.
  21. void ReturnSession(std::unique_ptr<ComputeSession> session);
  22. // Returns the count of outstanding unique sessions.
  23. int num_outstanding_sessions() {
  24. tensorflow::mutex_lock lock(lock_);
  25. return num_unique_sessions_ - sessions_.size();
  26. }
  27. private:
  28. friend class ComputeSessionImplTestPoolAccessor;
  29. friend class ComputeSessionPoolTestPoolAccessor;
  30. // This is a creational injection setter. It should be used for tests
  31. // where we want our ComputeSessionPool to prepare and return
  32. // MockComputeSessions instead of actual ComputeSessionImpls.
  33. void SetComputeSessionBuilder(
  34. std::function<std::unique_ptr<ComputeSession>()> session_builder);
  35. // This injector will cause ComputeSessions built in this pool to use the
  36. // passed function to create Components. This is useful when you want a
  37. // ComputeSession to create MockComponents instead of real ones.
  38. void SetComponentBuilder(
  39. std::function<std::unique_ptr<Component>(const string &component_name,
  40. const string &backend_type)>
  41. component_builder);
  42. // The MasterSpec that will be used to initialize ComputeSessions from this
  43. // pool.
  44. const MasterSpec master_spec_;
  45. // The hyperparameters that will be used to initialize ComputeSessions from
  46. // this pool.
  47. const GridPoint hyperparams_;
  48. // The function that is used to create ComputeSessions.
  49. std::function<std::unique_ptr<ComputeSession>()> session_builder_;
  50. // The function passed to ComputeSessions that will be used by that session
  51. // to create components.
  52. std::function<std::unique_ptr<Component>(const string &component_name,
  53. const string &backend_type)>
  54. component_builder_;
  55. // ComputeSessions that are not currently being used. These sessions are not
  56. // reset until they are requested by another thread.
  57. std::vector<std::unique_ptr<ComputeSession>> sessions_;
  58. // Count of the number of unique ComputeSession objects that have been
  59. // created. Used to assign IDs to new Sessions.
  60. int num_unique_sessions_;
  61. // Mutex that protects accesses to all members of this object.
  62. tensorflow::mutex lock_;
  63. };
  64. } // namespace dragnn
  65. } // namespace syntaxnet
  66. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_