digraph_ops.py

  1. """TensorFlow ops for directed graphs."""
  2. import tensorflow as tf
  3. from syntaxnet.util import check
  4. def ArcPotentialsFromTokens(source_tokens, target_tokens, weights):
  5. r"""Returns arc potentials computed from token activations and weights.
  6. For each batch of source and target token activations, computes a scalar
  7. potential for each arc as the 3-way product between the activation vectors of
  8. the source and target of the arc and the |weights|. Specifically,
  9. arc[b,s,t] =
  10. \sum_{i,j} source_tokens[b,s,i] * weights[i,j] * target_tokens[b,t,j]
  11. Note that the token activations can be extended with bias terms to implement a
  12. "biaffine" model (Dozat and Manning, 2017).
  13. Args:
  14. source_tokens: [B,N,S] tensor of batched activations for the source token in
  15. each arc.
  16. target_tokens: [B,N,T] tensor of batched activations for the target token in
  17. each arc.
  18. weights: [S,T] matrix of weights.
  19. B,N may be statically-unknown, but S,T must be statically-known. The dtype
  20. of all arguments must be compatible.
  21. Returns:
  22. [B,N,N] tensor A of arc potentials where A_{b,s,t} is the potential of the
  23. arc from s to t in batch element b. The dtype of A is the same as that of
  24. the arguments. Note that the diagonal entries (i.e., where s==t) represent
  25. self-loops and may not be meaningful.
  26. """
  27. # All arguments must have statically-known rank.
  28. check.Eq(source_tokens.get_shape().ndims, 3, 'source_tokens must be rank 3')
  29. check.Eq(target_tokens.get_shape().ndims, 3, 'target_tokens must be rank 3')
  30. check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')
  31. # All activation dimensions must be statically-known.
  32. num_source_activations = weights.get_shape().as_list()[0]
  33. num_target_activations = weights.get_shape().as_list()[1]
  34. check.NotNone(num_source_activations, 'unknown source activation dimension')
  35. check.NotNone(num_target_activations, 'unknown target activation dimension')
  36. check.Eq(source_tokens.get_shape().as_list()[2], num_source_activations,
  37. 'dimension mismatch between weights and source_tokens')
  38. check.Eq(target_tokens.get_shape().as_list()[2], num_target_activations,
  39. 'dimension mismatch between weights and target_tokens')
  40. # All arguments must share the same type.
  41. check.Same([weights.dtype.base_dtype,
  42. source_tokens.dtype.base_dtype,
  43. target_tokens.dtype.base_dtype],
  44. 'dtype mismatch')
  45. source_tokens_shape = tf.shape(source_tokens)
  46. target_tokens_shape = tf.shape(target_tokens)
  47. batch_size = source_tokens_shape[0]
  48. num_tokens = source_tokens_shape[1]
  49. with tf.control_dependencies([
  50. tf.assert_equal(batch_size, target_tokens_shape[0]),
  51. tf.assert_equal(num_tokens, target_tokens_shape[1])]):
  52. # Flatten out the batch dimension so we can use one big multiplication.
  53. targets_bnxt = tf.reshape(target_tokens, [-1, num_target_activations])
  54. # Matrices are row-major, so we arrange for the RHS argument of each matmul
  55. # to have its transpose flag set. That way no copying is required to align
  56. # the rows of the LHS with the columns of the RHS.
  57. weights_targets_bnxs = tf.matmul(targets_bnxt, weights, transpose_b=True)
  58. # The next computation is over pairs of tokens within each batch element, so
  59. # restore the batch dimension.
  60. weights_targets_bxnxs = tf.reshape(
  61. weights_targets_bnxs, [batch_size, num_tokens, num_source_activations])
  62. # Note that this multiplication is repeated across the batch dimension,
  63. # instead of being one big multiplication as in the first matmul. There
  64. # doesn't seem to be a way to arrange this as a single multiplication given
  65. # the pairwise nature of this computation.
  66. arcs_bxnxn = tf.matmul(source_tokens, weights_targets_bxnxs,
  67. transpose_b=True)
  68. return arcs_bxnxn
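
# Usage sketch (illustrative only; the shapes and variable names below are
# assumptions, not part of the original module).  With S = T = 8:
#
#   source = tf.placeholder(tf.float32, [None, None, 8])        # [B,N,S]
#   target = tf.placeholder(tf.float32, [None, None, 8])        # [B,N,T]
#   arc_weights = tf.get_variable('arc_weights', [8, 8])        # [S,T]
#   arcs = ArcPotentialsFromTokens(source, target, arc_weights)  # [B,N,N]
#
# For a biaffine variant, a constant 1.0 activation can be appended to each
# token before the call, e.g.
#   tf.concat([source, tf.ones_like(source[:, :, :1])], axis=2)
# so the corresponding rows/columns of the weights act as bias terms.
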
def ArcSourcePotentialsFromTokens(tokens, weights):
  r"""Returns arc source potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  arc as the product between the activations of the source token and the
  |weights|.  Specifically,

    arc[b,s,:] = \sum_{i} weights[i] * tokens[b,s,i]

  Args:
    tokens: [B,N,S] tensor of batched activations for source tokens.
    weights: [S] vector of weights.

    B,N may be statically-unknown, but S must be statically-known.  The dtype
    of all arguments must be compatible.

  Returns:
    [B,N,N] tensor A of arc potentials as defined above.  The dtype of A is
    the same as that of the arguments.  Note that the diagonal entries (i.e.,
    where s==t) represent self-loops and may not be meaningful.
  """
  # All arguments must have statically-known rank.
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 1, 'weights must be a vector')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_source_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxs = tf.reshape(tokens, [-1, num_source_activations])
  weights_sx1 = tf.expand_dims(weights, 1)
  sources_bnx1 = tf.matmul(tokens_bnxs, weights_sx1)
  sources_bnxn = tf.tile(sources_bnx1, [1, num_tokens])

  # Restore the batch dimension in the output.
  sources_bxnxn = tf.reshape(sources_bnxn,
                             [batch_size, num_tokens, num_tokens])
  return sources_bxnxn
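
# Usage sketch (illustrative only; shapes and names are assumptions).  The op
# broadcasts each source token's score across all N target positions:
#
#   tokens = tf.placeholder(tf.float32, [None, None, 8])          # [B,N,S]
#   source_weights = tf.get_variable('source_weights', [8])       # [S]
#   scores = ArcSourcePotentialsFromTokens(tokens, source_weights)  # [B,N,N]
#   # scores[b, s, t] has the same value for every t.
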
def RootPotentialsFromTokens(root, tokens, weights):
  r"""Returns root selection potentials computed from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  root selection as the 3-way product between the activations of the
  artificial root token, the token activations, and the |weights|.
  Specifically,

    roots[b,r] = \sum_{i,j} root[i] * weights[i,j] * tokens[b,r,j]

  Args:
    root: [S] vector of activations for the artificial root token.
    tokens: [B,N,T] tensor of batched activations for root tokens.
    weights: [S,T] matrix of weights.

    B,N may be statically-unknown, but S,T must be statically-known.  The
    dtype of all arguments must be compatible.

  Returns:
    [B,N] matrix R of root-selection potentials as defined above.  The dtype
    of R is the same as that of the arguments.
  """
  # All arguments must have statically-known rank.
  check.Eq(root.get_shape().ndims, 1, 'root must be a vector')
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  # All activation dimensions must be statically-known.
  num_source_activations = weights.get_shape().as_list()[0]
  num_target_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(root.get_shape().as_list()[0], num_source_activations,
           'dimension mismatch between weights and root')
  check.Eq(tokens.get_shape().as_list()[2], num_target_activations,
           'dimension mismatch between weights and tokens')

  # All arguments must share the same type.
  check.Same([weights.dtype.base_dtype,
              root.dtype.base_dtype,
              tokens.dtype.base_dtype],
             'dtype mismatch')

  root_1xs = tf.expand_dims(root, 0)

  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  # Flatten out the batch dimension so we can use a couple big matmuls.
  tokens_bnxt = tf.reshape(tokens, [-1, num_target_activations])
  weights_targets_bnxs = tf.matmul(tokens_bnxt, weights, transpose_b=True)
  roots_1xbn = tf.matmul(root_1xs, weights_targets_bnxs, transpose_b=True)

  # Restore the batch dimension in the output.
  roots_bxn = tf.reshape(roots_1xbn, [batch_size, num_tokens])
  return roots_bxn
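
# Usage sketch (illustrative only; shapes and names are assumptions).  The
# |root| activation vector is shared across the whole batch:
#
#   root = tf.get_variable('root_activation', [8])                # [S]
#   tokens = tf.placeholder(tf.float32, [None, None, 8])          # [B,N,T]
#   root_weights = tf.get_variable('root_weights', [8, 8])        # [S,T]
#   roots = RootPotentialsFromTokens(root, tokens, root_weights)  # [B,N]
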
def CombineArcAndRootPotentials(arcs, roots):
  """Combines arc and root potentials into a single set of potentials.

  Args:
    arcs: [B,N,N] tensor of batched arc potentials.
    roots: [B,N] matrix of batched root potentials.

  Returns:
    [B,N,N] tensor P of combined potentials where
      P_{b,s,t} = s == t ? roots[b,t] : arcs[b,s,t]
  """
  # All arguments must have statically-known rank.
  check.Eq(arcs.get_shape().ndims, 3, 'arcs must be rank 3')
  check.Eq(roots.get_shape().ndims, 2, 'roots must be a matrix')

  # All arguments must share the same type.
  dtype = arcs.dtype.base_dtype
  check.Same([dtype, roots.dtype.base_dtype], 'dtype mismatch')

  roots_shape = tf.shape(roots)
  arcs_shape = tf.shape(arcs)
  batch_size = roots_shape[0]
  num_tokens = roots_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, arcs_shape[0]),
      tf.assert_equal(num_tokens, arcs_shape[1]),
      tf.assert_equal(num_tokens, arcs_shape[2])]):
    return tf.matrix_set_diag(arcs, roots)
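
# Usage sketch (illustrative only): combining the outputs of the ops above so
# that diagonal entries hold root-selection potentials and off-diagonal
# entries hold arc potentials.  Names are assumptions.
#
#   potentials = CombineArcAndRootPotentials(arcs, roots)   # [B,N,N]
#   # potentials[b, t, t] == roots[b, t]
#   # potentials[b, s, t] == arcs[b, s, t] for s != t
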
def LabelPotentialsFromTokens(tokens, weights):
  r"""Computes label potentials from tokens and weights.

  For each batch of token activations, computes a scalar potential for each
  label as the product between the activations of the source token and the
  |weights|.  Specifically,

    labels[b,t,l] = \sum_{i} weights[l,i] * tokens[b,t,i]

  Args:
    tokens: [B,N,T] tensor of batched token activations.
    weights: [L,T] matrix of weights.

    B,N may be dynamic, but L,T must be static.  The dtype of all arguments
    must be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype
    as the arguments.
  """
  check.Eq(tokens.get_shape().ndims, 3, 'tokens must be rank 3')
  check.Eq(weights.get_shape().ndims, 2, 'weights must be a matrix')

  num_labels = weights.get_shape().as_list()[0]
  num_activations = weights.get_shape().as_list()[1]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_activations, 'unknown activation dimension')
  check.Eq(tokens.get_shape().as_list()[2], num_activations,
           'activation mismatch between weights and tokens')
  tokens_shape = tf.shape(tokens)
  batch_size = tokens_shape[0]
  num_tokens = tokens_shape[1]

  check.Same([tokens.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  # Flatten out the batch dimension so we can use one big matmul().
  tokens_bnxt = tf.reshape(tokens, [-1, num_activations])
  labels_bnxl = tf.matmul(tokens_bnxt, weights, transpose_b=True)

  # Restore the batch dimension in the output.
  labels_bxnxl = tf.reshape(labels_bnxl, [batch_size, num_tokens, num_labels])
  return labels_bxnxl
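
# Usage sketch (illustrative only; shapes and names are assumptions).  With
# L = 40 labels and T = 8 activations:
#
#   tokens = tf.placeholder(tf.float32, [None, None, 8])         # [B,N,T]
#   label_weights = tf.get_variable('label_weights', [40, 8])    # [L,T]
#   labels = LabelPotentialsFromTokens(tokens, label_weights)    # [B,N,L]
#   best_labels = tf.argmax(labels, axis=2)                      # [B,N]
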
def LabelPotentialsFromTokenPairs(sources, targets, weights):
  r"""Computes label potentials from source and target tokens and weights.

  For each aligned pair of source and target token activations, computes a
  scalar potential for each label on the arc from the source to the target.
  Specifically,

    labels[b,t,l] = \sum_{i,j} sources[b,t,i] * weights[l,i,j] * targets[b,t,j]

  Args:
    sources: [B,N,S] tensor of batched source token activations.
    targets: [B,N,T] tensor of batched target token activations.
    weights: [L,S,T] tensor of weights.

    B,N may be dynamic, but L,S,T must be static.  The dtype of all arguments
    must be compatible.

  Returns:
    [B,N,L] tensor of label potentials as defined above, with the same dtype
    as the arguments.
  """
  check.Eq(sources.get_shape().ndims, 3, 'sources must be rank 3')
  check.Eq(targets.get_shape().ndims, 3, 'targets must be rank 3')
  check.Eq(weights.get_shape().ndims, 3, 'weights must be rank 3')

  num_labels = weights.get_shape().as_list()[0]
  num_source_activations = weights.get_shape().as_list()[1]
  num_target_activations = weights.get_shape().as_list()[2]
  check.NotNone(num_labels, 'unknown number of labels')
  check.NotNone(num_source_activations, 'unknown source activation dimension')
  check.NotNone(num_target_activations, 'unknown target activation dimension')
  check.Eq(sources.get_shape().as_list()[2], num_source_activations,
           'activation mismatch between weights and source tokens')
  check.Eq(targets.get_shape().as_list()[2], num_target_activations,
           'activation mismatch between weights and target tokens')

  check.Same([sources.dtype.base_dtype,
              targets.dtype.base_dtype,
              weights.dtype.base_dtype],
             'dtype mismatch')

  sources_shape = tf.shape(sources)
  targets_shape = tf.shape(targets)
  batch_size = sources_shape[0]
  num_tokens = sources_shape[1]
  with tf.control_dependencies([
      tf.assert_equal(batch_size, targets_shape[0]),
      tf.assert_equal(num_tokens, targets_shape[1])]):
    # For each token, we must compute a vector-3tensor-vector product.  There
    # is no op for this, but we can use reshape() and matmul() to compute it.

    # Reshape |weights| and |targets| so we can use a single matmul().
    weights_lsxt = tf.reshape(weights, [num_labels * num_source_activations,
                                        num_target_activations])
    targets_bnxt = tf.reshape(targets, [-1, num_target_activations])
    weights_targets_bnxls = tf.matmul(targets_bnxt, weights_lsxt,
                                      transpose_b=True)

    # Restore all dimensions.
    weights_targets_bxnxlxs = tf.reshape(
        weights_targets_bnxls,
        [batch_size, num_tokens, num_labels, num_source_activations])

    # Incorporate the source activations.  In this case, we perform a batched
    # matmul() between the trailing [L,S] matrices of the current result and
    # the trailing [S] vectors of the tokens.
    sources_bxnx1xs = tf.expand_dims(sources, 2)
    labels_bxnxlx1 = tf.matmul(weights_targets_bxnxlxs, sources_bxnx1xs,
                               transpose_b=True)
    labels_bxnxl = tf.squeeze(labels_bxnxlx1, [3])
    return labels_bxnxl
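
# Usage sketch (illustrative only; shapes and names are assumptions).  Here
# |sources| and |targets| are assumed to be aligned so that position t of both
# tensors describes the selected arc into token t:
#
#   sources = tf.placeholder(tf.float32, [None, None, 8])          # [B,N,S]
#   targets = tf.placeholder(tf.float32, [None, None, 8])          # [B,N,T]
#   pair_weights = tf.get_variable('pair_weights', [40, 8, 8])     # [L,S,T]
#   labels = LabelPotentialsFromTokenPairs(sources, targets, pair_weights)
#   # labels has shape [B,N,L].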