bulk_feature_extractor.h 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
  3. #include <functional>
  4. #include <utility>
  5. #include "tensorflow/core/platform/types.h"
  6. namespace syntaxnet {
  7. namespace dragnn {
  8. // Provides a wrapper for allocator functions and padding data for the Bulk
  9. // ExtractFixedFeatures operation.
  10. class BulkFeatureExtractor {
  11. public:
  12. // Create a BulkFeatureExtractor with the given allocator functions and
  13. // padding. The allocator functions should take a channel and an element
  14. // count and return a contigous block of memory that is associated with that
  15. // channel (the caller can decide what that means). If use_padding is true,
  16. // the provided pad_to_step and pad_to_element will be used to calculate
  17. // the ID size.
  18. BulkFeatureExtractor(
  19. std::function<tensorflow::int32 *(int channel, int num_elements)>
  20. allocate_indices_by_channel,
  21. std::function<tensorflow::int64 *(int channel, int num_elements)>
  22. allocate_ids_by_channel,
  23. std::function<float *(int channel, int num_elements)>
  24. allocate_weights_by_channel,
  25. bool use_padding, int pad_to_step, int pad_to_element)
  26. : use_padding_(use_padding),
  27. pad_to_step_(pad_to_step),
  28. pad_to_element_(pad_to_element),
  29. allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
  30. allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
  31. allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
  32. // Create a BulkFeatureExtractor with allocator functions as above, but with
  33. // use_padding set to False. Useful when you know your caller will never
  34. // need to pad.
  35. BulkFeatureExtractor(
  36. std::function<tensorflow::int32 *(int channel, int num_elements)>
  37. allocate_indices_by_channel,
  38. std::function<tensorflow::int64 *(int channel, int num_elements)>
  39. allocate_ids_by_channel,
  40. std::function<float *(int channel, int num_elements)>
  41. allocate_weights_by_channel)
  42. : use_padding_(false),
  43. pad_to_step_(-1),
  44. pad_to_element_(-1),
  45. allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
  46. allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
  47. allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
  48. // Invoke the index memory allocator.
  49. tensorflow::int32 *AllocateIndexMemory(int channel, int num_elements) const {
  50. return allocate_indices_by_channel_(channel, num_elements);
  51. }
  52. // Invoke the ID memory allocator.
  53. tensorflow::int64 *AllocateIdMemory(int channel, int num_elements) const {
  54. return allocate_ids_by_channel_(channel, num_elements);
  55. }
  56. // Invoke the weight memory allocator.
  57. float *AllocateWeightMemory(int channel, int num_elements) const {
  58. return allocate_weights_by_channel_(channel, num_elements);
  59. }
  60. // Given the total number of steps and total number of elements for a given
  61. // feature, calculate the index (not ID) of that feature. Based on how the
  62. // BulkFeatureExtractor was constructed, it may use the given number of steps
  63. // and number of elements, or it may use the passed padded number.
  64. int GetIndex(int total_steps, int num_elements, int feature_idx,
  65. int element_idx, int step_idx) const {
  66. const int steps = (use_padding_) ? pad_to_step_ : total_steps;
  67. const int elements = (use_padding_) ? pad_to_element_ : num_elements;
  68. const int feature_offset = elements * steps;
  69. const int element_offset = steps;
  70. return (feature_idx * feature_offset) + (element_idx * element_offset) +
  71. step_idx;
  72. }
  73. private:
  74. const bool use_padding_;
  75. const int pad_to_step_;
  76. const int pad_to_element_;
  77. const std::function<tensorflow::int32 *(int, int)>
  78. allocate_indices_by_channel_;
  79. const std::function<tensorflow::int64 *(int, int)> allocate_ids_by_channel_;
  80. const std::function<float *(int, int)> allocate_weights_by_channel_;
  81. };
  82. } // namespace dragnn
  83. } // namespace syntaxnet
  84. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_