Browse Source

Update DRAGNN, fix some macOS issues

Ivan Bogatyy 8 years ago
parent
commit
ea3fa4a338
100 changed files with 3345 additions and 169 deletions
  1. 34 0
      syntaxnet/dragnn/components/stateless/BUILD
  2. 131 0
      syntaxnet/dragnn/components/stateless/stateless_component.cc
  3. 171 0
      syntaxnet/dragnn/components/stateless/stateless_component_test.cc
  4. 8 9
      syntaxnet/dragnn/components/syntaxnet/BUILD
  5. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc
  6. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h
  7. 104 5
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc
  8. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc
  9. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h
  10. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc
  11. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc
  12. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h
  13. 15 0
      syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc
  14. 5 2
      syntaxnet/dragnn/components/util/BUILD
  15. 15 0
      syntaxnet/dragnn/components/util/bulk_feature_extractor.h
  16. 23 31
      syntaxnet/dragnn/core/BUILD
  17. 18 2
      syntaxnet/dragnn/core/beam.h
  18. 15 0
      syntaxnet/dragnn/core/beam_test.cc
  19. 15 0
      syntaxnet/dragnn/core/component_registry.cc
  20. 15 0
      syntaxnet/dragnn/core/component_registry.h
  21. 15 0
      syntaxnet/dragnn/core/compute_session.h
  22. 15 0
      syntaxnet/dragnn/core/compute_session_impl.cc
  23. 15 0
      syntaxnet/dragnn/core/compute_session_impl.h
  24. 15 0
      syntaxnet/dragnn/core/compute_session_impl_test.cc
  25. 15 0
      syntaxnet/dragnn/core/compute_session_pool.cc
  26. 15 0
      syntaxnet/dragnn/core/compute_session_pool.h
  27. 15 0
      syntaxnet/dragnn/core/compute_session_pool_test.cc
  28. 15 0
      syntaxnet/dragnn/core/index_translator.cc
  29. 15 0
      syntaxnet/dragnn/core/index_translator.h
  30. 15 0
      syntaxnet/dragnn/core/index_translator_test.cc
  31. 20 5
      syntaxnet/dragnn/core/input_batch_cache.h
  32. 15 0
      syntaxnet/dragnn/core/input_batch_cache_test.cc
  33. 4 1
      syntaxnet/dragnn/core/interfaces/BUILD
  34. 15 0
      syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h
  35. 15 0
      syntaxnet/dragnn/core/interfaces/component.h
  36. 15 0
      syntaxnet/dragnn/core/interfaces/input_batch.h
  37. 15 0
      syntaxnet/dragnn/core/interfaces/transition_state.h
  38. 15 0
      syntaxnet/dragnn/core/interfaces/transition_state_starter_test.cc
  39. 15 0
      syntaxnet/dragnn/core/ops/compute_session_op.cc
  40. 15 0
      syntaxnet/dragnn/core/ops/compute_session_op.h
  41. 15 0
      syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc
  42. 15 0
      syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc
  43. 22 3
      syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc
  44. 15 0
      syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc
  45. 15 0
      syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc
  46. 21 0
      syntaxnet/dragnn/core/ops/dragnn_ops.cc
  47. 15 0
      syntaxnet/dragnn/core/resource_container.h
  48. 15 0
      syntaxnet/dragnn/core/resource_container_test.cc
  49. 6 7
      syntaxnet/dragnn/core/test/BUILD
  50. 15 0
      syntaxnet/dragnn/core/test/generic.cc
  51. 15 0
      syntaxnet/dragnn/core/test/generic.h
  52. 15 0
      syntaxnet/dragnn/core/test/mock_component.h
  53. 15 0
      syntaxnet/dragnn/core/test/mock_compute_session.h
  54. 15 0
      syntaxnet/dragnn/core/test/mock_transition_state.h
  55. 15 0
      syntaxnet/dragnn/io/sentence_input_batch.cc
  56. 15 0
      syntaxnet/dragnn/io/sentence_input_batch.h
  57. 15 0
      syntaxnet/dragnn/io/sentence_input_batch_test.cc
  58. 15 0
      syntaxnet/dragnn/io/syntaxnet_sentence.h
  59. 1 0
      syntaxnet/dragnn/python/BUILD
  60. 22 7
      syntaxnet/dragnn/python/biaffine_units.py
  61. 19 2
      syntaxnet/dragnn/python/bulk_component.py
  62. 15 0
      syntaxnet/dragnn/python/bulk_component_test.py
  63. 43 3
      syntaxnet/dragnn/python/component.py
  64. 15 0
      syntaxnet/dragnn/python/composite_optimizer.py
  65. 16 2
      syntaxnet/dragnn/python/composite_optimizer_test.py
  66. 15 0
      syntaxnet/dragnn/python/digraph_ops.py
  67. 15 0
      syntaxnet/dragnn/python/digraph_ops_test.py
  68. 15 0
      syntaxnet/dragnn/python/dragnn_ops.py
  69. 49 27
      syntaxnet/dragnn/python/graph_builder.py
  70. 32 0
      syntaxnet/dragnn/python/graph_builder_test.py
  71. 41 40
      syntaxnet/dragnn/python/network_units.py
  72. 15 0
      syntaxnet/dragnn/python/network_units_test.py
  73. 15 0
      syntaxnet/dragnn/python/render_parse_tree_graphviz.py
  74. 15 0
      syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py
  75. 15 0
      syntaxnet/dragnn/python/render_spec_with_graphviz.py
  76. 15 0
      syntaxnet/dragnn/python/render_spec_with_graphviz_test.py
  77. 15 0
      syntaxnet/dragnn/python/sentence_io.py
  78. 15 0
      syntaxnet/dragnn/python/sentence_io_test.py
  79. 6 2
      syntaxnet/dragnn/python/spec_builder.py
  80. 17 2
      syntaxnet/dragnn/python/trainer_lib.py
  81. 15 0
      syntaxnet/dragnn/python/visualization.py
  82. 15 0
      syntaxnet/dragnn/python/visualization_test.py
  83. 41 17
      syntaxnet/dragnn/python/wrapped_units.py
  84. 38 1
      syntaxnet/dragnn/tools/BUILD
  85. 15 0
      syntaxnet/dragnn/tools/build_pip_package.py
  86. 21 0
      syntaxnet/dragnn/tools/evaluator.py
  87. 197 0
      syntaxnet/dragnn/tools/model_trainer.py
  88. 54 0
      syntaxnet/dragnn/tools/model_trainer_test.sh
  89. 15 0
      syntaxnet/dragnn/tools/oss_notebook_launcher.py
  90. 15 0
      syntaxnet/dragnn/tools/parse-to-conll.py
  91. 0 1
      syntaxnet/dragnn/tools/parser_trainer.py
  92. 15 0
      syntaxnet/dragnn/tools/segmenter-evaluator.py
  93. 4 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/config.txt
  94. 18 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/hyperparameters.pbtxt
  95. 1135 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/master.pbtxt
  96. 7 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/resources/category-map
  97. 18 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-map
  98. 46 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-ngram-map
  99. 8 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/resources/label-map
  100. 0 0
      syntaxnet/dragnn/tools/testdata/biaffine.model/resources/lcword-map

+ 34 - 0
syntaxnet/dragnn/components/stateless/BUILD

@@ -0,0 +1,34 @@
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
+
+cc_library(
+    name = "stateless_component",
+    srcs = ["stateless_component.cc"],
+    deps = [
+        "//dragnn/core:component_registry",
+        "//dragnn/core/interfaces:component",
+        "//dragnn/core/interfaces:transition_state",
+        "//dragnn/io:sentence_input_batch",
+        "//dragnn/protos:data_proto",
+        "//syntaxnet:base",
+    ],
+    alwayslink = 1,
+)
+
+cc_test(
+    name = "stateless_component_test",
+    srcs = ["stateless_component_test.cc"],
+    deps = [
+        ":stateless_component",
+        "//dragnn/core:component_registry",
+        "//dragnn/core:input_batch_cache",
+        "//dragnn/core/test:generic",
+        "//dragnn/core/test:mock_transition_state",
+        "//dragnn/io:sentence_input_batch",
+        "//syntaxnet:base",
+        "//syntaxnet:sentence_proto",
+        "//syntaxnet:test_main",
+    ],
+)

+ 131 - 0
syntaxnet/dragnn/components/stateless/stateless_component.cc

@@ -0,0 +1,131 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/interfaces/component.h"
+#include "dragnn/core/interfaces/transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "dragnn/protos/data.pb.h"
+#include "syntaxnet/base.h"
+
+namespace syntaxnet {
+namespace dragnn {
+namespace {
+
+// A component that does not create its own transition states; instead, it
+// simply forwards the states of the previous component.  Does not support all
+// methods.  Intended for "compute-only" bulk components that only use linked
+// features, which use only a small subset of DRAGNN functionality.
+class StatelessComponent : public Component {
+ public:
+  void InitializeComponent(const ComponentSpec &spec) override {
+    name_ = spec.name();
+  }
+
+  // Stores the |parent_states| for forwarding to downstream components.
+  void InitializeData(
+      const std::vector<std::vector<const TransitionState *>> &parent_states,
+      int max_beam_size, InputBatchCache *input_data) override {
+    // Must use SentenceInputBatch to match SyntaxNetComponent.
+    batch_size_ = input_data->GetAs<SentenceInputBatch>()->data()->size();
+    beam_size_ = max_beam_size;
+    parent_states_ = parent_states;
+
+    // The beam should be wide enough for the previous component.
+    for (const auto &beam : parent_states) {
+      CHECK_LE(beam.size(), beam_size_);
+    }
+  }
+
+  // Forwards the states of the previous component.
+  std::vector<std::vector<const TransitionState *>> GetBeam() override {
+    return parent_states_;
+  }
+
+  // Forwards the |current_index| to the previous component.
+  int GetSourceBeamIndex(int current_index, int batch) const override {
+    return current_index;
+  }
+
+  string Name() const override { return name_; }
+  int BeamSize() const override { return beam_size_; }
+  int BatchSize() const override { return batch_size_; }
+  int StepsTaken(int batch_index) const override { return 0; }
+  bool IsReady() const override { return true; }
+  bool IsTerminal() const override { return true; }
+  void FinalizeData() override {}
+  void ResetComponent() override {}
+  void InitializeTracing() override {}
+  void DisableTracing() override {}
+  std::vector<std::vector<ComponentTrace>> GetTraceProtos() const override {
+    return {};
+  }
+
+  // Unsupported methods.
+  int GetBeamIndexAtStep(int step, int current_index,
+                         int batch) const override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return 0;
+  }
+  std::function<int(int, int, int)> GetStepLookupFunction(
+      const string &method) override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return nullptr;
+  }
+  void AdvanceFromPrediction(const float transition_matrix[],
+                             int matrix_length) override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+  }
+  void AdvanceFromOracle() override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+  }
+  std::vector<std::vector<int>> GetOracleLabels() const override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return {};
+  }
+  int GetFixedFeatures(std::function<int32 *(int)> allocate_indices,
+                       std::function<int64 *(int)> allocate_ids,
+                       std::function<float *(int)> allocate_weights,
+                       int channel_id) const override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return 0;
+  }
+  int BulkGetFixedFeatures(const BulkFeatureExtractor &extractor) override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return 0;
+  }
+  std::vector<LinkFeatures> GetRawLinkFeatures(int channel_id) const override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+    return {};
+  }
+  void AddTranslatedLinkFeaturesToTrace(
+      const std::vector<LinkFeatures> &features, int channel_id) override {
+    LOG(FATAL) << "[" << name_ << "] Method not supported";
+  }
+
+ private:
+  string name_;  // component name
+  int batch_size_ = 1;  // number of sentences in current batch
+  int beam_size_ = 1;  // maximum beam size
+
+  // Parent states passed to InitializeData(), and passed along in GetBeam().
+  std::vector<std::vector<const TransitionState *>> parent_states_;
+};
+
+REGISTER_DRAGNN_COMPONENT(StatelessComponent);
+
+}  // namespace
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 171 - 0
syntaxnet/dragnn/components/stateless/stateless_component_test.cc

@@ -0,0 +1,171 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "dragnn/core/component_registry.h"
+#include "dragnn/core/input_batch_cache.h"
+#include "dragnn/core/test/generic.h"
+#include "dragnn/core/test/mock_transition_state.h"
+#include "dragnn/io/sentence_input_batch.h"
+#include "syntaxnet/base.h"
+#include "syntaxnet/sentence.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace syntaxnet {
+namespace dragnn {
+namespace {
+
+const char kSentence0[] = R"(
+token {
+  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+  break_level: NO_BREAK
+}
+token {
+  word: "0" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
+  break_level: NO_BREAK
+}
+)";
+
+const char kSentence1[] = R"(
+token {
+  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+  break_level: NO_BREAK
+}
+token {
+  word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 10 end: 10 head: 0 tag: "." category: "." label: "punct"
+  break_level: NO_BREAK
+}
+)";
+
+const char kLongSentence[] = R"(
+token {
+  word: "Sentence" start: 0 end: 7 tag: "NN" category: "NOUN" label: "ROOT"
+  break_level: NO_BREAK
+}
+token {
+  word: "1" start: 9 end: 9 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "2" start: 10 end: 10 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "3" start: 11 end: 11 head: 0 tag: "CD" category: "NUM" label: "num"
+  break_level: SPACE_BREAK
+}
+token {
+  word: "." start: 12 end: 12 head: 0 tag: "." category: "." label: "punct"
+  break_level: NO_BREAK
+}
+)";
+
+const char kMasterSpec[] = R"(
+component {
+  name: "test"
+  transition_system {
+    registered_name: "shift-only"
+  }
+  linked_feature {
+    name: "prev"
+    fml: "input.focus"
+    embedding_dim: 32
+    size: 1
+    source_component: "prev"
+    source_translator: "identity"
+    source_layer: "last_layer"
+  }
+  backend {
+    registered_name: "StatelessComponent"
+  }
+}
+)";
+
+}  // namespace
+
+using testing::Return;
+
+class StatelessComponentTest : public ::testing::Test {
+ public:
+  std::unique_ptr<Component> CreateParser(
+      int beam_size,
+      const std::vector<std::vector<const TransitionState *>> &states,
+      const std::vector<string> &data) {
+    MasterSpec master_spec;
+    CHECK(TextFormat::ParseFromString(kMasterSpec, &master_spec));
+    data_.reset(new InputBatchCache(data));
+
+    // Create a parser component with the specified beam size.
+    std::unique_ptr<Component> parser_component(
+        Component::Create("StatelessComponent"));
+    parser_component->InitializeComponent(master_spec.component(0));
+    parser_component->InitializeData(states, beam_size, data_.get());
+    return parser_component;
+  }
+
+  std::unique_ptr<InputBatchCache> data_;
+};
+
+TEST_F(StatelessComponentTest, ForwardsTransitionStates) {
+  const MockTransitionState mock_state_1, mock_state_2, mock_state_3;
+  const std::vector<std::vector<const TransitionState *>> parent_states = {
+      {}, {&mock_state_1}, {&mock_state_2, &mock_state_3}};
+
+  std::vector<string> data;
+  for (const string &textproto : {kSentence0, kSentence1, kLongSentence}) {
+    Sentence sentence;
+    CHECK(TextFormat::ParseFromString(textproto, &sentence));
+    data.emplace_back();
+    CHECK(sentence.SerializeToString(&data.back()));
+  }
+  CHECK_EQ(parent_states.size(), data.size());
+
+  const int kBeamSize = 2;
+  auto test_parser = CreateParser(kBeamSize, parent_states, data);
+
+  EXPECT_TRUE(test_parser->IsReady());
+  EXPECT_TRUE(test_parser->IsTerminal());
+  EXPECT_EQ(kBeamSize, test_parser->BeamSize());
+  EXPECT_EQ(data.size(), test_parser->BatchSize());
+  EXPECT_TRUE(test_parser->GetTraceProtos().empty());
+
+  for (int batch_index = 0; batch_index < parent_states.size(); ++batch_index) {
+    EXPECT_EQ(0, test_parser->StepsTaken(batch_index));
+    const auto &beam = parent_states[batch_index];
+    for (int beam_index = 0; beam_index < beam.size(); ++beam_index) {
+      // Expect an identity mapping.
+      EXPECT_EQ(beam_index,
+                test_parser->GetSourceBeamIndex(beam_index, batch_index));
+    }
+  }
+
+  const auto forwarded_states = test_parser->GetBeam();
+  EXPECT_EQ(parent_states, forwarded_states);
+}
+
+}  // namespace dragnn
+}  // namespace syntaxnet

+ 8 - 9
syntaxnet/dragnn/components/syntaxnet/BUILD

@@ -1,4 +1,7 @@
-package(default_visibility = ["//visibility:public"])
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
 
 cc_library(
     name = "syntaxnet_component",
@@ -25,7 +28,6 @@ cc_library(
         "//syntaxnet:task_context",
         "//syntaxnet:task_spec_proto",
         "//syntaxnet:utils",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
     ],
     alwayslink = 1,
 )
@@ -36,10 +38,10 @@ cc_library(
     hdrs = ["syntaxnet_link_feature_extractor.h"],
     deps = [
         "//dragnn/protos:spec_proto",
+        "//syntaxnet:base",
         "//syntaxnet:embedding_feature_extractor",
         "//syntaxnet:parser_transitions",
         "//syntaxnet:task_context",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
     ],
 )
 
@@ -54,7 +56,6 @@ cc_library(
         "//dragnn/protos:trace_proto",
         "//syntaxnet:base",
         "//syntaxnet:parser_transitions",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
     ],
 )
 
@@ -75,9 +76,9 @@ cc_test(
         "//dragnn/core/test:generic",
         "//dragnn/core/test:mock_transition_state",
         "//dragnn/io:sentence_input_batch",
+        "//syntaxnet:base",
         "//syntaxnet:sentence_proto",
-        "@org_tensorflow//tensorflow/core:lib",
-        "@org_tensorflow//tensorflow/core:test",
+        "//syntaxnet:test_main",
     ],
 )
 
@@ -90,7 +91,6 @@ cc_test(
         "//dragnn/protos:spec_proto",
         "//syntaxnet:task_context",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
         "@org_tensorflow//tensorflow/core:testlib",
     ],
 )
@@ -107,10 +107,9 @@ cc_test(
         "//dragnn/core/test:mock_transition_state",
         "//dragnn/io:sentence_input_batch",
         "//dragnn/protos:spec_proto",
+        "//syntaxnet:base",
         "//syntaxnet:sentence_proto",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:lib",
-        "@org_tensorflow//tensorflow/core:test",
         "@org_tensorflow//tensorflow/core:testlib",
     ],
 )

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_component.h"
 
 #include <vector>

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_COMPONENT_H_
 

+ 104 - 5
syntaxnet/dragnn/components/syntaxnet/syntaxnet_component_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_component.h"
 
 #include "dragnn/core/input_batch_cache.h"
@@ -833,10 +848,94 @@ TEST_F(SyntaxNetComponentTest, ExportsFixedFeatures) {
   const int num_features =
       test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
 
+  constexpr int kExpectedOutputSize = 12;
+  const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
+  const vector<int64> expected_ids({7, 50, 12, 7, 12, 7, 7, 50, 12, 7, 12, 7});
+  const vector<float> expected_weights(
+      {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
+
+  EXPECT_EQ(expected_indices.size(), kExpectedOutputSize);
+  EXPECT_EQ(expected_ids.size(), kExpectedOutputSize);
+  EXPECT_EQ(expected_weights.size(), kExpectedOutputSize);
+  EXPECT_EQ(num_features, kExpectedOutputSize);
+
+  EXPECT_EQ(expected_indices, indices);
+  EXPECT_EQ(expected_ids, ids);
+  EXPECT_EQ(expected_weights, weights);
+}
+
+TEST_F(SyntaxNetComponentTest, AdvancesAccordingToHighestWeightedInputOption) {
+  // Create an empty input batch and beam vector to initialize the parser.
+  Sentence sentence_0;
+  TextFormat::ParseFromString(kSentence0, &sentence_0);
+  string sentence_0_str;
+  sentence_0.SerializeToString(&sentence_0_str);
+
+  Sentence sentence_1;
+  TextFormat::ParseFromString(kSentence1, &sentence_1);
+  string sentence_1_str;
+  sentence_1.SerializeToString(&sentence_1_str);
+
+  constexpr int kBeamSize = 3;
+
+  auto test_parser =
+      CreateParserWithBeamSize(kBeamSize, {}, {sentence_0_str, sentence_1_str});
+
+  // There are 93 possible transitions for any given state. Create a transition
+  // array with a score of 10.0 for each transition.
+  constexpr int kBatchSize = 2;
+  constexpr int kNumPossibleTransitions = 93;
+  constexpr float kTransitionValue = 10.0;
+  float transition_matrix[kNumPossibleTransitions * kBeamSize * kBatchSize];
+  for (int i = 0; i < kNumPossibleTransitions * kBeamSize * kBatchSize; ++i) {
+    transition_matrix[i] = kTransitionValue;
+  }
+
+  // Replace the first several options with varying scores to test sorting.
+  constexpr int kBatchOffset = kNumPossibleTransitions * kBeamSize;
+  transition_matrix[0] = 3 * kTransitionValue;
+  transition_matrix[1] = 3 * kTransitionValue;
+  transition_matrix[2] = 4 * kTransitionValue;
+  transition_matrix[3] = 4 * kTransitionValue;
+  transition_matrix[4] = 2 * kTransitionValue;
+  transition_matrix[5] = 2 * kTransitionValue;
+  transition_matrix[kBatchOffset + 0] = 3 * kTransitionValue;
+  transition_matrix[kBatchOffset + 1] = 3 * kTransitionValue;
+  transition_matrix[kBatchOffset + 2] = 4 * kTransitionValue;
+  transition_matrix[kBatchOffset + 3] = 4 * kTransitionValue;
+  transition_matrix[kBatchOffset + 4] = 2 * kTransitionValue;
+  transition_matrix[kBatchOffset + 5] = 2 * kTransitionValue;
+
+  // Advance twice, so that the underlying parser fills the beam.
+  test_parser->AdvanceFromPrediction(
+      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+  test_parser->AdvanceFromPrediction(
+      transition_matrix, kNumPossibleTransitions * kBeamSize * kBatchSize);
+
+  // Get and check the raw link features.
+  vector<int32> indices;
+  auto indices_fn = [&indices](int size) {
+    indices.resize(size);
+    return indices.data();
+  };
+  vector<int64> ids;
+  auto ids_fn = [&ids](int size) {
+    ids.resize(size);
+    return ids.data();
+  };
+  vector<float> weights;
+  auto weights_fn = [&weights](int size) {
+    weights.resize(size);
+    return weights.data();
+  };
+  constexpr int kChannelId = 0;
+  const int num_features =
+      test_parser->GetFixedFeatures(indices_fn, ids_fn, weights_fn, kChannelId);
+
   // In this case, all even features and all odd features are identical.
   constexpr int kExpectedOutputSize = 12;
   const vector<int32> expected_indices({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11});
-  const vector<int64> expected_ids({12, 7, 12, 7, 12, 7, 12, 7, 12, 7, 12, 7});
+  const vector<int64> expected_ids({12, 7, 7, 50, 12, 7, 12, 7, 7, 50, 12, 7});
   const vector<float> expected_weights(
       {1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0});
 
@@ -1024,11 +1123,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
   EXPECT_EQ(link_features.size(), kBeamSize * kBatchSize * kNumLinkFeatures);
 
   // These should index into batch 0.
-  EXPECT_EQ(link_features.at(0).feature_value(), -1);
+  EXPECT_EQ(link_features.at(0).feature_value(), 1);
   EXPECT_EQ(link_features.at(0).batch_idx(), 0);
   EXPECT_EQ(link_features.at(0).beam_idx(), 0);
 
-  EXPECT_EQ(link_features.at(1).feature_value(), -2);
+  EXPECT_EQ(link_features.at(1).feature_value(), 0);
   EXPECT_EQ(link_features.at(1).batch_idx(), 0);
   EXPECT_EQ(link_features.at(1).beam_idx(), 0);
 
@@ -1049,11 +1148,11 @@ TEST_F(SyntaxNetComponentTest, ExportsRawLinkFeatures) {
   EXPECT_EQ(link_features.at(5).beam_idx(), 2);
 
   // These should index into batch 1.
-  EXPECT_EQ(link_features.at(6).feature_value(), -1);
+  EXPECT_EQ(link_features.at(6).feature_value(), 1);
   EXPECT_EQ(link_features.at(6).batch_idx(), 1);
   EXPECT_EQ(link_features.at(6).beam_idx(), 0);
 
-  EXPECT_EQ(link_features.at(7).feature_value(), -2);
+  EXPECT_EQ(link_features.at(7).feature_value(), 0);
   EXPECT_EQ(link_features.at(7).batch_idx(), 1);
   EXPECT_EQ(link_features.at(7).beam_idx(), 0);
 

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
 
 #include "tensorflow/core/platform/logging.h"

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_LINK_FEATURE_EXTRACTOR_H_
 

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_link_feature_extractor_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_link_feature_extractor.h"
 
 #include <string>

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
 
 #include "tensorflow/core/lib/strings/strcat.h"

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_SYNTAXNET_SYNTAXNET_TRANSITION_STATE_H_
 

+ 15 - 0
syntaxnet/dragnn/components/syntaxnet/syntaxnet_transition_state_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/syntaxnet/syntaxnet_transition_state.h"
 
 #include "dragnn/components/syntaxnet/syntaxnet_component.h"

+ 5 - 2
syntaxnet/dragnn/components/util/BUILD

@@ -1,9 +1,12 @@
-package(default_visibility = ["//visibility:public"])
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
 
 cc_library(
     name = "bulk_feature_extractor",
     hdrs = ["bulk_feature_extractor.h"],
     deps = [
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
     ],
 )

+ 15 - 0
syntaxnet/dragnn/components/util/bulk_feature_extractor.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_COMPONENTS_UTIL_BULK_FEATURE_EXTRACTOR_H_
 

+ 23 - 31
syntaxnet/dragnn/core/BUILD

@@ -1,4 +1,7 @@
-package(default_visibility = ["//visibility:public"])
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
 
 # Test data.
 filegroup(
@@ -12,7 +15,7 @@ cc_library(
     deps = [
         "//dragnn/core/interfaces:cloneable_transition_state",
         "//dragnn/core/interfaces:transition_state",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+        "//syntaxnet:base",
     ],
 )
 
@@ -50,8 +53,8 @@ cc_library(
         "//dragnn/protos:data_proto",
         "//dragnn/protos:spec_proto",
         "//dragnn/protos:trace_proto",
+        "//syntaxnet:base",
         "//syntaxnet:registry",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
     ],
 )
 
@@ -64,7 +67,7 @@ cc_library(
         ":compute_session",
         ":compute_session_impl",
         "//dragnn/protos:spec_proto",
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
     ],
 )
 
@@ -75,7 +78,7 @@ cc_library(
     deps = [
         "//dragnn/core/interfaces:component",
         "//dragnn/core/interfaces:transition_state",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+        "//syntaxnet:base",
     ],
 )
 
@@ -84,17 +87,14 @@ cc_library(
     hdrs = ["input_batch_cache.h"],
     deps = [
         "//dragnn/core/interfaces:input_batch",
-        "@org_tensorflow//tensorflow/core:lib",  # For tf/core/platform/logging.h
+        "//syntaxnet:base",
     ],
 )
 
 cc_library(
     name = "resource_container",
     hdrs = ["resource_container.h"],
-    deps = [
-        "//syntaxnet:base",
-        "@org_tensorflow//tensorflow/core:framework",
-    ],
+    deps = ["//syntaxnet:base"],
 )
 
 # Tests
@@ -107,8 +107,8 @@ cc_test(
         "//dragnn/core/interfaces:cloneable_transition_state",
         "//dragnn/core/interfaces:transition_state",
         "//dragnn/core/test:mock_transition_state",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -125,7 +125,7 @@ cc_test(
         "//dragnn/core/test:generic",
         "//dragnn/core/test:mock_component",
         "//dragnn/core/test:mock_transition_state",
-        "@org_tensorflow//tensorflow/core:test",
+        "//syntaxnet:base",
     ],
 )
 
@@ -138,9 +138,8 @@ cc_test(
         "//dragnn/core/test:generic",
         "//dragnn/core/test:mock_component",
         "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:lib",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -151,8 +150,8 @@ cc_test(
         ":index_translator",
         "//dragnn/core/test:mock_component",
         "//dragnn/core/test:mock_transition_state",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -162,8 +161,8 @@ cc_test(
     deps = [
         ":input_batch_cache",
         "//dragnn/core/interfaces:input_batch",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -172,8 +171,8 @@ cc_test(
     srcs = ["resource_container_test.cc"],
     deps = [
         ":resource_container",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -213,7 +212,7 @@ cc_library(
     deps = [
         ":compute_session",
         ":resource_container",
-        "@org_tensorflow//tensorflow/core:framework",
+        "//syntaxnet:base",
         "@org_tensorflow//third_party/eigen3",
     ],
 )
@@ -231,8 +230,7 @@ cc_library(
         ":resource_container",
         "//dragnn/protos:data_proto",
         "//dragnn/protos:spec_proto",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
         "@org_tensorflow//third_party/eigen3",
     ],
     alwayslink = 1,
@@ -247,8 +245,7 @@ cc_library(
     deps = [
         ":compute_session_op",
         ":resource_container",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
         "@org_tensorflow//third_party/eigen3",
     ],
@@ -271,8 +268,7 @@ tf_kernel_library(
         ":resource_container",
         "//dragnn/protos:data_proto",
         "//dragnn/protos:spec_proto",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
         "@org_tensorflow//third_party/eigen3",
     ],
 )
@@ -292,8 +288,7 @@ tf_kernel_library(
         ":resource_container",
         "//dragnn/components/util:bulk_feature_extractor",
         "//dragnn/protos:spec_proto",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:lib",
+        "//syntaxnet:base",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
         "@org_tensorflow//third_party/eigen3",
     ],
@@ -311,11 +306,9 @@ cc_test(
         ":resource_container",
         "//dragnn/core/test:generic",
         "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:framework",
         "@org_tensorflow//tensorflow/core:protos_all_cc",
-        "@org_tensorflow//tensorflow/core:test",
-        "@org_tensorflow//tensorflow/core:testlib",
         "@org_tensorflow//tensorflow/core/kernels:ops_testutil",
         "@org_tensorflow//tensorflow/core/kernels:ops_util",
         "@org_tensorflow//tensorflow/core/kernels:quantized_ops",
@@ -331,9 +324,8 @@ cc_test(
         ":resource_container",
         "//dragnn/components/util:bulk_feature_extractor",
         "//dragnn/core/test:mock_compute_session",
+        "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:framework",
-        "@org_tensorflow//tensorflow/core:testlib",
         "@org_tensorflow//tensorflow/core/kernels:ops_testutil",
         "@org_tensorflow//tensorflow/core/kernels:quantized_ops",
     ],

+ 18 - 2
syntaxnet/dragnn/core/beam.h

@@ -1,7 +1,23 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_BEAM_H_
 
 #include <algorithm>
+#include <cmath>
 #include <memory>
 #include <vector>
 
@@ -112,7 +128,7 @@ class Beam {
             CHECK_LT(matrix_idx, matrix_length)
                 << "Matrix index out of bounds!";
             const double score_delta = transition_matrix[matrix_idx];
-            CHECK(!isnan(score_delta));
+            CHECK(!std::isnan(score_delta));
             candidate.source_idx = beam_idx;
             candidate.action = action_idx;
             candidate.resulting_score = state->GetScore() + score_delta;
@@ -125,7 +141,7 @@ class Beam {
       const auto comparator = [](const Transition &a, const Transition &b) {
         return a.resulting_score > b.resulting_score;
       };
-      std::sort(candidates.begin(), candidates.end(), comparator);
+      std::stable_sort(candidates.begin(), candidates.end(), comparator);
 
       // Apply the top transitions, up to a maximum of 'max_size_'.
       std::vector<std::unique_ptr<T>> new_beam;

+ 15 - 0
syntaxnet/dragnn/core/beam_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/beam.h"
 
 #include "dragnn/core/interfaces/cloneable_transition_state.h"

+ 15 - 0
syntaxnet/dragnn/core/component_registry.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/component_registry.h"
 
 namespace syntaxnet {

+ 15 - 0
syntaxnet/dragnn/core/component_registry.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPONENT_REGISTRY_H_
 

+ 15 - 0
syntaxnet/dragnn/core/compute_session.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_H_
 

+ 15 - 0
syntaxnet/dragnn/core/compute_session_impl.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/compute_session_impl.h"
 
 #include <algorithm>

+ 15 - 0
syntaxnet/dragnn/core/compute_session_impl.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_IMPL_H_
 

+ 15 - 0
syntaxnet/dragnn/core/compute_session_impl_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/compute_session_impl.h"
 
 #include <memory>

+ 15 - 0
syntaxnet/dragnn/core/compute_session_pool.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/compute_session_pool.h"
 
 #include <utility>

+ 15 - 0
syntaxnet/dragnn/core/compute_session_pool.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_COMPUTE_SESSION_POOL_H_
 

+ 15 - 0
syntaxnet/dragnn/core/compute_session_pool_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/compute_session_pool.h"
 
 #include <memory>

+ 15 - 0
syntaxnet/dragnn/core/index_translator.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/index_translator.h"
 
 #include "tensorflow/core/platform/logging.h"

+ 15 - 0
syntaxnet/dragnn/core/index_translator.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INDEX_TRANSLATOR_H_
 

+ 15 - 0
syntaxnet/dragnn/core/index_translator_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/index_translator.h"
 
 #include "dragnn/core/test/mock_component.h"

+ 20 - 5
syntaxnet/dragnn/core/input_batch_cache.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INPUT_BATCH_CACHE_H_
 
@@ -16,14 +31,14 @@ namespace dragnn {
 
 class InputBatchCache {
  public:
-  // Create an empty cache..
+  // Creates an empty cache.
   InputBatchCache() : stored_type_(std::type_index(typeid(void))) {}
 
-  // Create a InputBatchCache from a single example. This copies the string.
+  // Creates an InputBatchCache from a single example. This copies the string.
   explicit InputBatchCache(const string &data)
       : stored_type_(std::type_index(typeid(void))), source_data_({data}) {}
 
-  // Create a InputBatchCache from a vector of examples. The vector is copied.
+  // Creates an InputBatchCache from a vector of examples. The vector is copied.
   explicit InputBatchCache(const std::vector<string> &data)
       : stored_type_(std::type_index(typeid(void))), source_data_(data) {}
 
@@ -36,7 +51,7 @@ class InputBatchCache {
     source_data_.emplace_back(data);
   }
 
-  // Convert the stored strings into protos and return them in a specific
+  // Converts the stored strings into protos and returns them in a specific
   // InputBatch subclass. T should always be of type InputBatch. After this
   // method is called once, all further calls must be of the same data type.
   template <class T>
@@ -54,7 +69,7 @@ class InputBatchCache {
     return dynamic_cast<T *>(converted_data_.get());
   }
 
-  // Return the serialized representation of the data held in the input batch
+  // Returns the serialized representation of the data held in the input batch
   // object within this cache.
   const std::vector<string> SerializedData() const {
     CHECK(converted_data_) << "Cannot return batch without data.";

+ 15 - 0
syntaxnet/dragnn/core/input_batch_cache_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/input_batch_cache.h"
 
 #include "dragnn/core/interfaces/input_batch.h"

+ 4 - 1
syntaxnet/dragnn/core/interfaces/BUILD

@@ -1,4 +1,7 @@
-package(default_visibility = ["//visibility:public"])
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
 
 cc_library(
     name = "cloneable_transition_state",

+ 15 - 0
syntaxnet/dragnn/core/interfaces/cloneable_transition_state.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_CLONEABLE_TRANSITION_STATE_H_
 

+ 15 - 0
syntaxnet/dragnn/core/interfaces/component.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_COMPONENT_H_
 

+ 15 - 0
syntaxnet/dragnn/core/interfaces/input_batch.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_INPUT_BATCH_H_
 

+ 15 - 0
syntaxnet/dragnn/core/interfaces/transition_state.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_INTERFACES_TRANSITION_STATE_H_
 

+ 15 - 0
syntaxnet/dragnn/core/interfaces/transition_state_starter_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/test/mock_transition_state.h"
 #include <gmock/gmock.h>
 #include "testing/base/public/googletest.h"

+ 15 - 0
syntaxnet/dragnn/core/ops/compute_session_op.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/ops/compute_session_op.h"
 
 #include "dragnn/core/compute_session.h"

+ 15 - 0
syntaxnet/dragnn/core/ops/compute_session_op.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_OPS_COMPUTE_SESSION_OP_H_
 

+ 15 - 0
syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include <math.h>
 #include <algorithm>
 #include <utility>

+ 15 - 0
syntaxnet/dragnn/core/ops/dragnn_bulk_op_kernels_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/components/util/bulk_feature_extractor.h"
 #include "dragnn/core/compute_session_pool.h"
 #include "dragnn/core/resource_container.h"

+ 22 - 3
syntaxnet/dragnn/core/ops/dragnn_bulk_ops.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/shape_inference.h"
 
@@ -93,9 +108,13 @@ REGISTER_OP("BulkAdvanceFromPrediction")
     .Output("output_handle: string")
     .Attr("component: string")
     .Attr("T: type")
-    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *context) {
-      auto scores = context->input(1);
-      TF_RETURN_IF_ERROR(context->WithRank(scores, 2, &scores));
+    .SetShapeFn([](tensorflow::shape_inference::InferenceContext *c) {
+      tensorflow::shape_inference::ShapeHandle handle;
+      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->Vector(2), &handle));
+      c->set_output(0, handle);
+
+      auto scores = c->input(1);
+      TF_RETURN_IF_ERROR(c->WithRank(scores, 2, &scores));
       return tensorflow::Status::OK();
     })
     .Doc(R"doc(

+ 15 - 0
syntaxnet/dragnn/core/ops/dragnn_op_kernels.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include <memory>
 #include <string>
 #include <vector>

+ 15 - 0
syntaxnet/dragnn/core/ops/dragnn_op_kernels_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include <functional>
 #include <memory>
 #include <vector>

+ 21 - 0
syntaxnet/dragnn/core/ops/dragnn_ops.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "tensorflow/core/framework/op.h"
 
 namespace syntaxnet {
@@ -113,6 +128,8 @@ REGISTER_OP("DragnnEmbeddingInitializer")
     .Attr("embedding_input: string")
     .Attr("vocab: string")
     .Attr("scaling_coefficient: float = 1.0")
+    .Attr("seed: int = 0")
+    .Attr("seed2: int = 0")
     .Doc(R"doc(
 *** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
 
@@ -122,6 +139,10 @@ embeddings: A tensor containing embeddings from the specified sstable.
 embedding_input: Path to location with embedding vectors.
 vocab: Path to list of keys corresponding to the input.
 scaling_coefficient: A scaling coefficient for the embedding matrix.
+seed: If either `seed` or `seed2` are set to be non-zero, the random number
+      generator is seeded by the given seed.  Otherwise, it is seeded by a
+      random seed.
+seed2: A second seed to avoid seed collision.
 )doc");
 
 REGISTER_OP("ExtractFixedFeatures")

+ 15 - 0
syntaxnet/dragnn/core/resource_container.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_RESOURCE_CONTAINER_H_
 

+ 15 - 0
syntaxnet/dragnn/core/resource_container_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 // Tests the methods of ResourceContainer.
 //
 // NOTE(danielandor): For all tests: ResourceContainer is derived from

+ 6 - 7
syntaxnet/dragnn/core/test/BUILD

@@ -1,4 +1,7 @@
-package(default_visibility = ["//visibility:public"])
+package(
+    default_visibility = ["//visibility:public"],
+    features = ["-layering_check"],
+)
 
 cc_library(
     name = "mock_component",
@@ -13,7 +16,6 @@ cc_library(
         "//dragnn/protos:spec_proto",
         "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -27,7 +29,7 @@ cc_library(
         "//dragnn/protos:data_proto",
         "//dragnn/protos:spec_proto",
         "//syntaxnet:base",
-        "@org_tensorflow//tensorflow/core:test",
+        "//syntaxnet:test_main",
     ],
 )
 
@@ -39,7 +41,6 @@ cc_library(
         "//dragnn/core/interfaces:transition_state",
         "//syntaxnet:base",
         "//syntaxnet:test_main",
-        "@org_tensorflow//tensorflow/core:test",
     ],
 )
 
@@ -50,8 +51,6 @@ cc_library(
     hdrs = ["generic.h"],
     deps = [
         "//syntaxnet:base",
-        "@org_tensorflow//tensorflow/core:lib",
-        "@org_tensorflow//tensorflow/core:test",
-        "@org_tensorflow//tensorflow/core:testlib",
+        "//syntaxnet:test_main",
     ],
 )

+ 15 - 0
syntaxnet/dragnn/core/test/generic.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/core/test/generic.h"
 
 #include "tensorflow/core/lib/io/path.h"

+ 15 - 0
syntaxnet/dragnn/core/test/generic.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_GENERIC_H_
 

+ 15 - 0
syntaxnet/dragnn/core/test/mock_component.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPONENT_H_
 

+ 15 - 0
syntaxnet/dragnn/core/test/mock_compute_session.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_COMPUTE_SESSION_H_
 

+ 15 - 0
syntaxnet/dragnn/core/test/mock_transition_state.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_CORE_TEST_MOCK_TRANSITION_STATE_H_
 

+ 15 - 0
syntaxnet/dragnn/io/sentence_input_batch.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/io/sentence_input_batch.h"
 
 #include "syntaxnet/sentence.pb.h"

+ 15 - 0
syntaxnet/dragnn/io/sentence_input_batch.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SENTENCE_INPUT_BATCH_H_
 

+ 15 - 0
syntaxnet/dragnn/io/sentence_input_batch_test.cc

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #include "dragnn/io/sentence_input_batch.h"
 
 #include "dragnn/core/test/generic.h"

+ 15 - 0
syntaxnet/dragnn/io/syntaxnet_sentence.h

@@ -1,3 +1,18 @@
+// Copyright 2017 Google Inc. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
 #ifndef NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
 #define NLP_SAFT_OPENSOURCE_DRAGNN_IO_SYNTAXNET_SENTENCE_H_
 

+ 1 - 0
syntaxnet/dragnn/python/BUILD

@@ -9,6 +9,7 @@ cc_binary(
     linkshared = 1,
     linkstatic = 1,
     deps = [
+        "//dragnn/components/stateless:stateless_component",
         "//dragnn/components/syntaxnet:syntaxnet_component",
         "//dragnn/core:dragnn_bulk_ops_cc",
         "//dragnn/core:dragnn_ops_cc",

+ 22 - 7
syntaxnet/dragnn/python/biaffine_units.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Network units used in the Dozat and Manning (2017) biaffine parser."""
 
 from __future__ import absolute_import
@@ -68,13 +83,13 @@ class BiaffineDigraphNetwork(network_units.NetworkUnitInterface):
     self._weights = []
     self._weights.append(tf.get_variable(
         'weights_arc', [self._source_dim, self._target_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
     self._weights.append(tf.get_variable(
         'weights_source', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
     self._weights.append(tf.get_variable(
         'root', [self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
 
     self._params.extend(self._weights)
     self._regularized_weights.extend(self._weights)
@@ -178,18 +193,18 @@ class BiaffineLabelNetwork(network_units.NetworkUnitInterface):
     self._weights = []
     self._weights.append(tf.get_variable(
         'weights_pair', [self._num_labels, self._source_dim, self._target_dim],
-        tf.float32, tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.float32, tf.random_normal_initializer(stddev=1e-4)))
     self._weights.append(tf.get_variable(
         'weights_source', [self._num_labels, self._source_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
     self._weights.append(tf.get_variable(
         'weights_target', [self._num_labels, self._target_dim], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
 
     self._biases = []
     self._biases.append(tf.get_variable(
         'biases', [self._num_labels], tf.float32,
-        tf.random_normal_initializer(stddev=1e-4, seed=self._seed)))
+        tf.random_normal_initializer(stddev=1e-4)))
 
     self._params.extend(self._weights + self._biases)
     self._regularized_weights.extend(self._weights)

+ 19 - 2
syntaxnet/dragnn/python/bulk_component.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Component builders for non-recurrent networks in DRAGNN."""
 
 
@@ -249,7 +264,8 @@ class BulkFeatureExtractorComponentBuilder(component.ComponentBuilderBase):
     update_network_states(self, tensors, network_states, stride)
     cost = self.add_regularizer(tf.constant(0.))
 
-    return state.handle, cost, 0, 0
+    correct, total = tf.constant(0), tf.constant(0)
+    return state.handle, cost, correct, total
 
   def build_greedy_inference(self, state, network_states,
                              during_training=False):
@@ -327,7 +343,8 @@ class BulkFeatureIdExtractorComponentBuilder(component.ComponentBuilderBase):
     """See base class."""
     state.handle = self._extract_feature_ids(state, network_states, True)
     cost = self.add_regularizer(tf.constant(0.))
-    return state.handle, cost, 0, 0
+    correct, total = tf.constant(0), tf.constant(0)
+    return state.handle, cost, correct, total
 
   def build_greedy_inference(self, state, network_states,
                              during_training=False):

+ 15 - 0
syntaxnet/dragnn/python/bulk_component_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for bulk_component.
 
 Verifies that:

+ 43 - 3
syntaxnet/dragnn/python/component.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Builds a DRAGNN graph for local training."""
 
 from abc import ABCMeta
@@ -147,6 +162,32 @@ class ComponentBuilderBase(object):
     """
     pass
 
+  def build_structured_training(self, state, network_states):
+    """Builds a beam search based training loop for this component.
+
+    The default implementation builds a dummy graph and raises a
+    TensorFlow runtime exception to indicate that structured training
+    is not implemented.
+
+    Args:
+      state: MasterState from the 'AdvanceMaster' op that advances the
+        underlying master to this component.
+      network_states: dictionary of component NetworkState objects.
+
+    Returns:
+      (handle, cost, correct, total) -- These are TF ops corresponding
+      to the final handle after unrolling, the total cost, and the
+      total number of actions. Since the number of correctly predicted
+      actions is not applicable in the structured training setting, a
+      dummy value should returned.
+    """
+    del network_states  # Unused.
+    with tf.control_dependencies([tf.Assert(False, ['Not implemented.'])]):
+      handle = tf.identity(state.handle)
+    cost = tf.constant(0.)
+    correct, total = tf.constant(0), tf.constant(0)
+    return handle, cost, correct, total
+
   @abstractmethod
   def build_greedy_inference(self, state, network_states,
                              during_training=False):
@@ -349,14 +390,13 @@ class DynamicComponentBuilder(ComponentBuilderBase):
       correctly predicted actions, and the total number of actions.
     """
     logging.info('Building component: %s', self.spec.name)
-    stride = state.current_batch_size * self.training_beam_size
+    with tf.control_dependencies([tf.assert_equal(self.training_beam_size, 1)]):
+      stride = state.current_batch_size * self.training_beam_size
 
     cost = tf.constant(0.)
     correct = tf.constant(0)
     total = tf.constant(0)
 
-    # Create the TensorArray's to store activations for downstream/recurrent
-    # connections.
     def cond(handle, *_):
       all_final = dragnn_ops.emit_all_final(handle, component=self.name)
       return tf.logical_not(tf.reduce_all(all_final))

+ 15 - 0
syntaxnet/dragnn/python/composite_optimizer.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """An optimizer that switches between several methods."""
 
 import tensorflow as tf

+ 16 - 2
syntaxnet/dragnn/python/composite_optimizer_test.py

@@ -1,5 +1,19 @@
-"""Tests for CompositeOptimizer.
-"""
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for CompositeOptimizer."""
 
 
 import numpy as np

+ 15 - 0
syntaxnet/dragnn/python/digraph_ops.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """TensorFlow ops for directed graphs."""
 
 import tensorflow as tf

+ 15 - 0
syntaxnet/dragnn/python/digraph_ops_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for digraph ops."""
 
 import tensorflow as tf

+ 15 - 0
syntaxnet/dragnn/python/dragnn_ops.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Groups the DRAGNN TensorFlow ops in one module."""
 
 

+ 49 - 27
syntaxnet/dragnn/python/graph_builder.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Builds a DRAGNN graph for local training."""
 
 
@@ -65,6 +80,13 @@ def _create_optimizer(hyperparams, learning_rate_var, step_var=None):
         beta2=hyperparams.adam_beta2,
         epsilon=hyperparams.adam_eps,
         use_locking=True)
+  elif hyperparams.learning_method == 'lazyadam':
+    return tf.contrib.opt.LazyAdamOptimizer(
+        learning_rate_var,
+        beta1=hyperparams.adam_beta1,
+        beta2=hyperparams.adam_beta2,
+        epsilon=hyperparams.adam_eps,
+        use_locking=True)
   elif hyperparams.learning_method == 'momentum':
     return tf.train.MomentumOptimizer(
         learning_rate_var, hyperparams.momentum, use_locking=True)
@@ -138,6 +160,10 @@ class MasterBuilder(object):
                         if hyperparam_config is None else hyperparam_config)
     self.pool_scope = pool_scope
 
+    # Set the graph-level random seed before creating the Components so the ops
+    # they create will use this seed.
+    tf.set_random_seed(hyperparam_config.seed)
+
     # Construct all utility classes and variables for each Component.
     self.components = []
     self.lookup_component = {}
@@ -318,15 +344,18 @@ class MasterBuilder(object):
                                            dragnn_ops.batch_size(
                                                handle, component=comp.name))
       with tf.control_dependencies([handle, cost]):
-        component_cost = tf.constant(0.)
-        component_correct = tf.constant(0)
-        component_total = tf.constant(0)
+        args = (master_state, network_states)
         if unroll_using_oracle[component_index]:
-          handle, component_cost, component_correct, component_total = (
-              comp.build_greedy_training(master_state, network_states))
+
+          handle, component_cost, component_correct, component_total = (tf.cond(
+              comp.training_beam_size > 1,
+              lambda: comp.build_structured_training(*args),
+              lambda: comp.build_greedy_training(*args)))
+
         else:
-          handle = comp.build_greedy_inference(
-              master_state, network_states, during_training=True)
+          handle = comp.build_greedy_inference(*args, during_training=True)
+          component_cost = tf.constant(0.)
+          component_correct, component_total = tf.constant(0), tf.constant(0)
 
         weighted_component_cost = tf.multiply(
             component_cost,
@@ -497,30 +526,23 @@ class MasterBuilder(object):
     with tf.name_scope(scope_id):
       # Construct training targets. Disable tracing during training.
       handle, input_batch = self._get_session_with_reader(trace_only)
+
+      # If `trace_only` is True, the training graph shouldn't have any
+      # side effects. Otherwise, the standard training scenario should
+      # generate gradients and update counters.
+      handle, outputs = self.build_training(
+          handle,
+          compute_gradients=not trace_only,
+          advance_counters=not trace_only,
+          component_weights=target_config.component_weights,
+          unroll_using_oracle=target_config.unroll_using_oracle,
+          max_index=target_config.max_index,
+          **kwargs)
       if trace_only:
-        # Build a training graph that doesn't have any side effects.
-        handle, outputs = self.build_training(
-            handle,
-            compute_gradients=False,
-            advance_counters=False,
-            component_weights=target_config.component_weights,
-            unroll_using_oracle=target_config.unroll_using_oracle,
-            max_index=target_config.max_index,
-            **kwargs)
         outputs['traces'] = dragnn_ops.get_component_trace(
             handle, component=self.spec.component[-1].name)
       else:
-        # The standard training scenario has gradients and updates counters.
-        handle, outputs = self.build_training(
-            handle,
-            compute_gradients=True,
-            advance_counters=True,
-            component_weights=target_config.component_weights,
-            unroll_using_oracle=target_config.unroll_using_oracle,
-            max_index=target_config.max_index,
-            **kwargs)
-
-        # In addition, it keeps track of the number of training steps.
+        # Standard training keeps track of the number of training steps.
         outputs['target_step'] = tf.get_variable(
             scope_id + '/TargetStep', [],
             initializer=tf.zeros_initializer(),

+ 32 - 0
syntaxnet/dragnn/python/graph_builder_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for graph_builder."""
 
 
@@ -517,6 +532,23 @@ class GraphBuilderTest(test_util.TensorFlowTestCase):
           expected_num_actions=9,
           expected=_TAGGER_PARSER_EXPECTED_SENTENCES)
 
+  def testStructuredTrainingNotImplementedDeath(self):
+    spec = self.LoadSpec('simple_parser_master_spec.textproto')
+
+    # Make the 'parser' component have a beam at training time.
+    self.assertEqual('parser', spec.component[0].name)
+    spec.component[0].training_beam_size = 8
+
+    # The training run should fail at runtime rather than build time.
+    with self.assertRaisesRegexp(tf.errors.InvalidArgumentError,
+                                 r'\[Not implemented.\]'):
+      self.RunFullTrainingAndInference(
+          'simple-parser',
+          master_spec=spec,
+          expected_num_actions=8,
+          component_weights=[1],
+          expected=_LABELED_PARSER_EXPECTED_SENTENCES)
+
   def testSimpleParser(self):
     self.RunFullTrainingAndInference(
         'simple-parser',

+ 41 - 40
syntaxnet/dragnn/python/network_units.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Basic network units used in assembling DRAGNN graphs."""
 
 from abc import ABCMeta
@@ -88,7 +103,7 @@ class NamedTensor(object):
     self.dim = dim
 
 
-def add_embeddings(channel_id, feature_spec, seed):
+def add_embeddings(channel_id, feature_spec, seed=None):
   """Adds a variable for the embedding of a given fixed feature.
 
+  Supports pre-trained or randomly initialized embeddings. In both cases, extra
@@ -119,11 +134,14 @@ def add_embeddings(channel_id, feature_spec, seed):
     if len(feature_spec.vocab.part) > 1:
       raise RuntimeError('vocab resource contains more than one part:\n%s',
                          str(feature_spec.vocab))
+    seed1, seed2 = tf.get_seed(seed)
     embeddings = dragnn_ops.dragnn_embedding_initializer(
         embedding_input=feature_spec.pretrained_embedding_matrix.part[0]
         .file_pattern,
         vocab=feature_spec.vocab.part[0].file_pattern,
-        scaling_coefficient=1.0)
+        scaling_coefficient=1.0,
+        seed=seed1,
+        seed2=seed2)
     return tf.get_variable(name, initializer=tf.reshape(embeddings, shape))
   else:
     return tf.get_variable(
@@ -622,7 +640,6 @@ class NetworkUnitInterface(object):
       init_layers: optional initial layers.
       init_context_layers: optional initial context layers.
     """
-    self._seed = component.master.hyperparams.seed
     self._component = component
     self._params = []
     self._layers = init_layers if init_layers else []
@@ -640,7 +657,7 @@ class NetworkUnitInterface(object):
       check.Gt(spec.size, 0, 'Invalid fixed feature size')
       if spec.embedding_dim > 0:
         fixed_dim = spec.embedding_dim
-        self._params.append(add_embeddings(channel_id, spec, self._seed))
+        self._params.append(add_embeddings(channel_id, spec))
       else:
         fixed_dim = 1  # assume feature ID extraction; only one ID per step
       self._fixed_feature_dims[spec.name] = spec.size * fixed_dim
@@ -663,7 +680,7 @@ class NetworkUnitInterface(object):
                 linked_embeddings_name(channel_id),
                 [source_array_dim + 1, spec.embedding_dim],
                 initializer=tf.random_normal_initializer(
-                    stddev=1 / spec.embedding_dim**.5, seed=self._seed)))
+                    stddev=1 / spec.embedding_dim**.5)))
 
         self._linked_feature_dims[spec.name] = spec.size * spec.embedding_dim
       else:
@@ -698,14 +715,12 @@ class NetworkUnitInterface(object):
           tf.get_variable(
               'attention_weights_pm_0',
               [attention_hidden_layer_size, hidden_layer_size],
-              initializer=tf.random_normal_initializer(
-                  stddev=1e-4, seed=self._seed)))
+              initializer=tf.random_normal_initializer(stddev=1e-4)))
 
       self._params.append(
           tf.get_variable(
               'attention_weights_hm_0', [hidden_layer_size, hidden_layer_size],
-              initializer=tf.random_normal_initializer(
-                  stddev=1e-4, seed=self._seed)))
+              initializer=tf.random_normal_initializer(stddev=1e-4)))
 
       self._params.append(
           tf.get_variable(
@@ -721,8 +736,7 @@ class NetworkUnitInterface(object):
           tf.get_variable(
               'attention_weights_pu',
               [attention_hidden_layer_size, component.num_actions],
-              initializer=tf.random_normal_initializer(
-                  stddev=1e-4, seed=self._seed)))
+              initializer=tf.random_normal_initializer(stddev=1e-4)))
 
   @abstractmethod
   def create(self,
@@ -961,8 +975,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
     for index, hidden_layer_size in enumerate(self._hidden_layer_sizes):
       weights = tf.get_variable(
           'weights_%d' % index, [last_layer_dim, hidden_layer_size],
-          initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                   seed=self._seed))
+          initializer=tf.random_normal_initializer(stddev=1e-4))
       self._params.append(weights)
       if index > 0 or self._layer_norm_hidden is None:
         self._params.append(
@@ -988,8 +1001,7 @@ class FeedForwardNetwork(NetworkUnitInterface):
       self._params.append(
           tf.get_variable(
               'weights_softmax', [last_layer_dim, component.num_actions],
-              initializer=tf.random_normal_initializer(
-                  stddev=1e-4, seed=self._seed)))
+              initializer=tf.random_normal_initializer(stddev=1e-4)))
       self._params.append(
           tf.get_variable(
               'bias_softmax', [component.num_actions],
@@ -1106,47 +1118,39 @@ class LSTMNetwork(NetworkUnitInterface):
     # e.g. truncated_normal_initializer?
     self._x2i = tf.get_variable(
         'x2i', [layer_input_dim, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._h2i = tf.get_variable(
         'h2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._c2i = tf.get_variable(
         'c2i', [self._hidden_layer_sizes, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._bi = tf.get_variable(
         'bi', [self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
 
     self._x2o = tf.get_variable(
         'x2o', [layer_input_dim, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._h2o = tf.get_variable(
         'h2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._c2o = tf.get_variable(
         'c2o', [self._hidden_layer_sizes, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._bo = tf.get_variable(
         'bo', [self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
 
     self._x2c = tf.get_variable(
         'x2c', [layer_input_dim, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._h2c = tf.get_variable(
         'h2c', [self._hidden_layer_sizes, self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
     self._bc = tf.get_variable(
         'bc', [self._hidden_layer_sizes],
-        initializer=tf.random_normal_initializer(stddev=1e-4, seed=self._seed))
+        initializer=tf.random_normal_initializer(stddev=1e-4))
 
     self._params.extend([
         self._x2i, self._h2i, self._c2i, self._bi, self._x2o, self._h2o,
@@ -1166,8 +1170,7 @@ class LSTMNetwork(NetworkUnitInterface):
 
     self.params.append(tf.get_variable(
         'weights_softmax', [self._hidden_layer_sizes, component.num_actions],
-        initializer=tf.random_normal_initializer(stddev=1e-4,
-                                                 seed=self._seed)))
+        initializer=tf.random_normal_initializer(stddev=1e-4)))
     self.params.append(
         tf.get_variable(
             'bias_softmax', [component.num_actions],
@@ -1324,8 +1327,7 @@ class ConvNetwork(NetworkUnitInterface):
             tf.get_variable(
                 'weights',
                 self.kernel_shapes[i],
-                initializer=tf.random_normal_initializer(
-                    stddev=1e-4, seed=self._seed),
+                initializer=tf.random_normal_initializer(stddev=1e-4),
                 dtype=tf.float32))
         bias_init = 0.0 if (i == len(self._widths) - 1) else 0.2
         self._biases.append(
@@ -1473,8 +1475,7 @@ class PairwiseConvNetwork(NetworkUnitInterface):
             tf.get_variable(
                 'weights',
                 kernel_shape,
-                initializer=tf.random_normal_initializer(
-                    stddev=1e-4, seed=self._seed),
+                initializer=tf.random_normal_initializer(stddev=1e-4),
                 dtype=tf.float32))
         bias_init = 0.0 if i in self._relu_layers else 0.2
         self._biases.append(

+ 15 - 0
syntaxnet/dragnn/python/network_units_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for network_units."""
 
 

+ 15 - 0
syntaxnet/dragnn/python/render_parse_tree_graphviz.py

@@ -1,4 +1,19 @@
 # -*- coding: utf-8 -*-
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Renders parse trees with Graphviz."""
 from __future__ import absolute_import
 from __future__ import division

+ 15 - 0
syntaxnet/dragnn/python/render_parse_tree_graphviz_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for ....dragnn.python.render_parse_tree_graphviz."""
 
 from __future__ import absolute_import

+ 15 - 0
syntaxnet/dragnn/python/render_spec_with_graphviz.py

@@ -1,4 +1,19 @@
 # -*- coding: utf-8 -*-
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Renders DRAGNN specs with Graphviz."""
 from __future__ import absolute_import
 from __future__ import division

+ 15 - 0
syntaxnet/dragnn/python/render_spec_with_graphviz_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for render_spec_with_graphviz."""
 
 from __future__ import absolute_import

+ 15 - 0
syntaxnet/dragnn/python/sentence_io.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Utilities for reading and writing sentences in dragnn."""
 import tensorflow as tf
 from syntaxnet.ops import gen_parser_ops

+ 15 - 0
syntaxnet/dragnn/python/sentence_io_test.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 import os
 import tensorflow as tf
 

+ 6 - 2
syntaxnet/dragnn/python/spec_builder.py

@@ -36,16 +36,20 @@ class ComponentSpecBuilder(object):
     spec: The dragnn.ComponentSpec proto.
   """
 
-  def __init__(self, name, builder='DynamicComponentBuilder'):
+  def __init__(self,
+               name,
+               builder='DynamicComponentBuilder',
+               backend='SyntaxNetComponent'):
     """Initializes the ComponentSpec with some defaults for SyntaxNet.
 
     Args:
       name: The name of this Component in the pipeline.
       builder: The component builder type.
+      backend: The component backend type.
     """
     self.spec = spec_pb2.ComponentSpec(
         name=name,
-        backend=self.make_module('SyntaxNetComponent'),
+        backend=self.make_module(backend),
         component_builder=self.make_module(builder))
 
   def make_module(self, name, **kwargs):

+ 17 - 2
syntaxnet/dragnn/python/trainer_lib.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Utility functions to build DRAGNN MasterSpecs and schedule model training.
 
 Provides functions to finish a MasterSpec, building required lexicons for it and
@@ -27,7 +42,7 @@ def calculate_component_accuracies(eval_res_values):
   ]
 
 
-def _write_summary(summary_writer, label, value, step):
+def write_summary(summary_writer, label, value, step):
   """Write a summary for a certain evaluation."""
   summary = Summary(value=[Summary.Value(tag=label, simple_value=float(value))])
   summary_writer.add_summary(summary, step)
@@ -135,7 +150,7 @@ def run_training(sess, trainers, annotator, evaluator, pretrain_steps,
       annotated = annotate_dataset(sess, annotator, eval_corpus)
       summaries = evaluator(eval_gold, annotated)
       for label, metric in summaries.iteritems():
-        _write_summary(summary_writer, label, metric, actual_step + step)
+        write_summary(summary_writer, label, metric, actual_step + step)
       eval_metric = summaries['eval_metric']
       if best_eval_metric < eval_metric:
         tf.logging.info('Updating best eval to %.2f%%, saving checkpoint.',

+ 15 - 0
syntaxnet/dragnn/python/visualization.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Helper library for visualizations.
 
 TODO(googleuser): Find a more reliable way to serve stuff from IPython

+ 15 - 0
syntaxnet/dragnn/python/visualization_test.py

@@ -1,4 +1,19 @@
 # -*- coding: utf-8 -*-
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Tests for dragnn.python.visualization."""
 
 from __future__ import absolute_import

+ 41 - 17
syntaxnet/dragnn/python/wrapped_units.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Network units wrapping TensorFlows' tf.contrib.rnn cells.
 
 Please put all wrapping logic for tf.contrib.rnn in this module; this will help
@@ -25,7 +40,7 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     logits: Logits associated with component actions.
   """
 
-  def __init__(self, component):
+  def __init__(self, component, additional_attr_defaults=None):
     """Initializes the LSTM base class.
 
     Parameters used:
@@ -42,15 +57,18 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
 
     Args:
       component: parent ComponentBuilderBase object.
+      additional_attr_defaults: Additional attributes for use by derived class.
     """
+    attr_defaults = additional_attr_defaults or {}
+    attr_defaults.update({
+        'layer_norm': True,
+        'input_dropout_rate': -1.0,
+        'recurrent_dropout_rate': 0.8,
+        'hidden_layer_sizes': '256',
+    })
     self._attrs = dragnn.get_attrs_with_defaults(
         component.spec.network_unit.parameters,
-        defaults={
-            'layer_norm': True,
-            'input_dropout_rate': -1.0,
-            'recurrent_dropout_rate': 0.8,
-            'hidden_layer_sizes': '256',
-        })
+        defaults=attr_defaults)
 
     self._hidden_layer_sizes = map(int,
                                    self._attrs['hidden_layer_sizes'].split(','))
@@ -87,8 +105,7 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     self._params.append(
         tf.get_variable(
             'weights_softmax', [last_layer_dim, component.num_actions],
-            initializer=tf.random_normal_initializer(
-                stddev=1e-4, seed=self._seed)))
+            initializer=tf.random_normal_initializer(stddev=1e-4)))
     self._params.append(
         tf.get_variable(
             'bias_softmax', [component.num_actions],
@@ -116,14 +133,9 @@ class BaseLSTMNetwork(dragnn.NetworkUnitInterface):
     """Appends layers defined by the base class to the |hidden_layers|."""
     last_layer = hidden_layers[-1]
 
-    # TODO(googleuser): Uncomment the version that uses component.get_variable()
-    # and delete the uses of tf.get_variable().
-    # logits = tf.nn.xw_plus_b(last_layer,
-    #                          self._component.get_variable('weights_softmax'),
-    #                          self._component.get_variable('bias_softmax'))
     logits = tf.nn.xw_plus_b(last_layer,
-                             tf.get_variable('weights_softmax'),
-                             tf.get_variable('bias_softmax'))
+                             self._component.get_variable('weights_softmax'),
+                             self._component.get_variable('bias_softmax'))
     return hidden_layers + [last_layer, logits]
 
   def _create_cell(self, num_units, during_training):
@@ -321,7 +333,18 @@ class BulkBiLSTMNetwork(BaseLSTMNetwork):
   """
 
   def __init__(self, component):
-    super(BulkBiLSTMNetwork, self).__init__(component)
+    """Initializes the bulk bi-LSTM.
+
+    Parameters used:
+      parallel_iterations (1): Parallelism of the underlying tf.while_loop().
+        Defaults to 1 thread to encourage deterministic behavior, but can be
+        increased to trade memory for speed.
+
+    Args:
+      component: parent ComponentBuilderBase object.
+    """
+    super(BulkBiLSTMNetwork, self).__init__(
+        component, additional_attr_defaults={'parallel_iterations': 1})
 
     check.In('lengths', self._linked_feature_dims,
              'Missing required linked feature')
@@ -426,6 +449,7 @@ class BulkBiLSTMNetwork(BaseLSTMNetwork):
           initial_states_fw=initial_states_forward,
           initial_states_bw=initial_states_backward,
           sequence_length=lengths_s,
+          parallel_iterations=self._attrs['parallel_iterations'],
           scope=scope)
       return outputs_sxnxd
 

+ 38 - 1
syntaxnet/dragnn/tools/BUILD

@@ -1,5 +1,10 @@
 package(default_visibility = ["//visibility:public"])
 
+filegroup(
+    name = "testdata",
+    srcs = glob(["testdata/**"]),
+)
+
 py_binary(
     name = "evaluator",
     srcs = ["evaluator.py"],
@@ -78,10 +83,29 @@ py_binary(
         "//dragnn/core:dragnn_bulk_ops",
         "//dragnn/core:dragnn_ops",
         "//dragnn/protos:spec_py_pb2",
+        "//dragnn/python:evaluation",
+        "//dragnn/python:graph_builder",
+        "//dragnn/python:load_dragnn_cc_impl_py",
+        "//dragnn/python:sentence_io",
+        "//dragnn/python:spec_builder",
+        "//dragnn/python:trainer_lib",
+        "//syntaxnet:load_parser_ops_py",
+        "//syntaxnet:parser_ops",
+        "@org_tensorflow//tensorflow:tensorflow_py",
+        "@org_tensorflow//tensorflow/core:protos_all_py",
+    ],
+)
+
+py_binary(
+    name = "model_trainer",
+    srcs = ["model_trainer.py"],
+    deps = [
+        "//dragnn/core:dragnn_bulk_ops",
+        "//dragnn/core:dragnn_ops",
+        "//dragnn/protos:spec_py_pb2",
         "//dragnn/python:dragnn_ops",
         "//dragnn/python:evaluation",
         "//dragnn/python:graph_builder",
-        "//dragnn/python:lexicon",
         "//dragnn/python:load_dragnn_cc_impl_py",
         "//dragnn/python:sentence_io",
         "//dragnn/python:spec_builder",
@@ -90,11 +114,24 @@ py_binary(
         "//syntaxnet:parser_ops",
         "//syntaxnet:sentence_py_pb2",
         "//syntaxnet:task_spec_py_pb2",
+        "//syntaxnet/util:check",
         "@org_tensorflow//tensorflow:tensorflow_py",
         "@org_tensorflow//tensorflow/core:protos_all_py",
     ],
 )
 
+sh_test(
+    name = "model_trainer_test",
+    size = "medium",
+    srcs = ["model_trainer_test.sh"],
+    data = [
+        ":model_trainer",
+        ":testdata",
+    ],
+    deps = [
+    ],
+)
+
 # This is meant to be run inside the Docker image. In the OSS directory, run,
 #
 #     ./build_devel.sh bazel run //dragnn/python:oss_notebook_launcher

+ 15 - 0
syntaxnet/dragnn/tools/build_pip_package.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Builds a pip package suitable for redistribution.
 
 Adapted from tensorflow/tools/pip_package/build_pip_package.sh. This might have

+ 21 - 0
syntaxnet/dragnn/tools/evaluator.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 r"""Runs a DRAGNN model on a given set of CoNLL-formatted sentences.
 
 Sample invocation:
@@ -51,6 +66,9 @@ flags.DEFINE_integer('threads', 10, 'Number of threads used for intra- and '
 flags.DEFINE_string('timeline_output_file', '', 'Path to save timeline to. '
                     'If specified, the final iteration of the evaluation loop '
                     'will capture and save a TensorFlow timeline.')
+flags.DEFINE_string('log_file', '', 'File path to write parser eval results.')
+flags.DEFINE_string('language_name', '_', 'Name of language being parsed, '
+                    'for logging.')
 
 
 def main(unused_argv):
@@ -134,6 +152,9 @@ def main(unused_argv):
     tf.logging.info('Processed %d documents in %.2f seconds.',
                     len(input_corpus), time.time() - start_time)
     pos, uas, las = evaluation.calculate_parse_metrics(input_corpus, processed)
+    if FLAGS.log_file:
+      with gfile.GFile(FLAGS.log_file, 'w') as f:
+        f.write('%s\t%f\t%f\t%f\n' % (FLAGS.language_name, pos, uas, las))
 
     if FLAGS.output_file:
       with gfile.GFile(FLAGS.output_file, 'w') as f:

+ 197 - 0
syntaxnet/dragnn/tools/model_trainer.py

@@ -0,0 +1,197 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Trainer for generic DRAGNN models.
+
+This trainer uses a "model directory" for both input and output.  When invoked,
+the model directory should contain the following inputs:
+
+  <model_dir>/config.txt: A stringified dict that defines high-level
+    configuration parameters.  Unset parameters default to False.
+  <model_dir>/master.pbtxt: A text-format MasterSpec proto that defines
+    the DRAGNN network to train.
+  <model_dir>/hyperparameters.pbtxt: A text-format GridPoint proto that
+    defines training hyper-parameters.
+  <model_dir>/targets.pbtxt: (Optional) A text-format TrainingGridSpec whose
+    "target" field defines the training targets.  If missing, then default
+    training targets are used instead.
+
+On success, the model directory will contain the following outputs:
+
+  <model_dir>/checkpoints/best: The best checkpoint seen during training, as
+    measured by accuracy on the eval corpus.
+  <model_dir>/tensorboard: TensorBoard log directory.
+
+Outside of the files and subdirectories named above, the model directory should
+contain any other necessary files (e.g., pretrained embeddings).  See the model
+builders in dragnn/examples.
+"""
+
+import ast
+import collections
+import os
+import os.path
+import tensorflow as tf
+
+from google.protobuf import text_format
+
+from dragnn.protos import spec_pb2
+from dragnn.python import evaluation
+from dragnn.python import graph_builder
+from dragnn.python import sentence_io
+from dragnn.python import spec_builder
+from dragnn.python import trainer_lib
+from syntaxnet.ops import gen_parser_ops
+from syntaxnet.util import check
+
+import dragnn.python.load_dragnn_cc_impl
+import syntaxnet.load_parser_ops
+
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('tf_master', '',
+                    'TensorFlow execution engine to connect to.')
+flags.DEFINE_string('model_dir', None, 'Path to a prepared model directory.')
+
+flags.DEFINE_string(
+    'pretrain_steps', None,
+    'Comma-delimited list of pre-training steps per training target.')
+flags.DEFINE_string(
+    'pretrain_epochs', None,
+    'Comma-delimited list of pre-training epochs per training target.')
+flags.DEFINE_string(
+    'train_steps', None,
+    'Comma-delimited list of training steps per training target.')
+flags.DEFINE_string(
+    'train_epochs', None,
+    'Comma-delimited list of training epochs per training target.')
+
+flags.DEFINE_integer('batch_size', 4, 'Batch size.')
+flags.DEFINE_integer('report_every', 200,
+                     'Report cost and training accuracy every this many steps.')
+
+
+def _read_text_proto(path, proto_type):
+  """Reads a text-format instance of |proto_type| from the |path|."""
+  proto = proto_type()
+  with tf.gfile.FastGFile(path) as proto_file:
+    text_format.Parse(proto_file.read(), proto)
+  return proto
+
+
+def _convert_to_char_corpus(corpus):
+  """Converts the word-based |corpus| into a char-based corpus."""
+  with tf.Session(graph=tf.Graph()) as tmp_session:
+    conversion_op = gen_parser_ops.segmenter_training_data_constructor(corpus)
+    return tmp_session.run(conversion_op)
+
+
+def _get_steps(steps_flag, epochs_flag, corpus_length):
+  """Converts the |steps_flag| or |epochs_flag| into a list of step counts."""
+  if steps_flag:
+    return map(int, steps_flag.split(','))
+  return [corpus_length * int(epochs) for epochs in epochs_flag.split(',')]
+
+
+def main(unused_argv):
+  tf.logging.set_verbosity(tf.logging.INFO)
+
+  check.NotNone(FLAGS.model_dir, '--model_dir is required')
+  check.Ne(FLAGS.pretrain_steps is None, FLAGS.pretrain_epochs is None,
+           'Exactly one of --pretrain_steps or --pretrain_epochs is required')
+  check.Ne(FLAGS.train_steps is None, FLAGS.train_epochs is None,
+           'Exactly one of --train_steps or --train_epochs is required')
+
+  config_path = os.path.join(FLAGS.model_dir, 'config.txt')
+  master_path = os.path.join(FLAGS.model_dir, 'master.pbtxt')
+  hyperparameters_path = os.path.join(FLAGS.model_dir, 'hyperparameters.pbtxt')
+  targets_path = os.path.join(FLAGS.model_dir, 'targets.pbtxt')
+  checkpoint_path = os.path.join(FLAGS.model_dir, 'checkpoints/best')
+  tensorboard_dir = os.path.join(FLAGS.model_dir, 'tensorboard')
+
+  with tf.gfile.FastGFile(config_path) as config_file:
+    config = collections.defaultdict(bool, ast.literal_eval(config_file.read()))
+  train_corpus_path = config['train_corpus_path']
+  tune_corpus_path = config['tune_corpus_path']
+  projectivize_train_corpus = config['projectivize_train_corpus']
+
+  master = _read_text_proto(master_path, spec_pb2.MasterSpec)
+  hyperparameters = _read_text_proto(hyperparameters_path, spec_pb2.GridPoint)
+  targets = spec_builder.default_targets_from_spec(master)
+  if tf.gfile.Exists(targets_path):
+    targets = _read_text_proto(targets_path, spec_pb2.TrainingGridSpec).target
+
+  # Build the TensorFlow graph.
+  graph = tf.Graph()
+  with graph.as_default():
+    tf.set_random_seed(hyperparameters.seed)
+    builder = graph_builder.MasterBuilder(master, hyperparameters)
+    trainers = [
+        builder.add_training_from_config(target) for target in targets
+    ]
+    annotator = builder.add_annotation()
+    builder.add_saver()
+
+  # Read in serialized protos from training data.
+  train_corpus = sentence_io.ConllSentenceReader(
+      train_corpus_path, projectivize=projectivize_train_corpus).corpus()
+  tune_corpus = sentence_io.ConllSentenceReader(
+      tune_corpus_path, projectivize=False).corpus()
+  gold_tune_corpus = tune_corpus
+
+  # Convert to char-based corpora, if requested.
+  if config['convert_to_char_corpora']:
+    # NB: Do not convert the |gold_tune_corpus|, which should remain word-based
+    # for segmentation evaluation purposes.
+    train_corpus = _convert_to_char_corpus(train_corpus)
+    tune_corpus = _convert_to_char_corpus(tune_corpus)
+
+  pretrain_steps = _get_steps(FLAGS.pretrain_steps, FLAGS.pretrain_epochs,
+                              len(train_corpus))
+  train_steps = _get_steps(FLAGS.train_steps, FLAGS.train_epochs,
+                           len(train_corpus))
+  check.Eq(len(targets), len(pretrain_steps),
+           'Length mismatch between training targets and --pretrain_steps')
+  check.Eq(len(targets), len(train_steps),
+           'Length mismatch between training targets and --train_steps')
+
+  # Ready to train!
+  tf.logging.info('Training on %d sentences.', len(train_corpus))
+  tf.logging.info('Tuning on %d sentences.', len(tune_corpus))
+
+  tf.logging.info('Creating TensorFlow checkpoint dir...')
+  summary_writer = trainer_lib.get_summary_writer(tensorboard_dir)
+
+  checkpoint_dir = os.path.dirname(checkpoint_path)
+  if tf.gfile.IsDirectory(checkpoint_dir):
+    tf.gfile.DeleteRecursively(checkpoint_dir)
+  elif tf.gfile.Exists(checkpoint_dir):
+    tf.gfile.Remove(checkpoint_dir)
+  tf.gfile.MakeDirs(checkpoint_dir)
+
+  with tf.Session(FLAGS.tf_master, graph=graph) as sess:
+    # Make sure to re-initialize all underlying state.
+    sess.run(tf.global_variables_initializer())
+    trainer_lib.run_training(sess, trainers, annotator,
+                             evaluation.parser_summaries, pretrain_steps,
+                             train_steps, train_corpus, tune_corpus,
+                             gold_tune_corpus, FLAGS.batch_size, summary_writer,
+                             FLAGS.report_every, builder.saver, checkpoint_path)
+
+  tf.logging.info('Best checkpoint written to:\n%s', checkpoint_path)
+
+
+if __name__ == '__main__':
+  tf.app.run()

+ 54 - 0
syntaxnet/dragnn/tools/model_trainer_test.sh

@@ -0,0 +1,54 @@
+#!/bin/bash
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# This test runs the model trainer on a snapshotted model directory.  This is a
+# "don't crash" test, so it does not evaluate the trained model.
+
+
+
+
+set -eu
+
+readonly DRAGNN_DIR="${TEST_SRCDIR}/${TEST_WORKSPACE}/dragnn"
+readonly MODEL_TRAINER="${DRAGNN_DIR}/tools/model_trainer"
+readonly MODEL_DIR="${DRAGNN_DIR}/tools/testdata/biaffine.model"
+readonly CORPUS="${DRAGNN_DIR}/tools/testdata/small.conll"
+readonly TMP_DIR="/tmp/model_trainer_test.$$"
+readonly TMP_MODEL_DIR="${TMP_DIR}/biaffine.model"
+
+rm -rf "${TMP_DIR}"
+mkdir -p "${TMP_DIR}"
+
+# Copy all testdata files to a temp dir, so they can be modified (see below).
+cp "${CORPUS}" "${TMP_DIR}"
+mkdir -p "${TMP_MODEL_DIR}"
+for name in hyperparameters.pbtxt targets.pbtxt resources; do
+  cp -r "${MODEL_DIR}/${name}" "${TMP_MODEL_DIR}/${name}"
+done
+
+# Replace "TESTDATA" with the temp dir path in config files that contain paths.
+for name in config.txt master.pbtxt; do
+  sed "s=TESTDATA=${TMP_DIR}=" "${MODEL_DIR}/${name}" \
+    > "${TMP_MODEL_DIR}/${name}"
+done
+
+"${MODEL_TRAINER}" \
+  --model_dir="${TMP_MODEL_DIR}" \
+  --pretrain_steps='1' \
+  --train_epochs='10' \
+  --alsologtostderr
+
+echo "PASS"

+ 15 - 0
syntaxnet/dragnn/tools/oss_notebook_launcher.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 """Mini OSS launcher so we can build a py_binary for OSS."""
 
 from __future__ import absolute_import

+ 15 - 0
syntaxnet/dragnn/tools/parse-to-conll.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 r"""Runs a both a segmentation and parsing model on a CoNLL dataset.
 """
 

+ 0 - 1
syntaxnet/dragnn/tools/parser_trainer.py

@@ -60,7 +60,6 @@ flags.DEFINE_string('dev_corpus_path', '', 'Path to development set data.')
 flags.DEFINE_bool('compute_lexicon', False, '')
 flags.DEFINE_bool('projectivize_training_set', True, '')
 
-flags.DEFINE_integer('num_epochs', 10, 'Number of epochs to train for.')
 flags.DEFINE_integer('batch_size', 4, 'Batch size.')
 flags.DEFINE_integer('report_every', 200,
                      'Report cost and training accuracy every this many steps.')

+ 15 - 0
syntaxnet/dragnn/tools/segmenter-evaluator.py

@@ -1,3 +1,18 @@
+# Copyright 2017 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
 r"""Runs a DRAGNN model on a given set of CoNLL-formatted sentences.
 
 Sample invocation:

+ 4 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/config.txt

@@ -0,0 +1,4 @@
+{
+  'train_corpus_path': 'TESTDATA/small.conll',
+  'tune_corpus_path': 'TESTDATA/small.conll',
+}

+ 18 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/hyperparameters.pbtxt

@@ -0,0 +1,18 @@
+learning_method: "adam"
+adam_beta1: 0.9
+adam_beta2: 0.9
+adam_eps: 1e-12
+
+learning_rate: 0.002
+decay_base: 0.75
+decay_staircase: false
+decay_steps: 2500
+
+dropout_rate: 0.67
+recurrent_dropout_rate: 0.75
+
+gradient_clip_norm: 15
+l2_regularization_coefficient: 0
+use_moving_average: false
+
+seed: 1

File diff suppressed because it is too large
+ 1135 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/master.pbtxt


+ 7 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/category-map

@@ -0,0 +1,7 @@
+6
+VERB 6
+NOUN 5
+PRON 5
+PUNCT 5
+DET 2
+CONJ 1

+ 18 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-map

@@ -0,0 +1,18 @@
+17
+o 10
+e 8
+b 6
+s 6
+. 5
+h 5
+l 5
+y 5
+k 4
+T 3
+a 3
+n 3
+u 3
+I 2
+v 2
+c 1
+d 1

+ 46 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/char-ngram-map

@@ -0,0 +1,46 @@
+45
+o 8
+^ b 6
+^ . $ 5
+e 5
+y $ 5
+^ bo 4
+k 4
+ks $ 4
+ok 4
+oo 4
+s $ 4
+^ T 3
+^ Th 3
+e $ 3
+ey $ 3
+h 3
+he 3
+l 3
+u 3
+^ I $ 2
+^ bu 2
+^ h 2
+^ ha 2
+^ n 2
+^ no $ 2
+^ s 2
+^ se 2
+a 2
+av 2
+el 2
+l $ 2
+ll $ 2
+o $ 2
+uy $ 2
+v 2
+ve $ 2
+^ a 1
+^ an 1
+^ c 1
+^ cl 1
+d $ 1
+lu 1
+n 1
+nd $ 1
+ue $ 1

+ 8 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/label-map

@@ -0,0 +1,8 @@
+7
+ROOT 5
+nsubj 5
+obj 5
+punct 5
+det 2
+cc 1
+conj 1

+ 0 - 0
syntaxnet/dragnn/tools/testdata/biaffine.model/resources/lcword-map


Some files were not shown because too many files changed in this diff