|
|
@@ -14,7 +14,7 @@
|
|
|
# ==============================================================================
|
|
|
import tensorflow as tf
|
|
|
|
|
|
-def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
|
|
|
+def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
|
|
|
"""Spatial Transformer Layer
|
|
|
|
|
|
Implements a spatial transformer layer as described in [1]_.
|
|
|
@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
theta: float
|
|
|
The output of the
|
|
|
localisation network should be [num_batch, 6].
|
|
|
- downsample_factor : float
|
|
|
- A value of 1 will keep the original size of the image
|
|
|
- Values larger than 1 will downsample the image.
|
|
|
- Values below 1 will upsample the image
|
|
|
- example image: height = 100, width = 200
|
|
|
- downsample_factor = 2
|
|
|
- output image will then be 50, 100
|
|
|
-
|
|
|
+ out_size: tuple of two ints
|
|
|
+ The size of the output of the network (height, width)
|
|
|
+
|
|
|
References
|
|
|
----------
|
|
|
.. [1] Spatial Transformer Networks
|
|
|
@@ -61,7 +56,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
|
|
|
return tf.reshape(x,[-1])
|
|
|
|
|
|
- def _interpolate(im, x, y, downsample_factor):
|
|
|
+ def _interpolate(im, x, y, out_size):
|
|
|
with tf.variable_scope('_interpolate'):
|
|
|
# constants
|
|
|
num_batch = tf.shape(im)[0]
|
|
|
@@ -73,8 +68,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
y = tf.cast(y, 'float32')
|
|
|
height_f = tf.cast(height, 'float32')
|
|
|
width_f = tf.cast(width, 'float32')
|
|
|
- out_height = tf.cast(height_f // downsample_factor, 'int32')
|
|
|
- out_width = tf.cast(width_f // downsample_factor, 'int32')
|
|
|
+ out_height = out_size[0]
|
|
|
+ out_width = out_size[1]
|
|
|
zero = tf.zeros([], dtype='int32')
|
|
|
max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
|
|
|
max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
|
|
|
@@ -142,7 +137,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
|
|
|
return grid
|
|
|
|
|
|
- def _transform(theta, input_dim, downsample_factor):
|
|
|
+ def _transform(theta, input_dim, out_size):
|
|
|
with tf.variable_scope('_transform'):
|
|
|
num_batch = tf.shape(input_dim)[0]
|
|
|
height = tf.shape(input_dim)[1]
|
|
|
@@ -154,8 +149,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
# grid of (x_t, y_t, 1), eq (1) in ref [1]
|
|
|
height_f = tf.cast(height, 'float32')
|
|
|
width_f = tf.cast(width, 'float32')
|
|
|
- out_height = tf.cast(height_f // downsample_factor, 'int32')
|
|
|
- out_width = tf.cast(width_f // downsample_factor, 'int32')
|
|
|
+ out_height = out_size[0]
|
|
|
+ out_width = out_size[1]
|
|
|
grid = _meshgrid(out_height, out_width)
|
|
|
grid = tf.expand_dims(grid,0)
|
|
|
grid = tf.reshape(grid,[-1])
|
|
|
@@ -171,11 +166,34 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
|
|
|
|
|
|
input_transformed = _interpolate(
|
|
|
input_dim, x_s_flat, y_s_flat,
|
|
|
- downsample_factor)
|
|
|
+ out_size)
|
|
|
|
|
|
output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
|
|
|
return output
|
|
|
|
|
|
with tf.variable_scope(name):
|
|
|
- output = _transform(theta, U, downsample_factor)
|
|
|
- return output
|
|
|
+ output = _transform(theta, U, out_size)
|
|
|
+ return output
|
|
|
+
|
|
|
def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
    """Batch Spatial Transformer Layer

    Repeats every input in `U` once per transformation and hands the
    repeated batch, together with `thetas`, to `transformer`.

    Parameters
    ----------
    U : float
        tensor of inputs [num_batch,height,width,num_channels]
    thetas : float
        a set of transformations for each input [num_batch,num_transforms,6].
        Its first two dimensions must be statically known, because they
        are converted to Python ints while the graph is being built.
    out_size : tuple of two ints
        the size of the output [out_height,out_width]
    name : str
        name of the variable scope the ops are created in

    Returns
    -------
    float
        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
    """
    with tf.variable_scope(name):
        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
        # Input i is selected num_transforms times in a row, giving the flat
        # index list [0, 0, ..., 1, 1, ...] after the reshape below.
        # `range` (rather than the Python 2-only `xrange`) keeps this
        # working on both Python 2 and Python 3.
        indices = [[i] * num_transforms for i in range(num_batch)]
        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
        return transformer(input_repeated, thetas, out_size)
|
|
|
+
|