// spec.proto
  1. // DRAGNN Configuration proto. See go/dragnn-design for more information.
  2. syntax = "proto2";
  3. package syntaxnet.dragnn;
  4. // Proto to specify a set of DRAGNN components (transition systems) that are
  5. // trained and evaluated jointly. Each component gets one ComponentSpec.
  6. //
  7. // The order of component is important: a component can only link to components
  8. // that come before (for now.)
  9. // NEXT ID: 6
  10. message MasterSpec {
  11. repeated ComponentSpec component = 1;
  12. // DEPRECATED: Use the "batch_size" param of DragnnTensorFlowTrainer instead.
  13. optional int32 deprecated_batch_size = 2 [default = 1, deprecated = true];
  14. // DEPRECATED: Use ComponentSpec.*_beam_size instead.
  15. optional int32 deprecated_beam_size = 3 [default = 1, deprecated = true];
  16. // Whether to extract debug traces.
  17. optional bool debug_tracing = 4 [default = false];
  18. }
  19. // Complete specification for a single task.
  20. message ComponentSpec {
  21. // Name for this component: this is used in linked features via the
  22. // "source_component" field.
  23. optional string name = 1;
  24. // TransitionSystem to use.
  25. optional RegisteredModuleSpec transition_system = 2;
  26. // Resources that this component depends on. These are copied to TaskInputs
  27. // when calling SAFT code.
  28. repeated Resource resource = 3;
  29. // Feature space configurations.
  30. repeated FixedFeatureChannel fixed_feature = 4;
  31. repeated LinkedFeatureChannel linked_feature = 5;
  32. // Neural Network builder specification.
  33. optional RegisteredModuleSpec network_unit = 6;
  34. // The registered C++ implementation of the dragnn::Component class; e.g.
  35. // "SyntaxNetComponent".
  36. optional RegisteredModuleSpec backend = 7;
  37. // Number of possible actions from every state.
  38. optional int32 num_actions = 8;
  39. // Specify the name of the lower level component on which it has attention.
  40. optional string attention_component = 9 [default = ""];
  41. // Options for the ComponentBuilder. If this is empty, the regular
  42. // tf.while_loop based builder is assumed.
  43. optional RegisteredModuleSpec component_builder = 10;
  44. // Default max number of active states for beam training.
  45. optional int32 training_beam_size = 11 [default = 1];
  46. // Default max number of active states for beam inference.
  47. optional int32 inference_beam_size = 12 [default = 1];
  48. }
  49. // Super generic container for any registered sub-piece of DRAGNN.
  50. message RegisteredModuleSpec {
  51. // Name of the registered class.
  52. optional string registered_name = 1;
  53. // Parameters to set while initializing this system; these are copied to
  54. // Parameters in a TaskSpec when calling SAFT code, or via kwargs in TF Python
  55. // code.
  56. map<string, string> parameters = 2;
  57. }
  58. // Fixed resources that will be converted into TaskInput's when calling SAFT
  59. // code.
  60. message Resource {
  61. optional string name = 1;
  62. repeated Part part = 2;
  63. }
  64. // The Parts here should be more or less compatible with TaskInput.
  65. message Part {
  66. optional string file_pattern = 1;
  67. optional string file_format = 2;
  68. optional string record_format = 3;
  69. }
  70. // ------------------------------------------------------------------------
  71. // Feature specifications.
  72. //
  73. // A *feature channel* is a named collection of feature templates that share an
  74. // embedding matrix. Thus all features in the channel are assumed to use the
  75. // same vocabulary: e.g., words, POS tags, hidden layer activations, etc. These
  76. // are extracted, embedded, and then concatenated together as a group.
  77. // Specification for a feature channel that is a *fixed* function of the input.
  78. // NEXT_ID: 10
  79. message FixedFeatureChannel {
  80. // Interpretable name for this feature channel. NN builders might depend on
  81. // this to determine how to hook different channels up internally.
  82. optional string name = 1;
  83. // String describing the FML for this feature channel.
  84. optional string fml = 2;
  85. // Size of parameters for this space:
  86. // Dimensions of embedding space, or -1 if the feature should not be embedded.
  87. optional int32 embedding_dim = 3;
  88. // No. of possible values returned.
  89. optional int32 vocabulary_size = 4;
  90. // No. of different feature templates in the channel, i.e. the # of features
  91. // that will be concatenated but share the embedding for this channel.
  92. optional int32 size = 5;
  93. // Whether the embeddings for this channel should be held constant at their
  94. // pretrained values, instead of being trained. Pretrained embeddings are
  95. // required when true.
  96. optional bool is_constant = 9;
  97. // Resources for this space:
  98. // Predicate map for compacting feature values.
  99. optional string predicate_map = 6;
  100. // Pointer to a pretrained embedding matrix for this feature set.
  101. optional Resource pretrained_embedding_matrix = 7;
  102. // Vocab file, containing all vocabulary words one per line.
  103. optional Resource vocab = 8;
  104. }
  105. // Specification for a feature channel that *links* to component
  106. // activations. Note that the "vocabulary" of these features is the activations
  107. // that they are linked to, so it is determined by the other components in the
  108. // spec.
  109. message LinkedFeatureChannel {
  110. // Interpretable name for this feature channel. NN builders might depend on
  111. // this to determine how to hook different channels up internally.
  112. optional string name = 1;
  113. // Feature function specification. Note: these should all be of type
  114. // LinkedFeatureType.
  115. optional string fml = 2;
  116. // Embedding dimension, or -1 if the link should not be embedded.
  117. optional int32 embedding_dim = 3;
  118. // No. of different feature templates in the channel, i.e. the # of features
  119. // that will be concatenated but share the embedding for this channel.
  120. optional int32 size = 4;
  121. // Component to use for translation, e.g. "tagger"
  122. optional string source_component = 5;
  123. // Translator target, e.g. "token" or "last_action", to translate raw feature
  124. // values into indices. This must be interpretable by the Component referenced
  125. // by source_component.
  126. optional string source_translator = 6;
  127. // Layer that these features should connect to.
  128. optional string source_layer = 7;
  129. }
  130. // A vector of hyperparameter configurations to search over.
  131. message TrainingGridSpec {
  132. // Grid points to search over.
  133. repeated GridPoint grid_point = 1;
  134. // Training targets to create in the graph builder stage.
  135. repeated TrainTarget target = 2;
  136. }
  137. // A hyperparameter configuration for a training run.
  138. // NEXT ID: 22
  139. message GridPoint {
  140. // Global learning rate initialization point.
  141. optional double learning_rate = 1 [default = 0.1];
  142. // Momentum coefficient when using MomentumOptimizer.
  143. optional double momentum = 2 [default = 0.9];
  144. // Decay rate and base for global learning rate decay. The learning rate is
  145. // reduced by a factor of |decay_base| every |decay_steps|.
  146. optional double decay_base = 16 [default = 0.96];
  147. optional int32 decay_steps = 3 [default = 1000];
  148. // Whether to decay the learning rate in a "staircase" manner. If true, the
  149. // rate is adjusted exactly once every |decay_steps|. Otherwise, the rate is
  150. // adjusted in smaller increments on every step, such that the overall rate of
  151. // decay is still |decay_base| every |decay_steps|.
  152. optional bool decay_staircase = 17 [default = true];
  153. // Random seed to initialize parameters.
  154. optional int32 seed = 4 [default = 0];
  155. // Specify the optimizer used in training, the default is MomentumOptimizer.
  156. optional string learning_method = 7 [default = 'momentum'];
  157. // Whether or not to use a moving average of the weights in inference time.
  158. optional bool use_moving_average = 8 [default = false];
  159. // Rolling average update co-efficient.
  160. optional double average_weight = 9 [default = 0.9999];
  161. // The dropout *keep* probability rate used in the model. 1.0 = no dropout.
  162. optional double dropout_rate = 10 [default = 1.0];
  163. // The dropout *keep* probability rate for recurrent connections. If < 0.0,
  164. // recurrent connections should use |dropout_rate| instead. 1.0 = no dropout.
  165. optional double recurrent_dropout_rate = 20 [default = -1.0];
  166. // Gradient clipping threshold, applied if greater than zero. A value in the
  167. // range 1-20 seems to work well to prevent large learning rates from causing
  168. // problems for updates at the start of training.
  169. optional double gradient_clip_norm = 11 [default = 0.0];
  170. // DEPRECATED: Use TrainTarget instead.
  171. repeated double component_weights = 5;
  172. repeated bool unroll_using_oracle = 6;
  173. // A spec for using multiple optimization methods.
  174. message CompositeOptimizerSpec {
  175. // First optimizer.
  176. optional GridPoint method1 = 1;
  177. // Second optimizer.
  178. optional GridPoint method2 = 2;
  179. // After this number of steps, switch from first to second.
  180. optional int32 switch_after_steps = 3;
  181. }
  182. optional CompositeOptimizerSpec composite_optimizer_spec = 12;
  183. // Parameters for Adam training.
  184. optional double adam_beta1 = 13 [default = 0.01];
  185. optional double adam_beta2 = 14 [default = 0.9999];
  186. optional double adam_eps = 15 [default = 1e-8];
  187. // Coefficient for global L2 regularization.
  188. optional double l2_regularization_coefficient = 18 [default = 1e-4];
  189. // Coefficient for global self normalization regularization.
  190. // A value of zero turns it off.
  191. optional double self_norm_alpha = 19 [default = 0.0];
  192. // Comma separated list of components to which self_norm_alpha
  193. // should be restricted. If left empty, no filtering will take
  194. // place. Typically a single component.
  195. optional string self_norm_components_filter = 21;
  196. }
  197. // Training target to be built into the graph.
  198. message TrainTarget {
  199. // Name for this target. This should be unique across all targets.
  200. optional string name = 1;
  201. // Specify the weights for different components. This should be the same size
  202. // as the number of components in the spec, or empty (defaults to equal
  203. // weights). Weights are normalized across the components being trained to sum
  204. // to one.
  205. repeated double component_weights = 2;
  206. // Specify whether to train a component using supervised signal or not. This
  207. // should be the same size as the number of components in the spec, or empty
  208. // (defaults to all true).
  209. repeated bool unroll_using_oracle = 3;
  210. // Maximum length of the pipeline to train. E.g. if max_index is 1, then only
  211. // the first component will be trained via this target.
  212. optional int32 max_index = 4 [default = -1];
  213. }