dragnn_ops.cc 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
#include "tensorflow/core/framework/op.h"
  2. namespace syntaxnet {
  3. namespace dragnn {
  4. REGISTER_OP("GetSession")
  5. .Input("container: string")
  6. .Attr("master_spec: string")
  7. .Attr("grid_point: string")
  8. .Output("handle: string")
  9. .SetIsStateful()
  10. .Doc(R"doc(
  11. Given MasterSpec and GridPoint protos, outputs a handle to a ComputeSession.
  12. container: A unique identifier for the ComputeSessionPool from which a
  13. ComputeSession will be allocated.
  14. master_spec: A serialized syntaxnet.dragnn.MasterSpec proto.
  15. grid_point: A serialized syntaxnet.dragnn.GridPoint proto.
  16. handle: A string handle to a ComputeSession.
  17. )doc");
  18. REGISTER_OP("ReleaseSession").Input("handle: string").SetIsStateful().Doc(R"doc(
  19. Given a ComputeSession, return it to the ComputeSession pool.
  20. This ComputeSession will no longer be available after this op returns.
  21. handle: A handle to a ComputeSession that will be returned to the backing pool.
  22. )doc");
  23. REGISTER_OP("InitComponentData")
  24. .Input("handle: string")
  25. .Input("beam_size: int32")
  26. .Attr("component: string")
  27. .Output("output_handle: string")
  28. .Doc(R"doc(
  29. Initialize a component with the given beam size for a given ComputeSession.
  30. handle: A handle to a ComputeSession.
  31. beam_size: The size of the beam to use on the component.
  32. component: The name of a Component instance, matching the ComponentSpec.name.
  33. output_handle: The handle to the same ComputeSession after initialization.
  34. )doc");
  35. REGISTER_OP("BatchSize")
  36. .Input("handle: string")
  37. .Attr("component: string")
  38. .Output("batch_size: int32")
  39. .Doc(R"doc(
  40. Given a ComputeSession and a component name,return the component batch size.
  41. handle: A handle to a ComputeSession.
  42. component: The name of a Component instance, matching the ComponentSpec.name.
  43. batch_size: The size of the given component's batch.
  44. )doc");
  45. REGISTER_OP("SetTracing")
  46. .Input("handle: string")
  47. .Input("tracing_on: bool")
  48. .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
  49. .Output("output_handle: string")
  50. .Doc(R"doc(
  51. Given a ComputeSession, turns on or off tracing for all components.
  52. handle: A handle to a ComputeSession.
  53. tracing_on: Whether or not to record traces.
  54. output_handle: The handle to the same ComputeSession, with the tracing status changed.
  55. )doc");
  56. REGISTER_OP("AttachDataReader")
  57. .Input("handle: string")
  58. .Input("input_spec: string")
  59. .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
  60. .Output("output_handle: string")
  61. .Doc(R"doc(
  62. Given a ComputeSession, attach a data source.
  63. This op is agnostic to the type of input data. The vector of input strings is
  64. interpreted by the backend.
  65. handle: A handle to a ComputeSession.
  66. input_spec: A vector of strings, where each string represents one batch item.
  67. output_handle: The handle to the same ComputeSession after attachment.
  68. )doc");
  69. REGISTER_OP("AdvanceFromOracle")
  70. .Input("handle: string")
  71. .Attr("component: string")
  72. .Output("output_handle: string")
  73. .Doc(R"doc(
  74. Given a ComputeSession and a Component name, advance the component via oracle.
  75. handle: A handle to a ComputeSession.
  76. component: The name of a Component instance, matching the ComponentSpec.name.
  77. output_handle: The handle to the same ComputeSession after advancement.
  78. )doc");
  79. REGISTER_OP("AdvanceFromPrediction")
  80. .Input("handle: string")
  81. .Input("scores: float")
  82. .Attr("component: string")
  83. .Output("output_handle: string")
  84. .Doc(R"doc(
  85. Given a ComputeSession, a Component name, and a score tensor, advance the state.
  86. handle: A handle to a ComputeSession.
  87. scores: A tensor of scores, ordered by {batch_size, beam_size, num_actions}.
  88. component: The name of a Component instance, matching the ComponentSpec.name.
  89. output_handle: A handle to the same ComputeSession after advancement.
  90. )doc");
  91. REGISTER_OP("DragnnEmbeddingInitializer")
  92. .Output("embeddings: float")
  93. .Attr("embedding_input: string")
  94. .Attr("vocab: string")
  95. .Attr("scaling_coefficient: float = 1.0")
  96. .Doc(R"doc(
  97. *** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
  98. Read embeddings from an an input for every key specified in a text vocab file.
  99. embeddings: A tensor containing embeddings from the specified sstable.
  100. embedding_input: Path to location with embedding vectors.
  101. vocab: Path to list of keys corresponding to the input.
  102. scaling_coefficient: A scaling coefficient for the embedding matrix.
  103. )doc");
  104. REGISTER_OP("ExtractFixedFeatures")
  105. .Input("handle: string")
  106. .Output("indices: int32")
  107. .Output("ids: int64")
  108. .Output("weights: float")
  109. .Attr("component: string")
  110. .Attr("channel_id: int")
  111. .Doc(R"doc(
  112. Given a ComputeSession, Component, and channel index, output fixed features.
  113. Fixed features returned as 3 vectors, 'indices', 'ids', and 'weights' of equal
  114. length. 'ids' specifies which rows should be looked up in the embedding
  115. matrix. 'weights' specifies a scale for each embedding vector. 'indices' is a
  116. sorted vector that assigns the same index to embedding vectors that should be
  117. summed together.
  118. handle: A handle to a ComputeSession.
  119. indices: The row to add the feature to.
  120. ids: The indices into embedding matrices for each feature.
  121. weights: The weight for each looked up feature.
  122. component: The name of a Component instance, matching the ComponentSpec.name.
  123. channel_id: The feature channel to extract features for.
  124. )doc");
  125. REGISTER_OP("ExtractLinkFeatures")
  126. .Input("handle: string")
  127. .Output("step_idx: int32")
  128. .Output("idx: int32")
  129. .Attr("component: string")
  130. .Attr("channel_id: int")
  131. .Doc(R"doc(
  132. Given a ComputeSession, Component, and a channel index, outputs link features.
  133. Output indices have shape {batch_size * beam_size * channel_size}.
  134. handle: A handle to a ComputeSession.
  135. step_idx: The step indices to read activations from.
  136. idx: indices The index within a step to read the activations from.
  137. component: The name of a Component instance, matching the ComponentSpec.name.
  138. channel_id: The feature channel to extract features for.
  139. )doc");
  140. REGISTER_OP("EmitOracleLabels")
  141. .Input("handle: string")
  142. .Output("gold_labels: int32")
  143. .Attr("component: string")
  144. .Doc(R"doc(
  145. Given a ComputeSession and Component, emit a vector of gold labels.
  146. handle: A handle to a ComputeSession.
  147. gold_labels: A [batch_size * beam_size] vector of gold labels for the current
  148. ComputeSession.
  149. component: The name of a Component instance, matching the ComponentSpec.name.
  150. )doc");
  151. REGISTER_OP("EmitAllFinal")
  152. .Input("handle: string")
  153. .Output("all_final: bool")
  154. .Attr("component: string")
  155. .Doc(R"doc(
  156. Given a ComputeSession and Component, returns whether the Component is final.
  157. A component is considered final when all elements in the batch have beams
  158. containing all final states.
  159. handle: A handle to a ComputeSession.
  160. all_final: Whether every element in the specified component is 'final'.
  161. component: The name of a Component instance, matching the ComponentSpec.name.
  162. )doc");
  163. REGISTER_OP("WriteAnnotations")
  164. .Input("handle: string")
  165. .Output("output_handle: string")
  166. .Attr("component: string")
  167. .Doc(R"doc(
  168. Given a ComputeSession, has the given component write out its annotations.
  169. The annotations are written to the underlying data objects passed in at the
  170. beginning of the computation.
  171. handle: A handle to a ComputeSession.
  172. output_handle: A handle to the same ComputeSession after writing.
  173. component: The name of a Component instance, matching the ComponentSpec.name.
  174. )doc");
  175. REGISTER_OP("EmitAnnotations")
  176. .Input("handle: string")
  177. .Output("annotations: string")
  178. .Attr("component: string")
  179. .Doc(R"doc(
  180. Given a ComputeSession, emits strings with final predictions for the model.
  181. Predictions are given for each element in the final component's batch.
  182. handle: A handle to a ComputeSession.
  183. annotations: A vector of strings representing the annotated data.
  184. component: The name of a Component instance, matching the ComponentSpec.name.
  185. )doc");
  186. REGISTER_OP("GetComponentTrace")
  187. .Input("handle: string")
  188. .Output("trace: string")
  189. .Attr("component: string")
  190. .Doc(R"doc(
  191. Gets the raw MasterTrace proto for each batch, state, and beam slot.
  192. handle: A handle to a ComputeSession.
  193. trace: A vector of MasterTrace protos.
  194. component: The name of a Component instance, matching the ComponentSpec.name.
  195. )doc");
  196. } // namespace dragnn
  197. } // namespace syntaxnet