index_translator.cc 2.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. #include "dragnn/core/index_translator.h"
  2. #include "tensorflow/core/platform/logging.h"
  3. namespace syntaxnet {
  4. namespace dragnn {
  5. using Index = IndexTranslator::Index;
  6. IndexTranslator::IndexTranslator(const std::vector<Component *> &path,
  7. const string &method)
  8. : path_(path), method_(method) {
  9. if (method_ == "identity") {
  10. // Identity lookup: Return the feature index.
  11. step_lookup_ = [](int batch_index, int beam_index, int feature) {
  12. return feature;
  13. };
  14. } else if (method_ == "history") {
  15. // History lookup: Return the number of steps taken less the feature.
  16. step_lookup_ = [this](int batch_index, int beam_index, int feature) {
  17. if (feature > path_.back()->StepsTaken(batch_index) - 1) {
  18. VLOG(2) << "Translation to outside: feature is " << feature
  19. << " and steps_taken is "
  20. << path_.back()->StepsTaken(batch_index);
  21. return -1;
  22. }
  23. return ((path_.back()->StepsTaken(batch_index) - 1) - feature);
  24. };
  25. } else {
  26. // Component defined lookup: Get the lookup function from the component.
  27. // If the lookup function is not defined, this function will CHECK.
  28. step_lookup_ = path_.back()->GetStepLookupFunction(method_);
  29. }
  30. }
  31. Index IndexTranslator::Translate(int batch_index, int beam_index,
  32. int feature_value) {
  33. Index translated_index;
  34. translated_index.batch_index = batch_index;
  35. VLOG(2) << "Translation requested (type: " << method_ << ") for batch "
  36. << batch_index << " beam " << beam_index << " feature "
  37. << feature_value;
  38. // For all save the last item in the path, get the source index for the
  39. // previous component.
  40. int current_beam_index = beam_index;
  41. VLOG(2) << "Beam index before walk is " << current_beam_index;
  42. for (int i = 0; i < path_.size() - 1; ++i) {
  43. // Backtrack through previous components. For each non-final component,
  44. // figure out what state in the prior component was used to initialize the
  45. // state at the current beam index.
  46. current_beam_index =
  47. path_.at(i)->GetSourceBeamIndex(current_beam_index, batch_index);
  48. VLOG(2) << "Beam index updated to " << current_beam_index;
  49. }
  50. VLOG(2) << "Beam index after walk is " << current_beam_index;
  51. translated_index.step_index =
  52. step_lookup_(batch_index, current_beam_index, feature_value);
  53. VLOG(2) << "Translated step index is " << translated_index.step_index;
  54. translated_index.beam_index = path_.back()->GetBeamIndexAtStep(
  55. translated_index.step_index, current_beam_index, batch_index);
  56. VLOG(2) << "Translated beam index is " << translated_index.beam_index;
  57. return translated_index;
  58. }
  59. } // namespace dragnn
  60. } // namespace syntaxnet