component.h 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
  3. #include <vector>
  4. #include "dragnn/components/util/bulk_feature_extractor.h"
  5. #include "dragnn/core/input_batch_cache.h"
  6. #include "dragnn/core/interfaces/transition_state.h"
  7. #include "dragnn/protos/spec.pb.h"
  8. #include "dragnn/protos/trace.pb.h"
  9. #include "syntaxnet/registry.h"
  10. namespace syntaxnet {
  11. namespace dragnn {
  12. class Component : public RegisterableClass<Component> {
  13. public:
  14. virtual ~Component() {}
  15. // Initializes this component from the spec.
  16. virtual void InitializeComponent(const ComponentSpec &spec) = 0;
  17. // Provides the previous beam to the component.
  18. virtual void InitializeData(
  19. const std::vector<std::vector<const TransitionState *>> &states,
  20. int max_beam_size, InputBatchCache *input_data) = 0;
  21. // Returns true if the component has had InitializeData called on it since
  22. // the last time it was reset.
  23. virtual bool IsReady() const = 0;
  24. // Initializes the component for tracing execution, resetting any existing
  25. // traces. This will typically have the side effect of slowing down all
  26. // subsequent Component calculations and storing a trace in memory that can be
  27. // returned by GetTraceProtos().
  28. virtual void InitializeTracing() = 0;
  29. // Disables tracing, freeing any associated traces and avoiding triggering
  30. // additional computation in the future.
  31. virtual void DisableTracing() = 0;
  32. // Returns the string name of this component.
  33. virtual string Name() const = 0;
  34. // Returns the current batch size of the component's underlying data.
  35. virtual int BatchSize() const = 0;
  36. // Returns the maximum beam size of this component.
  37. virtual int BeamSize() const = 0;
  38. // Returns the number of steps taken by this component so far.
  39. virtual int StepsTaken(int batch_index) const = 0;
  40. // Return the beam index of the item which is currently at index
  41. // 'index', when the beam was at step 'step', for batch element 'batch'.
  42. virtual int GetBeamIndexAtStep(int step, int current_index,
  43. int batch) const = 0;
  44. // Return the source index of the item which is currently at index 'index'
  45. // for batch element 'batch'. This index is into the final beam of the
  46. // Component that this Component was initialized from.
  47. virtual int GetSourceBeamIndex(int current_index, int batch) const = 0;
  48. // Request a translation function based on the given method string.
  49. // The translation function will be called with arguments (beam, batch, value)
  50. // and should return the step index corresponding to the given value, for the
  51. // data in the given beam and batch.
  52. virtual std::function<int(int, int, int)> GetStepLookupFunction(
  53. const string &method) = 0;
  54. // Advances this component from the given transition matrix.
  55. virtual void AdvanceFromPrediction(const float transition_matrix[],
  56. int transition_matrix_length) = 0;
  57. // Advances this component from the state oracles.
  58. virtual void AdvanceFromOracle() = 0;
  59. // Returns true if all states within this component are terminal.
  60. virtual bool IsTerminal() const = 0;
  61. // Returns the current batch of beams for this component.
  62. virtual std::vector<std::vector<const TransitionState *>> GetBeam() = 0;
  63. // Extracts and populates the vector of FixedFeatures for the specified
  64. // channel. Each functor allocates storage space for the indices, the IDs, and
  65. // the weights (respectively).
  66. virtual int GetFixedFeatures(
  67. std::function<int32 *(int num_elements)> allocate_indices,
  68. std::function<int64 *(int num_elements)> allocate_ids,
  69. std::function<float *(int num_elements)> allocate_weights,
  70. int channel_id) const = 0;
  71. // Extracts and populates all FixedFeatures for all channels, advancing this
  72. // component via the oracle until it is terminal. This call uses a
  73. // BulkFeatureExtractor object to contain the functors and other information.
  74. virtual int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) = 0;
  75. // Extracts and returns the vector of LinkFeatures for the specified
  76. // channel. Note: these are NOT translated.
  77. virtual std::vector<LinkFeatures> GetRawLinkFeatures(
  78. int channel_id) const = 0;
  79. // Returns a vector of oracle labels for each element in the beam and
  80. // batch.
  81. virtual std::vector<std::vector<int>> GetOracleLabels() const = 0;
  82. // Annotate the underlying data object with the results of this Component's
  83. // calculation.
  84. virtual void FinalizeData() = 0;
  85. // Reset this component.
  86. virtual void ResetComponent() = 0;
  87. // Get a vector of all traces managed by this component.
  88. virtual std::vector<std::vector<ComponentTrace>> GetTraceProtos() const = 0;
  89. // Add the translated link features (done outside the component) to the traces
  90. // managed by this component.
  91. virtual void AddTranslatedLinkFeaturesToTrace(
  92. const std::vector<LinkFeatures> &features, int channel_id) = 0;
  93. };
  94. } // namespace dragnn
  95. } // namespace syntaxnet
  96. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_