spec.proto

// DRAGNN Configuration proto. See go/dragnn-design for more information.
syntax = "proto2";

package syntaxnet.dragnn;

// Proto to specify a set of DRAGNN components (transition systems) that are
// trained and evaluated jointly. Each component gets one ComponentSpec.
//
// The order of components is important: a component can only link to
// components that come before it (for now).
// NEXT ID: 6
message MasterSpec {
  repeated ComponentSpec component = 1;

  // Whether to extract debug traces.
  optional bool debug_tracing = 4 [default = false];

  reserved 2, 3, 5;
}
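
// As a sketch, a two-component MasterSpec might look like this in proto text
// format (component contents elided; see ComponentSpec below):
//
//   component { name: "tagger" ... }
//   component { name: "parser" ... }
//   debug_tracing: false
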
// Complete specification for a single task.
message ComponentSpec {
  // Name for this component: this is used in linked features via the
  // "source_component" field.
  optional string name = 1;

  // TransitionSystem to use.
  optional RegisteredModuleSpec transition_system = 2;

  // Resources that this component depends on. These are copied to TaskInputs
  // when calling SAFT code.
  repeated Resource resource = 3;

  // Feature space configurations.
  repeated FixedFeatureChannel fixed_feature = 4;
  repeated LinkedFeatureChannel linked_feature = 5;

  // Neural network builder specification.
  optional RegisteredModuleSpec network_unit = 6;

  // The registered C++ implementation of the dragnn::Component class; e.g.
  // "SyntaxNetComponent".
  optional RegisteredModuleSpec backend = 7;

  // Number of possible actions from every state.
  optional int32 num_actions = 8;

  // Name of the lower-level component on which this component has attention.
  optional string attention_component = 9 [default = ""];

  // Options for the ComponentBuilder. If this is empty, the regular
  // tf.while_loop-based builder is assumed.
  optional RegisteredModuleSpec component_builder = 10;

  // Default max number of active states for beam training.
  optional int32 training_beam_size = 11 [default = 1];

  // Default max number of active states for beam inference.
  optional int32 inference_beam_size = 12 [default = 1];
}
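
// A minimal ComponentSpec sketch in proto text format. The registered names,
// FML, and action count below are illustrative values, not requirements:
//
//   name: "tagger"
//   transition_system { registered_name: "tagger" }
//   network_unit { registered_name: "FeedForwardNetwork" }
//   backend { registered_name: "SyntaxNetComponent" }
//   fixed_feature {
//     name: "words" fml: "input.word" embedding_dim: 64 size: 1
//   }
//   num_actions: 45
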
// Super generic container for any registered sub-piece of DRAGNN.
message RegisteredModuleSpec {
  // Name of the registered class.
  optional string registered_name = 1;

  // Parameters to set while initializing this system; these are copied to
  // Parameters in a TaskSpec when calling SAFT code, or via kwargs in TF
  // Python code.
  map<string, string> parameters = 2;
}
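
// For example, a network unit taking keyword-style parameters could be
// configured in proto text format as follows (the names and values are
// illustrative):
//
//   registered_name: "FeedForwardNetwork"
//   parameters { key: "hidden_layer_sizes" value: "256,256" }
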
// Fixed resources that will be converted into TaskInputs when calling SAFT
// code.
message Resource {
  optional string name = 1;
  repeated Part part = 2;
}

// The Parts here should be more or less compatible with TaskInput.
message Part {
  optional string file_pattern = 1;
  optional string file_format = 2;
  optional string record_format = 3;
}
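
// A Resource sketch in proto text format (the path and format strings are
// illustrative placeholders):
//
//   name: "word-vocab"
//   part {
//     file_pattern: "/path/to/word-vocab.txt"
//     file_format: "text"
//   }
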
// ------------------------------------------------------------------------
// Feature specifications.
//
// A *feature channel* is a named collection of feature templates that share an
// embedding matrix. Thus all features in the channel are assumed to use the
// same vocabulary: e.g., words, POS tags, hidden layer activations, etc. These
// are extracted, embedded, and then concatenated together as a group.

// Specification for a feature channel that is a *fixed* function of the input.
// NEXT_ID: 10
message FixedFeatureChannel {
  // Interpretable name for this feature channel. NN builders might depend on
  // this to determine how to hook different channels up internally.
  optional string name = 1;

  // String describing the FML for this feature channel.
  optional string fml = 2;

  // Size of parameters for this space:

  // Dimension of the embedding space, or -1 if the feature should not be
  // embedded.
  optional int32 embedding_dim = 3;

  // Number of possible values returned.
  optional int32 vocabulary_size = 4;

  // Number of different feature templates in the channel, i.e. the number of
  // features that will be concatenated but share the embedding for this
  // channel.
  optional int32 size = 5;

  // Whether the embeddings for this channel should be held constant at their
  // pretrained values, instead of being trained. Pretrained embeddings are
  // required when true.
  optional bool is_constant = 9;

  // Resources for this space:

  // Predicate map for compacting feature values.
  optional string predicate_map = 6;

  // Pointer to a pretrained embedding matrix for this feature set.
  optional Resource pretrained_embedding_matrix = 7;

  // Vocab file, containing all vocabulary words, one per line.
  optional Resource vocab = 8;
}
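
// A fixed channel sketch in proto text format (the FML and sizes are
// illustrative):
//
//   name: "words"
//   fml: "input.word"
//   embedding_dim: 64
//   vocabulary_size: 50000
//   size: 1
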
// Specification for a feature channel that *links* to component activations.
// Note that the "vocabulary" of these features is the set of activations that
// they are linked to, so it is determined by the other components in the spec.
message LinkedFeatureChannel {
  // Interpretable name for this feature channel. NN builders might depend on
  // this to determine how to hook different channels up internally.
  optional string name = 1;

  // Feature function specification. Note: these should all be of type
  // LinkedFeatureType.
  optional string fml = 2;

  // Embedding dimension, or -1 if the link should not be embedded.
  optional int32 embedding_dim = 3;

  // Number of different feature templates in the channel, i.e. the number of
  // features that will be concatenated but share the embedding for this
  // channel.
  optional int32 size = 4;

  // Component to use for translation, e.g. "tagger".
  optional string source_component = 5;

  // Translator target, e.g. "token" or "last_action", used to translate raw
  // feature values into indices. This must be interpretable by the Component
  // referenced by source_component.
  optional string source_translator = 6;

  // Layer that these features should connect to.
  optional string source_layer = 7;
}
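
// A linked channel sketch in proto text format: a component reading the
// hidden layer of an earlier "tagger" component. The FML, translator, and
// layer names are illustrative and must match what the source component
// exports:
//
//   name: "tagger-hidden"
//   fml: "input.focus"
//   embedding_dim: 64
//   size: 1
//   source_component: "tagger"
//   source_translator: "token"
//   source_layer: "layer_0"
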
// A vector of hyperparameter configurations to search over.
message TrainingGridSpec {
  // Grid points to search over.
  repeated GridPoint grid_point = 1;

  // Training targets to create in the graph builder stage.
  repeated TrainTarget target = 2;
}
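
// A grid sketch in proto text format, sweeping two learning rates with a
// single training target:
//
//   grid_point { learning_rate: 0.1 }
//   grid_point { learning_rate: 0.01 }
//   target { name: "all-components" }
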
// A hyperparameter configuration for a training run.
// NEXT ID: 22
message GridPoint {
  // Global learning rate initialization point.
  optional double learning_rate = 1 [default = 0.1];

  // Momentum coefficient when using MomentumOptimizer.
  optional double momentum = 2 [default = 0.9];

  // Base and steps for global learning rate decay: the learning rate is
  // multiplied by |decay_base| every |decay_steps| steps.
  optional double decay_base = 16 [default = 0.96];
  optional int32 decay_steps = 3 [default = 1000];

  // Whether to decay the learning rate in a "staircase" manner. If true, the
  // rate is adjusted exactly once every |decay_steps| steps. Otherwise, the
  // rate is adjusted in smaller increments on every step, such that the
  // overall rate of decay is still |decay_base| every |decay_steps| steps.
  optional bool decay_staircase = 17 [default = true];
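
  // In other words, assuming the standard exponential decay schedule (as in
  // TensorFlow's tf.train.exponential_decay), the effective rate at a given
  // step would be:
  //
  //   learning_rate * decay_base ^ floor(step / decay_steps)  (staircase)
  //   learning_rate * decay_base ^ (step / decay_steps)       (otherwise)
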
  // Random seed to initialize parameters.
  optional int32 seed = 4 [default = 0];

  // Specifies the optimizer used in training; the default is
  // MomentumOptimizer.
  optional string learning_method = 7 [default = 'momentum'];

  // Whether or not to use a moving average of the weights at inference time.
  optional bool use_moving_average = 8 [default = false];

  // Rolling average update coefficient.
  optional double average_weight = 9 [default = 0.9999];

  // The dropout *keep* probability used in the model. 1.0 = no dropout.
  optional double dropout_rate = 10 [default = 1.0];

  // The dropout *keep* probability for recurrent connections. If < 0.0,
  // recurrent connections should use |dropout_rate| instead. 1.0 = no dropout.
  optional double recurrent_dropout_rate = 20 [default = -1.0];

  // Gradient clipping threshold, applied if greater than zero. A value in the
  // range 1-20 seems to work well to prevent large learning rates from causing
  // problems for updates at the start of training.
  optional double gradient_clip_norm = 11 [default = 0.0];

  // A spec for using multiple optimization methods.
  message CompositeOptimizerSpec {
    // First optimizer.
    optional GridPoint method1 = 1;

    // Second optimizer.
    optional GridPoint method2 = 2;

    // After this number of steps, switch from the first to the second.
    optional int32 switch_after_steps = 3;
  }
  optional CompositeOptimizerSpec composite_optimizer_spec = 12;
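
  // For illustration, a composite schedule that runs momentum for the first
  // 10000 steps and then switches to Adam might be written in proto text
  // format as follows (the step count and method strings are assumed values):
  //
  //   composite_optimizer_spec {
  //     method1 { learning_method: 'momentum' momentum: 0.9 }
  //     method2 { learning_method: 'adam' }
  //     switch_after_steps: 10000
  //   }
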
  // Parameters for Adam training.
  optional double adam_beta1 = 13 [default = 0.01];
  optional double adam_beta2 = 14 [default = 0.9999];
  optional double adam_eps = 15 [default = 1e-8];

  // Coefficient for global L2 regularization.
  optional double l2_regularization_coefficient = 18 [default = 1e-4];

  // Coefficient for global self-normalization regularization. A value of zero
  // turns it off.
  optional double self_norm_alpha = 19 [default = 0.0];

  // Comma-separated list of components to which self_norm_alpha should be
  // restricted. If left empty, no filtering will take place. Typically a
  // single component.
  optional string self_norm_components_filter = 21;

  reserved 5, 6;
}

// Training target to be built into the graph.
message TrainTarget {
  // Name for this target. This should be unique across all targets.
  optional string name = 1;

  // Specifies the weights for the different components. This should be the
  // same size as the number of components in the spec, or empty (defaults to
  // equal weights). Weights are normalized across the components being trained
  // to sum to one.
  repeated double component_weights = 2;

  // Specifies whether or not to train each component using a supervised
  // signal. This should be the same size as the number of components in the
  // spec, or empty (defaults to all true).
  repeated bool unroll_using_oracle = 3;

  // Maximum length of the pipeline to train. E.g., if max_index is 1, then
  // only the first component will be trained via this target.
  optional int32 max_index = 4 [default = -1];
}
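
// A TrainTarget sketch in proto text format for a two-component spec, training
// both components with equal weight:
//
//   name: "joint"
//   component_weights: 0.5
//   component_weights: 0.5
//   unroll_using_oracle: true
//   unroll_using_oracle: true
//   max_index: -1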