beam.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  #include <algorithm>
  #include <cmath>
  #include <cstddef>
  #include <functional>
  #include <memory>
  #include <string>
  #include <type_traits>
  #include <utility>
  #include <vector>

  #include "dragnn/core/interfaces/cloneable_transition_state.h"
  #include "dragnn/core/interfaces/transition_state.h"
  #include "tensorflow/core/platform/logging.h"
  10. namespace syntaxnet {
  11. namespace dragnn {
  12. // The Beam class wraps the logic necessary to advance a set of transition
  13. // states for an arbitrary Component. Because the Beam class is generic, it
  14. // doesn't know how to act on the states it is provided - the instantiating
  15. // Component is expected to provide it the three functions it needs to interact
  16. // with that Component's TransitionState subclasses.
  17. template <typename T>
  18. class Beam {
  19. public:
  20. // Creates a new Beam which can grow up to max_size elements.
  21. explicit Beam(int max_size) : max_size_(max_size), num_steps_(0) {
  22. VLOG(2) << "Creating beam with max size " << max_size_;
  23. static_assert(
  24. std::is_base_of<CloneableTransitionState<T>, T>::value,
  25. "This class must be instantiated to use a CloneableTransitionState");
  26. }
  27. // Sets the Beam functions, as follows:
  28. // bool is_allowed(TransitionState *, int): Return true if transition 'int' is
  29. // allowed for transition state 'TransitionState *'.
  30. // void perform_transition(TransitionState *, int): Performs transition 'int'
  31. // on transition state 'TransitionState *'.
  32. // int oracle_function(TransitionState *): Returns the oracle-specified action
  33. // for transition state 'TransitionState *'.
  34. void SetFunctions(std::function<bool(T *, int)> is_allowed,
  35. std::function<bool(T *)> is_final,
  36. std::function<void(T *, int)> perform_transition,
  37. std::function<int(T *)> oracle_function) {
  38. is_allowed_ = is_allowed;
  39. is_final_ = is_final;
  40. perform_transition_ = perform_transition;
  41. oracle_function_ = oracle_function;
  42. }
  43. // Resets the Beam and initializes it with the given set of states. The Beam
  44. // takes ownership of these TransitionStates.
  45. void Init(std::vector<std::unique_ptr<T>> initial_states) {
  46. VLOG(2) << "Initializing beam. Beam max size is " << max_size_;
  47. CHECK_LE(initial_states.size(), max_size_)
  48. << "Attempted to initialize a beam with more states ("
  49. << initial_states.size() << ") than the max size " << max_size_;
  50. beam_ = std::move(initial_states);
  51. std::vector<int> previous_beam_indices(max_size_, -1);
  52. for (int i = 0; i < beam_.size(); ++i) {
  53. previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
  54. beam_[i]->SetBeamIndex(i);
  55. }
  56. beam_index_history_.emplace_back(previous_beam_indices);
  57. }
  58. // Advances the Beam from the given transition matrix.
  59. void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
  60. int num_actions) {
  61. // Ensure that the transition matrix is the correct size. All underlying
  62. // states should have the same transition profile, so using the one at 0
  63. // should be safe.
  64. CHECK_EQ(matrix_length, max_size_ * num_actions)
  65. << "Transition matrix size does not match max beam size * number of "
  66. "state transitions!";
  67. if (max_size_ == 1) {
  68. // In the case where beam size is 1, we can advance by simply finding the
  69. // highest score and advancing the beam state in place.
  70. VLOG(2) << "Beam size is 1. Using fast beam path.";
  71. int best_action = -1;
  72. float best_score = -INFINITY;
  73. auto &state = beam_[0];
  74. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  75. if (is_allowed_(state.get(), action_idx) &&
  76. transition_matrix[action_idx] > best_score) {
  77. best_score = transition_matrix[action_idx];
  78. best_action = action_idx;
  79. }
  80. }
  81. CHECK_GE(best_action, 0) << "Num actions: " << num_actions
  82. << " score[0]: " << transition_matrix[0];
  83. perform_transition_(state.get(), best_action);
  84. const float new_score = state->GetScore() + best_score;
  85. state->SetScore(new_score);
  86. state->SetBeamIndex(0);
  87. } else {
  88. // Create the vector of all possible transitions, along with their scores.
  89. std::vector<Transition> candidates;
  90. // Iterate through all beams, examining all actions for each beam.
  91. for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
  92. const auto &state = beam_[beam_idx];
  93. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  94. // If the action is allowed, calculate the proposed new score and add
  95. // the candidate action to the vector of all actions at this state.
  96. if (is_allowed_(state.get(), action_idx)) {
  97. Transition candidate;
  98. // The matrix is laid out by beam index, with a linear set of
  99. // actions for that index - so beam N's actions start at [nr. of
  100. // actions]*[N].
  101. const int matrix_idx = action_idx + beam_idx * num_actions;
  102. CHECK_LT(matrix_idx, matrix_length)
  103. << "Matrix index out of bounds!";
  104. const double score_delta = transition_matrix[matrix_idx];
  105. CHECK(!std::isnan(score_delta));
  106. candidate.source_idx = beam_idx;
  107. candidate.action = action_idx;
  108. candidate.resulting_score = state->GetScore() + score_delta;
  109. candidates.emplace_back(candidate);
  110. }
  111. }
  112. }
  113. // Sort the vector of all possible transitions and scores.
  114. const auto comparator = [](const Transition &a, const Transition &b) {
  115. return a.resulting_score > b.resulting_score;
  116. };
  117. std::sort(candidates.begin(), candidates.end(), comparator);
  118. // Apply the top transitions, up to a maximum of 'max_size_'.
  119. std::vector<std::unique_ptr<T>> new_beam;
  120. std::vector<int> previous_beam_indices(max_size_, -1);
  121. const int beam_size =
  122. std::min(max_size_, static_cast<int>(candidates.size()));
  123. VLOG(2) << "Previous beam size = " << beam_.size();
  124. VLOG(2) << "New beam size = " << beam_size;
  125. VLOG(2) << "Maximum beam size = " << max_size_;
  126. for (int i = 0; i < beam_size; ++i) {
  127. // Get the source of the i'th transition.
  128. const auto &transition = candidates[i];
  129. VLOG(2) << "Taking transition with score: "
  130. << transition.resulting_score
  131. << " and action: " << transition.action;
  132. VLOG(2) << "transition.source_idx = " << transition.source_idx;
  133. const auto &source = beam_[transition.source_idx];
  134. // Put the new transition on the new state beam.
  135. auto new_state = source->Clone();
  136. perform_transition_(new_state.get(), transition.action);
  137. new_state->SetScore(transition.resulting_score);
  138. new_state->SetBeamIndex(i);
  139. previous_beam_indices.at(i) = transition.source_idx;
  140. new_beam.emplace_back(std::move(new_state));
  141. }
  142. beam_ = std::move(new_beam);
  143. beam_index_history_.emplace_back(previous_beam_indices);
  144. }
  145. ++num_steps_;
  146. }
  147. // Advances the Beam from the state oracles.
  148. void AdvanceFromOracle() {
  149. std::vector<int> previous_beam_indices(max_size_, -1);
  150. for (int i = 0; i < beam_.size(); ++i) {
  151. previous_beam_indices.at(i) = i;
  152. if (is_final_(beam_[i].get())) continue;
  153. const auto oracle_label = oracle_function_(beam_[i].get());
  154. VLOG(2) << "AdvanceFromOracle beam_index:" << i
  155. << " oracle_label:" << oracle_label;
  156. perform_transition_(beam_[i].get(), oracle_label);
  157. beam_[i]->SetScore(0.0);
  158. beam_[i]->SetBeamIndex(i);
  159. }
  160. if (max_size_ > 1) {
  161. beam_index_history_.emplace_back(previous_beam_indices);
  162. }
  163. num_steps_++;
  164. }
  165. // Returns true if all states in the beam are final.
  166. bool IsTerminal() {
  167. for (auto &state : beam_) {
  168. if (!is_final_(state.get())) {
  169. return false;
  170. }
  171. }
  172. return true;
  173. }
  174. // Destroys the states held by this beam and resets its history.
  175. void Reset() {
  176. beam_.clear();
  177. beam_index_history_.clear();
  178. num_steps_ = 0;
  179. }
  180. // Given an index into the current beam, determine the index of the item's
  181. // parent at beam step "step", which should be less than the total number
  182. // of steps taken by this beam.
  183. int FindPreviousIndex(int current_index, int step) const {
  184. VLOG(2) << "FindPreviousIndex requested for current_index:" << current_index
  185. << " at step:" << step;
  186. if (VLOG_IS_ON(2)) {
  187. int step_index = 0;
  188. for (const auto &step : beam_index_history_) {
  189. string row =
  190. "Step " + std::to_string(step_index) + " element source slot: ";
  191. for (const auto &index : step) {
  192. if (index == -1) {
  193. row += " X";
  194. } else {
  195. row += " " + std::to_string(index);
  196. }
  197. }
  198. VLOG(2) << row;
  199. ++step_index;
  200. }
  201. }
  202. // If the max size of the beam is 1, make sure the steps are in sync with
  203. // the size.
  204. if (max_size_ > 1) {
  205. CHECK(num_steps_ == beam_index_history_.size() - 1);
  206. }
  207. // Check if the step is too far into the past or future.
  208. if (step < 0 || step > num_steps_) {
  209. return -1;
  210. }
  211. // Check that the index is within the beam.
  212. if (current_index < 0 || current_index >= max_size_) {
  213. return -1;
  214. }
  215. // If the max size of the beam is 1, always return 0.
  216. if (max_size_ == 1) {
  217. return 0;
  218. }
  219. // Check that the start index isn't -1; -1 means that we don't have an
  220. // actual transition state in that beam slot.
  221. if (beam_index_history_.back().at(current_index) == -1) {
  222. return -1;
  223. }
  224. int beam_index = current_index;
  225. for (int i = beam_index_history_.size() - 1; i >= step; --i) {
  226. beam_index = beam_index_history_.at(i).at(beam_index);
  227. }
  228. CHECK_GE(beam_index, 0);
  229. VLOG(2) << "Index is " << beam_index;
  230. return beam_index;
  231. }
  232. // Returns the current state of the beam.
  233. std::vector<const TransitionState *> beam() const {
  234. std::vector<const TransitionState *> state_ptrs;
  235. for (const auto &beam_state : beam_) {
  236. state_ptrs.emplace_back(beam_state.get());
  237. }
  238. return state_ptrs;
  239. }
  240. // Returns the beam at the current state index.
  241. T *beam_state(int beam_index) { return beam_.at(beam_index).get(); }
  242. // Returns the raw history vectors for this beam.
  243. const std::vector<std::vector<int>> &history() {
  244. if (max_size_ == 1) {
  245. // If max size is 1, we haven't been keeping track of the beam. Quick
  246. // create it.
  247. beam_index_history_.clear();
  248. beam_index_history_.push_back({beam_[0]->ParentBeamIndex()});
  249. for (int i = 0; i < num_steps_; ++i) {
  250. beam_index_history_.push_back({0});
  251. }
  252. }
  253. return beam_index_history_;
  254. }
  255. // Sets the max size of the beam.
  256. void SetMaxSize(int max_size) {
  257. max_size_ = max_size;
  258. Reset();
  259. }
  260. // Returns the number of steps taken so far.
  261. const int num_steps() const { return num_steps_; }
  262. // Returns the max size of this beam.
  263. const int max_size() const { return max_size_; }
  264. // Returns the current size of the beam.
  265. const int size() const { return beam_.size(); }
  266. private:
  267. // Associates an action taken on an index into current_state_ with a score.
  268. struct Transition {
  269. // The index of the source item.
  270. int source_idx;
  271. // The index of the action being taken.
  272. int action;
  273. // The score of the full derivation.
  274. double resulting_score;
  275. };
  276. // The maximum beam size.
  277. int max_size_;
  278. // The current beam.
  279. std::vector<std::unique_ptr<T>> beam_;
  280. // Function to check if a transition is allowed for a given state.
  281. std::function<bool(T *, int)> is_allowed_;
  282. // Function to check if a state is final.
  283. std::function<int(T *)> is_final_;
  284. // Function to perform a transition on a given state.
  285. std::function<void(T *, int)> perform_transition_;
  286. // Function to provide the oracle action for a given state.
  287. std::function<int(T *)> oracle_function_;
  288. // The history of the states in this beam. The vector indexes across steps.
  289. // For every step, there is a vector in the vector. This inner vector denotes
  290. // the state of the beam at that step, and contains the beam index that
  291. // was transitioned to create the transition state at that index (so,
  292. // if at step 2 the transition state at beam index 4 was created by applying
  293. // a transition to the state in beam index 3 during step 1, the query would
  294. // be "beam_index_history_.at(2).at(4)" and the value would be 3. Empty beam
  295. // states will return -1.
  296. std::vector<std::vector<int>> beam_index_history_;
  297. // The number of steps taken so far.
  298. int num_steps_;
  299. };
  300. } // namespace dragnn
  301. } // namespace syntaxnet
  302. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_