compute_session.h 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
  #include <functional>
  #include <string>
  #include <vector>

  #include "dragnn/components/util/bulk_feature_extractor.h"
  #include "dragnn/core/index_translator.h"
  #include "dragnn/core/interfaces/component.h"
  #include "dragnn/protos/spec.pb.h"
  #include "dragnn/protos/trace.pb.h"
  23. namespace syntaxnet {
  24. namespace dragnn {
  25. // This defines the interface for a ComputeSession object. We only ever expect
  26. // ComputeSessionImpl to implement the ComputeSession - this is only used
  27. // to provide a mocking seam.
  28. class ComputeSession {
  29. public:
  30. virtual ~ComputeSession() {}
  31. // Initialize this ComputeSession to compute the graph defined in the given
  32. // MasterSpec with the hyperparameters passed in the GridPoint. This should
  33. // only be called once, when the ComputeSession is created.
  34. virtual void Init(const MasterSpec &master_spec,
  35. const GridPoint &hyperparams) = 0;
  36. // Initialize a component with data and a given maximum beam
  37. // size. Note that attempting to initialize a component that depends on
  38. // another component that has not yet finished will cause a CHECK failure.
  39. virtual void InitializeComponentData(const string &component_name,
  40. int max_beam_size) = 0;
  41. // Return the batch size for the given component.
  42. virtual int BatchSize(const string &component_name) const = 0;
  43. // Return the beam size for the given component.
  44. virtual int BeamSize(const string &component_name) const = 0;
  45. // Returns the spec used to create this ComputeSession.
  46. virtual const ComponentSpec &Spec(const string &component_name) const = 0;
  47. // For a given component and linked feature channel, get the beam size of the
  48. // component that is the source of the linked features.
  49. virtual int SourceComponentBeamSize(const string &component_name,
  50. int channel_id) = 0;
  51. // Advance the given component using the component's oracle.
  52. virtual void AdvanceFromOracle(const string &component_name) = 0;
  53. // Advance the given component using the given score matrix.
  54. virtual void AdvanceFromPrediction(const string &component_name,
  55. const float score_matrix[],
  56. int score_matrix_length) = 0;
  57. // Get the input features for the given component and channel. This passes
  58. // through to the relevant Component's GetFixedFeatures() call.
  59. virtual int GetInputFeatures(
  60. const string &component_name,
  61. std::function<int32 *(int num_items)> allocate_indices,
  62. std::function<int64 *(int num_items)> allocate_ids,
  63. std::function<float *(int num_items)> allocate_weights,
  64. int channel_id) const = 0;
  65. // Get the input features for the given component and channel, advancing via
  66. // the oracle until the state is final. This passes through to the relevant
  67. // Component's BulkGetFixedFeatures() call.
  68. virtual int BulkGetInputFeatures(const string &component_name,
  69. const BulkFeatureExtractor &extractor) = 0;
  70. // Get the input features for the given component and channel. This function
  71. // can return empty LinkFeatures protos, which represent unused padding slots
  72. // in the output weight tensor.
  73. virtual std::vector<LinkFeatures> GetTranslatedLinkFeatures(
  74. const string &component_name, int channel_id) = 0;
  75. // Get the oracle labels for the given component.
  76. virtual std::vector<std::vector<int>> EmitOracleLabels(
  77. const string &component_name) = 0;
  78. // Returns true if the given component is terminal.
  79. virtual bool IsTerminal(const string &component_name) = 0;
  80. // Force the given component to write out its predictions to the backing data.
  81. virtual void FinalizeData(const string &component_name) = 0;
  82. // Return the finalized predictions from this compute session.
  83. virtual std::vector<string> GetSerializedPredictions() = 0;
  84. // Returns the trace protos. This will CHECK fail or be empty if the
  85. // SetTracing() has not been called to initialize the underlying Component
  86. // traces.
  87. virtual std::vector<MasterTrace> GetTraceProtos() = 0;
  88. // Provides the ComputeSession with a batch of data to compute.
  89. virtual void SetInputData(const std::vector<string> &data) = 0;
  90. // Resets all components owned by this ComputeSession.
  91. virtual void ResetSession() = 0;
  92. // Set the tracing for this ComputeSession.
  93. virtual void SetTracing(bool tracing_on) = 0;
  94. // Returns a unique identifier for this ComputeSession.
  95. virtual int Id() const = 0;
  96. // Returns a string describing the given component.
  97. virtual string GetDescription(const string &component_name) const = 0;
  98. // Get all the translators for the given component. Should only be used to
  99. // validate correct construction of translators in tests.
  100. virtual const std::vector<const IndexTranslator *> Translators(
  101. const string &component_name) const = 0;
  102. };
  103. } // namespace dragnn
  104. } // namespace syntaxnet
  105. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_