# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Per-example gradients for selected ops."""

import collections

from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

OrderedDict = collections.OrderedDict


def _ListUnion(list_1, list_2):
  """Returns the union of two lists.

  Python sets can have a non-deterministic iteration order. In some
  contexts, this could lead to TensorFlow producing two different
  programs when the same Python script is run twice. In these contexts
  we use lists instead of sets.

  This function is not designed to be especially fast and should only
  be used with small lists.

  Args:
    list_1: A list
    list_2: Another list

  Returns:
    A new list containing one copy of each unique element of list_1 and
    list_2. Uniqueness is determined by "x in union" logic; e.g. two
    equal strings contribute only one copy of that value to the union.

  Raises:
    TypeError: The arguments are not lists.
  """
  if not (isinstance(list_1, list) and isinstance(list_2, list)):
    raise TypeError("Arguments must be lists.")

  union = []
  for x in list_1 + list_2:
    if x not in union:
      union.append(x)

  return union


def Interface(ys, xs):
  """Returns a dict mapping each element of xs to any of its consumers that
  are indirectly consumed by ys.

  Args:
    ys: The outputs
    xs: The inputs

  Returns:
    out: Dict mapping each member x of `xs` to a list of all Tensors that are
      direct consumers of x and are eventually consumed by a member of `ys`.
  """
  if isinstance(ys, (list, tuple)):
    queue = list(ys)
  else:
    queue = [ys]

  out = OrderedDict()
  if isinstance(xs, (list, tuple)):
    for x in xs:
      out[x] = []
  else:
    out[xs] = []

  done = set()

  while queue:
    y = queue.pop()
    if y in done:
      continue
    done = done.union(set([y]))
    for x in y.op.inputs:
      if x in out:
        out[x].append(y)
      else:
        assert id(x) not in [id(foo) for foo in out]
    queue.extend(y.op.inputs)

  return out
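

# Example: for y = tf.matmul(x, w) followed by loss = tf.reduce_sum(y),
# Interface([loss], [w]) returns {w: [y]}, since y is the only direct
# consumer of w that is eventually consumed by loss.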


class PXGRegistry(object):
  """Per-Example Gradient registry.

  Maps names of ops to per-example gradient rules for those ops.
  These rules are only needed for ops that directly touch values that
  are shared between examples. For most machine learning applications,
  this means only ops that directly operate on the parameters.

  See http://arxiv.org/abs/1510.01799 for more information, and please
  consider citing that tech report if you use this function in published
  research.
  """

  def __init__(self):
    self.d = OrderedDict()

  def __call__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    if op.node_def.op not in self.d:
      raise NotImplementedError("No per-example gradient rule registered "
                                "for " + op.node_def.op + " in pxg_registry.")
    return self.d[op.node_def.op](op,
                                  colocate_gradients_with_ops,
                                  gate_gradients)

  def Register(self, op_name, pxg_class):
    """Associates `op_name` key with `pxg_class` value.

    Registers `pxg_class` as the class that will be called to perform
    per-example differentiation through ops with `op_name`.

    Args:
      op_name: String op name.
      pxg_class: Any class with the same constructor and __call__ signatures
        as MatMulPXG.
    """
    self.d[op_name] = pxg_class


pxg_registry = PXGRegistry()
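
# Per-example gradient rules for individual ops are registered on this
# instance below, e.g. pxg_registry.Register("MatMul", MatMulPXG);
# pxg_registry(op) then instantiates the registered rule for `op`.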


class MatMulPXG(object):
  """Per-example gradient rule for MatMul op."""

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    """Construct an instance of the rule for `op`.

    Args:
      op: The Operation to differentiate through.
      colocate_gradients_with_ops: currently unsupported
      gate_gradients: currently unsupported
    """
    assert op.node_def.op == "MatMul"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def __call__(self, x, z_grads):
    """Build the graph for the per-example gradient through the op.

    Assumes that the MatMul was called with a design matrix with examples
    in rows as the first argument and parameters as the second argument.

    Args:
      x: The Tensor to differentiate with respect to. This tensor must
        represent the weights.
      z_grads: The list of gradients on the output of the op.

    Returns:
      x_grads: A Tensor containing the gradient with respect to `x` for
        each example. This is a 3-D tensor, with the first axis corresponding
        to examples and the remaining axes matching the shape of x.
    """
    idx = list(self.op.inputs).index(x)
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    assert idx == 1  # We expect weights to be arg 1
    # We don't expect anyone to per-example differentiate with respect
    # to anything other than the weights.
    x, w = self.op.inputs
    z_grads, = z_grads
    x_expanded = tf.expand_dims(x, 2)
    z_grads_expanded = tf.expand_dims(z_grads, 1)
    return tf.mul(x_expanded, z_grads_expanded)


pxg_registry.Register("MatMul", MatMulPXG)
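
# Shape sketch: for activations z = tf.matmul(x, w) with x of shape
# [batch_size, num_inputs] and w of shape [num_inputs, num_outputs], z_grads
# has shape [batch_size, num_outputs] and MatMulPXG returns the per-example
# outer products of shape [batch_size, num_inputs, num_outputs]; summing the
# result over axis 0 recovers the aggregate gradient with respect to w that
# tf.gradients would return.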


class Conv2DPXG(object):
  """Per-example gradient rule for Conv2D op.

  Same interface as MatMulPXG.
  """

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    assert op.node_def.op == "Conv2D"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def _PxConv2DBuilder(self, input_, w, strides, padding):
    """conv2d run separately per example, to help compute per-example gradients.

    Args:
      input_: tensor containing a minibatch of images / feature maps.
        Shape [batch_size, rows, columns, channels]
      w: convolution kernels. Shape
        [kernel rows, kernel columns, input channels, output channels]
      strides: passed through to regular conv_2d
      padding: passed through to regular conv_2d

    Returns:
      conv: the output of the convolution.
        single tensor, same as what regular conv_2d does
      w_px: a list of batch_size copies of w. each copy was used
        for the corresponding example in the minibatch.
        calling tf.gradients on the copy gives the gradient for just
        that example.
    """
    input_shape = [int(e) for e in input_.get_shape()]
    batch_size = input_shape[0]
    input_px = [tf.slice(input_, [example] + [0] * 3, [1] + input_shape[1:])
                for example in xrange(batch_size)]
    for input_x in input_px:
      assert int(input_x.get_shape()[0]) == 1
    w_px = [tf.identity(w) for example in xrange(batch_size)]
    conv_px = [tf.nn.conv2d(input_x, w_x,
                            strides=strides,
                            padding=padding)
               for input_x, w_x in zip(input_px, w_px)]
    for conv_x in conv_px:
      num_x = int(conv_x.get_shape()[0])
      assert num_x == 1, num_x
    assert len(conv_px) == batch_size
    conv = tf.concat(0, conv_px)
    assert int(conv.get_shape()[0]) == batch_size
    return conv, w_px
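
  # Note: the builder above creates batch_size separate conv2d ops and kernel
  # copies, so graph size grows linearly with the batch size; this is the cost
  # of recovering per-example gradients for a shared filter via tf.gradients.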

  def __call__(self, w, z_grads):
    idx = list(self.op.inputs).index(w)
    # Make sure that `op` was actually applied to `w`
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    # The following assert may be removed when we are ready to use this
    # for general purpose code.
    # This assert is only expected to hold in the context of our preliminary
    # MNIST experiments.
    assert idx == 1  # We expect convolution weights to be arg 1
    images, filters = self.op.inputs
    strides = self.op.get_attr("strides")
    padding = self.op.get_attr("padding")
    # Currently assuming that one specifies at most these four arguments and
    # that all other arguments to conv2d are set to default.
    conv, w_px = self._PxConv2DBuilder(images, filters, strides, padding)
    z_grads, = z_grads

    gradients_list = tf.gradients(conv, w_px, z_grads,
                                  colocate_gradients_with_ops=
                                  self.colocate_gradients_with_ops,
                                  gate_gradients=self.gate_gradients)

    return tf.pack(gradients_list)


pxg_registry.Register("Conv2D", Conv2DPXG)
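
# Shape sketch: with images of shape [batch_size, rows, cols, in_channels]
# (the batch-major layout assumed by _PxConv2DBuilder's slicing) and filters
# of shape [kernel_rows, kernel_cols, in_channels, out_channels], Conv2DPXG
# returns a tensor of shape
# [batch_size, kernel_rows, kernel_cols, in_channels, out_channels]; summing
# over axis 0 recovers the aggregate filter gradient from tf.gradients.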


class AddPXG(object):
  """Per-example gradient rule for Add op.

  Same interface as MatMulPXG.
  """

  def __init__(self, op,
               colocate_gradients_with_ops=False,
               gate_gradients=False):
    assert op.node_def.op == "Add"
    self.op = op
    self.colocate_gradients_with_ops = colocate_gradients_with_ops
    self.gate_gradients = gate_gradients

  def __call__(self, x, z_grads):
    idx = list(self.op.inputs).index(x)
    # Make sure that `op` was actually applied to `x`
    assert idx != -1
    assert len(z_grads) == len(self.op.outputs)
    # The following assert may be removed when we are ready to use this
    # for general purpose code.
    # This assert is only expected to hold in the context of our preliminary
    # MNIST experiments.
    assert idx == 1  # We expect biases to be arg 1
    # We don't expect anyone to per-example differentiate with respect
    # to anything other than the biases.
    x, b = self.op.inputs
    z_grads, = z_grads
    return z_grads


pxg_registry.Register("Add", AddPXG)
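
# Shape sketch: when a bias b of shape [num_units] is broadcast-added to
# activations of shape [batch_size, num_units], the per-example gradient of an
# additive cost with respect to b for example i is just row i of the incoming
# gradient, so z_grads (shape [batch_size, num_units]) already carries the
# per-example leading axis and is returned unchanged.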


def PerExampleGradients(ys, xs, grad_ys=None, name="gradients",
                        colocate_gradients_with_ops=False,
                        gate_gradients=False):
  """Symbolic differentiation, separately for each example.

  Matches the interface of tf.gradients, but the return values each have an
  additional axis corresponding to the examples.

  Assumes that the cost in `ys` is additive across examples,
  e.g., no batch normalization.
  Individual rules for each op specify their own assumptions about how
  examples are put into tensors.
  """
  # Find the interface between the xs and the cost
  for x in xs:
    assert isinstance(x, tf.Tensor), type(x)
  interface = Interface(ys, xs)
  merged_interface = []
  for x in xs:
    merged_interface = _ListUnion(merged_interface, interface[x])
  # Differentiate with respect to the interface
  interface_gradients = tf.gradients(ys, merged_interface, grad_ys=grad_ys,
                                     name=name,
                                     colocate_gradients_with_ops=
                                     colocate_gradients_with_ops,
                                     gate_gradients=gate_gradients)
  grad_dict = OrderedDict(zip(merged_interface, interface_gradients))
  # Build the per-example gradients with respect to the xs
  if colocate_gradients_with_ops:
    raise NotImplementedError("The per-example gradients are not yet "
                              "colocated with ops.")
  if gate_gradients:
    raise NotImplementedError("The per-example gradients are not yet "
                              "gated.")
  out = []
  for x in xs:
    zs = interface[x]
    ops = []
    for z in zs:
      ops = _ListUnion(ops, [z.op])
    if len(ops) != 1:
      raise NotImplementedError("Currently we only support the case "
                                "where each x is consumed by exactly "
                                "one op, but %s is consumed by %d ops."
                                % (x.name, len(ops)))
    op = ops[0]
    pxg_rule = pxg_registry(op, colocate_gradients_with_ops, gate_gradients)
    x_grad = pxg_rule(x, [grad_dict[z] for z in zs])
    out.append(x_grad)

  return out
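

# Example usage (a sketch; the tensors named below are hypothetical and not
# part of this module). For a layer built as
#
#   logits = tf.matmul(hidden, weights) + biases
#   loss = tf.reduce_sum(per_example_losses)
#
# where `weights` is consumed only by the MatMul and `biases` only by the Add,
#
#   px_grads = PerExampleGradients(loss, [weights, biases])
#
# returns a list whose first element has shape [batch_size] plus the shape of
# `weights`, and whose second has shape [batch_size] plus the shape of
# `biases`; summing each over axis 0 matches tf.gradients(loss,
# [weights, biases]).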