compute_session_impl.h 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
  #include <functional>
  #include <map>
  #include <memory>
  #include <string>
  #include <vector>

  #include "dragnn/components/util/bulk_feature_extractor.h"
  #include "dragnn/core/compute_session.h"
  #include "dragnn/core/index_translator.h"
  #include "dragnn/core/input_batch_cache.h"
  #include "dragnn/protos/data.pb.h"
  #include "dragnn/protos/spec.pb.h"
  #include "dragnn/protos/trace.pb.h"
  25. namespace syntaxnet {
  26. namespace dragnn {
  27. class ComputeSessionImpl : public ComputeSession {
  28. public:
  29. // Creates a ComputeSessionImpl with the provided component builder function.
  30. ComputeSessionImpl(
  31. int id,
  32. std::function<std::unique_ptr<Component>(const string &component_name,
  33. const string &backend_type)>
  34. component_builder);
  35. void Init(const MasterSpec &master_spec,
  36. const GridPoint &hyperparams) override;
  37. void InitializeComponentData(const string &component_name,
  38. int max_beam_size) override;
  39. int BatchSize(const string &component_name) const override;
  40. int BeamSize(const string &component_name) const override;
  41. const ComponentSpec &Spec(const string &component_name) const override;
  42. int SourceComponentBeamSize(const string &component_name,
  43. int channel_id) override;
  44. void AdvanceFromOracle(const string &component_name) override;
  45. void AdvanceFromPrediction(const string &component_name,
  46. const float score_matrix[],
  47. int score_matrix_length) override;
  48. int GetInputFeatures(const string &component_name,
  49. std::function<int32 *(int)> allocate_indices,
  50. std::function<int64 *(int)> allocate_ids,
  51. std::function<float *(int)> allocate_weights,
  52. int channel_id) const override;
  53. int BulkGetInputFeatures(const string &component_name,
  54. const BulkFeatureExtractor &extractor) override;
  55. std::vector<LinkFeatures> GetTranslatedLinkFeatures(
  56. const string &component_name, int channel_id) override;
  57. std::vector<std::vector<int>> EmitOracleLabels(
  58. const string &component_name) override;
  59. bool IsTerminal(const string &component_name) override;
  60. void FinalizeData(const string &component_name) override;
  61. std::vector<string> GetSerializedPredictions() override;
  62. std::vector<MasterTrace> GetTraceProtos() override;
  63. void SetInputData(const std::vector<string> &data) override;
  64. void ResetSession() override;
  65. void SetTracing(bool tracing_on) override;
  66. int Id() const override;
  67. string GetDescription(const string &component_name) const override;
  68. const std::vector<const IndexTranslator *> Translators(
  69. const string &component_name) const override;
  70. private:
  71. // Get a given component. Fails if the component is not found.
  72. Component *GetComponent(const string &component_name) const;
  73. // Get a given component. CHECK-fail if the component's IsReady method
  74. // returns false.
  75. Component *GetReadiedComponent(const string &component_name) const;
  76. // Get the index translators for the given component.
  77. const std::vector<IndexTranslator *> &GetTranslators(
  78. const string &component_name) const;
  79. // Create an index translator.
  80. std::unique_ptr<IndexTranslator> CreateTranslator(
  81. const LinkedFeatureChannel &channel, Component *start_component);
  82. // Perform initialization on the given Component.
  83. void InitComponent(Component *component);
  84. // Holds all of the components owned by this ComputeSession, associated with
  85. // their names in the MasterSpec.
  86. std::map<string, std::unique_ptr<Component>> components_;
  87. // Holds a vector of translators for each component, indexed by the name
  88. // of the component they belong to.
  89. std::map<string, std::vector<IndexTranslator *>> translators_;
  90. // Holds ownership of all the IndexTranslators for this compute session.
  91. std::vector<std::unique_ptr<IndexTranslator>> owned_translators_;
  92. // The predecessor component for every component.
  93. // If a component is not in this map, it has no predecessor component and
  94. // will have its beam initialized without any data from other components.
  95. std::map<Component *, Component *> predecessors_;
  96. // Holds the current input data for this ComputeSession.
  97. std::unique_ptr<InputBatchCache> input_data_;
  98. // Function that, given a string, will return a Component.
  99. std::function<std::unique_ptr<Component>(const string &component_name,
  100. const string &backend_type)>
  101. component_builder_;
  102. // The master spec for this compute session.
  103. MasterSpec spec_;
  104. // The hyperparameters for this compute session.
  105. GridPoint grid_point_;
  106. // Unique identifier, assigned at construction.
  107. int id_;
  108. // Whether or not to perform tracing.
  109. bool do_tracing_ = false;
  110. };
  111. } // namespace dragnn
  112. } // namespace syntaxnet
  113. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_