compute_session_impl.cc 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. #include "dragnn/core/compute_session_impl.h"
  2. #include <algorithm>
  3. #include <utility>
  4. #include "dragnn/protos/data.pb.h"
  5. #include "dragnn/protos/spec.pb.h"
  6. #include "dragnn/protos/trace.pb.h"
  7. #include "syntaxnet/registry.h"
  8. #include "tensorflow/core/platform/logging.h"
  9. namespace syntaxnet {
  10. namespace dragnn {
  11. ComputeSessionImpl::ComputeSessionImpl(
  12. int id,
  13. std::function<std::unique_ptr<Component>(const string &component_name,
  14. const string &backend_type)>
  15. component_builder)
  16. : component_builder_(std::move(component_builder)), id_(id) {}
  17. void ComputeSessionImpl::Init(const MasterSpec &master_spec,
  18. const GridPoint &hyperparams) {
  19. spec_ = master_spec;
  20. grid_point_ = hyperparams;
  21. VLOG(2) << "Creating components.";
  22. bool is_input = true;
  23. Component *predecessor;
  24. for (const ComponentSpec &spec : master_spec.component()) {
  25. // Construct the component using the specified backend.
  26. VLOG(2) << "Creating component '" << spec.name()
  27. << "' with backend: " << spec.backend().registered_name();
  28. auto component =
  29. component_builder_(spec.name(), spec.backend().registered_name());
  30. // Initializes the component.
  31. component->InitializeComponent(spec);
  32. // Adds a predecessor to non-input components.
  33. if (!is_input) {
  34. predecessors_.insert(
  35. std::pair<Component *, Component *>(component.get(), predecessor));
  36. }
  37. // The current component will be the predecessor component next time around.
  38. predecessor = component.get();
  39. // All components after the first are non-input components.
  40. is_input = false;
  41. // Move into components list.
  42. components_.insert(std::pair<string, std::unique_ptr<Component>>(
  43. spec.name(), std::move(component)));
  44. }
  45. VLOG(2) << "Done creating components.";
  46. VLOG(2) << "Adding translators.";
  47. for (const ComponentSpec &spec : master_spec.component()) {
  48. // First, get the component object for this spec.
  49. VLOG(2) << "Examining component: " << spec.name();
  50. auto map_result = components_.find(spec.name());
  51. CHECK(map_result != components_.end()) << "Unable to find component.";
  52. Component *start_component = map_result->second.get();
  53. if (spec.linked_feature_size() > 0) {
  54. VLOG(2) << "Adding " << spec.linked_feature_size() << " translators for "
  55. << spec.name();
  56. // Attach all the translators described in the spec.
  57. std::vector<IndexTranslator *> translator_set;
  58. for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
  59. // For every translator, save off a non-unique ptr in the component name
  60. // to translator map, then push the unique ptr onto the management
  61. // vector.
  62. auto translator = CreateTranslator(channel, start_component);
  63. translator_set.push_back(translator.get());
  64. owned_translators_.push_back(std::move(translator));
  65. }
  66. // Once all translators have been created, associate this group of
  67. // translators with a component.
  68. translators_.insert(std::pair<string, std::vector<IndexTranslator *>>(
  69. spec.name(), std::move(translator_set)));
  70. } else {
  71. VLOG(2) << "No translators found for " << spec.name();
  72. }
  73. }
  74. VLOG(2) << "Done adding translators.";
  75. VLOG(2) << "Initialization complete.";
  76. }
  77. void ComputeSessionImpl::InitializeComponentData(const string &component_name,
  78. int max_beam_size) {
  79. CHECK(input_data_ != nullptr) << "Attempted to access a component without "
  80. "providing input data for this session.";
  81. Component *component = GetComponent(component_name);
  82. // Try and find the source component. If one exists, check that it is terminal
  83. // and get its data; if not, pass in an empty vector for source data.
  84. auto source_result = predecessors_.find(component);
  85. if (source_result == predecessors_.end()) {
  86. VLOG(1) << "Source result not found. Using empty initialization vector for "
  87. << component_name;
  88. component->InitializeData({}, max_beam_size, input_data_.get());
  89. } else {
  90. VLOG(1) << "Source result found. Using prior initialization vector for "
  91. << component_name;
  92. auto source = source_result->second;
  93. CHECK(source->IsTerminal()) << "Source is not terminal for component '"
  94. << component_name << "'. Exiting.";
  95. component->InitializeData(source->GetBeam(), max_beam_size,
  96. input_data_.get());
  97. }
  98. if (do_tracing_) {
  99. component->InitializeTracing();
  100. }
  101. }
  102. int ComputeSessionImpl::BatchSize(const string &component_name) const {
  103. return GetReadiedComponent(component_name)->BatchSize();
  104. }
  105. int ComputeSessionImpl::BeamSize(const string &component_name) const {
  106. return GetReadiedComponent(component_name)->BeamSize();
  107. }
  108. const ComponentSpec &ComputeSessionImpl::Spec(
  109. const string &component_name) const {
  110. for (const auto &component : spec_.component()) {
  111. if (component.name() == component_name) {
  112. return component;
  113. }
  114. }
  115. LOG(FATAL) << "Missing component '" << component_name << "'. Exiting.";
  116. }
  117. int ComputeSessionImpl::SourceComponentBeamSize(const string &component_name,
  118. int channel_id) {
  119. const auto &translators = GetTranslators(component_name);
  120. return translators.at(channel_id)->path().back()->BeamSize();
  121. }
  122. void ComputeSessionImpl::AdvanceFromOracle(const string &component_name) {
  123. GetReadiedComponent(component_name)->AdvanceFromOracle();
  124. }
  125. void ComputeSessionImpl::AdvanceFromPrediction(const string &component_name,
  126. const float score_matrix[],
  127. int score_matrix_length) {
  128. GetReadiedComponent(component_name)
  129. ->AdvanceFromPrediction(score_matrix, score_matrix_length);
  130. }
  131. int ComputeSessionImpl::GetInputFeatures(
  132. const string &component_name, std::function<int32 *(int)> allocate_indices,
  133. std::function<int64 *(int)> allocate_ids,
  134. std::function<float *(int)> allocate_weights, int channel_id) const {
  135. return GetReadiedComponent(component_name)
  136. ->GetFixedFeatures(allocate_indices, allocate_ids, allocate_weights,
  137. channel_id);
  138. }
  139. int ComputeSessionImpl::BulkGetInputFeatures(
  140. const string &component_name, const BulkFeatureExtractor &extractor) {
  141. return GetReadiedComponent(component_name)->BulkGetFixedFeatures(extractor);
  142. }
  143. std::vector<LinkFeatures> ComputeSessionImpl::GetTranslatedLinkFeatures(
  144. const string &component_name, int channel_id) {
  145. auto *component = GetReadiedComponent(component_name);
  146. auto features = component->GetRawLinkFeatures(channel_id);
  147. IndexTranslator *translator = GetTranslators(component_name).at(channel_id);
  148. for (int i = 0; i < features.size(); ++i) {
  149. LinkFeatures &feature = features[i];
  150. if (feature.has_feature_value()) {
  151. VLOG(2) << "Raw feature[" << i << "]: " << feature.ShortDebugString();
  152. IndexTranslator::Index index = translator->Translate(
  153. feature.batch_idx(), feature.beam_idx(), feature.feature_value());
  154. feature.set_step_idx(index.step_index);
  155. feature.set_batch_idx(index.batch_index);
  156. feature.set_beam_idx(index.beam_index);
  157. } else {
  158. VLOG(2) << "Raw feature[" << i << "]: PADDING (empty proto)";
  159. }
  160. }
  161. // Add the translated link features to the component's trace.
  162. if (do_tracing_) {
  163. component->AddTranslatedLinkFeaturesToTrace(features, channel_id);
  164. }
  165. return features;
  166. }
  167. std::vector<std::vector<int>> ComputeSessionImpl::EmitOracleLabels(
  168. const string &component_name) {
  169. return GetReadiedComponent(component_name)->GetOracleLabels();
  170. }
  171. bool ComputeSessionImpl::IsTerminal(const string &component_name) {
  172. return GetReadiedComponent(component_name)->IsTerminal();
  173. }
  174. void ComputeSessionImpl::SetTracing(bool tracing_on) {
  175. do_tracing_ = tracing_on;
  176. for (auto &component_pair : components_) {
  177. if (!tracing_on) {
  178. component_pair.second->DisableTracing();
  179. }
  180. }
  181. }
  182. void ComputeSessionImpl::FinalizeData(const string &component_name) {
  183. VLOG(2) << "Finalizing data for " << component_name;
  184. GetReadiedComponent(component_name)->FinalizeData();
  185. }
  186. std::vector<string> ComputeSessionImpl::GetSerializedPredictions() {
  187. VLOG(2) << "Geting serialized predictions.";
  188. return input_data_->SerializedData();
  189. }
  190. std::vector<MasterTrace> ComputeSessionImpl::GetTraceProtos() {
  191. std::vector<MasterTrace> traces;
  192. // First compute all possible traces for each component.
  193. std::map<string, std::vector<std::vector<ComponentTrace>>> component_traces;
  194. std::vector<string> pipeline;
  195. for (auto &component_spec : spec_.component()) {
  196. pipeline.push_back(component_spec.name());
  197. component_traces.insert(
  198. {component_spec.name(),
  199. GetComponent(component_spec.name())->GetTraceProtos()});
  200. }
  201. // Only output for the actual number of states in each beam.
  202. auto final_beam = GetComponent(pipeline.back())->GetBeam();
  203. for (int batch_idx = 0; batch_idx < final_beam.size(); ++batch_idx) {
  204. for (int beam_idx = 0; beam_idx < final_beam[batch_idx].size();
  205. ++beam_idx) {
  206. std::vector<int> beam_path;
  207. beam_path.push_back(beam_idx);
  208. // Trace components backwards, finding the source of each state in the
  209. // prior component.
  210. VLOG(2) << "Start trace: " << beam_idx;
  211. for (int i = pipeline.size() - 1; i > 0; --i) {
  212. const auto *component = GetComponent(pipeline[i]);
  213. int source_beam_idx =
  214. component->GetSourceBeamIndex(beam_path.back(), batch_idx);
  215. beam_path.push_back(source_beam_idx);
  216. VLOG(2) << "Tracing path: " << pipeline[i] << " = " << source_beam_idx;
  217. }
  218. // Trace the path from the *start* to the end.
  219. std::reverse(beam_path.begin(), beam_path.end());
  220. MasterTrace master_trace;
  221. for (int i = 0; i < pipeline.size(); ++i) {
  222. *master_trace.add_component_trace() =
  223. component_traces[pipeline[i]][batch_idx][beam_path[i]];
  224. }
  225. traces.push_back(master_trace);
  226. }
  227. }
  228. return traces;
  229. }
  230. void ComputeSessionImpl::SetInputData(const std::vector<string> &data) {
  231. input_data_.reset(new InputBatchCache(data));
  232. }
  233. void ComputeSessionImpl::ResetSession() {
  234. // Reset all component states.
  235. for (auto &component_pair : components_) {
  236. component_pair.second->ResetComponent();
  237. }
  238. // Reset the input data pointer.
  239. input_data_.reset();
  240. }
  241. int ComputeSessionImpl::Id() const { return id_; }
  242. string ComputeSessionImpl::GetDescription(const string &component_name) const {
  243. return GetComponent(component_name)->Name();
  244. }
  245. const std::vector<const IndexTranslator *> ComputeSessionImpl::Translators(
  246. const string &component_name) const {
  247. auto translators = GetTranslators(component_name);
  248. std::vector<const IndexTranslator *> const_translators;
  249. for (const auto &translator : translators) {
  250. const_translators.push_back(translator);
  251. }
  252. return const_translators;
  253. }
  254. Component *ComputeSessionImpl::GetReadiedComponent(
  255. const string &component_name) const {
  256. auto component = GetComponent(component_name);
  257. CHECK(component->IsReady())
  258. << "Attempted to access component " << component_name
  259. << " without first initializing it.";
  260. return component;
  261. }
  262. Component *ComputeSessionImpl::GetComponent(
  263. const string &component_name) const {
  264. auto result = components_.find(component_name);
  265. if (result == components_.end()) {
  266. LOG(ERROR) << "Could not find component \"" << component_name
  267. << "\" in the component set. Current components are: ";
  268. for (const auto &component_pair : components_) {
  269. LOG(ERROR) << component_pair.first;
  270. }
  271. LOG(FATAL) << "Missing component. Exiting.";
  272. }
  273. auto component = result->second.get();
  274. return component;
  275. }
  276. const std::vector<IndexTranslator *> &ComputeSessionImpl::GetTranslators(
  277. const string &component_name) const {
  278. auto result = translators_.find(component_name);
  279. if (result == translators_.end()) {
  280. LOG(ERROR) << "Could not find component " << component_name
  281. << " in the translator set. Current components are: ";
  282. for (const auto &component_pair : translators_) {
  283. LOG(ERROR) << component_pair.first;
  284. }
  285. LOG(FATAL) << "Missing component. Exiting.";
  286. }
  287. return result->second;
  288. }
  289. std::unique_ptr<IndexTranslator> ComputeSessionImpl::CreateTranslator(
  290. const LinkedFeatureChannel &channel, Component *start_component) {
  291. const int num_components = spec_.component_size();
  292. VLOG(2) << "Channel spec: " << channel.ShortDebugString();
  293. // Find the linked feature's source component, if it exists.
  294. auto source_map_result = components_.find(channel.source_component());
  295. CHECK(source_map_result != components_.end())
  296. << "Unable to find source component " << channel.source_component();
  297. const Component *end_component = source_map_result->second.get();
  298. // Our goal here is to iterate up the source map from the
  299. // start_component to the end_component.
  300. Component *current_component = start_component;
  301. std::vector<Component *> path;
  302. path.push_back(current_component);
  303. while (current_component != end_component) {
  304. // Try to find the next link upwards in the source chain.
  305. auto source_result = predecessors_.find(current_component);
  306. // If this component doesn't have a source to find, that's an error.
  307. CHECK(source_result != predecessors_.end())
  308. << "No link to source " << channel.source_component();
  309. // If we jump more times than there are components in the graph, that
  310. // is an error state.
  311. CHECK_LT(path.size(), num_components) << "Too many jumps. Is there a "
  312. "loop in the MasterSpec "
  313. "component definition?";
  314. // Add the source to the vector and repeat.
  315. path.push_back(source_result->second);
  316. current_component = source_result->second;
  317. }
  318. // At this point, we have the source chain for the traslator and can
  319. // build it.
  320. std::unique_ptr<IndexTranslator> translator(
  321. new IndexTranslator(path, channel.source_translator()));
  322. return translator;
  323. }
  324. } // namespace dragnn
  325. } // namespace syntaxnet