dragnn_ops.cc 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. // Copyright 2017 Google Inc. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. // =============================================================================
  15. #include "tensorflow/core/framework/op.h"
  16. namespace syntaxnet {
  17. namespace dragnn {
  18. REGISTER_OP("GetSession")
  19. .Input("container: string")
  20. .Attr("master_spec: string")
  21. .Attr("grid_point: string")
  22. .Output("handle: string")
  23. .SetIsStateful()
  24. .Doc(R"doc(
  25. Given MasterSpec and GridPoint protos, outputs a handle to a ComputeSession.
  26. container: A unique identifier for the ComputeSessionPool from which a
  27. ComputeSession will be allocated.
  28. master_spec: A serialized syntaxnet.dragnn.MasterSpec proto.
  29. grid_point: A serialized syntaxnet.dragnn.GridPoint proto.
  30. handle: A string handle to a ComputeSession.
  31. )doc");
  32. REGISTER_OP("ReleaseSession").Input("handle: string").SetIsStateful().Doc(R"doc(
  33. Given a ComputeSession, return it to the ComputeSession pool.
  34. This ComputeSession will no longer be available after this op returns.
  35. handle: A handle to a ComputeSession that will be returned to the backing pool.
  36. )doc");
  37. REGISTER_OP("InitComponentData")
  38. .Input("handle: string")
  39. .Input("beam_size: int32")
  40. .Attr("component: string")
  41. .Output("output_handle: string")
  42. .Doc(R"doc(
  43. Initialize a component with the given beam size for a given ComputeSession.
  44. handle: A handle to a ComputeSession.
  45. beam_size: The size of the beam to use on the component.
  46. component: The name of a Component instance, matching the ComponentSpec.name.
  47. output_handle: The handle to the same ComputeSession after initialization.
  48. )doc");
  49. REGISTER_OP("BatchSize")
  50. .Input("handle: string")
  51. .Attr("component: string")
  52. .Output("batch_size: int32")
  53. .Doc(R"doc(
  54. Given a ComputeSession and a component name,return the component batch size.
  55. handle: A handle to a ComputeSession.
  56. component: The name of a Component instance, matching the ComponentSpec.name.
  57. batch_size: The size of the given component's batch.
  58. )doc");
  59. REGISTER_OP("SetTracing")
  60. .Input("handle: string")
  61. .Input("tracing_on: bool")
  62. .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
  63. .Output("output_handle: string")
  64. .Doc(R"doc(
  65. Given a ComputeSession, turns on or off tracing for all components.
  66. handle: A handle to a ComputeSession.
  67. tracing_on: Whether or not to record traces.
  68. output_handle: The handle to the same ComputeSession, with the tracing status changed.
  69. )doc");
  70. REGISTER_OP("AttachDataReader")
  71. .Input("handle: string")
  72. .Input("input_spec: string")
  73. .Attr("component: string = 'NOT_USED_FOR_THIS_OP'")
  74. .Output("output_handle: string")
  75. .Doc(R"doc(
  76. Given a ComputeSession, attach a data source.
  77. This op is agnostic to the type of input data. The vector of input strings is
  78. interpreted by the backend.
  79. handle: A handle to a ComputeSession.
  80. input_spec: A vector of strings, where each string represents one batch item.
  81. output_handle: The handle to the same ComputeSession after attachment.
  82. )doc");
  83. REGISTER_OP("AdvanceFromOracle")
  84. .Input("handle: string")
  85. .Attr("component: string")
  86. .Output("output_handle: string")
  87. .Doc(R"doc(
  88. Given a ComputeSession and a Component name, advance the component via oracle.
  89. handle: A handle to a ComputeSession.
  90. component: The name of a Component instance, matching the ComponentSpec.name.
  91. output_handle: The handle to the same ComputeSession after advancement.
  92. )doc");
  93. REGISTER_OP("AdvanceFromPrediction")
  94. .Input("handle: string")
  95. .Input("scores: float")
  96. .Attr("component: string")
  97. .Output("output_handle: string")
  98. .Doc(R"doc(
  99. Given a ComputeSession, a Component name, and a score tensor, advance the state.
  100. handle: A handle to a ComputeSession.
  101. scores: A tensor of scores, ordered by {batch_size, beam_size, num_actions}.
  102. component: The name of a Component instance, matching the ComponentSpec.name.
  103. output_handle: A handle to the same ComputeSession after advancement.
  104. )doc");
  105. REGISTER_OP("DragnnEmbeddingInitializer")
  106. .Output("embeddings: float")
  107. .Attr("embedding_input: string")
  108. .Attr("vocab: string")
  109. .Attr("scaling_coefficient: float = 1.0")
  110. .Attr("seed: int = 0")
  111. .Attr("seed2: int = 0")
  112. .Doc(R"doc(
  113. *** PLACEHOLDER OP - FUNCTIONALITY NOT YET IMPLEMENTED ***
  114. Read embeddings from an an input for every key specified in a text vocab file.
  115. embeddings: A tensor containing embeddings from the specified sstable.
  116. embedding_input: Path to location with embedding vectors.
  117. vocab: Path to list of keys corresponding to the input.
  118. scaling_coefficient: A scaling coefficient for the embedding matrix.
  119. seed: If either `seed` or `seed2` are set to be non-zero, the random number
  120. generator is seeded by the given seed. Otherwise, it is seeded by a
  121. random seed.
  122. seed2: A second seed to avoid seed collision.
  123. )doc");
  124. REGISTER_OP("ExtractFixedFeatures")
  125. .Input("handle: string")
  126. .Output("indices: int32")
  127. .Output("ids: int64")
  128. .Output("weights: float")
  129. .Attr("component: string")
  130. .Attr("channel_id: int")
  131. .Doc(R"doc(
  132. Given a ComputeSession, Component, and channel index, output fixed features.
  133. Fixed features returned as 3 vectors, 'indices', 'ids', and 'weights' of equal
  134. length. 'ids' specifies which rows should be looked up in the embedding
  135. matrix. 'weights' specifies a scale for each embedding vector. 'indices' is a
  136. sorted vector that assigns the same index to embedding vectors that should be
  137. summed together.
  138. handle: A handle to a ComputeSession.
  139. indices: The row to add the feature to.
  140. ids: The indices into embedding matrices for each feature.
  141. weights: The weight for each looked up feature.
  142. component: The name of a Component instance, matching the ComponentSpec.name.
  143. channel_id: The feature channel to extract features for.
  144. )doc");
  145. REGISTER_OP("ExtractLinkFeatures")
  146. .Input("handle: string")
  147. .Output("step_idx: int32")
  148. .Output("idx: int32")
  149. .Attr("component: string")
  150. .Attr("channel_id: int")
  151. .Doc(R"doc(
  152. Given a ComputeSession, Component, and a channel index, outputs link features.
  153. Output indices have shape {batch_size * beam_size * channel_size}.
  154. handle: A handle to a ComputeSession.
  155. step_idx: The step indices to read activations from.
  156. idx: indices The index within a step to read the activations from.
  157. component: The name of a Component instance, matching the ComponentSpec.name.
  158. channel_id: The feature channel to extract features for.
  159. )doc");
  160. REGISTER_OP("EmitOracleLabels")
  161. .Input("handle: string")
  162. .Output("gold_labels: int32")
  163. .Attr("component: string")
  164. .Doc(R"doc(
  165. Given a ComputeSession and Component, emit a vector of gold labels.
  166. handle: A handle to a ComputeSession.
  167. gold_labels: A [batch_size * beam_size] vector of gold labels for the current
  168. ComputeSession.
  169. component: The name of a Component instance, matching the ComponentSpec.name.
  170. )doc");
  171. REGISTER_OP("EmitAllFinal")
  172. .Input("handle: string")
  173. .Output("all_final: bool")
  174. .Attr("component: string")
  175. .Doc(R"doc(
  176. Given a ComputeSession and Component, returns whether the Component is final.
  177. A component is considered final when all elements in the batch have beams
  178. containing all final states.
  179. handle: A handle to a ComputeSession.
  180. all_final: Whether every element in the specified component is 'final'.
  181. component: The name of a Component instance, matching the ComponentSpec.name.
  182. )doc");
  183. REGISTER_OP("WriteAnnotations")
  184. .Input("handle: string")
  185. .Output("output_handle: string")
  186. .Attr("component: string")
  187. .Doc(R"doc(
  188. Given a ComputeSession, has the given component write out its annotations.
  189. The annotations are written to the underlying data objects passed in at the
  190. beginning of the computation.
  191. handle: A handle to a ComputeSession.
  192. output_handle: A handle to the same ComputeSession after writing.
  193. component: The name of a Component instance, matching the ComponentSpec.name.
  194. )doc");
  195. REGISTER_OP("EmitAnnotations")
  196. .Input("handle: string")
  197. .Output("annotations: string")
  198. .Attr("component: string")
  199. .Doc(R"doc(
  200. Given a ComputeSession, emits strings with final predictions for the model.
  201. Predictions are given for each element in the final component's batch.
  202. handle: A handle to a ComputeSession.
  203. annotations: A vector of strings representing the annotated data.
  204. component: The name of a Component instance, matching the ComponentSpec.name.
  205. )doc");
  206. REGISTER_OP("GetComponentTrace")
  207. .Input("handle: string")
  208. .Output("trace: string")
  209. .Attr("component: string")
  210. .Doc(R"doc(
  211. Gets the raw MasterTrace proto for each batch, state, and beam slot.
  212. handle: A handle to a ComputeSession.
  213. trace: A vector of MasterTrace protos.
  214. component: The name of a Component instance, matching the ComponentSpec.name.
  215. )doc");
  216. } // namespace dragnn
  217. } // namespace syntaxnet