Procházet zdrojové kódy

Spatial transformer: (#57)

* Modified the way the output size is specified.
* Added support for batches of inputs.
Timur před 9 roky
rodič
revize
bf60abf833
1 změnil soubory, kde provedl 36 přidání a 18 odebrání
  1. 36 18
      transformer/spatial_transformer.py

+ 36 - 18
transformer/spatial_transformer.py

@@ -14,7 +14,7 @@
 # ==============================================================================
 import tensorflow as tf
 
-def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+def transformer(U, theta, out_size, name='SpatialTransformer', **kwargs):
     """Spatial Transformer Layer
     
     Implements a spatial transformer layer as described in [1]_.
@@ -28,14 +28,9 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
     theta: float   
         The output of the
         localisation network should be [num_batch, 6].
-    downsample_factor : float
-        A value of 1 will keep the original size of the image
-        Values larger than 1 will downsample the image. 
-        Values below 1 will upsample the image
-        example image: height = 100, width = 200
-        downsample_factor = 2
-        output image will then be 50, 100
-        
+    out_size: tuple of two floats
+        The size of the output of the network
+
     References
     ----------
     .. [1]  Spatial Transformer Networks
@@ -61,7 +56,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
             x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
             return tf.reshape(x,[-1])
 
-    def _interpolate(im, x, y, downsample_factor):
+    def _interpolate(im, x, y, out_size):
         with tf.variable_scope('_interpolate'):
             # constants
             num_batch = tf.shape(im)[0]
@@ -73,8 +68,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
             y = tf.cast(y, 'float32')
             height_f = tf.cast(height, 'float32')
             width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1] 
             zero = tf.zeros([], dtype='int32')
             max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
             max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
@@ -142,7 +137,7 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
             grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
             return grid
 
-    def _transform(theta, input_dim, downsample_factor):
+    def _transform(theta, input_dim, out_size):
         with tf.variable_scope('_transform'):
             num_batch = tf.shape(input_dim)[0]
             height = tf.shape(input_dim)[1]
@@ -154,8 +149,8 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
             # grid of (x_t, y_t, 1), eq (1) in ref [1]
             height_f = tf.cast(height, 'float32')
             width_f = tf.cast(width, 'float32')
-            out_height = tf.cast(height_f // downsample_factor, 'int32')
-            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            out_height = out_size[0]
+            out_width = out_size[1] 
             grid = _meshgrid(out_height, out_width)
             grid = tf.expand_dims(grid,0)
             grid = tf.reshape(grid,[-1])
@@ -171,11 +166,34 @@ def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwar
 
             input_transformed = _interpolate(
                   input_dim, x_s_flat, y_s_flat,
-                  downsample_factor)
+                  out_size)
 
             output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
             return output
     
     with tf.variable_scope(name):
-        output = _transform(theta, U, downsample_factor)
-        return output
+        output = _transform(theta, U, out_size)
+        return output
+
+def batch_transformer(U, thetas, out_size, name='BatchSpatialTransformer'):
+    """Batch Spatial Transformer Layer
+
+    Parameters
+    ----------
+    
+    U : float
+        tensor of inputs [num_batch,height,width,num_channels]
+    thetas : float
+        a set of transformations for each input [num_batch,num_transforms,6]
+    out_size : int
+        the size of the output [out_height,out_width]
+
+    Returns: float
+        Tensor of size [num_batch*num_transforms,out_height,out_width,num_channels]
+    """
+    with tf.variable_scope(name):
+        num_batch, num_transforms = map(int, thetas.get_shape().as_list()[:2])
+        indices = [[i]*num_transforms for i in xrange(num_batch)]
+        input_repeated = tf.gather(U, tf.reshape(indices, [-1]))
+        return transformer(input_repeated, thetas, out_size)
+