@@ -0,0 +1,181 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+
+def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+    """Spatial Transformer Layer
+
+    Implements a spatial transformer layer as described in [1]_.
+    Based on [2]_ and edited by David Dao for TensorFlow.
+
+    Parameters
+    ----------
+    U : float
+        The output of a convolutional net should have the
+        shape [num_batch, height, width, num_channels].
+    theta : float
+        The output of the
+        localisation network should be [num_batch, 6].
+    downsample_factor : float
+        A value of 1 will keep the original size of the image.
+        Values larger than 1 will downsample the image.
+        Values below 1 will upsample the image.
+        For example, with an input image of height = 100 and width = 200,
+        downsample_factor = 2 will produce an output image of size 50 x 100.
+
+    Returns
+    -------
+    output : float
+        The transformed image, of shape
+        [num_batch, out_height, out_width, num_channels].
+
+    References
+    ----------
+    .. [1] Spatial Transformer Networks
+           Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
+           Submitted on 5 Jun 2015
+    .. [2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
+
+    Notes
+    -----
+    To initialize the network to the identity transform, initialize
+    ``theta`` to::
+
+        identity = np.array([[1., 0., 0.],
+                             [0., 1., 0.]])
+        identity = identity.flatten()
+        theta = tf.Variable(initial_value=identity)
+
+    """
+
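+    # Usage sketch (hypothetical tensor names, for illustration only): given
+    # `conv_features` of shape [num_batch, height, width, num_channels] and
+    # `loc_theta`, the [num_batch, 6] output of a localisation network, the
+    # layer would be applied as
+    #     h_trans = transformer(conv_features, loc_theta, downsample_factor=1)
+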
+    def _repeat(x, n_repeats):
+        # Repeat each element of x n_repeats times, e.g. [a, b] -> [a, a, ..., b, b, ...],
+        # used to build per-image offsets for flat indexing.
+        with tf.variable_scope('_repeat'):
+            rep = tf.transpose(
+                tf.expand_dims(tf.ones(shape=tf.pack([n_repeats, ])), 1), [1, 0])
+            rep = tf.cast(rep, 'int32')
+            x = tf.matmul(tf.reshape(x, (-1, 1)), rep)
+            return tf.reshape(x, [-1])
+
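+    # _interpolate performs the differentiable bilinear sampling: each sampling
+    # coordinate (x, y), given in normalised [-1, 1] space, is mapped to pixel
+    # space, the four surrounding pixels are gathered from the flattened image,
+    # and they are blended with bilinear weights.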
+    def _interpolate(im, x, y, downsample_factor):
+        with tf.variable_scope('_interpolate'):
+            # constants
+            num_batch = tf.shape(im)[0]
+            height = tf.shape(im)[1]
+            width = tf.shape(im)[2]
+            channels = tf.shape(im)[3]
+
+            x = tf.cast(x, 'float32')
+            y = tf.cast(y, 'float32')
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            zero = tf.zeros([], dtype='int32')
+            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
+            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
+
+            # scale indices from [-1, 1] to [0, width/height]
+            x = (x + 1.0) * (width_f) / 2.0
+            y = (y + 1.0) * (height_f) / 2.0
+
+            # do sampling: find the four pixel corners around each point
+            x0 = tf.cast(tf.floor(x), 'int32')
+            x1 = x0 + 1
+            y0 = tf.cast(tf.floor(y), 'int32')
+            y1 = y0 + 1
+
+            x0 = tf.clip_by_value(x0, zero, max_x)
+            x1 = tf.clip_by_value(x1, zero, max_x)
+            y0 = tf.clip_by_value(y0, zero, max_y)
+            y1 = tf.clip_by_value(y1, zero, max_y)
+
+            # flat indices of the four corners: base is the offset of each
+            # image in the flattened batch, dim2/dim1 step over rows/images
+            dim2 = width
+            dim1 = width*height
+            base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
+            base_y0 = base + y0*dim2
+            base_y1 = base + y1*dim2
+            idx_a = base_y0 + x0
+            idx_b = base_y1 + x0
+            idx_c = base_y0 + x1
+            idx_d = base_y1 + x1
+
+            # use indices to lookup pixels in the flat image and restore
+            # channels dim
+            im_flat = tf.reshape(im, tf.pack([-1, channels]))
+            im_flat = tf.cast(im_flat, 'float32')
+            Ia = tf.gather(im_flat, idx_a)
+            Ib = tf.gather(im_flat, idx_b)
+            Ic = tf.gather(im_flat, idx_c)
+            Id = tf.gather(im_flat, idx_d)
+
+            # and finally calculate interpolated values: each corner pixel is
+            # weighted by the area of the rectangle opposite to it
+            x0_f = tf.cast(x0, 'float32')
+            x1_f = tf.cast(x1, 'float32')
+            y0_f = tf.cast(y0, 'float32')
+            y1_f = tf.cast(y1, 'float32')
+            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)), 1)
+            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)), 1)
+            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)), 1)
+            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)), 1)
+            output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+            return output
+
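+    # _meshgrid builds the regular grid of target coordinates (x_t, y_t, 1),
+    # with x and y ranging over [-1, 1], as a [3, height*width] tensor of
+    # homogeneous coordinates.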
+    def _meshgrid(height, width):
+        with tf.variable_scope('_meshgrid'):
+            # This should be equivalent to:
+            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
+            #                         np.linspace(-1, 1, height))
+            #  ones = np.ones(np.prod(x_t.shape))
+            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
+            x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
+                            tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width), 1), [1, 0]))
+            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height), 1),
+                            tf.ones(shape=tf.pack([1, width])))
+
+            x_t_flat = tf.reshape(x_t, (1, -1))
+            y_t_flat = tf.reshape(y_t, (1, -1))
+
+            ones = tf.ones_like(x_t_flat)
+            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
+            return grid
+
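+    # _transform ties the pieces together: it applies the [2, 3] affine matrix
+    # theta to the target grid to obtain source coordinates (x_s, y_s), samples
+    # the input there via _interpolate, and reshapes the result to
+    # [num_batch, out_height, out_width, num_channels].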
+    def _transform(theta, input_dim, downsample_factor):
+        with tf.variable_scope('_transform'):
+            num_batch = tf.shape(input_dim)[0]
+            height = tf.shape(input_dim)[1]
+            width = tf.shape(input_dim)[2]
+            num_channels = tf.shape(input_dim)[3]
+            theta = tf.reshape(theta, (-1, 2, 3))
+            theta = tf.cast(theta, 'float32')
+
+            # grid of (x_t, y_t, 1), eq (1) in ref [1]
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            grid = _meshgrid(out_height, out_width)
+            grid = tf.expand_dims(grid, 0)
+            grid = tf.reshape(grid, [-1])
+            grid = tf.tile(grid, tf.pack([num_batch]))
+            grid = tf.reshape(grid, tf.pack([num_batch, 3, -1]))
+
+            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
+            T_g = tf.batch_matmul(theta, grid)
+            x_s = tf.slice(T_g, [0, 0, 0], [-1, 1, -1])
+            y_s = tf.slice(T_g, [0, 1, 0], [-1, 1, -1])
+            x_s_flat = tf.reshape(x_s, [-1])
+            y_s_flat = tf.reshape(y_s, [-1])
+
+            input_transformed = _interpolate(
+                input_dim, x_s_flat, y_s_flat,
+                downsample_factor)
+
+            output = tf.reshape(
+                input_transformed,
+                tf.pack([num_batch, out_height, out_width, num_channels]))
+            return output
+
+    with tf.variable_scope(name):
+        output = _transform(theta, U, downsample_factor)
+        return output