beam.h 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  2. #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
  #include <algorithm>
  #include <cmath>
  #include <functional>
  #include <memory>
  #include <string>
  #include <type_traits>
  #include <utility>
  #include <vector>

  #include "dragnn/core/interfaces/cloneable_transition_state.h"
  #include "dragnn/core/interfaces/transition_state.h"
  #include "tensorflow/core/platform/logging.h"
  9. namespace syntaxnet {
  10. namespace dragnn {
  11. // The Beam class wraps the logic necessary to advance a set of transition
  12. // states for an arbitrary Component. Because the Beam class is generic, it
  13. // doesn't know how to act on the states it is provided - the instantiating
  14. // Component is expected to provide it the three functions it needs to interact
  15. // with that Component's TransitionState subclasses.
  16. template <typename T>
  17. class Beam {
  18. public:
  19. // Creates a new Beam which can grow up to max_size elements.
  20. explicit Beam(int max_size) : max_size_(max_size), num_steps_(0) {
  21. VLOG(2) << "Creating beam with max size " << max_size_;
  22. static_assert(
  23. std::is_base_of<CloneableTransitionState<T>, T>::value,
  24. "This class must be instantiated to use a CloneableTransitionState");
  25. }
  26. // Sets the Beam functions, as follows:
  27. // bool is_allowed(TransitionState *, int): Return true if transition 'int' is
  28. // allowed for transition state 'TransitionState *'.
  29. // void perform_transition(TransitionState *, int): Performs transition 'int'
  30. // on transition state 'TransitionState *'.
  31. // int oracle_function(TransitionState *): Returns the oracle-specified action
  32. // for transition state 'TransitionState *'.
  33. void SetFunctions(std::function<bool(T *, int)> is_allowed,
  34. std::function<bool(T *)> is_final,
  35. std::function<void(T *, int)> perform_transition,
  36. std::function<int(T *)> oracle_function) {
  37. is_allowed_ = is_allowed;
  38. is_final_ = is_final;
  39. perform_transition_ = perform_transition;
  40. oracle_function_ = oracle_function;
  41. }
  42. // Resets the Beam and initializes it with the given set of states. The Beam
  43. // takes ownership of these TransitionStates.
  44. void Init(std::vector<std::unique_ptr<T>> initial_states) {
  45. VLOG(2) << "Initializing beam. Beam max size is " << max_size_;
  46. CHECK_LE(initial_states.size(), max_size_)
  47. << "Attempted to initialize a beam with more states ("
  48. << initial_states.size() << ") than the max size " << max_size_;
  49. beam_ = std::move(initial_states);
  50. std::vector<int> previous_beam_indices(max_size_, -1);
  51. for (int i = 0; i < beam_.size(); ++i) {
  52. previous_beam_indices.at(i) = beam_[i]->ParentBeamIndex();
  53. beam_[i]->SetBeamIndex(i);
  54. }
  55. beam_index_history_.emplace_back(previous_beam_indices);
  56. }
  57. // Advances the Beam from the given transition matrix.
  58. void AdvanceFromPrediction(const float transition_matrix[], int matrix_length,
  59. int num_actions) {
  60. // Ensure that the transition matrix is the correct size. All underlying
  61. // states should have the same transition profile, so using the one at 0
  62. // should be safe.
  63. CHECK_EQ(matrix_length, max_size_ * num_actions)
  64. << "Transition matrix size does not match max beam size * number of "
  65. "state transitions!";
  66. if (max_size_ == 1) {
  67. // In the case where beam size is 1, we can advance by simply finding the
  68. // highest score and advancing the beam state in place.
  69. VLOG(2) << "Beam size is 1. Using fast beam path.";
  70. int best_action = -1;
  71. float best_score = -INFINITY;
  72. auto &state = beam_[0];
  73. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  74. if (is_allowed_(state.get(), action_idx) &&
  75. transition_matrix[action_idx] > best_score) {
  76. best_score = transition_matrix[action_idx];
  77. best_action = action_idx;
  78. }
  79. }
  80. CHECK_GE(best_action, 0) << "Num actions: " << num_actions
  81. << " score[0]: " << transition_matrix[0];
  82. perform_transition_(state.get(), best_action);
  83. const float new_score = state->GetScore() + best_score;
  84. state->SetScore(new_score);
  85. state->SetBeamIndex(0);
  86. } else {
  87. // Create the vector of all possible transitions, along with their scores.
  88. std::vector<Transition> candidates;
  89. // Iterate through all beams, examining all actions for each beam.
  90. for (int beam_idx = 0; beam_idx < beam_.size(); ++beam_idx) {
  91. const auto &state = beam_[beam_idx];
  92. for (int action_idx = 0; action_idx < num_actions; ++action_idx) {
  93. // If the action is allowed, calculate the proposed new score and add
  94. // the candidate action to the vector of all actions at this state.
  95. if (is_allowed_(state.get(), action_idx)) {
  96. Transition candidate;
  97. // The matrix is laid out by beam index, with a linear set of
  98. // actions for that index - so beam N's actions start at [nr. of
  99. // actions]*[N].
  100. const int matrix_idx = action_idx + beam_idx * num_actions;
  101. CHECK_LT(matrix_idx, matrix_length)
  102. << "Matrix index out of bounds!";
  103. const double score_delta = transition_matrix[matrix_idx];
  104. CHECK(!isnan(score_delta));
  105. candidate.source_idx = beam_idx;
  106. candidate.action = action_idx;
  107. candidate.resulting_score = state->GetScore() + score_delta;
  108. candidates.emplace_back(candidate);
  109. }
  110. }
  111. }
  112. // Sort the vector of all possible transitions and scores.
  113. const auto comparator = [](const Transition &a, const Transition &b) {
  114. return a.resulting_score > b.resulting_score;
  115. };
  116. std::sort(candidates.begin(), candidates.end(), comparator);
  117. // Apply the top transitions, up to a maximum of 'max_size_'.
  118. std::vector<std::unique_ptr<T>> new_beam;
  119. std::vector<int> previous_beam_indices(max_size_, -1);
  120. const int beam_size =
  121. std::min(max_size_, static_cast<int>(candidates.size()));
  122. VLOG(2) << "Previous beam size = " << beam_.size();
  123. VLOG(2) << "New beam size = " << beam_size;
  124. VLOG(2) << "Maximum beam size = " << max_size_;
  125. for (int i = 0; i < beam_size; ++i) {
  126. // Get the source of the i'th transition.
  127. const auto &transition = candidates[i];
  128. VLOG(2) << "Taking transition with score: "
  129. << transition.resulting_score
  130. << " and action: " << transition.action;
  131. VLOG(2) << "transition.source_idx = " << transition.source_idx;
  132. const auto &source = beam_[transition.source_idx];
  133. // Put the new transition on the new state beam.
  134. auto new_state = source->Clone();
  135. perform_transition_(new_state.get(), transition.action);
  136. new_state->SetScore(transition.resulting_score);
  137. new_state->SetBeamIndex(i);
  138. previous_beam_indices.at(i) = transition.source_idx;
  139. new_beam.emplace_back(std::move(new_state));
  140. }
  141. beam_ = std::move(new_beam);
  142. beam_index_history_.emplace_back(previous_beam_indices);
  143. }
  144. ++num_steps_;
  145. }
  146. // Advances the Beam from the state oracles.
  147. void AdvanceFromOracle() {
  148. std::vector<int> previous_beam_indices(max_size_, -1);
  149. for (int i = 0; i < beam_.size(); ++i) {
  150. previous_beam_indices.at(i) = i;
  151. if (is_final_(beam_[i].get())) continue;
  152. const auto oracle_label = oracle_function_(beam_[i].get());
  153. VLOG(2) << "AdvanceFromOracle beam_index:" << i
  154. << " oracle_label:" << oracle_label;
  155. perform_transition_(beam_[i].get(), oracle_label);
  156. beam_[i]->SetScore(0.0);
  157. beam_[i]->SetBeamIndex(i);
  158. }
  159. if (max_size_ > 1) {
  160. beam_index_history_.emplace_back(previous_beam_indices);
  161. }
  162. num_steps_++;
  163. }
  164. // Returns true if all states in the beam are final.
  165. bool IsTerminal() {
  166. for (auto &state : beam_) {
  167. if (!is_final_(state.get())) {
  168. return false;
  169. }
  170. }
  171. return true;
  172. }
  173. // Destroys the states held by this beam and resets its history.
  174. void Reset() {
  175. beam_.clear();
  176. beam_index_history_.clear();
  177. num_steps_ = 0;
  178. }
  179. // Given an index into the current beam, determine the index of the item's
  180. // parent at beam step "step", which should be less than the total number
  181. // of steps taken by this beam.
  182. int FindPreviousIndex(int current_index, int step) const {
  183. VLOG(2) << "FindPreviousIndex requested for current_index:" << current_index
  184. << " at step:" << step;
  185. if (VLOG_IS_ON(2)) {
  186. int step_index = 0;
  187. for (const auto &step : beam_index_history_) {
  188. string row =
  189. "Step " + std::to_string(step_index) + " element source slot: ";
  190. for (const auto &index : step) {
  191. if (index == -1) {
  192. row += " X";
  193. } else {
  194. row += " " + std::to_string(index);
  195. }
  196. }
  197. VLOG(2) << row;
  198. ++step_index;
  199. }
  200. }
  201. // If the max size of the beam is 1, make sure the steps are in sync with
  202. // the size.
  203. if (max_size_ > 1) {
  204. CHECK(num_steps_ == beam_index_history_.size() - 1);
  205. }
  206. // Check if the step is too far into the past or future.
  207. if (step < 0 || step > num_steps_) {
  208. return -1;
  209. }
  210. // Check that the index is within the beam.
  211. if (current_index < 0 || current_index >= max_size_) {
  212. return -1;
  213. }
  214. // If the max size of the beam is 1, always return 0.
  215. if (max_size_ == 1) {
  216. return 0;
  217. }
  218. // Check that the start index isn't -1; -1 means that we don't have an
  219. // actual transition state in that beam slot.
  220. if (beam_index_history_.back().at(current_index) == -1) {
  221. return -1;
  222. }
  223. int beam_index = current_index;
  224. for (int i = beam_index_history_.size() - 1; i >= step; --i) {
  225. beam_index = beam_index_history_.at(i).at(beam_index);
  226. }
  227. CHECK_GE(beam_index, 0);
  228. VLOG(2) << "Index is " << beam_index;
  229. return beam_index;
  230. }
  231. // Returns the current state of the beam.
  232. std::vector<const TransitionState *> beam() const {
  233. std::vector<const TransitionState *> state_ptrs;
  234. for (const auto &beam_state : beam_) {
  235. state_ptrs.emplace_back(beam_state.get());
  236. }
  237. return state_ptrs;
  238. }
  239. // Returns the beam at the current state index.
  240. T *beam_state(int beam_index) { return beam_.at(beam_index).get(); }
  241. // Returns the raw history vectors for this beam.
  242. const std::vector<std::vector<int>> &history() {
  243. if (max_size_ == 1) {
  244. // If max size is 1, we haven't been keeping track of the beam. Quick
  245. // create it.
  246. beam_index_history_.clear();
  247. beam_index_history_.push_back({beam_[0]->ParentBeamIndex()});
  248. for (int i = 0; i < num_steps_; ++i) {
  249. beam_index_history_.push_back({0});
  250. }
  251. }
  252. return beam_index_history_;
  253. }
  254. // Sets the max size of the beam.
  255. void SetMaxSize(int max_size) {
  256. max_size_ = max_size;
  257. Reset();
  258. }
  259. // Returns the number of steps taken so far.
  260. const int num_steps() const { return num_steps_; }
  261. // Returns the max size of this beam.
  262. const int max_size() const { return max_size_; }
  263. // Returns the current size of the beam.
  264. const int size() const { return beam_.size(); }
  265. private:
  266. // Associates an action taken on an index into current_state_ with a score.
  267. struct Transition {
  268. // The index of the source item.
  269. int source_idx;
  270. // The index of the action being taken.
  271. int action;
  272. // The score of the full derivation.
  273. double resulting_score;
  274. };
  275. // The maximum beam size.
  276. int max_size_;
  277. // The current beam.
  278. std::vector<std::unique_ptr<T>> beam_;
  279. // Function to check if a transition is allowed for a given state.
  280. std::function<bool(T *, int)> is_allowed_;
  281. // Function to check if a state is final.
  282. std::function<int(T *)> is_final_;
  283. // Function to perform a transition on a given state.
  284. std::function<void(T *, int)> perform_transition_;
  285. // Function to provide the oracle action for a given state.
  286. std::function<int(T *)> oracle_function_;
  287. // The history of the states in this beam. The vector indexes across steps.
  288. // For every step, there is a vector in the vector. This inner vector denotes
  289. // the state of the beam at that step, and contains the beam index that
  290. // was transitioned to create the transition state at that index (so,
  291. // if at step 2 the transition state at beam index 4 was created by applying
  292. // a transition to the state in beam index 3 during step 1, the query would
  293. // be "beam_index_history_.at(2).at(4)" and the value would be 3. Empty beam
  294. // states will return -1.
  295. std::vector<std::vector<int>> beam_index_history_;
  296. // The number of steps taken so far.
  297. int num_steps_;
  298. };
  299. } // namespace dragnn
  300. } // namespace syntaxnet
  301. #endif // NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_