workspace.h 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218
  1. /* Copyright 2016 Google Inc. All Rights Reserved.
  2. Licensed under the Apache License, Version 2.0 (the "License");
  3. you may not use this file except in compliance with the License.
  4. You may obtain a copy of the License at
  5. http://www.apache.org/licenses/LICENSE-2.0
  6. Unless required by applicable law or agreed to in writing, software
  7. distributed under the License is distributed on an "AS IS" BASIS,
  8. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. See the License for the specific language governing permissions and
  10. limitations under the License.
  11. ==============================================================================*/
  12. // Notes on thread-safety: All of the classes here are thread-compatible. More
  13. // specifically, the registry machinery is thread-safe, as long as each thread
  14. // performs feature extraction on a different Sentence object.
  15. #ifndef SYNTAXNET_WORKSPACE_H_
  16. #define SYNTAXNET_WORKSPACE_H_
  17. #include <string>
  18. #include <typeindex>
  19. #include <unordered_map>
  20. #include <utility>
  21. #include <vector>
  22. #include "syntaxnet/utils.h"
  23. namespace syntaxnet {
  24. // A base class for shared workspaces. Derived classes implement a static member
  25. // function TypeName() which returns a human readable string name for the class.
  26. class Workspace {
  27. public:
  28. // Polymorphic destructor.
  29. virtual ~Workspace() {}
  30. protected:
  31. // Create an empty workspace.
  32. Workspace() {}
  33. private:
  34. TF_DISALLOW_COPY_AND_ASSIGN(Workspace);
  35. };
  36. // A registry that keeps track of workspaces.
  37. class WorkspaceRegistry {
  38. public:
  39. // Create an empty registry.
  40. WorkspaceRegistry() {}
  41. // Returns the index of a named workspace, adding it to the registry first
  42. // if necessary.
  43. template <class W>
  44. int Request(const string &name) {
  45. const std::type_index id = std::type_index(typeid(W));
  46. workspace_types_[id] = W::TypeName();
  47. std::vector<string> &names = workspace_names_[id];
  48. for (int i = 0; i < names.size(); ++i) {
  49. if (names[i] == name) return i;
  50. }
  51. names.push_back(name);
  52. return names.size() - 1;
  53. }
  54. const std::unordered_map<std::type_index, std::vector<string> >
  55. &WorkspaceNames() const {
  56. return workspace_names_;
  57. }
  58. // Returns a string describing the registered workspaces.
  59. string DebugString() const;
  60. private:
  61. // Workspace type names, indexed as workspace_types_[typeid].
  62. std::unordered_map<std::type_index, string> workspace_types_;
  63. // Workspace names, indexed as workspace_names_[typeid][workspace].
  64. std::unordered_map<std::type_index, std::vector<string> > workspace_names_;
  65. TF_DISALLOW_COPY_AND_ASSIGN(WorkspaceRegistry);
  66. };
  67. // A typed collected of workspaces. The workspaces are indexed according to an
  68. // external WorkspaceRegistry. If the WorkspaceSet is const, the contents are
  69. // also immutable.
  70. class WorkspaceSet {
  71. public:
  72. ~WorkspaceSet() { Reset(WorkspaceRegistry()); }
  73. // Returns true if a workspace has been set.
  74. template <class W>
  75. bool Has(int index) const {
  76. const std::type_index id = std::type_index(typeid(W));
  77. DCHECK(workspaces_.find(id) != workspaces_.end());
  78. DCHECK_LT(index, workspaces_.find(id)->second.size());
  79. return workspaces_.find(id)->second[index] != nullptr;
  80. }
  81. // Returns an indexed workspace; the workspace must have been set.
  82. template <class W>
  83. const W &Get(int index) const {
  84. DCHECK(Has<W>(index));
  85. const Workspace *w =
  86. workspaces_.find(std::type_index(typeid(W)))->second[index];
  87. return reinterpret_cast<const W &>(*w);
  88. }
  89. // Sets an indexed workspace; this takes ownership of the workspace, which
  90. // must have been new-allocated. It is an error to set a workspace twice.
  91. template <class W>
  92. void Set(int index, W *workspace) {
  93. const std::type_index id = std::type_index(typeid(W));
  94. DCHECK(workspaces_.find(id) != workspaces_.end());
  95. DCHECK_LT(index, workspaces_[id].size());
  96. DCHECK(workspaces_[id][index] == nullptr);
  97. DCHECK(workspace != nullptr);
  98. workspaces_[id][index] = workspace;
  99. }
  100. void Reset(const WorkspaceRegistry &registry) {
  101. // Deallocate current workspaces.
  102. for (auto &it : workspaces_) {
  103. for (size_t index = 0; index < it.second.size(); ++index) {
  104. delete it.second[index];
  105. }
  106. }
  107. workspaces_.clear();
  108. // Allocate space for new workspaces.
  109. for (auto &it : registry.WorkspaceNames()) {
  110. workspaces_[it.first].resize(it.second.size());
  111. }
  112. }
  113. private:
  114. // The set of workspaces, indexed as workspaces_[typeid][index].
  115. std::unordered_map<std::type_index, std::vector<Workspace *> > workspaces_;
  116. };
  117. // A workspace that wraps around a single int.
  118. class SingletonIntWorkspace : public Workspace {
  119. public:
  120. // Default-initializes the int value.
  121. SingletonIntWorkspace() {}
  122. // Initializes the int with the given value.
  123. explicit SingletonIntWorkspace(int value) : value_(value) {}
  124. // Returns the name of this type of workspace.
  125. static string TypeName() { return "SingletonInt"; }
  126. // Returns the int value.
  127. int get() const { return value_; }
  128. // Sets the int value.
  129. void set(int value) { value_ = value; }
  130. private:
  131. // The enclosed int.
  132. int value_ = 0;
  133. };
  134. // A workspace that wraps around a vector of int.
  135. class VectorIntWorkspace : public Workspace {
  136. public:
  137. // Creates a vector of the given size.
  138. explicit VectorIntWorkspace(int size);
  139. // Creates a vector initialized with the given array.
  140. explicit VectorIntWorkspace(const std::vector<int> &elements);
  141. // Creates a vector of the given size, with each element initialized to the
  142. // given value.
  143. VectorIntWorkspace(int size, int value);
  144. // Returns the name of this type of workspace.
  145. static string TypeName();
  146. // Returns the i'th element.
  147. int element(int i) const { return elements_[i]; }
  148. // Sets the i'th element.
  149. void set_element(int i, int value) { elements_[i] = value; }
  150. int size() const { return elements_.size(); }
  151. private:
  152. // The enclosed vector.
  153. std::vector<int> elements_;
  154. };
  155. // A workspace that wraps around a vector of vector of int.
  156. class VectorVectorIntWorkspace : public Workspace {
  157. public:
  158. // Creates a vector of empty vectors of the given size.
  159. explicit VectorVectorIntWorkspace(int size);
  160. // Returns the name of this type of workspace.
  161. static string TypeName();
  162. // Returns the i'th vector of elements.
  163. const std::vector<int> &elements(int i) const { return elements_[i]; }
  164. // Mutable access to the i'th vector of elements.
  165. std::vector<int> *mutable_elements(int i) { return &(elements_[i]); }
  166. private:
  167. // The enclosed vector of vector of elements.
  168. std::vector<std::vector<int> > elements_;
  169. };
  170. } // namespace syntaxnet
  171. #endif // SYNTAXNET_WORKSPACE_H_