// File: beam.h (13 KB)
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  16. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  #include <algorithm>
  #include <cmath>
  #include <functional>
  #include <memory>
  #include <string>
  #include <type_traits>
  #include <utility>
  #include <vector>

  #include "dragnn/core/interfaces/cloneable_transition_state.h"
  #include "dragnn/core/interfaces/transition_state.h"
  #include "tensorflow/core/platform/logging.h"
  24. namespace syntaxnet {
  25. namespace dragnn {
  26. // The Beam class wraps the logic necessary to advance a set of transition
  27. // states for an arbitrary Component. Because the Beam class is generic, it
  28. // doesn't know how to act on the states it is provided - the instantiating
  29. // Component is expected to provide it the three functions it needs to interact
  30. // with that Component's TransitionState subclasses.
  31. template <typename T>
  32. class Beam {
  33. public:
  34. // Creates a new Beam which can grow up to max_size elements.
  35. explicit Beam(int max_size) : max_size_(max_size), num_steps_(0) {
  36. VLOG(2) << "Creating beam with max size " << max_size_;
  37. static_assert(
  38. std::is_base_of<CloneableTransitionState<T>, T>::value,
  39. "This class must be instantiated to use a CloneableTransitionState");
  40. }
  41. // Sets the Beam functions, as follows:
  42. // bool is_allowed(TransitionState *, int): Return true if transition 'int' is
  43. // allowed for transition state 'TransitionState *'.
  44. // void perform_transition(TransitionState *, int): Performs transition 'int'
  45. // on transition state 'TransitionState *'.
  46. // int oracle_function(TransitionState *): Returns the oracle-specified action
  47. // for transition state 'TransitionState *'.
  48. void SetFunctions(std::function<bool(T *, int)> is_allowed,
  49. std::function<bool(T *)> is_final,
  50. std::function<void(T *, int)> perform_transition,
  51. std::function<int(T *)> oracle_function) {
  52. is_allowed_ = is_allowed;
  53. is_final_ = is_final;
  54. perform_transition_ = perform_transition;
  55. oracle_function_ = oracle_function;
  56. }
  57. // Resets the Beam and initializes it with the given set of states. The Beam
  58. // takes ownership of these TransitionStates.
  59. void Init(std::vector<std::unique_ptr<T>> initial_states) {
  60. VLOG(2) << "Initializing beam. Beam max size is " << max_size_;
  61. CHECK_LE(initial_states.size(), max_size_)
  62. << "Attempted to initialize a beam with more states ("
  63. << initial_states.size() << ") than the max size " << max_size_;
  64. beam_ = std::move(initial_states);
  65. std::vector<int> previous_beam_indices(max_size_, -1);
  66. for (int i = 0; i < beam_.size(); ++i) {
  67. previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
  68. beam_[i]->SetBeamIndex(i);
  69. }
  70. beam_index_history_.emplace_back(previous_beam_indices);
  71. }
  72. // Advances the Beam from the given transition matrix.
  73. void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
  74. int num_actions) {
  75. // Ensure that the transition matrix is the correct size. All underlying
  76. // states should have the same transition profile, so using the one at 0
  77. // should be safe.
  78. CHECK_EQ(matrix_length, max_size_ * num_actions)
  79. << "Transition matrix size does not match max beam size * number of "
  80. "state transitions!";
  81. if (max_size_ == 1) {
  82. // In the case where beam size is 1, we can advance by simply finding the
  83. // highest score and advancing the beam state in place.
  84. VLOG(2) << "Beam size is 1. Using fast beam path.";
  85. int best_action = -1;
  86. float best_score = -INFINITY;
  87. auto &state = beam_[0];
  88. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  89. if (is_allowed_(state.get(), action_idx) &&
  90. transition_matrix[action_idx] > best_score) {
  91. best_score = transition_matrix[action_idx];
  92. best_action = action_idx;
  93. }
  94. }
  95. CHECK_GE(best_action, 0) << "Num actions: " << num_actions
  96. << " score[0]: " << transition_matrix[0];
  97. perform_transition_(state.get(), best_action);
  98. const float new_score = state->GetScore() + best_score;
  99. state->SetScore(new_score);
  100. state->SetBeamIndex(0);
  101. } else {
  102. // Create the vector of all possible transitions, along with their scores.
  103. std::vector<Transition> candidates;
  104. // Iterate through all beams, examining all actions for each beam.
  105. for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
  106. const auto &state = beam_[beam_idx];
  107. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  108. // If the action is allowed, calculate the proposed new score and add
  109. // the candidate action to the vector of all actions at this state.
  110. if (is_allowed_(state.get(), action_idx)) {
  111. Transition candidate;
  112. // The matrix is laid out by beam index, with a linear set of
  113. // actions for that index - so beam N's actions start at [nr. of
  114. // actions]*[N].
  115. const int matrix_idx = action_idx + beam_idx * num_actions;
  116. CHECK_LT(matrix_idx, matrix_length)
  117. << "Matrix index out of bounds!";
  118. const double score_delta = transition_matrix[matrix_idx];
  119. CHECK(!std::isnan(score_delta));
  120. candidate.source_idx = beam_idx;
  121. candidate.action = action_idx;
  122. candidate.resulting_score = state->GetScore() + score_delta;
  123. candidates.emplace_back(candidate);
  124. }
  125. }
  126. }
  127. // Sort the vector of all possible transitions and scores.
  128. const auto comparator = [](const Transition &a, const Transition &b) {
  129. return a.resulting_score > b.resulting_score;
  130. };
  131. std::stable_sort(candidates.begin(), candidates.end(), comparator);
  132. // Apply the top transitions, up to a maximum of 'max_size_'.
  133. std::vector<std::unique_ptr<T>> new_beam;
  134. std::vector<int> previous_beam_indices(max_size_, -1);
  135. const int beam_size =
  136. std::min(max_size_, static_cast<int>(candidates.size()));
  137. VLOG(2) << "Previous beam size = " << beam_.size();
  138. VLOG(2) << "New beam size = " << beam_size;
  139. VLOG(2) << "Maximum beam size = " << max_size_;
  140. for (int i = 0; i < beam_size; ++i) {
  141. // Get the source of the i'th transition.
  142. const auto &transition = candidates[i];
  143. VLOG(2) << "Taking transition with score: "
  144. << transition.resulting_score
  145. << " and action: " << transition.action;
  146. VLOG(2) << "transition.source_idx = " << transition.source_idx;
  147. const auto &source = beam_[transition.source_idx];
  148. // Put the new transition on the new state beam.
  149. auto new_state = source->Clone();
  150. perform_transition_(new_state.get(), transition.action);
  151. new_state->SetScore(transition.resulting_score);
  152. new_state->SetBeamIndex(i);
  153. previous_beam_indices.at(i) = transition.source_idx;
  154. new_beam.emplace_back(std::move(new_state));
  155. }
  156. beam_ = std::move(new_beam);
  157. beam_index_history_.emplace_back(previous_beam_indices);
  158. }
  159. ++num_steps_;
  160. }
  161. // Advances the Beam from the state oracles.
  162. void AdvanceFromOracle() {
  163. std::vector<int> previous_beam_indices(max_size_, -1);
  164. for (int i = 0; i < beam_.size(); ++i) {
  165. previous_beam_indices.at(i) = i;
  166. if (is_final_(beam_[i].get())) continue;
  167. const auto oracle_label = oracle_function_(beam_[i].get());
  168. VLOG(2) << "AdvanceFromOracle beam_index:" << i
  169. << " oracle_label:" << oracle_label;
  170. perform_transition_(beam_[i].get(), oracle_label);
  171. beam_[i]->SetScore(0.0);
  172. beam_[i]->SetBeamIndex(i);
  173. }
  174. if (max_size_ > 1) {
  175. beam_index_history_.emplace_back(previous_beam_indices);
  176. }
  177. num_steps_++;
  178. }
  179. // Returns true if all states in the beam are final.
  180. bool IsTerminal() {
  181. for (auto &state : beam_) {
  182. if (!is_final_(state.get())) {
  183. return false;
  184. }
  185. }
  186. return true;
  187. }
  188. // Destroys the states held by this beam and resets its history.
  189. void Reset() {
  190. beam_.clear();
  191. beam_index_history_.clear();
  192. num_steps_ = 0;
  193. }
  194. // Given an index into the current beam, determine the index of the item's
  195. // parent at beam step "step", which should be less than the total number
  196. // of steps taken by this beam.
  197. int FindPreviousIndex(int current_index, int step) const {
  198. VLOG(2) << "FindPreviousIndex requested for current_index:" << current_index
  199. << " at step:" << step;
  200. if (VLOG_IS_ON(2)) {
  201. int step_index = 0;
  202. for (const auto &step : beam_index_history_) {
  203. string row =
  204. "Step " + std::to_string(step_index) + " element source slot: ";
  205. for (const auto &index : step) {
  206. if (index == -1) {
  207. row += " X";
  208. } else {
  209. row += " " + std::to_string(index);
  210. }
  211. }
  212. VLOG(2) << row;
  213. ++step_index;
  214. }
  215. }
  216. // If the max size of the beam is 1, make sure the steps are in sync with
  217. // the size.
  218. if (max_size_ > 1) {
  219. CHECK(num_steps_ == beam_index_history_.size() - 1);
  220. }
  221. // Check if the step is too far into the past or future.
  222. if (step < 0 || step > num_steps_) {
  223. return -1;
  224. }
  225. // Check that the index is within the beam.
  226. if (current_index < 0 || current_index >= max_size_) {
  227. return -1;
  228. }
  229. // If the max size of the beam is 1, always return 0.
  230. if (max_size_ == 1) {
  231. return 0;
  232. }
  233. // Check that the start index isn't -1; -1 means that we don't have an
  234. // actual transition state in that beam slot.
  235. if (beam_index_history_.back().at(current_index) == -1) {
  236. return -1;
  237. }
  238. int beam_index = current_index;
  239. for (int i = beam_index_history_.size() - 1; i >= step; --i) {
  240. beam_index = beam_index_history_.at(i).at(beam_index);
  241. }
  242. CHECK_GE(beam_index, 0);
  243. VLOG(2) << "Index is " << beam_index;
  244. return beam_index;
  245. }
  246. // Returns the current state of the beam.
  247. std::vector<const TransitionState *> beam() const {
  248. std::vector<const TransitionState *> state_ptrs;
  249. for (const auto &beam_state : beam_) {
  250. state_ptrs.emplace_back(beam_state.get());
  251. }
  252. return state_ptrs;
  253. }
  254. // Returns the beam at the current state index.
  255. T *beam_state(int beam_index) { return beam_.at(beam_index).get(); }
  256. // Returns the raw history vectors for this beam.
  257. const std::vector<std::vector<int>> &history() {
  258. if (max_size_ == 1) {
  259. // If max size is 1, we haven't been keeping track of the beam. Quick
  260. // create it.
  261. beam_index_history_.clear();
  262. beam_index_history_.push_back({beam_[0]->ParentBeamIndex()});
  263. for (int i = 0; i < num_steps_; ++i) {
  264. beam_index_history_.push_back({0});
  265. }
  266. }
  267. return beam_index_history_;
  268. }
  269. // Sets the max size of the beam.
  270. void SetMaxSize(int max_size) {
  271. max_size_ = max_size;
  272. Reset();
  273. }
  274. // Returns the number of steps taken so far.
  275. const int num_steps() const { return num_steps_; }
  276. // Returns the max size of this beam.
  277. const int max_size() const { return max_size_; }
  278. // Returns the current size of the beam.
  279. const int size() const { return beam_.size(); }
  280. private:
  281. // Associates an action taken on an index into current_state_ with a score.
  282. struct Transition {
  283. // The index of the source item.
  284. int source_idx;
  285. // The index of the action being taken.
  286. int action;
  287. // The score of the full derivation.
  288. double resulting_score;
  289. };
  290. // The maximum beam size.
  291. int max_size_;
  292. // The current beam.
  293. std::vector<std::unique_ptr<T>> beam_;
  294. // Function to check if a transition is allowed for a given state.
  295. std::function<bool(T *, int)> is_allowed_;
  296. // Function to check if a state is final.
  297. std::function<int(T *)> is_final_;
  298. // Function to perform a transition on a given state.
  299. std::function<void(T *, int)> perform_transition_;
  300. // Function to provide the oracle action for a given state.
  301. std::function<int(T *)> oracle_function_;
  302. // The history of the states in this beam. The vector indexes across steps.
  303. // For every step, there is a vector in the vector. This inner vector denotes
  304. // the state of the beam at that step, and contains the beam index that
  305. // was transitioned to create the transition state at that index (so,
  306. // if at step 2 the transition state at beam index 4 was created by applying
  307. // a transition to the state in beam index 3 during step 1, the query would
  308. // be "beam_index_history_.at(2).at(4)" and the value would be 3. Empty beam
  309. // states will return -1.
  310. std::vector<std::vector<int>> beam_index_history_;
  311. // The number of steps taken so far.
  312. int num_steps_;
  313. };
  314. } // namespace dragnn
  315. } // namespace syntaxnet
  316. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_