bulk_feature_extractor.h 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
  17. #include <functional>
  18. #include <utility>
  19. #include "tensorflow/core/platform/types.h"
  20. namespace syntaxnet {
  21. namespace dragnn {
  22. // Provides a wrapper for allocator functions and padding data for the Bulk
  23. // ExtractFixedFeatures operation.
  24. class BulkFeatureExtractor {
  25. public:
  26. // Create a BulkFeatureExtractor with the given allocator functions and
  27. // padding. The allocator functions should take a channel and an element
  28. // count and return a contigous block of memory that is associated with that
  29. // channel (the caller can decide what that means). If use_padding is true,
  30. // the provided pad_to_step and pad_to_element will be used to calculate
  31. // the ID size.
  32. BulkFeatureExtractor(
  33. std::function<tensorflow::int32 *(int channel, int num_elements)>
  34. allocate_indices_by_channel,
  35. std::function<tensorflow::int64 *(int channel, int num_elements)>
  36. allocate_ids_by_channel,
  37. std::function<float *(int channel, int num_elements)>
  38. allocate_weights_by_channel,
  39. bool use_padding, int pad_to_step, int pad_to_element)
  40. : use_padding_(use_padding),
  41. pad_to_step_(pad_to_step),
  42. pad_to_element_(pad_to_element),
  43. allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
  44. allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
  45. allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
  46. // Create a BulkFeatureExtractor with allocator functions as above, but with
  47. // use_padding set to False. Useful when you know your caller will never
  48. // need to pad.
  49. BulkFeatureExtractor(
  50. std::function<tensorflow::int32 *(int channel, int num_elements)>
  51. allocate_indices_by_channel,
  52. std::function<tensorflow::int64 *(int channel, int num_elements)>
  53. allocate_ids_by_channel,
  54. std::function<float *(int channel, int num_elements)>
  55. allocate_weights_by_channel)
  56. : use_padding_(false),
  57. pad_to_step_(-1),
  58. pad_to_element_(-1),
  59. allocate_indices_by_channel_(std::move(allocate_indices_by_channel)),
  60. allocate_ids_by_channel_(std::move(allocate_ids_by_channel)),
  61. allocate_weights_by_channel_(std::move(allocate_weights_by_channel)) {}
  62. // Invoke the index memory allocator.
  63. tensorflow::int32 *AllocateIndexMemory(int channel, int num_elements) const {
  64. return allocate_indices_by_channel_(channel, num_elements);
  65. }
  66. // Invoke the ID memory allocator.
  67. tensorflow::int64 *AllocateIdMemory(int channel, int num_elements) const {
  68. return allocate_ids_by_channel_(channel, num_elements);
  69. }
  70. // Invoke the weight memory allocator.
  71. float *AllocateWeightMemory(int channel, int num_elements) const {
  72. return allocate_weights_by_channel_(channel, num_elements);
  73. }
  74. // Given the total number of steps and total number of elements for a given
  75. // feature, calculate the index (not ID) of that feature. Based on how the
  76. // BulkFeatureExtractor was constructed, it may use the given number of steps
  77. // and number of elements, or it may use the passed padded number.
  78. int GetIndex(int total_steps, int num_elements, int feature_idx,
  79. int element_idx, int step_idx) const {
  80. const int steps = (use_padding_) ? pad_to_step_ : total_steps;
  81. const int elements = (use_padding_) ? pad_to_element_ : num_elements;
  82. const int feature_offset = elements * steps;
  83. const int element_offset = steps;
  84. return (feature_idx * feature_offset) + (element_idx * element_offset) +
  85. step_idx;
  86. }
  87. private:
  88. const bool use_padding_;
  89. const int pad_to_step_;
  90. const int pad_to_element_;
  91. const std::function<tensorflow::int32 *(int, int)>
  92. allocate_indices_by_channel_;
  93. const std::function<tensorflow::int64 *(int, int)> allocate_ids_by_channel_;
  94. const std::function<float *(int, int)> allocate_weights_by_channel_;
  95. };
  96. } // namespace dragnn
  97. } // namespace syntaxnet
  98. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_