compute_session_impl.cc 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/core/compute_session_impl.h"
  16. #include <algorithm>
  17. #include <utility>
  18. #include "dragnn/protos/data.pb.h"
  19. #include "dragnn/protos/spec.pb.h"
  20. #include "dragnn/protos/trace.pb.h"
  21. #include "syntaxnet/registry.h"
  22. #include "tensorflow/core/platform/logging.h"
  23. namespace syntaxnet {
  24. namespace dragnn {
  25. ComputeSessionImpl::ComputeSessionImpl(
  26. int id,
  27. std::function<std::unique_ptr<Component>(const string &component_name,
  28. const string &backend_type)>
  29. component_builder)
  30. : component_builder_(std::move(component_builder)), id_(id) {}
  31. void ComputeSessionImpl::Init(const MasterSpec &master_spec,
  32. const GridPoint &hyperparams) {
  33. spec_ = master_spec;
  34. grid_point_ = hyperparams;
  35. VLOG(2) << "Creating components.";
  36. bool is_input = true;
  37. Component *predecessor;
  38. for (const ComponentSpec &spec : master_spec.component()) {
  39. // Construct the component using the specified backend.
  40. VLOG(2) << "Creating component '" << spec.name()
  41. << "' with backend: " << spec.backend().registered_name();
  42. auto component =
  43. component_builder_(spec.name(), spec.backend().registered_name());
  44. // Initializes the component.
  45. component->InitializeComponent(spec);
  46. // Adds a predecessor to non-input components.
  47. if (!is_input) {
  48. predecessors_.insert(
  49. std::pair<Component *, Component *>(component.get(), predecessor));
  50. }
  51. // The current component will be the predecessor component next time around.
  52. predecessor = component.get();
  53. // All components after the first are non-input components.
  54. is_input = false;
  55. // Move into components list.
  56. components_.insert(std::pair<string, std::unique_ptr<Component>>(
  57. spec.name(), std::move(component)));
  58. }
  59. VLOG(2) << "Done creating components.";
  60. VLOG(2) << "Adding translators.";
  61. for (const ComponentSpec &spec : master_spec.component()) {
  62. // First, get the component object for this spec.
  63. VLOG(2) << "Examining component: " << spec.name();
  64. auto map_result = components_.find(spec.name());
  65. CHECK(map_result != components_.end()) << "Unable to find component.";
  66. Component *start_component = map_result->second.get();
  67. if (spec.linked_feature_size() > 0) {
  68. VLOG(2) << "Adding " << spec.linked_feature_size() << " translators for "
  69. << spec.name();
  70. // Attach all the translators described in the spec.
  71. std::vector<IndexTranslator *> translator_set;
  72. for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
  73. // For every translator, save off a non-unique ptr in the component name
  74. // to translator map, then push the unique ptr onto the management
  75. // vector.
  76. auto translator = CreateTranslator(channel, start_component);
  77. translator_set.push_back(translator.get());
  78. owned_translators_.push_back(std::move(translator));
  79. }
  80. // Once all translators have been created, associate this group of
  81. // translators with a component.
  82. translators_.insert(std::pair<string, std::vector<IndexTranslator *>>(
  83. spec.name(), std::move(translator_set)));
  84. } else {
  85. VLOG(2) << "No translators found for " << spec.name();
  86. }
  87. }
  88. VLOG(2) << "Done adding translators.";
  89. VLOG(2) << "Initialization complete.";
  90. }
  91. void ComputeSessionImpl::InitializeComponentData(const string &component_name,
  92. int max_beam_size) {
  93. CHECK(input_data_ != nullptr) << "Attempted to access a component without "
  94. "providing input data for this session.";
  95. Component *component = GetComponent(component_name);
  96. // Try and find the source component. If one exists, check that it is terminal
  97. // and get its data; if not, pass in an empty vector for source data.
  98. auto source_result = predecessors_.find(component);
  99. if (source_result == predecessors_.end()) {
  100. VLOG(1) << "Source result not found. Using empty initialization vector for "
  101. << component_name;
  102. component->InitializeData({}, max_beam_size, input_data_.get());
  103. } else {
  104. VLOG(1) << "Source result found. Using prior initialization vector for "
  105. << component_name;
  106. auto source = source_result->second;
  107. CHECK(source->IsTerminal()) << "Source is not terminal for component '"
  108. << component_name << "'. Exiting.";
  109. component->InitializeData(source->GetBeam(), max_beam_size,
  110. input_data_.get());
  111. }
  112. if (do_tracing_) {
  113. component->InitializeTracing();
  114. }
  115. }
  116. int ComputeSessionImpl::BatchSize(const string &component_name) const {
  117. return GetReadiedComponent(component_name)->BatchSize();
  118. }
  119. int ComputeSessionImpl::BeamSize(const string &component_name) const {
  120. return GetReadiedComponent(component_name)->BeamSize();
  121. }
  122. const ComponentSpec &ComputeSessionImpl::Spec(
  123. const string &component_name) const {
  124. for (const auto &component : spec_.component()) {
  125. if (component.name() == component_name) {
  126. return component;
  127. }
  128. }
  129. LOG(FATAL) << "Missing component '" << component_name << "'. Exiting.";
  130. }
  131. int ComputeSessionImpl::SourceComponentBeamSize(const string &component_name,
  132. int channel_id) {
  133. const auto &translators = GetTranslators(component_name);
  134. return translators.at(channel_id)->path().back()->BeamSize();
  135. }
  136. void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
  137. GetReadiedComponent(component_name)->AdvanceFromOracle();
  138. }
  139. void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
  140. const float score_matrix[],
  141. int score_matrix_length) {
  142. GetReadiedComponent(component_name)
  143. ->AdvanceFromPrediction(score_matrix, score_matrix_length);
  144. }
  145. int ComputeSessionImpl::GetInputFeatures(
  146. const string &component_name, std::function<int32 *(int)> allocate_indices,
  147. std::function<int64 *(int)> allocate_ids,
  148. std::function<float *(int)> allocate_weights, int channel_id) const {
  149. return GetReadiedComponent(component_name)
  150. ->GetFixedFeatures(allocate_indices, allocate_ids, allocate_weights,
  151. channel_id);
  152. }
  153. int ComputeSessionImpl::BulkGetInputFeatures(
  154. const string &component_name, const BulkFeatureExtractor &extractor) {
  155. return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
  156. }
  157. std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
  158. const string &component_name, int channel_id) {
  159. auto *component = GetReadiedComponent(component_name);
  160. auto features = component->GetRawLinkFeatures(channel_id);
  161. IndexTranslator *translator = GetTranslators(component_name).at(channel_id);
  162. for (int i = 0; i < features.size(); ++i) {
  163. LinkFeatures &feature = features[i];
  164. if (feature.has_feature_value()) {
  165. VLOG(2) << "Raw feature[" << i << "]: " << feature.ShortDebugString();
  166. IndexTranslator::Index index = translator->Translate(
  167. feature.batch_idx(), feature.beam_idx(), feature.feature_value());
  168. feature.set_step_idx(index.step_index);
  169. feature.set_batch_idx(index.batch_index);
  170. feature.set_beam_idx(index.beam_index);
  171. } else {
  172. VLOG(2) << "Raw feature[" << i << "]: PADDING (empty proto)";
  173. }
  174. }
  175. // Add the translated link features to the component's trace.
  176. if (do_tracing_) {
  177. component->AddTranslatedLinkFeaturesToTrace(features, channel_id);
  178. }
  179. return features;
  180. }
  181. std::vector<std::vector<int>> ComputeSessionImpl::EmitOracleLabels(
  182. const string &component_name) {
  183. return GetReadiedComponent(component_name)->GetOracleLabels();
  184. }
  185. bool ComputeSessionImpl::IsTerminal(const string &component_name) {
  186. return GetReadiedComponent(component_name)->IsTerminal();
  187. }
  188. void ComputeSessionImpl::SetTracing(bool tracing_on) {
  189. do_tracing_ = tracing_on;
  190. for (auto &component_pair : components_) {
  191. if (!tracing_on) {
  192. component_pair.second->DisableTracing();
  193. }
  194. }
  195. }
  196. void ComputeSessionImpl::FinalizeData(const string &component_name) {
  197. VLOG(2) << "Finalizing data for " << component_name;
  198. GetReadiedComponent(component_name)->FinalizeData();
  199. }
  200. std::vector<string> ComputeSessionImpl::GetSerializedPredictions() {
  201. VLOG(2) << "Geting serialized predictions.";
  202. return input_data_->SerializedData();
  203. }
  204. std::vector<MasterTrace> ComputeSessionImpl::GetTraceProtos() {
  205. std::vector<MasterTrace> traces;
  206. // First compute all possible traces for each component.
  207. std::map<string, std::vector<std::vector<ComponentTrace>>> component_traces;
  208. std::vector<string> pipeline;
  209. for (auto &component_spec : spec_.component()) {
  210. pipeline.push_back(component_spec.name());
  211. component_traces.insert(
  212. {component_spec.name(),
  213. GetComponent(component_spec.name())->GetTraceProtos()});
  214. }
  215. // Only output for the actual number of states in each beam.
  216. auto final_beam = GetComponent(pipeline.back())->GetBeam();
  217. for (int batch_idx = 0; batch_idx < final_beam.size(); ++batch_idx) {
  218. for (int beam_idx = 0; beam_idx < final_beam[batch_idx].size();
  219. ++beam_idx) {
  220. std::vector<int> beam_path;
  221. beam_path.push_back(beam_idx);
  222. // Trace components backwards, finding the source of each state in the
  223. // prior component.
  224. VLOG(2) << "Start trace: " << beam_idx;
  225. for (int i = pipeline.size() - 1; i > 0; --i) {
  226. const auto *component = GetComponent(pipeline[i]);
  227. int source_beam_idx =
  228. component->GetSourceBeamIndex(beam_path.back(), batch_idx);
  229. beam_path.push_back(source_beam_idx);
  230. VLOG(2) << "Tracing path: " << pipeline[i] << " = " << source_beam_idx;
  231. }
  232. // Trace the path from the *start* to the end.
  233. std::reverse(beam_path.begin(), beam_path.end());
  234. MasterTrace master_trace;
  235. for (int i = 0; i < pipeline.size(); ++i) {
  236. *master_trace.add_component_trace() =
  237. component_traces[pipeline[i]][batch_idx][beam_path[i]];
  238. }
  239. traces.push_back(master_trace);
  240. }
  241. }
  242. return traces;
  243. }
  244. void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
  245. input_data_.reset(new InputBatchCache(data));
  246. }
  247. void ComputeSessionImpl::ResetSession() {
  248. // Reset all component states.
  249. for (auto &component_pair : components_) {
  250. component_pair.second->ResetComponent();
  251. }
  252. // Reset the input data pointer.
  253. input_data_.reset();
  254. }
  255. int ComputeSessionImpl::Id() const { return id_; }
  256. string ComputeSessionImpl::GetDescription(const string &component_name) const {
  257. return GetComponent(component_name)->Name();
  258. }
  259. const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
  260. const string &component_name) const {
  261. auto translators = GetTranslators(component_name);
  262. std::vector<const IndexTranslator *> const_translators;
  263. for (const auto &translator : translators) {
  264. const_translators.push_back(translator);
  265. }
  266. return const_translators;
  267. }
  268. Component *ComputeSessionImpl::GetReadiedComponent(
  269. const string &component_name) const {
  270. auto component = GetComponent(component_name);
  271. CHECK(component->IsReady())
  272. << "Attempted to access component " << component_name
  273. << " without first initializing it.";
  274. return component;
  275. }
  276. Component *ComputeSessionImpl::GetComponent(
  277. const string &component_name) const {
  278. auto result = components_.find(component_name);
  279. if (result == components_.end()) {
  280. LOG(ERROR) << "Could not find component \"" << component_name
  281. << "\" in the component set. Current components are: ";
  282. for (const auto &component_pair : components_) {
  283. LOG(ERROR) << component_pair.first;
  284. }
  285. LOG(FATAL) << "Missing component. Exiting.";
  286. }
  287. auto component = result->second.get();
  288. return component;
  289. }
  290. const std::vector<IndexTranslator *> &ComputeSessionImpl::GetTranslators(
  291. const string &component_name) const {
  292. auto result = translators_.find(component_name);
  293. if (result == translators_.end()) {
  294. LOG(ERROR) << "Could not find component " << component_name
  295. << " in the translator set. Current components are: ";
  296. for (const auto &component_pair : translators_) {
  297. LOG(ERROR) << component_pair.first;
  298. }
  299. LOG(FATAL) << "Missing component. Exiting.";
  300. }
  301. return result->second;
  302. }
  303. std::unique_ptr<IndexTranslator> ComputeSessionImpl::CreateTranslator(
  304. const LinkedFeatureChannel &channel, Component *start_component) {
  305. const int num_components = spec_.component_size();
  306. VLOG(2) << "Channel spec: " << channel.ShortDebugString();
  307. // Find the linked feature's source component, if it exists.
  308. auto source_map_result = components_.find(channel.source_component());
  309. CHECK(source_map_result != components_.end())
  310. << "Unable to find source component " << channel.source_component();
  311. const Component *end_component = source_map_result->second.get();
  312. // Our goal here is to iterate up the source map from the
  313. // start_component to the end_component.
  314. Component *current_component = start_component;
  315. std::vector<Component *> path;
  316. path.push_back(current_component);
  317. while (current_component != end_component) {
  318. // Try to find the next link upwards in the source chain.
  319. auto source_result = predecessors_.find(current_component);
  320. // If this component doesn't have a source to find, that's an error.
  321. CHECK(source_result != predecessors_.end())
  322. << "No link to source " << channel.source_component();
  323. // If we jump more times than there are components in the graph, that
  324. // is an error state.
  325. CHECK_LT(path.size(), num_components) << "Too many jumps. Is there a "
  326. "loop in the MasterSpec "
  327. "component definition?";
  328. // Add the source to the vector and repeat.
  329. path.push_back(source_result->second);
  330. current_component = source_result->second;
  331. }
  332. // At this point, we have the source chain for the traslator and can
  333. // build it.
  334. std::unique_ptr<IndexTranslator> translator(
  335. new IndexTranslator(path, channel.source_translator()));
  336. return translator;
  337. }
  338. } // namespace dragnn
  339. } // namespace syntaxnet