Explorar o código

Spatial Transformer model

Shorten STN summary in README

relinked to data files

adding license header, editing AUTHORS file

adding tensorflow version
David Dao %!s(int64=9) %!d(string=hai) anos
pai
achega
41c52d60fe

+ 1 - 0
AUTHORS

@@ -7,3 +7,4 @@
 # The email address is not required for organizations.
 
 Google Inc.
+David Dao <daviddao@broad.mit.edu>

+ 64 - 0
transformer/README.md

@@ -0,0 +1,64 @@
+# Spatial Transformer Network
+
+The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
+
+<div align="center">
+  <img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
+</div>
+
+### API 
+
+A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2].
+
+#### How to use
+
+<div align="center">
+  <img src="http://i.imgur.com/gfqLV3f.png"><br><br>
+</div>
+
+```python
+transformer(U, theta, downsample_factor=1)
+```
+    
+#### Parameters
+
+    U : float 
+        The output of a convolutional net should have the
+        shape [num_batch, height, width, num_channels]. 
+    theta: float   
+        The output of the
+        localisation network should be [num_batch, 6].
+    downsample_factor : float
+        A value of 1 will keep the original size of the image
+        Values larger than 1 will downsample the image. 
+        Values below 1 will upsample the image
+        example image: height = 100, width = 200
+        downsample_factor = 2
+        output image will then be 50, 100
+        
+    
+#### Notes
+To initialize the network to the identity transform init ``theta`` to :
+
+```python
+identity = np.array([[1., 0., 0.],
+                    [0., 1., 0.]]) 
+identity = identity.flatten()
+theta = tf.Variable(initial_value=identity)
+```        
+
+#### Experiments
+
+<div align="center">
+  <img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
+</div>
+
+We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN.
+
+All experiments were run in Tensorflow 0.7.
+
+### References
+
+[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
+
+[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

+ 172 - 0
transformer/cluttered_mnist.py

@@ -0,0 +1,172 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from spatial_transformer import transformer
+from scipy import ndimage
+import numpy as np
+import matplotlib.pyplot as plt
+from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot
+
+# %% Load data
+mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
+
+X_train = mnist_cluttered['X_train']
+y_train = mnist_cluttered['y_train']
+X_valid = mnist_cluttered['X_valid']
+y_valid = mnist_cluttered['y_valid']
+X_test = mnist_cluttered['X_test']
+y_test = mnist_cluttered['y_test']
+
+# % turn from dense to one hot representation
+Y_train = dense_to_one_hot(y_train, n_classes=10)
+Y_valid = dense_to_one_hot(y_valid, n_classes=10)
+Y_test = dense_to_one_hot(y_test, n_classes=10)
+
+# %% Graph representation of our network
+
+# %% Placeholders for 40x40 resolution
+x = tf.placeholder(tf.float32, [None, 1600]) 
+y = tf.placeholder(tf.float32, [None, 10])
+
+# %% Since x is currently [batch, height*width], we need to reshape to a
+# 4-D tensor to use it in a convolutional graph.  If one component of
+# `shape` is the special value -1, the size of that dimension is
+# computed so that the total size remains constant.  Since we haven't
+# defined the batch dimension's shape yet, we use -1 to denote this
+# dimension should not change size.
+x_tensor = tf.reshape(x, [-1, 40, 40, 1])
+
+# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input
+# %% Create variables for fully connected layer
+W_fc_loc1 = weight_variable([1600, 20])
+b_fc_loc1 = bias_variable([20])
+
+W_fc_loc2 = weight_variable([20, 6])
+initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point
+initial = initial.astype('float32')
+initial = initial.flatten()
+b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
+
+# %% Define the two layer localisation network
+h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
+# %% We can add dropout for regularizing and to reduce overfitting like so:
+keep_prob = tf.placeholder(tf.float32)
+h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
+# %% Second layer
+h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
+
+# %% We'll create a spatial transformer module to identify discriminative patches
+h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
+
+# %% We'll setup the first convolutional layer
+# Weight matrix is [height x width x input_channels x output_channels]
+filter_size = 3
+n_filters_1 = 16
+W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
+
+# %% Bias is [output_channels]
+b_conv1 = bias_variable([n_filters_1])
+
+# %% Now we can build a graph which does the first layer of convolution:
+# we define our stride as batch x height x width x channels
+# instead of pooling, we use strides of 2 and more layers
+# with smaller filters.
+
+h_conv1 = tf.nn.relu(
+    tf.nn.conv2d(input=h_trans,
+                 filter=W_conv1,
+                 strides=[1, 2, 2, 1],
+                 padding='SAME') +
+    b_conv1)
+
+# %% And just like the first layer, add additional layers to create
+# a deep net
+n_filters_2 = 16
+W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
+b_conv2 = bias_variable([n_filters_2])
+h_conv2 = tf.nn.relu(
+    tf.nn.conv2d(input=h_conv1,
+                 filter=W_conv2,
+                 strides=[1, 2, 2, 1],
+                 padding='SAME') +
+    b_conv2)
+
+# %% We'll now reshape so we can connect to a fully-connected layer:
+h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
+
+# %% Create a fully-connected layer:
+n_fc = 1024
+W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
+b_fc1 = bias_variable([n_fc])
+h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
+
+h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
+
+# %% And finally our softmax layer:
+W_fc2 = weight_variable([n_fc, 10])
+b_fc2 = bias_variable([10])
+y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
+
+# %% Define loss/eval/training functions
+cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
+opt = tf.train.AdamOptimizer()
+optimizer = opt.minimize(cross_entropy)
+grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
+
+# %% Monitor accuracy
+correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
+
+# %% We now create a new session to actually perform the initialization the
+# variables:
+sess = tf.Session()
+sess.run(tf.initialize_all_variables())
+
+
+# %% We'll now train in minibatches and report accuracy, loss:
+iter_per_epoch = 100
+n_epochs = 500
+train_size = 10000
+
+indices = np.linspace(0,10000 - 1,iter_per_epoch)
+indices = indices.astype('int')
+
+for epoch_i in range(n_epochs):
+    for iter_i in range(iter_per_epoch - 1):
+    	batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
+        batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
+
+        if iter_i % 10 == 0:
+            loss = sess.run(cross_entropy,
+                   feed_dict={
+                       x: batch_xs,
+                       y: batch_ys,
+                       keep_prob: 1.0
+                   })
+            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
+
+        sess.run(optimizer, feed_dict={
+            x: batch_xs, y: batch_ys, keep_prob: 0.8})
+        
+        
+    print('Accuracy: ' + str(sess.run(accuracy,
+                   feed_dict={
+                       x: X_valid,
+                       y: Y_valid,
+                       keep_prob: 1.0
+                   })))
+    #theta = sess.run(h_fc_loc2, feed_dict={
+    #        x: batch_xs, keep_prob: 1.0})
+    #print(theta[0])

+ 20 - 0
transformer/data/README.md

@@ -0,0 +1,20 @@
+### How to get the data
+
+#### Cluttered MNIST
+
+The cluttered MNIST dataset can be found here [1] or can be generated via [2].
+
+Settings used for `cluttered_mnist.py` :
+
+```python
+
+ORG_SHP = [28, 28]
+OUT_SHP = [40, 40]
+NUM_DISTORTIONS = 8
+dist_size = (5, 5) 
+
+```
+
+[1] https://github.com/daviddao/spatial-transformer-tensorflow
+
+[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py

+ 58 - 0
transformer/example.py

@@ -0,0 +1,58 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from spatial_transformer import transformer
+from scipy import ndimage
+import numpy as np
+import matplotlib.pyplot as plt
+from tf_utils import conv2d, linear, weight_variable, bias_variable
+
+# %% Create a batch of three images (1600 x 1200)
+# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
+im = ndimage.imread('cat.jpg')
+im = im / 255.
+im = im.reshape(1, 1200, 1600, 3)
+im = im.astype('float32')
+
+# %% Simulate batch
+batch = np.append(im, im, axis=0)
+batch = np.append(batch, im, axis=0)
+num_batch = 3
+
+x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
+x = tf.cast(batch,'float32')
+
+# %% Create localisation network and convolutional layer
+with tf.variable_scope('spatial_transformer_0'):
+
+    # %% Create a fully-connected layer with 6 output nodes
+    n_fc = 6 
+    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
+
+    # %% Zoom into the image
+    initial = np.array([[0.5,0, 0],[0,0.5,0]]) 
+    initial = initial.astype('float32')
+    initial = initial.flatten()
+
+    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
+    h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
+    h_trans = transformer(x, h_fc1, downsample_factor=2)
+
+# %% Run session
+sess = tf.Session()
+sess.run(tf.initialize_all_variables())
+y = sess.run(h_trans, feed_dict={x: batch})
+
+# plt.imshow(y[0])

+ 181 - 0
transformer/spatial_transformer.py

@@ -0,0 +1,181 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+
+def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+    """Spatial Transformer Layer
+    
+    Implements a spatial transformer layer as described in [1]_.
+    Based on [2]_ and edited by David Dao for Tensorflow.
+    
+    Parameters
+    ----------
+    U : float 
+        The output of a convolutional net should have the
+        shape [num_batch, height, width, num_channels]. 
+    theta: float   
+        The output of the
+        localisation network should be [num_batch, 6].
+    downsample_factor : float
+        A value of 1 will keep the original size of the image
+        Values larger than 1 will downsample the image. 
+        Values below 1 will upsample the image
+        example image: height = 100, width = 200
+        downsample_factor = 2
+        output image will then be 50, 100
+        
+    References
+    ----------
+    .. [1]  Spatial Transformer Networks
+            Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
+            Submitted on 5 Jun 2015
+    .. [2]  https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
+            
+    Notes
+    -----
+    To initialize the network to the identity transform init
+    ``theta`` to :
+        identity = np.array([[1., 0., 0.],
+                             [0., 1., 0.]]) 
+        identity = identity.flatten()
+        theta = tf.Variable(initial_value=identity)
+        
+    """
+    
+    def _repeat(x, n_repeats):
+        with tf.variable_scope('_repeat'):
+            rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.pack([n_repeats,])),1),[1,0])
+            rep = tf.cast(rep, 'int32')
+            x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
+            return tf.reshape(x,[-1])
+
+    def _interpolate(im, x, y, downsample_factor):
+        with tf.variable_scope('_interpolate'):
+            # constants
+            num_batch = tf.shape(im)[0]
+            height = tf.shape(im)[1]
+            width = tf.shape(im)[2]
+            channels = tf.shape(im)[3]
+
+            x = tf.cast(x, 'float32')
+            y = tf.cast(y, 'float32')
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            zero = tf.zeros([], dtype='int32')
+            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
+            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
+
+            # scale indices from [-1, 1] to [0, width/height]
+            x = (x + 1.0)*(width_f) / 2.0 
+            y = (y + 1.0)*(height_f) / 2.0
+
+            # do sampling
+            x0 = tf.cast(tf.floor(x), 'int32')
+            x1 = x0 + 1
+            y0 = tf.cast(tf.floor(y), 'int32')
+            y1 = y0 + 1
+
+            x0 = tf.clip_by_value(x0, zero, max_x)
+            x1 = tf.clip_by_value(x1, zero, max_x)
+            y0 = tf.clip_by_value(y0, zero, max_y)
+            y1 = tf.clip_by_value(y1, zero, max_y)
+            dim2 = width
+            dim1 = width*height
+            base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
+            base_y0 = base + y0*dim2
+            base_y1 = base + y1*dim2
+            idx_a = base_y0 + x0
+            idx_b = base_y1 + x0
+            idx_c = base_y0 + x1
+            idx_d = base_y1 + x1
+
+            # use indices to lookup pixels in the flat image and restore channels dim
+            im_flat = tf.reshape(im,tf.pack([-1, channels]))
+            im_flat = tf.cast(im_flat, 'float32')
+            Ia = tf.gather(im_flat, idx_a)
+            Ib = tf.gather(im_flat, idx_b)
+            Ic = tf.gather(im_flat, idx_c)
+            Id = tf.gather(im_flat, idx_d)
+
+            # and finally calculate interpolated values
+            x0_f = tf.cast(x0, 'float32')
+            x1_f = tf.cast(x1, 'float32')
+            y0_f = tf.cast(y0, 'float32')
+            y1_f = tf.cast(y1, 'float32')
+            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)),1)
+            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)),1)
+            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)),1)
+            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)),1)
+            output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+            return output
+    
+    def _meshgrid(height, width):
+        with tf.variable_scope('_meshgrid'):
+            # This should be equivalent to:
+            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
+            #                         np.linspace(-1, 1, height))
+            #  ones = np.ones(np.prod(x_t.shape))
+            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
+            x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
+                        tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width),1),[1,0])) 
+            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height),1),
+                        tf.ones(shape=tf.pack([1, width]))) 
+
+            x_t_flat = tf.reshape(x_t,(1, -1))
+            y_t_flat = tf.reshape(y_t,(1, -1))
+
+            ones = tf.ones_like(x_t_flat)
+            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
+            return grid
+
+    def _transform(theta, input_dim, downsample_factor):
+        with tf.variable_scope('_transform'):
+            num_batch = tf.shape(input_dim)[0]
+            height = tf.shape(input_dim)[1]
+            width = tf.shape(input_dim)[2]            
+            num_channels = tf.shape(input_dim)[3]
+            theta = tf.reshape(theta, (-1, 2, 3))
+            theta = tf.cast(theta, 'float32')
+
+            # grid of (x_t, y_t, 1), eq (1) in ref [1]
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            grid = _meshgrid(out_height, out_width)
+            grid = tf.expand_dims(grid,0)
+            grid = tf.reshape(grid,[-1])
+            grid = tf.tile(grid,tf.pack([num_batch]))
+            grid = tf.reshape(grid,tf.pack([num_batch, 3, -1])) 
+            
+            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
+            T_g = tf.batch_matmul(theta, grid)
+            x_s = tf.slice(T_g, [0,0,0], [-1,1,-1])
+            y_s = tf.slice(T_g, [0,1,0], [-1,1,-1])
+            x_s_flat = tf.reshape(x_s,[-1])
+            y_s_flat = tf.reshape(y_s,[-1])
+
+            input_transformed = _interpolate(
+                  input_dim, x_s_flat, y_s_flat,
+                  downsample_factor)
+
+            output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
+            return output
+    
+    with tf.variable_scope(name):
+        output = _transform(theta, U, downsample_factor)
+        return output

+ 129 - 0
transformer/tf_utils.py

@@ -0,0 +1,129 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
+import tensorflow as tf
+import numpy as np
+
+def conv2d(x, n_filters,
+           k_h=5, k_w=5,
+           stride_h=2, stride_w=2,
+           stddev=0.02,
+           activation=lambda x: x,
+           bias=True,
+           padding='SAME',
+           name="Conv2D"):
+    """2D Convolution with options for kernel size, stride, and init deviation.
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to convolve.
+    n_filters : int
+        Number of filters to apply.
+    k_h : int, optional
+        Kernel height.
+    k_w : int, optional
+        Kernel width.
+    stride_h : int, optional
+        Stride in rows.
+    stride_w : int, optional
+        Stride in cols.
+    stddev : float, optional
+        Initialization's standard deviation.
+    activation : arguments, optional
+        Function which applies a nonlinearity
+    padding : str, optional
+        'SAME' or 'VALID'
+    name : str, optional
+        Variable scope to use.
+    Returns
+    -------
+    x : Tensor
+        Convolved input.
+    """
+    with tf.variable_scope(name):
+        w = tf.get_variable(
+            'w', [k_h, k_w, x.get_shape()[-1], n_filters],
+            initializer=tf.truncated_normal_initializer(stddev=stddev))
+        conv = tf.nn.conv2d(
+            x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
+        if bias:
+            b = tf.get_variable(
+                'b', [n_filters],
+                initializer=tf.truncated_normal_initializer(stddev=stddev))
+            conv = conv + b
+        return conv
+    
+def linear(x, n_units, scope=None, stddev=0.02,
+           activation=lambda x: x):
+    """Fully-connected network.
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to the network.
+    n_units : int
+        Number of units to connect to.
+    scope : str, optional
+        Variable scope to use.
+    stddev : float, optional
+        Initialization's standard deviation.
+    activation : arguments, optional
+        Function which applies a nonlinearity
+    Returns
+    -------
+    x : Tensor
+        Fully-connected output.
+    """
+    shape = x.get_shape().as_list()
+
+    with tf.variable_scope(scope or "Linear"):
+        matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
+                                 tf.random_normal_initializer(stddev=stddev))
+        return activation(tf.matmul(x, matrix))
+    
+# %%
+def weight_variable(shape):
+    '''Helper function to create a weight variable initialized with
+    a normal distribution
+    Parameters
+    ----------
+    shape : list
+        Size of weight variable
+    '''
+    #initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
+    initial = tf.zeros(shape)
+    return tf.Variable(initial)
+
+# %%
+def bias_variable(shape):
+    '''Helper function to create a bias variable initialized with
+    a constant value.
+    Parameters
+    ----------
+    shape : list
+        Size of weight variable
+    '''
+    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
+    return tf.Variable(initial)
+
+# %% 
+def dense_to_one_hot(labels, n_classes=2):
+    """Convert class labels from scalars to one-hot vectors."""
+    labels = np.array(labels)
+    n_labels = labels.shape[0]
+    index_offset = np.arange(n_labels) * n_classes
+    labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
+    labels_one_hot.flat[index_offset + labels.ravel()] = 1
+    return labels_one_hot