syntaxnet_link_feature_extractor.cc 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
  16. #include "tensorflow/core/platform/logging.h"
  17. namespace syntaxnet {
  18. namespace dragnn {
  19. void SyntaxNetLinkFeatureExtractor::Setup(TaskContext *context) {
  20. ParserEmbeddingFeatureExtractor::Setup(context);
  21. if (NumEmbeddings() > 0) {
  22. channel_sources_ = utils::Split(
  23. context->Get(
  24. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_components"),
  25. ""),
  26. ';');
  27. channel_layers_ = utils::Split(
  28. context->Get(
  29. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_layers"), ""),
  30. ';');
  31. channel_translators_ = utils::Split(
  32. context->Get(
  33. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_translators"),
  34. ""),
  35. ';');
  36. }
  37. CHECK_EQ(channel_sources_.size(), NumEmbeddings());
  38. CHECK_EQ(channel_layers_.size(), NumEmbeddings());
  39. CHECK_EQ(channel_translators_.size(), NumEmbeddings());
  40. }
  41. void SyntaxNetLinkFeatureExtractor::AddLinkedFeatureChannelProtos(
  42. ComponentSpec *spec) const {
  43. for (int embedding_idx = 0; embedding_idx < NumEmbeddings();
  44. ++embedding_idx) {
  45. LinkedFeatureChannel *channel = spec->add_linked_feature();
  46. channel->set_name(embedding_name(embedding_idx));
  47. channel->set_fml(embedding_fml()[embedding_idx]);
  48. channel->set_embedding_dim(EmbeddingDims(embedding_idx));
  49. channel->set_size(FeatureSize(embedding_idx));
  50. channel->set_source_layer(channel_layers_[embedding_idx]);
  51. channel->set_source_component(channel_sources_[embedding_idx]);
  52. channel->set_source_translator(channel_translators_[embedding_idx]);
  53. }
  54. }
  55. } // namespace dragnn
  56. } // namespace syntaxnet