compute_session_pool.cc 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. #include "dragnn/core/compute_session_pool.h"
  2. #include <utility>
  3. #include "dragnn/core/component_registry.h"
  4. #include "dragnn/core/compute_session_impl.h"
  5. #include "tensorflow/core/platform/logging.h"
  6. namespace syntaxnet {
  7. namespace dragnn {
  8. using tensorflow::mutex_lock;
  9. ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
  10. const GridPoint &hyperparams)
  11. : master_spec_(master_spec),
  12. hyperparams_(hyperparams),
  13. num_unique_sessions_(0) {
  14. // Create a default component builder function. This function looks up
  15. // components in the component registry and returns them.
  16. component_builder_ = [](
  17. const string &component_name,
  18. const string &backend_type) -> std::unique_ptr<Component> {
  19. VLOG(2) << "Creating component " << component_name << " with backend "
  20. << backend_type;
  21. std::unique_ptr<Component> component(Component::Create(backend_type));
  22. return component;
  23. };
  24. // Create a default session builder function. This function returns a
  25. // ComputeSessionImpl that uses the currently set component_builder_
  26. // function to create its components.
  27. session_builder_ = [this]() {
  28. return std::unique_ptr<ComputeSession>(
  29. new ComputeSessionImpl(num_unique_sessions_, this->component_builder_));
  30. };
  31. }
  32. ComputeSessionPool::~ComputeSessionPool() {
  33. LOG(INFO) << "Destroying pool: total number of sessions created = "
  34. << num_unique_sessions_;
  35. if (sessions_.size() < num_unique_sessions_) {
  36. LOG(WARNING) << "Destroying pool: number of unreturned sessions = "
  37. << (num_unique_sessions_ - sessions_.size());
  38. }
  39. }
  40. void ComputeSessionPool::SetComputeSessionBuilder(
  41. std::function<std::unique_ptr<ComputeSession>()> session_builder) {
  42. mutex_lock lock(lock_);
  43. session_builder_ = std::move(session_builder);
  44. }
  45. void ComputeSessionPool::SetComponentBuilder(
  46. std::function<std::unique_ptr<Component>(const string &component_name,
  47. const string &backend_type)>
  48. component_builder) {
  49. mutex_lock lock(lock_);
  50. component_builder_ = std::move(component_builder);
  51. }
  52. std::unique_ptr<ComputeSession> ComputeSessionPool::GetSession() {
  53. mutex_lock lock(lock_);
  54. std::unique_ptr<ComputeSession> session_ptr;
  55. if (sessions_.empty()) {
  56. // There are no available sessions, so create and initialize one.
  57. VLOG(2) << "Creating new session.";
  58. session_ptr = session_builder_();
  59. num_unique_sessions_++;
  60. session_ptr->Init(master_spec_, hyperparams_);
  61. } else {
  62. // Get the last free session, and remove it from the free sessions vector.
  63. VLOG(2) << "Reusing session from pool of size " << sessions_.size();
  64. session_ptr = std::move(sessions_.back());
  65. sessions_.pop_back();
  66. session_ptr->ResetSession();
  67. }
  68. return session_ptr;
  69. }
  70. void ComputeSessionPool::ReturnSession(
  71. std::unique_ptr<ComputeSession> session) {
  72. mutex_lock lock(lock_);
  73. sessions_.push_back(std::move(session));
  74. }
  75. } // namespace dragnn
  76. } // namespace syntaxnet