syntaxnet_component.cc

// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================

#include "dragnn/components/syntaxnet/syntaxnet_component.h"

#include <vector>

#include "dragnn/components/util/bulk_feature_extractor.h"
#include "dragnn/core/component_registry.h"
#include "dragnn/core/input_batch_cache.h"
#include "dragnn/core/interfaces/component.h"
#include "dragnn/core/interfaces/transition_state.h"
#include "dragnn/io/sentence_input_batch.h"
#include "dragnn/io/syntaxnet_sentence.h"
#include "syntaxnet/parser_state.h"
#include "syntaxnet/sparse.pb.h"
#include "syntaxnet/task_spec.pb.h"
#include "syntaxnet/utils.h"
#include "tensorflow/core/platform/logging.h"

namespace syntaxnet {
namespace dragnn {

using tensorflow::strings::StrCat;

namespace {

// Returns a new step in a trace based on a ComponentSpec.
ComponentStepTrace GetNewStepTrace(const ComponentSpec &spec,
                                   const TransitionState &state) {
  ComponentStepTrace step;
  for (auto &linked_spec : spec.linked_feature()) {
    auto &channel_trace = *step.add_linked_feature_trace();
    channel_trace.set_name(linked_spec.name());
    channel_trace.set_source_component(linked_spec.source_component());
    channel_trace.set_source_translator(linked_spec.source_translator());
    channel_trace.set_source_layer(linked_spec.source_layer());
  }
  for (auto &fixed_spec : spec.fixed_feature()) {
    step.add_fixed_feature_trace()->set_name(fixed_spec.name());
  }
  step.set_html_representation(state.HTMLRepresentation());
  return step;
}

// Returns the last step in the trace.
ComponentStepTrace *GetLastStepInTrace(ComponentTrace *trace) {
  CHECK_GT(trace->step_trace_size(), 0) << "Trace has no steps added yet";
  return trace->mutable_step_trace(trace->step_trace_size() - 1);
}

}  // anonymous namespace

SyntaxNetComponent::SyntaxNetComponent()
    : feature_extractor_("brain_parser"),
      rewrite_root_labels_(false),
      max_beam_size_(1),
      input_data_(nullptr) {}

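// Configures the component from its ComponentSpec: builds a TaskContext from
// the spec's resources and parameters, then initializes the fixed and linked
// feature extractors, the transition system, the label map, and the shared
// workspace registry.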
void SyntaxNetComponent::InitializeComponent(const ComponentSpec &spec) {
  // Save off the passed spec for future reference.
  spec_ = spec;

  // Create and populate a TaskContext for the underlying parser.
  TaskContext context;

  // Add the specified resources.
  for (const Resource &resource : spec_.resource()) {
    auto *input = context.GetInput(resource.name());
    for (const Part &part : resource.part()) {
      auto *input_part = input->add_part();
      input_part->set_file_pattern(part.file_pattern());
      input_part->set_file_format(part.file_format());
      input_part->set_record_format(part.record_format());
    }
  }

  // Add the specified task args to the transition system.
  for (const auto &param : spec_.transition_system().parameters()) {
    context.SetParameter(param.first, param.second);
  }

  // Set the arguments for the feature extractor.
  std::vector<string> names;
  std::vector<string> dims;
  std::vector<string> fml;
  std::vector<string> predicate_maps;
  for (const FixedFeatureChannel &channel : spec.fixed_feature()) {
    names.push_back(channel.name());
    fml.push_back(channel.fml());
    predicate_maps.push_back(channel.predicate_map());
    dims.push_back(StrCat(channel.embedding_dim()));
  }
  context.SetParameter("neurosis_feature_syntax_version", "2");
  context.SetParameter("brain_parser_embedding_dims", utils::Join(dims, ";"));
  context.SetParameter("brain_parser_predicate_maps",
                       utils::Join(predicate_maps, ";"));
  context.SetParameter("brain_parser_features", utils::Join(fml, ";"));
  context.SetParameter("brain_parser_embedding_names", utils::Join(names, ";"));

  names.clear();
  dims.clear();
  fml.clear();
  predicate_maps.clear();

  std::vector<string> source_components;
  std::vector<string> source_layers;
  std::vector<string> source_translators;
  for (const LinkedFeatureChannel &channel : spec.linked_feature()) {
    names.push_back(channel.name());
    fml.push_back(channel.fml());
    dims.push_back(StrCat(channel.embedding_dim()));
    source_components.push_back(channel.source_component());
    source_layers.push_back(channel.source_layer());
    source_translators.push_back(channel.source_translator());
    predicate_maps.push_back("none");
  }
  context.SetParameter("link_embedding_dims", utils::Join(dims, ";"));
  context.SetParameter("link_predicate_maps", utils::Join(predicate_maps, ";"));
  context.SetParameter("link_features", utils::Join(fml, ";"));
  context.SetParameter("link_embedding_names", utils::Join(names, ";"));
  context.SetParameter("link_source_layers", utils::Join(source_layers, ";"));
  context.SetParameter("link_source_translators",
                       utils::Join(source_translators, ";"));
  context.SetParameter("link_source_components",
                       utils::Join(source_components, ";"));

  context.SetParameter("parser_transition_system",
                       spec.transition_system().registered_name());

  // Set up the fixed feature extractor.
  feature_extractor_.Setup(&context);
  feature_extractor_.Init(&context);
  feature_extractor_.RequestWorkspaces(&workspace_registry_);

  // Set up the underlying transition system.
  transition_system_.reset(ParserTransitionSystem::Create(
      context.Get("parser_transition_system", "arc-standard")));
  transition_system_->Setup(&context);
  transition_system_->Init(&context);

  // Create label map.
  string path = TaskContext::InputFile(*context.GetInput("label-map"));
  label_map_ =
      SharedStoreUtils::GetWithDefaultName<TermFrequencyMap>(path, 0, 0);

  // Set up link feature extractors.
  if (spec.linked_feature_size() > 0) {
    link_feature_extractor_.Setup(&context);
    link_feature_extractor_.Init(&context);
    link_feature_extractor_.RequestWorkspaces(&workspace_registry_);
  }

  // Get the legacy flag for simulating old parser processor behavior. If the
  // flag is not set, default to 'false'.
  rewrite_root_labels_ = context.Get("rewrite_root_labels", false);
}

std::unique_ptr<Beam<SyntaxNetTransitionState>> SyntaxNetComponent::CreateBeam(
    int max_size) {
  std::unique_ptr<Beam<SyntaxNetTransitionState>> beam(
      new Beam<SyntaxNetTransitionState>(max_size));
  auto permission_function = [this](SyntaxNetTransitionState *state,
                                    int action) {
    VLOG(3) << "permission_function action:" << action
            << " is_allowed:" << this->IsAllowed(state, action);
    return this->IsAllowed(state, action);
  };
  auto finality_function = [this](SyntaxNetTransitionState *state) {
    VLOG(2) << "finality_function is_final:" << this->IsFinal(state);
    return this->IsFinal(state);
  };
  auto oracle_function = [this](SyntaxNetTransitionState *state) {
    VLOG(2) << "oracle_function action:" << this->GetOracleLabel(state);
    return this->GetOracleLabel(state);
  };
  auto beam_ptr = beam.get();
  auto advance_function = [this, beam_ptr](SyntaxNetTransitionState *state,
                                           int action) {
    VLOG(2) << "advance_function beam ptr:" << beam_ptr << " action:" << action;
    this->Advance(state, action, beam_ptr);
  };
  beam->SetFunctions(permission_function, finality_function, advance_function,
                     oracle_function);
  return beam;
}

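// Resizes the batch of beams to match the number of input sentences and seeds
// each beam, either with a single fresh state per sentence or with the parent
// component's states (up to the beam's maximum size).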
void SyntaxNetComponent::InitializeData(
    const std::vector<std::vector<const TransitionState *>> &parent_states,
    int max_beam_size, InputBatchCache *input_data) {
  // Save off the input data object.
  input_data_ = input_data;

  // If beam size has changed, change all beam sizes for existing beams.
  if (max_beam_size_ != max_beam_size) {
    CHECK_GT(max_beam_size, 0)
        << "Requested max beam size must be greater than 0.";
    VLOG(2) << "Adjusting max beam size from " << max_beam_size_ << " to "
            << max_beam_size;
    max_beam_size_ = max_beam_size;
    for (auto &beam : batch_) {
      beam->SetMaxSize(max_beam_size_);
    }
  }

  SentenceInputBatch *sentences = input_data->GetAs<SentenceInputBatch>();

  // Expect that the sentence data is the same size as the input states batch.
  if (!parent_states.empty()) {
    CHECK_EQ(parent_states.size(), sentences->data()->size());
  }

  // Adjust the beam vector so that it is the correct size for this batch.
  if (batch_.size() < sentences->data()->size()) {
    VLOG(1) << "Batch size is increased to " << sentences->data()->size()
            << " from " << batch_.size();
    for (int i = batch_.size(); i < sentences->data()->size(); ++i) {
      batch_.push_back(CreateBeam(max_beam_size));
    }
  } else if (batch_.size() > sentences->data()->size()) {
    VLOG(1) << "Batch size is decreased to " << sentences->data()->size()
            << " from " << batch_.size();
    batch_.erase(batch_.begin() + sentences->data()->size(), batch_.end());
  } else {
    VLOG(1) << "Batch size is constant at " << sentences->data()->size();
  }
  CHECK_EQ(batch_.size(), sentences->data()->size());

  // Fill the beams with the relevant data for that batch.
  for (int batch_index = 0; batch_index < sentences->data()->size();
       ++batch_index) {
    // Create a vector of states for this component's beam.
    std::vector<std::unique_ptr<SyntaxNetTransitionState>> initial_states;
    if (parent_states.empty()) {
      // If no states have been passed in, create a single state to seed the
      // beam.
      initial_states.push_back(
          CreateState(&(sentences->data()->at(batch_index))));
    } else {
      // If states have been passed in, seed the beam with them up to the max
      // beam size.
      int num_states =
          std::min(batch_.at(batch_index)->max_size(),
                   static_cast<int>(parent_states.at(batch_index).size()));
      VLOG(2) << "Creating a beam using " << num_states << " initial states";
      for (int i = 0; i < num_states; ++i) {
        std::unique_ptr<SyntaxNetTransitionState> state(
            CreateState(&(sentences->data()->at(batch_index))));
        state->Init(*parent_states.at(batch_index).at(i));
        initial_states.push_back(std::move(state));
      }
    }
    batch_.at(batch_index)->Init(std::move(initial_states));
  }
}

bool SyntaxNetComponent::IsReady() const { return input_data_ != nullptr; }

string SyntaxNetComponent::Name() const {
  return "SyntaxNet-backed beam parser";
}

int SyntaxNetComponent::BatchSize() const { return batch_.size(); }

int SyntaxNetComponent::BeamSize() const { return max_beam_size_; }

int SyntaxNetComponent::StepsTaken(int batch_index) const {
  return batch_.at(batch_index)->num_steps();
}

int SyntaxNetComponent::GetBeamIndexAtStep(int step, int current_index,
                                           int batch) const {
  return batch_.at(batch)->FindPreviousIndex(current_index, step);
}

int SyntaxNetComponent::GetSourceBeamIndex(int current_index, int batch) const {
  return batch_.at(batch)->FindPreviousIndex(current_index, 0);
}

std::function<int(int, int, int)> SyntaxNetComponent::GetStepLookupFunction(
    const string &method) {
  if (method == "shift-reduce-step") {
    // TODO(googleuser): Describe this function.
    // Returns the most recent step that touched the given token: the step at
    // which it was shifted, or the last step at which it acquired a child.
    return [this](int batch_index, int beam_index, int value) {
      SyntaxNetTransitionState *state =
          batch_.at(batch_index)->beam_state(beam_index);
      return state->step_for_token(value);
    };
  } else if (method == "reduce-step") {
    // TODO(googleuser): Describe this function.
    // Returns the step at which the given token was attached to its head.
    return [this](int batch_index, int beam_index, int value) {
      SyntaxNetTransitionState *state =
          batch_.at(batch_index)->beam_state(beam_index);
      return state->parent_step_for_token(value);
    };
  } else if (method == "parent-shift-reduce-step") {
    // TODO(googleuser): Describe this function.
    return [this](int batch_index, int beam_index, int value) {
      SyntaxNetTransitionState *state =
          batch_.at(batch_index)->beam_state(beam_index);
      return state->step_for_token(state->parent_step_for_token(value));
    };
  } else if (method == "reverse-token") {
    // TODO(googleuser): Describe this function.
    // Maps a token index to its index counted from the end of the sentence,
    // or -1 if the result would be out of range.
    return [this](int batch_index, int beam_index, int value) {
      SyntaxNetTransitionState *state =
          batch_.at(batch_index)->beam_state(beam_index);
      int result = state->sentence()->sentence()->token_size() - value - 1;
      if (result >= 0 && result < state->sentence()->sentence()->token_size()) {
        return result;
      } else {
        return -1;
      }
    };
  } else {
    LOG(FATAL) << "Unable to find step lookup function " << method;
  }
}

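// Advances each beam from its slice of the flattened transition score matrix.
// Terminal beams are not advanced, but their slice is still skipped over so
// the matrix layout stays aligned with the batch.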
void SyntaxNetComponent::AdvanceFromPrediction(const float transition_matrix[],
                                               int transition_matrix_length) {
  VLOG(2) << "Advancing from prediction.";
  int matrix_index = 0;
  int num_labels = transition_system_->NumActions(label_map_->Size());
  for (int i = 0; i < batch_.size(); ++i) {
    int max_beam_size = batch_.at(i)->max_size();
    int matrix_size = num_labels * max_beam_size;
    CHECK_LE(matrix_index + matrix_size, transition_matrix_length);
    if (!batch_.at(i)->IsTerminal()) {
      batch_.at(i)->AdvanceFromPrediction(&transition_matrix[matrix_index],
                                          matrix_size, num_labels);
    }
    matrix_index += num_labels * max_beam_size;
  }
}

void SyntaxNetComponent::AdvanceFromOracle() {
  VLOG(2) << "Advancing from oracle.";
  for (auto &beam : batch_) {
    beam->AdvanceFromOracle();
  }
}

bool SyntaxNetComponent::IsTerminal() const {
  VLOG(2) << "Checking terminal status.";
  for (const auto &beam : batch_) {
    if (!beam->IsTerminal()) {
      return false;
    }
  }
  return true;
}

std::vector<std::vector<const TransitionState *>>
SyntaxNetComponent::GetBeam() {
  std::vector<std::vector<const TransitionState *>> state_beam;
  for (auto &beam : batch_) {
    // Because this component only finalizes the data of the highest-ranked
    // state in each beam, the next component should only be initialized from
    // the highest-ranked state in that beam.
    state_beam.push_back({beam->beam().at(0)});
  }
  return state_beam;
}

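// Extracts the fixed (sparse) features for one channel across the whole batch,
// padding each beam out to max_beam_size_, and copies the resulting indices,
// ids, and weights into tensors obtained from the allocator callbacks. Returns
// the total number of feature ids written.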
int SyntaxNetComponent::GetFixedFeatures(
    std::function<int32 *(int)> allocate_indices,
    std::function<int64 *(int)> allocate_ids,
    std::function<float *(int)> allocate_weights, int channel_id) const {
  std::vector<SparseFeatures> features;
  const int channel_size = spec_.fixed_feature(channel_id).size();

  // For every beam in the batch...
  for (const auto &beam : batch_) {
    // For every element in the beam...
    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      // Get the SparseFeatures from the feature extractor.
      auto state = beam->beam_state(beam_idx);
      const std::vector<std::vector<SparseFeatures>> sparse_features =
          feature_extractor_.ExtractSparseFeatures(
              *(state->sentence()->workspace()), *(state->parser_state()));

      // Hold the SparseFeatures for later processing.
      for (const SparseFeatures &f : sparse_features[channel_id]) {
        features.emplace_back(f);
        if (do_tracing_) {
          FixedFeatures fixed_features;
          for (const string &name : f.description()) {
            fixed_features.add_value_name(name);
          }
          fixed_features.set_feature_name("");
          auto *trace = GetLastStepInTrace(state->mutable_trace());
          auto *fixed_trace = trace->mutable_fixed_feature_trace(channel_id);
          *fixed_trace->add_value_trace() = fixed_features;
        }
      }
    }

    // Pad the beam out to max_beam_size.
    const int pad_amount = max_beam_size_ - beam->size();
    features.resize(features.size() + pad_amount * channel_size);
  }

  int feature_count = 0;
  for (const auto &feature : features) {
    feature_count += feature.id_size();
  }
  VLOG(2) << "Feature count is " << feature_count;
  int32 *indices_tensor = allocate_indices(feature_count);
  int64 *ids_tensor = allocate_ids(feature_count);
  float *weights_tensor = allocate_weights(feature_count);

  int array_index = 0;
  for (int feature_index = 0; feature_index < features.size();
       ++feature_index) {
    VLOG(2) << "Extracting for feature_index " << feature_index;
    const auto &feature = features[feature_index];
    for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
      indices_tensor[array_index] = feature_index;
      ids_tensor[array_index] = feature.id(sub_idx);
      if (sub_idx < feature.weight_size()) {
        weights_tensor[array_index] = feature.weight(sub_idx);
      } else {
        // Features without an explicit weight default to 1.0.
        weights_tensor[array_index] = 1.0;
      }
      VLOG(2) << "Feature index: " << indices_tensor[array_index]
              << " id: " << ids_tensor[array_index]
              << " weight: " << weights_tensor[array_index];
      ++array_index;
    }
  }
  return feature_count;
}

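// Repeatedly extracts fixed features for every channel and advances all beams
// from the oracle until the whole batch is terminal, writing the results
// through the BulkFeatureExtractor. Returns the number of steps taken.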
int SyntaxNetComponent::BulkGetFixedFeatures(
    const BulkFeatureExtractor &extractor) {
  // Allocate a vector of SparseFeatures per channel.
  const int num_channels = spec_.fixed_feature_size();
  std::vector<int> channel_size(num_channels);
  for (int i = 0; i < num_channels; ++i) {
    channel_size[i] = spec_.fixed_feature(i).size();
  }
  std::vector<std::vector<SparseFeatures>> features(num_channels);
  std::vector<std::vector<int>> feature_indices(num_channels);
  std::vector<std::vector<int>> step_indices(num_channels);
  std::vector<std::vector<int>> element_indices(num_channels);
  std::vector<int> feature_counts(num_channels);
  int step_count = 0;

  while (!IsTerminal()) {
    int current_element = 0;
    // For every beam in the batch...
    for (const auto &beam : batch_) {
      // For every element in the beam...
      for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
        // Get the SparseFeatures from the parser.
        auto state = beam->beam_state(beam_idx);
        const std::vector<std::vector<SparseFeatures>> sparse_features =
            feature_extractor_.ExtractSparseFeatures(
                *(state->sentence()->workspace()), *(state->parser_state()));
        for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
          int feature_count = 0;
          for (const SparseFeatures &f : sparse_features[channel_id]) {
            // Trace, if requested.
            if (do_tracing_) {
              FixedFeatures fixed_features;
              for (const string &name : f.description()) {
                fixed_features.add_value_name(name);
              }
              fixed_features.set_feature_name("");
              auto *trace = GetLastStepInTrace(state->mutable_trace());
              auto *fixed_trace =
                  trace->mutable_fixed_feature_trace(channel_id);
              *fixed_trace->add_value_trace() = fixed_features;
            }

            // Hold the SparseFeatures for later processing.
            features[channel_id].emplace_back(f);
            element_indices[channel_id].emplace_back(current_element);
            step_indices[channel_id].emplace_back(step_count);
            feature_indices[channel_id].emplace_back(feature_count);
            feature_counts[channel_id] += f.id_size();
            ++feature_count;
          }
        }
        ++current_element;
      }

      // Advance the current element to skip unused beam slots.
      // Pad the beam out to max_beam_size.
      int pad_amount = max_beam_size_ - beam->size();
      current_element += pad_amount;
    }
    AdvanceFromOracle();
    ++step_count;
  }

  const int total_steps = step_count;
  const int num_elements = batch_.size() * max_beam_size_;

  // This would be a good place to add threading.
  for (int channel_id = 0; channel_id < num_channels; ++channel_id) {
    int feature_count = feature_counts[channel_id];
    LOG(INFO) << "Feature count is " << feature_count << " for channel "
              << channel_id;
    int32 *indices_tensor =
        extractor.AllocateIndexMemory(channel_id, feature_count);
    int64 *ids_tensor = extractor.AllocateIdMemory(channel_id, feature_count);
    float *weights_tensor =
        extractor.AllocateWeightMemory(channel_id, feature_count);
    int array_index = 0;
    for (int feat_idx = 0; feat_idx < features[channel_id].size(); ++feat_idx) {
      const auto &feature = features[channel_id][feat_idx];
      int element_index = element_indices[channel_id][feat_idx];
      int step_index = step_indices[channel_id][feat_idx];
      int feature_index = feature_indices[channel_id][feat_idx];
      for (int sub_idx = 0; sub_idx < feature.id_size(); ++sub_idx) {
        indices_tensor[array_index] =
            extractor.GetIndex(total_steps, num_elements, feature_index,
                               element_index, step_index);
        ids_tensor[array_index] = feature.id(sub_idx);
        if (sub_idx < feature.weight_size()) {
          weights_tensor[array_index] = feature.weight(sub_idx);
        } else {
          weights_tensor[array_index] = 1.0;
        }
        ++array_index;
      }
    }
  }
  return step_count;
}

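// Extracts the raw (untranslated) link feature values for one channel across
// the batch, tagging each value with its batch and beam index and padding each
// beam out to max_beam_size_.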
std::vector<LinkFeatures> SyntaxNetComponent::GetRawLinkFeatures(
    int channel_id) const {
  std::vector<LinkFeatures> features;
  const int channel_size = spec_.linked_feature(channel_id).size();
  std::unique_ptr<std::vector<string>> feature_names;
  if (do_tracing_) {
    feature_names.reset(new std::vector<string>);
    *feature_names = utils::Split(spec_.linked_feature(channel_id).fml(), ' ');
  }

  // For every beam in the batch...
  for (int batch_idx = 0; batch_idx < batch_.size(); ++batch_idx) {
    // For every element in the beam...
    const auto &beam = batch_[batch_idx];
    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      // Get the raw link features from the linked feature extractor.
      auto state = beam->beam_state(beam_idx);
      std::vector<FeatureVector> raw_features(
          link_feature_extractor_.NumEmbeddings());
      link_feature_extractor_.ExtractFeatures(*(state->sentence()->workspace()),
                                              *(state->parser_state()),
                                              &raw_features);

      // Add the raw feature values to the LinkFeatures proto.
      CHECK_LT(channel_id, raw_features.size());
      for (int i = 0; i < raw_features[channel_id].size(); ++i) {
        features.emplace_back();
        features.back().set_feature_value(raw_features[channel_id].value(i));
        features.back().set_batch_idx(batch_idx);
        features.back().set_beam_idx(beam_idx);
        if (do_tracing_) {
          features.back().set_feature_name(feature_names->at(i));
        }
      }
    }

    // Pad the beam out to max_beam_size.
    int pad_amount = max_beam_size_ - beam->size();
    features.resize(features.size() + pad_amount * channel_size);
  }
  return features;
}

std::vector<std::vector<int>> SyntaxNetComponent::GetOracleLabels() const {
  std::vector<std::vector<int>> oracle_labels;
  for (const auto &beam : batch_) {
    oracle_labels.emplace_back();
    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      // Get the oracle label for this beam element.
      auto state = beam->beam_state(beam_idx);
      oracle_labels.back().push_back(GetOracleLabel(state));
    }
  }
  return oracle_labels;
}

void SyntaxNetComponent::FinalizeData() {
  // This chooses the top-scoring member of the beam to annotate the underlying
  // document.
  VLOG(2) << "Finalizing data.";
  for (auto &beam : batch_) {
    if (beam->size() != 0) {
      auto top_state = beam->beam_state(0);
      VLOG(3) << "Finalizing for sentence: "
              << top_state->sentence()->sentence()->ShortDebugString();
      top_state->parser_state()->AddParseToDocument(
          top_state->sentence()->sentence(), rewrite_root_labels_);
      VLOG(3) << "Sentence is now: "
              << top_state->sentence()->sentence()->ShortDebugString();
    } else {
      LOG(WARNING) << "Attempting to finalize an empty beam for component "
                   << spec_.name();
    }
  }
}

void SyntaxNetComponent::ResetComponent() {
  for (auto &beam : batch_) {
    beam->Reset();
  }
  input_data_ = nullptr;
  max_beam_size_ = 0;
}

std::unique_ptr<SyntaxNetTransitionState> SyntaxNetComponent::CreateState(
    SyntaxNetSentence *sentence) {
  VLOG(3) << "Creating state for sentence "
          << sentence->sentence()->DebugString();
  std::unique_ptr<ParserState> parser_state(new ParserState(
      sentence->sentence(), transition_system_->NewTransitionState(false),
      label_map_));
  sentence->workspace()->Reset(workspace_registry_);
  feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
  link_feature_extractor_.Preprocess(sentence->workspace(), parser_state.get());
  std::unique_ptr<SyntaxNetTransitionState> transition_state(
      new SyntaxNetTransitionState(std::move(parser_state), sentence));
  return transition_state;
}

bool SyntaxNetComponent::IsAllowed(SyntaxNetTransitionState *state,
                                   int action) const {
  return transition_system_->IsAllowedAction(action, *(state->parser_state()));
}

bool SyntaxNetComponent::IsFinal(SyntaxNetTransitionState *state) const {
  return transition_system_->IsFinalState(*(state->parser_state()));
}

int SyntaxNetComponent::GetOracleLabel(SyntaxNetTransitionState *state) const {
  if (IsFinal(state)) {
    // It is not permitted to request an oracle label from a sentence that is
    // in a final state.
    return -1;
  } else {
    return transition_system_->GetNextGoldAction(*(state->parser_state()));
  }
}

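// Applies `action` to the given state. When the transition system exposes
// action metadata, this also records which step shifted or attached each
// token (the bookkeeping read by the step lookup functions above) and, when
// tracing, captions the completed step and opens a new one.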
void SyntaxNetComponent::Advance(SyntaxNetTransitionState *state, int action,
                                 Beam<SyntaxNetTransitionState> *beam) {
  auto parser_state = state->parser_state();
  auto sentence_size = state->sentence()->sentence()->token_size();
  const int num_steps = beam->num_steps();
  if (transition_system_->SupportsActionMetaData()) {
    const int parent_idx =
        transition_system_->ParentIndex(*parser_state, action);
    constexpr int kShiftAction = -1;
    if (parent_idx == kShiftAction) {
      // If all of the input has already been consumed, this is not a true
      // shift action, so skip the bookkeeping.
      if (parser_state->Next() < sentence_size && parser_state->Next() >= 0) {
        state->set_step_for_token(parser_state->Next(), num_steps);
      }
    } else if (parent_idx >= 0) {
      VLOG(2) << spec_.name() << ": Updating pointer: " << parent_idx << " -> "
              << num_steps;
      state->set_step_for_token(parent_idx, num_steps);
      const int child_idx =
          transition_system_->ChildIndex(*parser_state, action);
      assert(child_idx >= 0 && child_idx < sentence_size);
      state->set_parent_for_token(child_idx, parent_idx);
      VLOG(2) << spec_.name() << ": Updating parent for child: " << parent_idx
              << " -> " << child_idx;
      state->set_parent_step_for_token(child_idx, num_steps);
    } else {
      VLOG(2) << spec_.name() << ": Invalid parent index: " << parent_idx;
    }
  }
  if (do_tracing_) {
    auto *trace = state->mutable_trace();
    auto *last_step = GetLastStepInTrace(trace);

    // Add action to the prior step.
    last_step->set_caption(
        transition_system_->ActionAsString(action, *parser_state));
    last_step->set_step_finished(true);
  }
  transition_system_->PerformAction(action, parser_state);
  if (do_tracing_) {
    // Add info for the next step.
    *state->mutable_trace()->add_step_trace() = GetNewStepTrace(spec_, *state);
  }
}

void SyntaxNetComponent::InitializeTracing() {
  do_tracing_ = true;
  CHECK(IsReady()) << "Cannot initialize trace before InitializeData().";

  // Initialize each element of the beam with a new trace.
  for (auto &beam : batch_) {
    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      SyntaxNetTransitionState *state = beam->beam_state(beam_idx);
      std::unique_ptr<ComponentTrace> trace(new ComponentTrace());
      trace->set_name(spec_.name());
      *trace->add_step_trace() = GetNewStepTrace(spec_, *state);
      state->set_trace(std::move(trace));
    }
  }
  feature_extractor_.set_add_strings(true);
}

void SyntaxNetComponent::DisableTracing() {
  do_tracing_ = false;
  feature_extractor_.set_add_strings(false);
}

void SyntaxNetComponent::AddTranslatedLinkFeaturesToTrace(
    const std::vector<LinkFeatures> &features, int channel_id) {
  CHECK(do_tracing_) << "Tracing is not enabled.";
  int linear_idx = 0;
  const int channel_size = spec_.linked_feature(channel_id).size();

  // For every beam in the batch...
  for (const auto &beam : batch_) {
    // For every element in the beam...
    for (int beam_idx = 0; beam_idx < max_beam_size_; ++beam_idx) {
      for (int feature_idx = 0; feature_idx < channel_size; ++feature_idx) {
        if (beam_idx < beam->size()) {
          auto state = beam->beam_state(beam_idx);
          auto *trace = GetLastStepInTrace(state->mutable_trace());
          auto *link_trace = trace->mutable_linked_feature_trace(channel_id);
          if (features[linear_idx].feature_value() >= 0 &&
              features[linear_idx].step_idx() >= 0) {
            *link_trace->add_value_trace() = features[linear_idx];
          }
        }
        ++linear_idx;
      }
    }
  }
}

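// Returns a copy of the accumulated trace for every element of every beam in
// the batch.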
std::vector<std::vector<ComponentTrace>> SyntaxNetComponent::GetTraceProtos()
    const {
  std::vector<std::vector<ComponentTrace>> traces;
  // For every beam in the batch...
  for (const auto &beam : batch_) {
    std::vector<ComponentTrace> beam_trace;
    // For every element in the beam...
    for (int beam_idx = 0; beam_idx < beam->size(); ++beam_idx) {
      auto state = beam->beam_state(beam_idx);
      beam_trace.push_back(*state->mutable_trace());
    }
    traces.push_back(beam_trace);
  }
  return traces;
}

REGISTER_DRAGNN_COMPONENT(SyntaxNetComponent);

}  // namespace dragnn
}  // namespace syntaxnet