# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""TensorFlow ops for directed graphs."""

import tensorflow as tf

from syntaxnet.util import check


def ArcPotentialsFromTokens(source_tokens, target_tokens, weights):
  r"""Returns arc potentials computed from token activations and weights.

  For each batch of source and target token activations, computes a scalar
  potential for each arc as the 3-way product between the activation vectors
  of the source and target of the arc and the |weights|.  Specifically,

    arc[b,s,t] =
        \sum_{i,j} source_tokens[b,s,i] * weights[i,j] * target_tokens[b,t,j]

  Note that the token activations can be extended with bias terms to
  implement a "biaffine" model (Dozat and Manning, 2017).

  Args:
    source_tokens: [B,N,S] tensor of batched activations for the source token
      in each arc.
    target_tokens: [B,N,T] tensor of batched activations for the target token
      in each arc.
    weights: [S,T] matrix of weights.

    B,N may be statically-unknown, but S,T must be statically-known.  The
    dtype of all arguments must be compatible.

  Returns:
    [B,N,N] tensor A of arc potentials where A_{b,s,t} is the potential of
    the arc from s to t in batch element b.  The dtype of A is the same as
    that of the arguments.  Note that the diagonal entries (i.e., where s==t)
    represent self-loops and may not be meaningful.
  """
  # All arguments must have statically-known rank.
  check.Eq(source_tokens.get_shape().ndims, 3, 'source_tokens must be rank 3')
  check.Eq(target_tokens.get_shape().ndims, 3, 'target_tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  num_target_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(source_tokens.get_shape().as_list()[2], num_source_activations,
           'dimension mismatch between weights and source_tokens')
  check.Eq(target_tokens.get_shape().as_list()[2], num_target_activations,
           'dimension mismatch between weights and target_tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              source_tokens.dtype.base_dtype,
              target_tokens.dtype.base_dtype],
             'dtype mismatch')

  source_tokens_shape = tf.shape(source_tokens)
  target_tokens_shape = tf.shape(target_tokens)
  batch_size = source_tokens_shape[0]
  num_tokens = source_tokens_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, target_tokens_shape[0]),
      tf.assert_equal(num_tokens, target_tokens_shape[1])]):
    # Flatten out the batch dimension so we can use one big multiplication.
    targets_bnxt = tf.reshape(target_tokens, [-1, num_target_activations])

    # Matrices are row-major, so we arrange for the RHS argument of each
    # matmul to have its transpose flag set.  That way no copying is required
    # to align the rows of the LHS with the columns of the RHS.
    weights_targets_bnxs = tf.matmul(targets_bnxt, weights, transpose_b=True)

    # The next computation is over pairs of tokens within each batch element,
    # so restore the batch dimension.
    weights_targets_bxnxs = tf.reshape(
        weights_targets_bnxs, [batch_size, num_tokens, num_source_activations])

    # Note that this multiplication is repeated across the batch dimension,
    # instead of being one big multiplication as in the first matmul.  There
    # doesn't seem to be a way to arrange this as a single multiplication
    # given the pairwise nature of this computation.
    arcs_bxnxn = tf.matmul(source_tokens, weights_targets_bxnxs,
                           transpose_b=True)
    return arcs_bxnxn
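

# A minimal usage sketch, not part of the original file: with an identity
# |weights| matrix, the arc potential reduces to a plain dot product between
# source and target activations, which makes the output easy to check by
# hand.  The tensor values here are hypothetical.
def _ExampleArcPotentials():
  source_tokens = tf.constant([[[1.0, 0.0], [0.0, 2.0]]])  # [B=1,N=2,S=2]
  target_tokens = tf.constant([[[3.0, 0.0], [0.0, 4.0]]])  # [B=1,N=2,T=2]
  weights = tf.eye(2)  # identity: arc[b,s,t] = <source[b,s], target[b,t]>
  # Evaluates to [[[3, 0], [0, 8]]] when run in a session.
  return ArcPotentialsFromTokens(source_tokens, target_tokens, weights)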


def ArcSourcePotentialsFromTokens(tokens, weights):
  r"""Returns arc source potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  arc as the product between the activations of the source token and the
  |weights|.  Specifically,

    arc[b,s,:] = \sum_{i} weights[i] * tokens[b,s,i]

  Args:
    tokens: [B,N,S] tensor of batched activations for source tokens.
    weights: [S] vector of weights.

    B,N may be statically-unknown, but S must be statically-known.  The dtype
    of all arguments must be compatible.

  Returns:
    [B,N,N] tensor A of arc potentials as defined above.  The dtype of A is
    the same as that of the arguments.  Note that the diagonal entries (i.e.,
    where s==t) represent self-loops and may not be meaningful.
  """
  # All arguments must have statically-known rank.
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 1, 'weights must be a vector')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_source_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxs = tf.reshape(tokens, [-1, num_source_activations])
  weights_sx1 = tf.expand_dims(weights, 1)
  sources_bnx1 = tf.matmul(tokens_bnxs, weights_sx1)
  sources_bnxn = tf.tile(sources_bnx1, [1, num_tokens])

  # Restore the batch dimension in the output.
  sources_bxnxn = tf.reshape(sources_bnxn,
                             [batch_size, num_tokens, num_tokens])
  return sources_bxnxn
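

# Usage sketch (illustrative, not from the original file): source potentials
# depend only on the source token s, so each row of the [N,N] output is
# constant across target positions.
def _ExampleArcSourcePotentials():
  tokens = tf.constant([[[1.0, 2.0], [3.0, 4.0]]])  # [B=1,N=2,S=2]
  weights = tf.constant([10.0, 1.0])                # [S=2]
  # Row s repeats <weights, tokens[b,s]> in every column:
  # evaluates to [[[12, 12], [34, 34]]].
  return ArcSourcePotentialsFromTokens(tokens, weights)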


def RootPotentialsFromTokens(root, tokens, weights):
  r"""Returns root selection potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  root selection as the 3-way product between the activations of the
  artificial root token, the token activations, and the |weights|.
  Specifically,

    roots[b,r] = \sum_{i,j} root[i] * weights[i,j] * tokens[b,r,j]

  Args:
    root: [S] vector of activations for the artificial root token.
    tokens: [B,N,T] tensor of batched activations for root tokens.
    weights: [S,T] matrix of weights.

    B,N may be statically-unknown, but S,T must be statically-known.  The
    dtype of all arguments must be compatible.

  Returns:
    [B,N] matrix R of root-selection potentials as defined above.  The dtype
    of R is the same as that of the arguments.
  """
  # All arguments must have statically-known rank.
  check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  num_target_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(root.get_shape().as_list()[0], num_source_activations,
           'dimension mismatch between weights and root')
  check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              root.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  root_1xs = tf.expand_dims(root, 0)

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
  roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)

  # Restore the batch dimension in the output.
  roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
  return roots_bxn
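

# Usage sketch (illustrative, not from the original file): with identity
# |weights|, the potential for selecting token r as the root is just
# <root, tokens[b,r]>.
def _ExampleRootPotentials():
  root = tf.constant([1.0, -1.0])                   # [S=2]
  tokens = tf.constant([[[2.0, 1.0], [0.0, 5.0]]])  # [B=1,N=2,T=2]
  weights = tf.eye(2)                               # [S=2,T=2]
  # Evaluates to [[1, -5]]: 2-1 = 1 and 0-5 = -5.
  return RootPotentialsFromTokens(root, tokens, weights)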


def CombineArcAndRootPotentials(arcs, roots):
  """Combines arc and root potentials into a single set of potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # All arguments must have statically-known rank.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # All arguments must share the same type.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_shape = tf.shape(roots)
  arcs_shape = tf.shape(arcs)
  batch_size = roots_shape[0]
  num_tokens = roots_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, arcs_shape[0]),
      tf.assert_equal(num_tokens, arcs_shape[1]),
      tf.assert_equal(num_tokens, arcs_shape[2])]):
    return tf.matrix_set_diag(arcs, roots)
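

# Usage sketch (illustrative, not from the original file): the root
# potentials overwrite the (otherwise not meaningful) diagonal of the arcs.
def _ExampleCombinePotentials():
  arcs = tf.zeros([1, 2, 2])         # [B=1,N=2,N=2]
  roots = tf.constant([[7.0, 8.0]])  # [B=1,N=2]
  # Evaluates to [[[7, 0], [0, 8]]].
  return CombineArcAndRootPotentials(arcs, roots)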


def LabelPotentialsFromTokens(tokens, weights):
  r"""Computes label potentials from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  label as the product between the activations of the source token and the
  |weights|.  Specifically,

    labels[b,t,l] = \sum_{i} weights[l,i] * tokens[b,t,i]

  Args:
    tokens: [B,N,T] tensor of batched token activations.
    weights: [L,T] matrix of weights.

    B,N may be dynamic, but L,T must be static.  The dtype of all arguments
    must be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype
    as the arguments.
  """
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  num_labels = weights.get_shape().as_list()[0]
  num_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_activations, 'unknown activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_activations,
           'activation mismatch between weights and tokens')

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  check.Same([tokens.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  # Flatten out the batch dimension so we can use one big matmul().
  tokens_bnxt = tf.reshape(tokens, [-1, num_activations])
  labels_bnxl = tf.matmul(tokens_bnxt, weights, transpose_b=True)

  # Restore the batch dimension in the output.
  labels_bxnxl = tf.reshape(labels_bnxl, [batch_size, num_tokens, num_labels])
  return labels_bxnxl
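

# Usage sketch (illustrative, not from the original file): each row of
# |weights| scores every token independently, so the output holds
# unnormalized per-token scores over L labels.
def _ExampleLabelPotentials():
  tokens = tf.constant([[[1.0, 2.0]]])  # [B=1,N=1,T=2]
  weights = tf.constant([[1.0, 0.0],
                         [0.0, 1.0],
                         [1.0, 1.0]])   # [L=3,T=2]
  # Evaluates to [[[1, 2, 3]]].
  return LabelPotentialsFromTokens(tokens, weights)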


def LabelPotentialsFromTokenPairs(sources, targets, weights):
  r"""Computes label potentials from source and target tokens and weights.

  For each aligned pair of source and target token activations, computes a
  scalar potential for each label on the arc from the source to the target.
  Specifically,

    labels[b,t,l] = \sum_{i,j} sources[b,t,i] * weights[l,i,j] * targets[b,t,j]

  Args:
    sources: [B,N,S] tensor of batched source token activations.
    targets: [B,N,T] tensor of batched target token activations.
    weights: [L,S,T] tensor of weights.

    B,N may be dynamic, but L,S,T must be static.  The dtype of all arguments
    must be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype
    as the arguments.
  """
  check.Eq(sources.get_shape().ndims, 3, 'sources must be rank 3')
  check.Eq(targets.get_shape().ndims, 3, 'targets must be rank 3')
  check.Eq(weights.get_shape().ndims, 3, 'weights must be rank 3')

  num_labels = weights.get_shape().as_list()[0]
  num_source_activations = weights.get_shape().as_list()[1]
  num_target_activations = weights.get_shape().as_list()[2]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(sources.get_shape().as_list()[2], num_source_activations,
           'activation mismatch between weights and source tokens')
  check.Eq(targets.get_shape().as_list()[2], num_target_activations,
           'activation mismatch between weights and target tokens')

  check.Same([sources.dtype.base_dtype,
              targets.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  sources_shape = tf.shape(sources)
  targets_shape = tf.shape(targets)
  batch_size = sources_shape[0]
  num_tokens = sources_shape[1]
  with tf.control_dependencies([tf.assert_equal(batch_size, targets_shape[0]),
                                tf.assert_equal(num_tokens,
                                                targets_shape[1])]):
    # For each token, we must compute a vector-3tensor-vector product.  There
    # is no op for this, but we can use reshape() and matmul() to compute it.

    # Reshape |weights| and |targets| so we can use a single matmul().
    weights_lsxt = tf.reshape(weights, [num_labels * num_source_activations,
                                        num_target_activations])
    targets_bnxt = tf.reshape(targets, [-1, num_target_activations])
    weights_targets_bnxls = tf.matmul(targets_bnxt, weights_lsxt,
                                      transpose_b=True)

    # Restore all dimensions.
    weights_targets_bxnxlxs = tf.reshape(
        weights_targets_bnxls,
        [batch_size, num_tokens, num_labels, num_source_activations])

    # Incorporate the source activations.  In this case, we perform a batched
    # matmul() between the trailing [L,S] matrices of the current result and
    # the trailing [S] vectors of the tokens.
    sources_bxnx1xs = tf.expand_dims(sources, 2)
    labels_bxnxlx1 = tf.matmul(weights_targets_bxnxlxs, sources_bxnx1xs,
                               transpose_b=True)
    labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
    return labels_bxnxl
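

# Usage sketch (illustrative, not from the original file): each weights[l] is
# a bilinear form over (source, target) pairs.  One-hot slices make the
# result easy to verify: label l picks out sources[...,i] * targets[...,j].
def _ExampleLabelPotentialsFromTokenPairs():
  sources = tf.constant([[[1.0, 2.0]]])  # [B=1,N=1,S=2]
  targets = tf.constant([[[3.0, 4.0]]])  # [B=1,N=1,T=2]
  weights = tf.constant([[[1.0, 0.0], [0.0, 0.0]],   # label 0: s0*t0 = 1*3
                         [[0.0, 0.0], [0.0, 1.0]]])  # label 1: s1*t1 = 2*4
  # Evaluates to [[[3, 8]]].
  return LabelPotentialsFromTokenPairs(sources, targets, weights)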