spatial_transformer.py 7.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import tensorflow as tf
  16. def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
  17. """Spatial Transformer Layer
  18. Implements a spatial transformer layer as described in [1]_.
  19. Based on [2]_ and edited by David Dao for Tensorflow.
  20. Parameters
  21. ----------
  22. U : float
  23. The output of a convolutional net should have the
  24. shape [num_batch, height, width, num_channels].
  25. theta: float
  26. The output of the
  27. localisation network should be [num_batch, 6].
  28. out_size: tuple of two floats
  29. The size of the output of the network
  30. References
  31. ----------
  32. .. [1] Spatial Transformer Networks
  33. Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
  34. Submitted on 5 Jun 2015
  35. .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
  36. Notes
  37. -----
  38. To initialize the network to the identity transform init
  39. ``theta`` to :
  40. identity = np.array([[1., 0., 0.],
  41. [0., 1., 0.]])
  42. identity = identity.flatten()
  43. theta = tf.Variable(initial_value=identity)
  44. """
  45. def _repeat(x, n_repeats):
  46. with tf.variable_scope('_repeat'):
  47. rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.pack([n_repeats,])),1),[1,0])
  48. rep = tf.cast(rep, 'int32')
  49. x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
  50. return tf.reshape(x,[-1])
  51. def _interpolate(im, x, y, out_size):
  52. with tf.variable_scope('_interpolate'):
  53. # constants
  54. num_batch = tf.shape(im)[0]
  55. height = tf.shape(im)[1]
  56. width = tf.shape(im)[2]
  57. channels = tf.shape(im)[3]
  58. x = tf.cast(x, 'float32')
  59. y = tf.cast(y, 'float32')
  60. height_f = tf.cast(height, 'float32')
  61. width_f = tf.cast(width, 'float32')
  62. out_height = out_size[0]
  63. out_width = out_size[1]
  64. zero = tf.zeros([], dtype='int32')
  65. max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
  66. max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
  67. # scale indices from [-1, 1] to [0, width/height]
  68. x = (x + 1.0)*(width_f) / 2.0
  69. y = (y + 1.0)*(height_f) / 2.0
  70. # do sampling
  71. x0 = tf.cast(tf.floor(x), 'int32')
  72. x1 = x0 + 1
  73. y0 = tf.cast(tf.floor(y), 'int32')
  74. y1 = y0 + 1
  75. x0 = tf.clip_by_value(x0, zero, max_x)
  76. x1 = tf.clip_by_value(x1, zero, max_x)
  77. y0 = tf.clip_by_value(y0, zero, max_y)
  78. y1 = tf.clip_by_value(y1, zero, max_y)
  79. dim2 = width
  80. dim1 = width*height
  81. base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
  82. base_y0 = base + y0*dim2
  83. base_y1 = base + y1*dim2
  84. idx_a = base_y0 + x0
  85. idx_b = base_y1 + x0
  86. idx_c = base_y0 + x1
  87. idx_d = base_y1 + x1
  88. # use indices to lookup pixels in the flat image and restore channels dim
  89. im_flat = tf.reshape(im,tf.pack([-1, channels]))
  90. im_flat = tf.cast(im_flat, 'float32')
  91. Ia = tf.gather(im_flat, idx_a)
  92. Ib = tf.gather(im_flat, idx_b)
  93. Ic = tf.gather(im_flat, idx_c)
  94. Id = tf.gather(im_flat, idx_d)
  95. # and finally calculate interpolated values
  96. x0_f = tf.cast(x0, 'float32')
  97. x1_f = tf.cast(x1, 'float32')
  98. y0_f = tf.cast(y0, 'float32')
  99. y1_f = tf.cast(y1, 'float32')
  100. wa = tf.expand_dims(((x1_f-x) * (y1_f-y)),1)
  101. wb = tf.expand_dims(((x1_f-x) * (y-y0_f)),1)
  102. wc = tf.expand_dims(((x-x0_f) * (y1_f-y)),1)
  103. wd = tf.expand_dims(((x-x0_f) * (y-y0_f)),1)
  104. output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
  105. return output
  106. def _meshgrid(height, width):
  107. with tf.variable_scope('_meshgrid'):
  108. # This should be equivalent to:
  109. # x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
  110. # np.linspace(-1, 1, height))
  111. # ones = np.ones(np.prod(x_t.shape))
  112. # grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
  113. x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
  114. tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width),1),[1,0]))
  115. y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height),1),
  116. tf.ones(shape=tf.pack([1, width])))
  117. x_t_flat = tf.reshape(x_t,(1, -1))
  118. y_t_flat = tf.reshape(y_t,(1, -1))
  119. ones = tf.ones_like(x_t_flat)
  120. grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
  121. return grid
  122. def _transform(theta, input_dim, out_size):
  123. with tf.variable_scope('_transform'):
  124. num_batch = tf.shape(input_dim)[0]
  125. height = tf.shape(input_dim)[1]
  126. width = tf.shape(input_dim)[2]
  127. num_channels = tf.shape(input_dim)[3]
  128. theta = tf.reshape(theta, (-1, 2, 3))
  129. theta = tf.cast(theta, 'float32')
  130. # grid of (x_t, y_t, 1), eq (1) in ref [1]
  131. height_f = tf.cast(height, 'float32')
  132. width_f = tf.cast(width, 'float32')
  133. out_height = out_size[0]
  134. out_width = out_size[1]
  135. grid = _meshgrid(out_height, out_width)
  136. grid = tf.expand_dims(grid,0)
  137. grid = tf.reshape(grid,[-1])
  138. grid = tf.tile(grid,tf.pack([num_batch]))
  139. grid = tf.reshape(grid,tf.pack([num_batch, 3, -1]))
  140. # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
  141. T_g = tf.batch_matmul(theta, grid)
  142. x_s = tf.slice(T_g, [0,0,0], [-1,1,-1])
  143. y_s = tf.slice(T_g, [0,1,0], [-1,1,-1])
  144. x_s_flat = tf.reshape(x_s,[-1])
  145. y_s_flat = tf.reshape(y_s,[-1])
  146. input_transformed = _interpolate(
  147. input_dim, x_s_flat, y_s_flat,
  148. out_size)
  149. output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
  150. return output
  151. with tf.variable_scope(name):
  152. output = _transform(theta, U, out_size)
  153. return output
  154. def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
  155. """Batch Spatial Transformer Layer
  156. Parameters
  157. ----------
  158. U : float
  159. tensor of inputs [num_batch,height,width,num_channels]
  160. thetas : float
  161. a set of transformations for each input [num_batch,num_transforms,6]
  162. out_size : int
  163. the size of the output [out_height,out_width]
  164. Returns: float
  165. Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
  166. """
  167. with tf.variable_scope(name):
  168. num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
  169. indices = [[i]*num_transforms for i in xrange(num_batch)]
  170. input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
  171. return transformer(input_repeated, thetas, out_size)