// index_translator.h
#ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
#define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_

#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
  7. namespace syntaxnet {
  8. namespace dragnn {
  9. // A IndexTranslator provides an interface into the data of another component.
  10. // It allows one component to look up a translated array index from the history
  11. // or state of another component.
  12. //
  13. // When it is created, it is passed a pointer to the source component (that is,
  14. // the component whose data it will be accessing) and a string representing the
  15. // type of data access it will perform. There are two universal data access
  16. // methods - "identity" and "history" - and components can declare more via
  17. // their GetStepLookupFunction function.
  18. class IndexTranslator {
  19. public:
  20. // Index into a TensorArray. Provides a given step, and the beam index within
  21. // that step, for TensorArray access to data in the given batch.
  22. struct Index {
  23. int batch_index = -1;
  24. int beam_index = -1;
  25. int step_index = -1;
  26. };
  27. // Creates a new IndexTranslator with access method as determined by the
  28. // passed string. The Translator will walk the path "path" in order, and will
  29. // translate from the last Component in the path.
  30. IndexTranslator(const std::vector<Component *> &path, const string &method);
  31. // Returns an index in (step, beam, batch) index space as computed from the
  32. // given feature value.
  33. Index Translate(int batch_index, int beam_index, int feature_value);
  34. // Returns the path to be walked by this translator.
  35. const std::vector<Component *> &path() const { return path_; }
  36. // Returns the method to be used by this translator.
  37. const string &method() const { return method_; }
  38. private:
  39. // The ordered list of components that must be walked to get from the
  40. // requesting component to the source component. This vector has the
  41. // requesting component at index 0 and the source component at the end. If
  42. // the requesting component is the source component, this vector has only one
  43. // entry.
  44. const std::vector<Component *> path_;
  45. // The function this translator will use to look up the step in the source
  46. // component. The function is invoked as:
  47. // step_lookup_(batch_index, beam_index, feature).
  48. std::function<int(int, int, int)> step_lookup_;
  49. // This translator's method.
  50. string method_;
  51. };
  52. } // namespace dragnn
  53. } // namespace syntaxnet
  54. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_