index_translator.h 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
#include <functional>
#include <memory>
#include <vector>

#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
  21. namespace syntaxnet {
  22. namespace dragnn {
  23. // A IndexTranslator provides an interface into the data of another component.
  24. // It allows one component to look up a translated array index from the history
  25. // or state of another component.
  26. //
  27. // When it is created, it is passed a pointer to the source component (that is,
  28. // the component whose data it will be accessing) and a string representing the
  29. // type of data access it will perform. There are two universal data access
  30. // methods - "identity" and "history" - and components can declare more via
  31. // their GetStepLookupFunction function.
  32. class IndexTranslator {
  33. public:
  34. // Index into a TensorArray. Provides a given step, and the beam index within
  35. // that step, for TensorArray access to data in the given batch.
  36. struct Index {
  37. int batch_index = -1;
  38. int beam_index = -1;
  39. int step_index = -1;
  40. };
  41. // Creates a new IndexTranslator with access method as determined by the
  42. // passed string. The Translator will walk the path "path" in order, and will
  43. // translate from the last Component in the path.
  44. IndexTranslator(const std::vector<Component *> &path, const string &method);
  45. // Returns an index in (step, beam, batch) index space as computed from the
  46. // given feature value.
  47. Index Translate(int batch_index, int beam_index, int feature_value);
  48. // Returns the path to be walked by this translator.
  49. const std::vector<Component *> &path() const { return path_; }
  50. // Returns the method to be used by this translator.
  51. const string &method() const { return method_; }
  52. private:
  53. // The ordered list of components that must be walked to get from the
  54. // requesting component to the source component. This vector has the
  55. // requesting component at index 0 and the source component at the end. If
  56. // the requesting component is the source component, this vector has only one
  57. // entry.
  58. const std::vector<Component *> path_;
  59. // The function this translator will use to look up the step in the source
  60. // component. The function is invoked as:
  61. // step_lookup_(batch_index, beam_index, feature).
  62. std::function<int(int, int, int)> step_lookup_;
  63. // This translator's method.
  64. string method_;
  65. };
  66. } // namespace dragnn
  67. } // namespace syntaxnet
  68. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_