// index_translator.cc (3.4 KB)
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include "dragnn/core/index_translator.h"

#include <cstddef>
#include <vector>

#include "tensorflow/core/platform/logging.h"
  17. namespace syntaxnet {
  18. namespace dragnn {
  19. using Index = IndexTranslator::Index;
  20. IndexTranslator::IndexTranslator(const std::vector<Component *> &path,
  21. const string &method)
  22. : path_(path), method_(method) {
  23. if (method_ == "identity") {
  24. // Identity lookup: Return the feature index.
  25. step_lookup_ = [](int batch_index, int beam_index, int feature) {
  26. return feature;
  27. };
  28. } else if (method_ == "history") {
  29. // History lookup: Return the number of steps taken less the feature.
  30. step_lookup_ = [this](int batch_index, int beam_index, int feature) {
  31. if (feature > path_.back()->StepsTaken(batch_index) - 1) {
  32. VLOG(2) << "Translation to outside: feature is " << feature
  33. << " and steps_taken is "
  34. << path_.back()->StepsTaken(batch_index);
  35. return -1;
  36. }
  37. return ((path_.back()->StepsTaken(batch_index) - 1) - feature);
  38. };
  39. } else {
  40. // Component defined lookup: Get the lookup function from the component.
  41. // If the lookup function is not defined, this function will CHECK.
  42. step_lookup_ = path_.back()->GetStepLookupFunction(method_);
  43. }
  44. }
  45. Index IndexTranslator::Translate(int batch_index, int beam_index,
  46. int feature_value) {
  47. Index translated_index;
  48. translated_index.batch_index = batch_index;
  49. VLOG(2) << "Translation requested (type: " << method_ << ") for batch "
  50. << batch_index << " beam " << beam_index << " feature "
  51. << feature_value;
  52. // For all save the last item in the path, get the source index for the
  53. // previous component.
  54. int current_beam_index = beam_index;
  55. VLOG(2) << "Beam index before walk is " << current_beam_index;
  56. for (int i = 0; i < path_.size() - 1; ++i) {
  57. // Backtrack through previous components. For each non-final component,
  58. // figure out what state in the prior component was used to initialize the
  59. // state at the current beam index.
  60. current_beam_index =
  61. path_.at(i)->GetSourceBeamIndex(current_beam_index, batch_index);
  62. VLOG(2) << "Beam index updated to " << current_beam_index;
  63. }
  64. VLOG(2) << "Beam index after walk is " << current_beam_index;
  65. translated_index.step_index =
  66. step_lookup_(batch_index, current_beam_index, feature_value);
  67. VLOG(2) << "Translated step index is " << translated_index.step_index;
  68. translated_index.beam_index = path_.back()->GetBeamIndexAtStep(
  69. translated_index.step_index, current_beam_index, batch_index);
  70. VLOG(2) << "Translated beam index is " << translated_index.beam_index;
  71. return translated_index;
  72. }
  73. } // namespace dragnn
  74. } // namespace syntaxnet