|
@@ -14,9 +14,12 @@ limitations under the License.
|
|
|
==============================================================================*/
|
|
|
|
|
|
#include "syntaxnet/binary_segment_state.h"
|
|
|
+#include "syntaxnet/parser_features.h"
|
|
|
#include "syntaxnet/parser_state.h"
|
|
|
#include "syntaxnet/parser_transitions.h"
|
|
|
+#include "syntaxnet/task_context.h"
|
|
|
#include "syntaxnet/term_frequency_map.h"
|
|
|
+#include "syntaxnet/workspace.h"
|
|
|
#include "tensorflow/core/platform/test.h"
|
|
|
|
|
|
namespace syntaxnet {
|
|
@@ -38,6 +41,58 @@ class SegmentationTransitionTest : public ::testing::Test {
|
|
|
"token { word: '样' start: 14 end: 16 break_level: NO_BREAK } ";
|
|
|
sentence_ = std::unique_ptr<Sentence>(new Sentence());
|
|
|
TextFormat::ParseFromString(str_sentence, sentence_.get());
|
|
|
+
|
|
|
+ word_map_.Increment("因为");
|
|
|
+ word_map_.Increment("因为");
|
|
|
+ word_map_.Increment("有");
|
|
|
+ word_map_.Increment("这");
|
|
|
+ word_map_.Increment("这");
|
|
|
+ word_map_.Increment("样");
|
|
|
+ word_map_.Increment("样");
|
|
|
+ word_map_.Increment("这样");
|
|
|
+ word_map_.Increment("这样");
|
|
|
+ string filename = tensorflow::strings::StrCat(
|
|
|
+ tensorflow::testing::TmpDir(), "word-map");
|
|
|
+ word_map_.Save(filename);
|
|
|
+
|
|
|
+ // Re-load in sorted order, ignoring words that occur only once.
|
|
|
+ word_map_.Load(filename, 2, -1);
|
|
|
+
|
|
|
+ // Prepare task context.
|
|
|
+ context_ = std::unique_ptr<TaskContext>(new TaskContext());
|
|
|
+ AddInputToContext("word-map", filename, "text", "");
|
|
|
+ registry_ = std::unique_ptr<WorkspaceRegistry>( new WorkspaceRegistry());
|
|
|
+ }
|
|
|
+
|
|
|
+ // Adds an input to the task context: appends one part to the input called `name` and fills in that part's file pattern, file format, and record format.
|
|
|
+ void AddInputToContext(const string &name,
|
|
|
+ const string &file_pattern,
|
|
|
+ const string &file_format,
|
|
|
+ const string &record_format) {
|
|
|
+ TaskInput *input = context_->GetInput(name);  // TaskInput for `name`, obtained from context_.
|
|
|
+ TaskInput::Part *part = input->add_part();  // Each call appends a new part to that input.
|
|
|
+ part->set_file_pattern(file_pattern);
|
|
|
+ part->set_file_format(file_format);
|
|
|
+ part->set_record_format(record_format);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Prepares a feature for computations: replaces feature_extractor_ with a new extractor parsed from `feature_name`, initializes it against context_, resets workspace_, and preprocesses `state`.
|
|
|
+ void PrepareFeature(const string &feature_name, ParserState *state) {
|
|
|
+ feature_extractor_ = std::unique_ptr<ParserFeatureExtractor>(  // Discards any extractor left from a previous call.
|
|
|
+ new ParserFeatureExtractor());
|
|
|
+ feature_extractor_->Parse(feature_name);  // E.g. "last-word(1,min-freq=2)".
|
|
|
+ feature_extractor_->Setup(context_.get());
|
|
|
+ feature_extractor_->Init(context_.get());
|
|
|
+ feature_extractor_->RequestWorkspaces(registry_.get());
|
|
|
+ workspace_.Reset(*registry_);  // registry_ is reused across calls; it is not recreated here.
|
|
|
+ feature_extractor_->Preprocess(&workspace_, state);
|
|
|
+ }
|
|
|
+
|
|
|
+ // Computes the feature value for the parser state: returns the first extracted value, or -1 if nothing was extracted. Uses the feature_extractor_ and workspace_ set up by PrepareFeature().
|
|
|
+ FeatureValue ComputeFeature(const ParserState &state) const {
|
|
|
+ FeatureVector result;
|
|
|
+ feature_extractor_->ExtractFeatures(workspace_, state, &result);
|
|
|
+ return result.size() > 0 ? result.value(0) : -1;  // -1 signals that no feature value was extracted.
|
|
|
}
|
|
|
|
|
|
void CheckStarts(const ParserState &state, const vector<int> &target) {
|
|
@@ -48,10 +103,18 @@ class SegmentationTransitionTest : public ::testing::Test {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- // The test document, parse tree, and sentence with tags and partial parses.
|
|
|
+ // The test sentence.
|
|
|
std::unique_ptr<Sentence> sentence_;
|
|
|
+
|
|
|
+ // Members for testing features.
|
|
|
+ std::unique_ptr<ParserFeatureExtractor> feature_extractor_;
|
|
|
+ std::unique_ptr<TaskContext> context_;
|
|
|
+ std::unique_ptr<WorkspaceRegistry> registry_;
|
|
|
+ WorkspaceSet workspace_;
|
|
|
+
|
|
|
std::unique_ptr<ParserTransitionSystem> transition_system_;
|
|
|
TermFrequencyMap label_map_;
|
|
|
+ TermFrequencyMap word_map_;
|
|
|
};
|
|
|
|
|
|
TEST_F(SegmentationTransitionTest, GoldNextActionTest) {
|
|
@@ -108,4 +171,62 @@ TEST_F(SegmentationTransitionTest, DefaultActionTest) {
|
|
|
EXPECT_EQ(sentence_->token(4).word(), "样");
|
|
|
}
|
|
|
|
|
|
+TEST_F(SegmentationTransitionTest, LastWordFeatureTest) {  // Checks the last-word(N) feature value as segment starts are added to the state.
|
|
|
+ const int unk_id = word_map_.Size();  // Id expected for words that are not in word_map_.
|
|
|
+ const int outside_id = unk_id + 1;  // Id expected when the requested word does not exist in the state.
|
|
|
+
|
|
|
+ // Prepare a parser state.
|
|
|
+ BinarySegmentState *segment_state = new BinarySegmentState();  // NOTE(review): presumably ParserState takes ownership of this pointer -- confirm.
|
|
|
+ auto state = std::unique_ptr<ParserState>(new ParserState(
|
|
|
+ sentence_.get(), segment_state, &label_map_));
|
|
|
+
|
|
|
+ // Test initial state which contains no words.
|
|
|
+ PrepareFeature("last-word(1,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+ PrepareFeature("last-word(2,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+ PrepareFeature("last-word(3,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+
|
|
|
+ // Test when the state contains only one start.
|
|
|
+ segment_state->AddStart(0, state.get());
|
|
|
+ PrepareFeature("last-word(1,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+ PrepareFeature("last-word(2,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+
|
|
|
+ // Test when the state contains two starts, which form a complete word and
|
|
|
+ // the start of another new word.
|
|
|
+ segment_state->AddStart(2, state.get());
|
|
|
+ EXPECT_NE(word_map_.LookupIndex("因为", unk_id), unk_id);
|
|
|
+ PrepareFeature("last-word(1)", state.get());  // NOTE(review): unlike every other call here, no min-freq argument is given -- confirm this is intentional.
|
|
|
+ EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
|
|
|
+
|
|
|
+ // last-word(2) still points outside.
|
|
|
+ PrepareFeature("last-word(2,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(outside_id, ComputeFeature(*state));
|
|
|
+
|
|
|
+ // Adding more starts that lead to the following words:
|
|
|
+ // 因为 ‘ ’ 有 ‘ ’
|
|
|
+ segment_state->AddStart(3, state.get());
|
|
|
+ segment_state->AddStart(4, state.get());
|
|
|
+
|
|
|
+ // Note 有 is pruned from the map since its frequency is less than 2.
|
|
|
+ EXPECT_EQ(word_map_.LookupIndex("有", unk_id), unk_id);
|
|
|
+ PrepareFeature("last-word(1,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(unk_id, ComputeFeature(*state));
|
|
|
+
|
|
|
+ // Note that last-word(2) points to ' ', which is also an unknown word.
|
|
|
+ PrepareFeature("last-word(2,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(unk_id, ComputeFeature(*state));
|
|
|
+ PrepareFeature("last-word(3,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(word_map_.LookupIndex("因为", unk_id), ComputeFeature(*state));
|
|
|
+
|
|
|
+ // Adding two words: "这" and "样".
|
|
|
+ segment_state->AddStart(5, state.get());
|
|
|
+ segment_state->AddStart(6, state.get());
|
|
|
+ PrepareFeature("last-word(1,min-freq=2)", state.get());
|
|
|
+ EXPECT_EQ(word_map_.LookupIndex("这", unk_id), ComputeFeature(*state));
|
|
|
+}
|
|
|
+
|
|
|
} // namespace syntaxnet
|