compute_session_pool.cc 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/core/compute_session_pool.h"
  16. #include <utility>
  17. #include "dragnn/core/component_registry.h"
  18. #include "dragnn/core/compute_session_impl.h"
  19. #include "tensorflow/core/platform/logging.h"
  20. namespace syntaxnet {
  21. namespace dragnn {
  22. using tensorflow::mutex_lock;
  23. ComputeSessionPool::ComputeSessionPool(const MasterSpec &master_spec,
  24. const GridPoint &hyperparams)
  25. : master_spec_(master_spec),
  26. hyperparams_(hyperparams),
  27. num_unique_sessions_(0) {
  28. // Create a default component builder function. This function looks up
  29. // components in the component registry and returns them.
  30. component_builder_ = [](
  31. const string &component_name,
  32. const string &backend_type) -> std::unique_ptr<Component> {
  33. VLOG(2) << "Creating component " << component_name << " with backend "
  34. << backend_type;
  35. std::unique_ptr<Component> component(Component::Create(backend_type));
  36. return component;
  37. };
  38. // Create a default session builder function. This function returns a
  39. // ComputeSessionImpl that uses the currently set component_builder_
  40. // function to create its components.
  41. session_builder_ = [this]() {
  42. return std::unique_ptr<ComputeSession>(
  43. new ComputeSessionImpl(num_unique_sessions_, this->component_builder_));
  44. };
  45. }
  46. ComputeSessionPool::~ComputeSessionPool() {
  47. LOG(INFO) << "Destroying pool: total number of sessions created = "
  48. << num_unique_sessions_;
  49. if (sessions_.size() < num_unique_sessions_) {
  50. LOG(WARNING) << "Destroying pool: number of unreturned sessions = "
  51. << (num_unique_sessions_ - sessions_.size());
  52. }
  53. }
  54. void ComputeSessionPool::SetComputeSessionBuilder(
  55. std::function<std::unique_ptr<ComputeSession>()> session_builder) {
  56. mutex_lock lock(lock_);
  57. session_builder_ = std::move(session_builder);
  58. }
  59. void ComputeSessionPool::SetComponentBuilder(
  60. std::function<std::unique_ptr<Component>(const string &component_name,
  61. const string &backend_type)>
  62. component_builder) {
  63. mutex_lock lock(lock_);
  64. component_builder_ = std::move(component_builder);
  65. }
  66. std::unique_ptr<ComputeSession> ComputeSessionPool::GetSession() {
  67. mutex_lock lock(lock_);
  68. std::unique_ptr<ComputeSession> session_ptr;
  69. if (sessions_.empty()) {
  70. // There are no available sessions, so create and initialize one.
  71. VLOG(2) << "Creating new session.";
  72. session_ptr = session_builder_();
  73. num_unique_sessions_++;
  74. session_ptr->Init(master_spec_, hyperparams_);
  75. } else {
  76. // Get the last free session, and remove it from the free sessions vector.
  77. VLOG(2) << "Reusing session from pool of size " << sessions_.size();
  78. session_ptr = std::move(sessions_.back());
  79. sessions_.pop_back();
  80. session_ptr->ResetSession();
  81. }
  82. return session_ptr;
  83. }
  84. void ComputeSessionPool::ReturnSession(
  85. std::unique_ptr<ComputeSession> session) {
  86. mutex_lock lock(lock_);
  87. sessions_.push_back(std::move(session));
  88. }
  89. } // namespace dragnn
  90. } // namespace syntaxnet