stateless_component.cc 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/core/component_registry.h"
  16. #include "dragnn/core/interfaces/component.h"
  17. #include "dragnn/core/interfaces/transition_state.h"
  18. #include "dragnn/io/sentence_input_batch.h"
  19. #include "dragnn/protos/data.pb.h"
  20. #include "syntaxnet/base.h"
  21. namespace syntaxnet {
  22. namespace dragnn {
  23. namespace {
  24. // A component that does not create its own transition states; instead, it
  25. // simply forwards the states of the previous component. Does not support all
  26. // methods. Intended for "compute-only" bulk components that only use linked
  27. // features, which use only a small subset of DRAGNN functionality.
  28. class StatelessComponent : public Component {
  29. public:
  30. void InitializeComponent(const ComponentSpec &spec) override {
  31. name_ = spec.name();
  32. }
  33. // Stores the |parent_states| for forwarding to downstream components.
  34. void InitializeData(
  35. const std::vector<std::vector<const TransitionState *>> &parent_states,
  36. int max_beam_size, InputBatchCache *input_data) override {
  37. // Must use SentenceInputBatch to match SyntaxNetComponent.
  38. batch_size_ = input_data->GetAs<SentenceInputBatch>()->data()->size();
  39. beam_size_ = max_beam_size;
  40. parent_states_ = parent_states;
  41. // The beam should be wide enough for the previous component.
  42. for (const auto &beam : parent_states) {
  43. CHECK_LE(beam.size(), beam_size_);
  44. }
  45. }
  46. // Forwards the states of the previous component.
  47. std::vector<std::vector<const TransitionState *>> GetBeam() override {
  48. return parent_states_;
  49. }
  50. // Forwards the |current_index| to the previous component.
  51. int GetSourceBeamIndex(int current_index, int batch) const override {
  52. return current_index;
  53. }
  54. string Name() const override { return name_; }
  55. int BeamSize() const override { return beam_size_; }
  56. int BatchSize() const override { return batch_size_; }
  57. int StepsTaken(int batch_index) const override { return 0; }
  58. bool IsReady() const override { return true; }
  59. bool IsTerminal() const override { return true; }
  60. void FinalizeData() override {}
  61. void ResetComponent() override {}
  62. void InitializeTracing() override {}
  63. void DisableTracing() override {}
  64. std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
  65. return {};
  66. }
  67. // Unsupported methods.
  68. int GetBeamIndexAtStep(int step, int current_index,
  69. int batch) const override {
  70. LOG(FATAL) << "[" << name_ << "] Method not supported";
  71. return 0;
  72. }
  73. std::function<int(int, int, int)> GetStepLookupFunction(
  74. const string &method) override {
  75. LOG(FATAL) << "[" << name_ << "] Method not supported";
  76. return nullptr;
  77. }
  78. void AdvanceFromPrediction(const float transition_matrix[],
  79. int matrix_length) override {
  80. LOG(FATAL) << "[" << name_ << "] Method not supported";
  81. }
  82. void AdvanceFromOracle() override {
  83. LOG(FATAL) << "[" << name_ << "] Method not supported";
  84. }
  85. std::vector<std::vector<int>> GetOracleLabels() const override {
  86. LOG(FATAL) << "[" << name_ << "] Method not supported";
  87. return {};
  88. }
  89. int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
  90. std::function<int64 *(int)> allocate_ids,
  91. std::function<float *(int)> allocate_weights,
  92. int channel_id) const override {
  93. LOG(FATAL) << "[" << name_ << "] Method not supported";
  94. return 0;
  95. }
  96. int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
  97. LOG(FATAL) << "[" << name_ << "] Method not supported";
  98. return 0;
  99. }
  100. std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
  101. LOG(FATAL) << "[" << name_ << "] Method not supported";
  102. return {};
  103. }
  104. void AddTranslatedLinkFeaturesToTrace(
  105. const std::vector<LinkFeatures> &features, int channel_id) override {
  106. LOG(FATAL) << "[" << name_ << "] Method not supported";
  107. }
  108. private:
  109. string name_; // component name
  110. int batch_size_ = 1; // number of sentences in current batch
  111. int beam_size_ = 1; // maximum beam size
  112. // Parent states passed to InitializeData(), and passed along in GetBeam().
  113. std::vector<std::vector<const TransitionState *>> parent_states_;
  114. };
  // Registers StatelessComponent via the DRAGNN component registration macro.
  // NOTE(review): presumably this makes the class constructible by name through
  // the registry declared in dragnn/core/component_registry.h — confirm there.
  115. REGISTER_DRAGNN_COMPONENT(StatelessComponent);
  116. } // namespace
  117. } // namespace dragnn
  118. } // namespace syntaxnet