compute_session.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
  3. #include <string>
  4. #include "dragnn/components/util/bulk_feature_extractor.h"
  5. #include "dragnn/core/index_translator.h"
  6. #include "dragnn/core/interfaces/component.h"
  7. #include "dragnn/protos/spec.pb.h"
  8. #include "dragnn/protos/trace.pb.h"
  9. namespace syntaxnet {
  10. namespace dragnn {
  11. // This defines the interface for a ComputeSession object. We only ever expect
  12. // ComputeSessionImpl to implement the ComputeSession - this is only used
  13. // to provide a mocking seam.
  14. class ComputeSession {
  15. public:
  16. virtual ~ComputeSession() {}
  17. // Initialize this ComputeSession to compute the graph defined in the given
  18. // MasterSpec with the hyperparameters passed in the GridPoint. This should
  19. // only be called once, when the ComputeSession is created.
  20. virtual void Init(const MasterSpec &master_spec,
  21. const GridPoint &hyperparams) = 0;
  22. // Initialize a component with data and a given maximum beam
  23. // size. Note that attempting to initialize a component that depends on
  24. // another component that has not yet finished will cause a CHECK failure.
  25. virtual void InitializeComponentData(const string &component_name,
  26. int max_beam_size) = 0;
  27. // Return the batch size for the given component.
  28. virtual int BatchSize(const string &component_name) const = 0;
  29. // Return the beam size for the given component.
  30. virtual int BeamSize(const string &component_name) const = 0;
  31. // Returns the spec used to create this ComputeSession.
  32. virtual const ComponentSpec &Spec(const string &component_name) const = 0;
  33. // For a given component and linked feature channel, get the beam size of the
  34. // component that is the source of the linked features.
  35. virtual int SourceComponentBeamSize(const string &component_name,
  36. int channel_id) = 0;
  37. // Advance the given component using the component's oracle.
  38. virtual void AdvanceFromOracle(const string &component_name) = 0;
  39. // Advance the given component using the given score matrix.
  40. virtual void AdvanceFromPrediction(const string &component_name,
  41. const float score_matrix[],
  42. int score_matrix_length) = 0;
  43. // Get the input features for the given component and channel. This passes
  44. // through to the relevant Component's GetFixedFeatures() call.
  45. virtual int GetInputFeatures(
  46. const string &component_name,
  47. std::function<int32 *(int num_items)> allocate_indices,
  48. std::function<int64 *(int num_items)> allocate_ids,
  49. std::function<float *(int num_items)> allocate_weights,
  50. int channel_id) const = 0;
  51. // Get the input features for the given component and channel, advancing via
  52. // the oracle until the state is final. This passes through to the relevant
  53. // Component's BulkGetFixedFeatures() call.
  54. virtual int BulkGetInputFeatures(const string &component_name,
  55. const BulkFeatureExtractor &extractor) = 0;
  56. // Get the input features for the given component and channel. This function
  57. // can return empty LinkFeatures protos, which represent unused padding slots
  58. // in the output weight tensor.
  59. virtual std::vector<LinkFeatures> GetTranslatedLinkFeatures(
  60. const string &component_name, int channel_id) = 0;
  61. // Get the oracle labels for the given component.
  62. virtual std::vector<std::vector<int>> EmitOracleLabels(
  63. const string &component_name) = 0;
  64. // Returns true if the given component is terminal.
  65. virtual bool IsTerminal(const string &component_name) = 0;
  66. // Force the given component to write out its predictions to the backing data.
  67. virtual void FinalizeData(const string &component_name) = 0;
  68. // Return the finalized predictions from this compute session.
  69. virtual std::vector<string> GetSerializedPredictions() = 0;
  70. // Returns the trace protos. This will CHECK fail or be empty if the
  71. // SetTracing() has not been called to initialize the underlying Component
  72. // traces.
  73. virtual std::vector<MasterTrace> GetTraceProtos() = 0;
  74. // Provides the ComputeSession with a batch of data to compute.
  75. virtual void SetInputData(const std::vector<string> &data) = 0;
  76. // Resets all components owned by this ComputeSession.
  77. virtual void ResetSession() = 0;
  78. // Set the tracing for this ComputeSession.
  79. virtual void SetTracing(bool tracing_on) = 0;
  80. // Returns a unique identifier for this ComputeSession.
  81. virtual int Id() const = 0;
  82. // Returns a string describing the given component.
  83. virtual string GetDescription(const string &component_name) const = 0;
  84. // Get all the translators for the given component. Should only be used to
  85. // validate correct construction of translators in tests.
  86. virtual const std::vector<const IndexTranslator *> Translators(
  87. const string &component_name) const = 0;
  88. };
  89. } // namespace dragnn
  90. } // namespace syntaxnet
  91. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_