syntaxnet_link_feature_extractor.cc 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. #include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
  2. #include "tensorflow/core/platform/logging.h"
  3. namespace syntaxnet {
  4. namespace dragnn {
  5. void SyntaxNetLinkFeatureExtractor::Setup(TaskContext *context) {
  6. ParserEmbeddingFeatureExtractor::Setup(context);
  7. if (NumEmbeddings() > 0) {
  8. channel_sources_ = utils::Split(
  9. context->Get(
  10. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_components"),
  11. ""),
  12. ';');
  13. channel_layers_ = utils::Split(
  14. context->Get(
  15. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_layers"), ""),
  16. ';');
  17. channel_translators_ = utils::Split(
  18. context->Get(
  19. tensorflow::strings::StrCat(ArgPrefix(), "_", "source_translators"),
  20. ""),
  21. ';');
  22. }
  23. CHECK_EQ(channel_sources_.size(), NumEmbeddings());
  24. CHECK_EQ(channel_layers_.size(), NumEmbeddings());
  25. CHECK_EQ(channel_translators_.size(), NumEmbeddings());
  26. }
  27. void SyntaxNetLinkFeatureExtractor::AddLinkedFeatureChannelProtos(
  28. ComponentSpec *spec) const {
  29. for (int embedding_idx = 0; embedding_idx < NumEmbeddings();
  30. ++embedding_idx) {
  31. LinkedFeatureChannel *channel = spec->add_linked_feature();
  32. channel->set_name(embedding_name(embedding_idx));
  33. channel->set_fml(embedding_fml()[embedding_idx]);
  34. channel->set_embedding_dim(EmbeddingDims(embedding_idx));
  35. channel->set_size(FeatureSize(embedding_idx));
  36. channel->set_source_layer(channel_layers_[embedding_idx]);
  37. channel->set_source_component(channel_sources_[embedding_idx]);
  38. channel->set_source_translator(channel_translators_[embedding_idx]);
  39. }
  40. }
  41. } // namespace dragnn
  42. } // namespace syntaxnet