Преглед на файлове

Spatial Transformer model

Shorten STN summary in README

relinked to data files

adding license header, editing AUTHORS file

adding tensorflow version
David Dao преди 10 години
родител
ревизия
41c52d60fe
променени са 7 файла, в които са добавени 625 реда и са изтрити 0 реда
  1. 1 0
      AUTHORS
  2. 64 0
      transformer/README.md
  3. 172 0
      transformer/cluttered_mnist.py
  4. 20 0
      transformer/data/README.md
  5. 58 0
      transformer/example.py
  6. 181 0
      transformer/spatial_transformer.py
  7. 129 0
      transformer/tf_utils.py

+ 1 - 0
AUTHORS

@@ -7,3 +7,4 @@
 # The email address is not required for organizations.
 
 Google Inc.
+David Dao <daviddao@broad.mit.edu>

+ 64 - 0
transformer/README.md

@@ -0,0 +1,64 @@
+# Spatial Transformer Network
+
+The Spatial Transformer Network [1] allows the spatial manipulation of data within the network.
+
+<div align="center">
+  <img width="600px" src="http://i.imgur.com/ExGDVul.png"><br><br>
+</div>
+
+### API 
+
+A Spatial Transformer Network implemented in Tensorflow 0.7 and based on [2].
+
+#### How to use
+
+<div align="center">
+  <img src="http://i.imgur.com/gfqLV3f.png"><br><br>
+</div>
+
+```python
+transformer(U, theta, downsample_factor=1)
+```
+    
+#### Parameters
+
+    U : float 
+        The output of a convolutional net should have the
+        shape [num_batch, height, width, num_channels]. 
+    theta: float   
+        The output of the
+        localisation network should be [num_batch, 6].
+    downsample_factor : float
+        A value of 1 will keep the original size of the image
+        Values larger than 1 will downsample the image. 
+        Values below 1 will upsample the image
+        example image: height = 100, width = 200
+        downsample_factor = 2
+        output image will then be 50, 100
+        
+    
+#### Notes
+To initialize the network to the identity transform init ``theta`` to :
+
+```python
+identity = np.array([[1., 0., 0.],
+                    [0., 1., 0.]]) 
+identity = identity.flatten()
+theta = tf.Variable(initial_value=identity)
+```        
+
+#### Experiments
+
+<div align="center">
+  <img width="600px" src="http://i.imgur.com/HtCBYk2.png"><br><br>
+</div>
+
+We used cluttered MNIST. Left column are the input images, right are the attended parts of the image by an STN.
+
+All experiments were run in Tensorflow 0.7.
+
+### References
+
+[1] Jaderberg, Max, et al. "Spatial Transformer Networks." arXiv preprint arXiv:1506.02025 (2015)
+
+[2] https://github.com/skaae/transformer_network/blob/master/transformerlayer.py

+ 172 - 0
transformer/cluttered_mnist.py

@@ -0,0 +1,172 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from spatial_transformer import transformer
+from scipy import ndimage
+import numpy as np
+import matplotlib.pyplot as plt
+from tf_utils import conv2d, linear, weight_variable, bias_variable, dense_to_one_hot
+
+# %% Load data
+mnist_cluttered = np.load('./data/mnist_sequence1_sample_5distortions5x5.npz')
+
+X_train = mnist_cluttered['X_train']
+y_train = mnist_cluttered['y_train']
+X_valid = mnist_cluttered['X_valid']
+y_valid = mnist_cluttered['y_valid']
+X_test = mnist_cluttered['X_test']
+y_test = mnist_cluttered['y_test']
+
+# % turn from dense to one hot representation
+Y_train = dense_to_one_hot(y_train, n_classes=10)
+Y_valid = dense_to_one_hot(y_valid, n_classes=10)
+Y_test = dense_to_one_hot(y_test, n_classes=10)
+
+# %% Graph representation of our network
+
+# %% Placeholders for 40x40 resolution
+x = tf.placeholder(tf.float32, [None, 1600]) 
+y = tf.placeholder(tf.float32, [None, 10])
+
+# %% Since x is currently [batch, height*width], we need to reshape to a
+# 4-D tensor to use it in a convolutional graph.  If one component of
+# `shape` is the special value -1, the size of that dimension is
+# computed so that the total size remains constant.  Since we haven't
+# defined the batch dimension's shape yet, we use -1 to denote this
+# dimension should not change size.
+x_tensor = tf.reshape(x, [-1, 40, 40, 1])
+
+# %% We'll setup the two-layer localisation network to figure out the parameters for an affine transformation of the input
+# %% Create variables for fully connected layer
+W_fc_loc1 = weight_variable([1600, 20])
+b_fc_loc1 = bias_variable([20])
+
+W_fc_loc2 = weight_variable([20, 6])
+initial = np.array([[1.,0, 0],[0,1.,0]]) # Use identity transformation as starting point
+initial = initial.astype('float32')
+initial = initial.flatten()
+b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
+
+# %% Define the two layer localisation network
+h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
+# %% We can add dropout for regularizing and to reduce overfitting like so:
+keep_prob = tf.placeholder(tf.float32)
+h_fc_loc1_drop = tf.nn.dropout(h_fc_loc1, keep_prob)
+# %% Second layer
+h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)
+
+# %% We'll create a spatial transformer module to identify discriminative patches
+h_trans = transformer(x_tensor, h_fc_loc2, downsample_factor=1)
+
+# %% We'll setup the first convolutional layer
+# Weight matrix is [height x width x input_channels x output_channels]
+filter_size = 3
+n_filters_1 = 16
+W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])
+
+# %% Bias is [output_channels]
+b_conv1 = bias_variable([n_filters_1])
+
+# %% Now we can build a graph which does the first layer of convolution:
+# we define our stride as batch x height x width x channels
+# instead of pooling, we use strides of 2 and more layers
+# with smaller filters.
+
+h_conv1 = tf.nn.relu(
+    tf.nn.conv2d(input=h_trans,
+                 filter=W_conv1,
+                 strides=[1, 2, 2, 1],
+                 padding='SAME') +
+    b_conv1)
+
+# %% And just like the first layer, add additional layers to create
+# a deep net
+n_filters_2 = 16
+W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
+b_conv2 = bias_variable([n_filters_2])
+h_conv2 = tf.nn.relu(
+    tf.nn.conv2d(input=h_conv1,
+                 filter=W_conv2,
+                 strides=[1, 2, 2, 1],
+                 padding='SAME') +
+    b_conv2)
+
+# %% We'll now reshape so we can connect to a fully-connected layer:
+h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])
+
+# %% Create a fully-connected layer:
+n_fc = 1024
+W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
+b_fc1 = bias_variable([n_fc])
+h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)
+
+h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
+
+# %% And finally our softmax layer:
+W_fc2 = weight_variable([n_fc, 10])
+b_fc2 = bias_variable([10])
+y_pred = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
+
+# %% Define loss/eval/training functions
+cross_entropy = -tf.reduce_sum(y * tf.log(y_pred))
+opt = tf.train.AdamOptimizer()
+optimizer = opt.minimize(cross_entropy)
+grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])
+
+# %% Monitor accuracy
+correct_prediction = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y, 1))
+accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))
+
+# %% We now create a new session to actually perform the initialization the
+# variables:
+sess = tf.Session()
+sess.run(tf.initialize_all_variables())
+
+
+# %% We'll now train in minibatches and report accuracy, loss:
+iter_per_epoch = 100
+n_epochs = 500
+train_size = 10000
+
+indices = np.linspace(0,10000 - 1,iter_per_epoch)
+indices = indices.astype('int')
+
+for epoch_i in range(n_epochs):
+    for iter_i in range(iter_per_epoch - 1):
+    	batch_xs = X_train[indices[iter_i]:indices[iter_i+1]]
+        batch_ys = Y_train[indices[iter_i]:indices[iter_i+1]]
+
+        if iter_i % 10 == 0:
+            loss = sess.run(cross_entropy,
+                   feed_dict={
+                       x: batch_xs,
+                       y: batch_ys,
+                       keep_prob: 1.0
+                   })
+            print('Iteration: ' + str(iter_i) + ' Loss: ' + str(loss))
+
+        sess.run(optimizer, feed_dict={
+            x: batch_xs, y: batch_ys, keep_prob: 0.8})
+        
+        
+    print('Accuracy: ' + str(sess.run(accuracy,
+                   feed_dict={
+                       x: X_valid,
+                       y: Y_valid,
+                       keep_prob: 1.0
+                   })))
+    #theta = sess.run(h_fc_loc2, feed_dict={
+    #        x: batch_xs, keep_prob: 1.0})
+    #print(theta[0])

+ 20 - 0
transformer/data/README.md

@@ -0,0 +1,20 @@
+### How to get the data
+
+#### Cluttered MNIST
+
+The cluttered MNIST dataset can be found here [1] or can be generated via [2].
+
+Settings used for `cluttered_mnist.py` :
+
+```python
+
+ORG_SHP = [28, 28]
+OUT_SHP = [40, 40]
+NUM_DISTORTIONS = 8
+dist_size = (5, 5) 
+
+```
+
+[1] https://github.com/daviddao/spatial-transformer-tensorflow
+
+[2] https://github.com/skaae/recurrent-spatial-transformer-code/blob/master/MNIST_SEQUENCE/create_mnist_sequence.py

+ 58 - 0
transformer/example.py

@@ -0,0 +1,58 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+from spatial_transformer import transformer
+from scipy import ndimage
+import numpy as np
+import matplotlib.pyplot as plt
+from tf_utils import conv2d, linear, weight_variable, bias_variable
+
+# %% Create a batch of three images (1600 x 1200)
+# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
+im = ndimage.imread('cat.jpg')
+im = im / 255.
+im = im.reshape(1, 1200, 1600, 3)
+im = im.astype('float32')
+
+# %% Simulate batch
+batch = np.append(im, im, axis=0)
+batch = np.append(batch, im, axis=0)
+num_batch = 3
+
+x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
+x = tf.cast(batch,'float32')
+
+# %% Create localisation network and convolutional layer
+with tf.variable_scope('spatial_transformer_0'):
+
+    # %% Create a fully-connected layer with 6 output nodes
+    n_fc = 6 
+    W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
+
+    # %% Zoom into the image
+    initial = np.array([[0.5,0, 0],[0,0.5,0]]) 
+    initial = initial.astype('float32')
+    initial = initial.flatten()
+
+    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
+    h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
+    h_trans = transformer(x, h_fc1, downsample_factor=2)
+
+# %% Run session
+sess = tf.Session()
+sess.run(tf.initialize_all_variables())
+y = sess.run(h_trans, feed_dict={x: batch})
+
+# plt.imshow(y[0])

+ 181 - 0
transformer/spatial_transformer.py

@@ -0,0 +1,181 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+import tensorflow as tf
+
+def transformer(U, theta, downsample_factor=1, name='SpatialTransformer', **kwargs):
+    """Spatial Transformer Layer
+    
+    Implements a spatial transformer layer as described in [1]_.
+    Based on [2]_ and edited by David Dao for Tensorflow.
+    
+    Parameters
+    ----------
+    U : float 
+        The output of a convolutional net should have the
+        shape [num_batch, height, width, num_channels]. 
+    theta: float   
+        The output of the
+        localisation network should be [num_batch, 6].
+    downsample_factor : float
+        A value of 1 will keep the original size of the image
+        Values larger than 1 will downsample the image. 
+        Values below 1 will upsample the image
+        example image: height = 100, width = 200
+        downsample_factor = 2
+        output image will then be 50, 100
+        
+    References
+    ----------
+    .. [1]  Spatial Transformer Networks
+            Max Jaderberg, Karen Simonyan, Andrew Zisserman, Koray Kavukcuoglu
+            Submitted on 5 Jun 2015
+    .. [2]  https://github.com/skaae/transformer_network/blob/master/transformerlayer.py
+            
+    Notes
+    -----
+    To initialize the network to the identity transform init
+    ``theta`` to :
+        identity = np.array([[1., 0., 0.],
+                             [0., 1., 0.]]) 
+        identity = identity.flatten()
+        theta = tf.Variable(initial_value=identity)
+        
+    """
+    
+    def _repeat(x, n_repeats):
+        with tf.variable_scope('_repeat'):
+            rep = tf.transpose(tf.expand_dims(tf.ones(shape=tf.pack([n_repeats,])),1),[1,0])
+            rep = tf.cast(rep, 'int32')
+            x = tf.matmul(tf.reshape(x,(-1, 1)), rep)
+            return tf.reshape(x,[-1])
+
+    def _interpolate(im, x, y, downsample_factor):
+        with tf.variable_scope('_interpolate'):
+            # constants
+            num_batch = tf.shape(im)[0]
+            height = tf.shape(im)[1]
+            width = tf.shape(im)[2]
+            channels = tf.shape(im)[3]
+
+            x = tf.cast(x, 'float32')
+            y = tf.cast(y, 'float32')
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            zero = tf.zeros([], dtype='int32')
+            max_y = tf.cast(tf.shape(im)[1] - 1, 'int32')
+            max_x = tf.cast(tf.shape(im)[2] - 1, 'int32')
+
+            # scale indices from [-1, 1] to [0, width/height]
+            x = (x + 1.0)*(width_f) / 2.0 
+            y = (y + 1.0)*(height_f) / 2.0
+
+            # do sampling
+            x0 = tf.cast(tf.floor(x), 'int32')
+            x1 = x0 + 1
+            y0 = tf.cast(tf.floor(y), 'int32')
+            y1 = y0 + 1
+
+            x0 = tf.clip_by_value(x0, zero, max_x)
+            x1 = tf.clip_by_value(x1, zero, max_x)
+            y0 = tf.clip_by_value(y0, zero, max_y)
+            y1 = tf.clip_by_value(y1, zero, max_y)
+            dim2 = width
+            dim1 = width*height
+            base = _repeat(tf.range(num_batch)*dim1, out_height*out_width)
+            base_y0 = base + y0*dim2
+            base_y1 = base + y1*dim2
+            idx_a = base_y0 + x0
+            idx_b = base_y1 + x0
+            idx_c = base_y0 + x1
+            idx_d = base_y1 + x1
+
+            # use indices to lookup pixels in the flat image and restore channels dim
+            im_flat = tf.reshape(im,tf.pack([-1, channels]))
+            im_flat = tf.cast(im_flat, 'float32')
+            Ia = tf.gather(im_flat, idx_a)
+            Ib = tf.gather(im_flat, idx_b)
+            Ic = tf.gather(im_flat, idx_c)
+            Id = tf.gather(im_flat, idx_d)
+
+            # and finally calculate interpolated values
+            x0_f = tf.cast(x0, 'float32')
+            x1_f = tf.cast(x1, 'float32')
+            y0_f = tf.cast(y0, 'float32')
+            y1_f = tf.cast(y1, 'float32')
+            wa = tf.expand_dims(((x1_f-x) * (y1_f-y)),1)
+            wb = tf.expand_dims(((x1_f-x) * (y-y0_f)),1)
+            wc = tf.expand_dims(((x-x0_f) * (y1_f-y)),1)
+            wd = tf.expand_dims(((x-x0_f) * (y-y0_f)),1)
+            output = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])
+            return output
+    
+    def _meshgrid(height, width):
+        with tf.variable_scope('_meshgrid'):
+            # This should be equivalent to:
+            #  x_t, y_t = np.meshgrid(np.linspace(-1, 1, width),
+            #                         np.linspace(-1, 1, height))
+            #  ones = np.ones(np.prod(x_t.shape))
+            #  grid = np.vstack([x_t.flatten(), y_t.flatten(), ones])
+            x_t = tf.matmul(tf.ones(shape=tf.pack([height, 1])),
+                        tf.transpose(tf.expand_dims(tf.linspace(-1.0, 1.0, width),1),[1,0])) 
+            y_t = tf.matmul(tf.expand_dims(tf.linspace(-1.0, 1.0, height),1),
+                        tf.ones(shape=tf.pack([1, width]))) 
+
+            x_t_flat = tf.reshape(x_t,(1, -1))
+            y_t_flat = tf.reshape(y_t,(1, -1))
+
+            ones = tf.ones_like(x_t_flat)
+            grid = tf.concat(0, [x_t_flat, y_t_flat, ones])
+            return grid
+
+    def _transform(theta, input_dim, downsample_factor):
+        with tf.variable_scope('_transform'):
+            num_batch = tf.shape(input_dim)[0]
+            height = tf.shape(input_dim)[1]
+            width = tf.shape(input_dim)[2]            
+            num_channels = tf.shape(input_dim)[3]
+            theta = tf.reshape(theta, (-1, 2, 3))
+            theta = tf.cast(theta, 'float32')
+
+            # grid of (x_t, y_t, 1), eq (1) in ref [1]
+            height_f = tf.cast(height, 'float32')
+            width_f = tf.cast(width, 'float32')
+            out_height = tf.cast(height_f // downsample_factor, 'int32')
+            out_width = tf.cast(width_f // downsample_factor, 'int32')
+            grid = _meshgrid(out_height, out_width)
+            grid = tf.expand_dims(grid,0)
+            grid = tf.reshape(grid,[-1])
+            grid = tf.tile(grid,tf.pack([num_batch]))
+            grid = tf.reshape(grid,tf.pack([num_batch, 3, -1])) 
+            
+            # Transform A x (x_t, y_t, 1)^T -> (x_s, y_s)
+            T_g = tf.batch_matmul(theta, grid)
+            x_s = tf.slice(T_g, [0,0,0], [-1,1,-1])
+            y_s = tf.slice(T_g, [0,1,0], [-1,1,-1])
+            x_s_flat = tf.reshape(x_s,[-1])
+            y_s_flat = tf.reshape(y_s,[-1])
+
+            input_transformed = _interpolate(
+                  input_dim, x_s_flat, y_s_flat,
+                  downsample_factor)
+
+            output = tf.reshape(input_transformed, tf.pack([num_batch, out_height, out_width, num_channels]))
+            return output
+    
+    with tf.variable_scope(name):
+        output = _transform(theta, U, downsample_factor)
+        return output

+ 129 - 0
transformer/tf_utils.py

@@ -0,0 +1,129 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# %% Borrowed utils from here: https://github.com/pkmital/tensorflow_tutorials/
+import tensorflow as tf
+import numpy as np
+
+def conv2d(x, n_filters,
+           k_h=5, k_w=5,
+           stride_h=2, stride_w=2,
+           stddev=0.02,
+           activation=lambda x: x,
+           bias=True,
+           padding='SAME',
+           name="Conv2D"):
+    """2D Convolution with options for kernel size, stride, and init deviation.
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to convolve.
+    n_filters : int
+        Number of filters to apply.
+    k_h : int, optional
+        Kernel height.
+    k_w : int, optional
+        Kernel width.
+    stride_h : int, optional
+        Stride in rows.
+    stride_w : int, optional
+        Stride in cols.
+    stddev : float, optional
+        Initialization's standard deviation.
+    activation : arguments, optional
+        Function which applies a nonlinearity
+    padding : str, optional
+        'SAME' or 'VALID'
+    name : str, optional
+        Variable scope to use.
+    Returns
+    -------
+    x : Tensor
+        Convolved input.
+    """
+    with tf.variable_scope(name):
+        w = tf.get_variable(
+            'w', [k_h, k_w, x.get_shape()[-1], n_filters],
+            initializer=tf.truncated_normal_initializer(stddev=stddev))
+        conv = tf.nn.conv2d(
+            x, w, strides=[1, stride_h, stride_w, 1], padding=padding)
+        if bias:
+            b = tf.get_variable(
+                'b', [n_filters],
+                initializer=tf.truncated_normal_initializer(stddev=stddev))
+            conv = conv + b
+        return conv
+    
+def linear(x, n_units, scope=None, stddev=0.02,
+           activation=lambda x: x):
+    """Fully-connected network.
+    Parameters
+    ----------
+    x : Tensor
+        Input tensor to the network.
+    n_units : int
+        Number of units to connect to.
+    scope : str, optional
+        Variable scope to use.
+    stddev : float, optional
+        Initialization's standard deviation.
+    activation : arguments, optional
+        Function which applies a nonlinearity
+    Returns
+    -------
+    x : Tensor
+        Fully-connected output.
+    """
+    shape = x.get_shape().as_list()
+
+    with tf.variable_scope(scope or "Linear"):
+        matrix = tf.get_variable("Matrix", [shape[1], n_units], tf.float32,
+                                 tf.random_normal_initializer(stddev=stddev))
+        return activation(tf.matmul(x, matrix))
+    
+# %%
+def weight_variable(shape):
+    '''Helper function to create a weight variable initialized with
+    a normal distribution
+    Parameters
+    ----------
+    shape : list
+        Size of weight variable
+    '''
+    #initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
+    initial = tf.zeros(shape)
+    return tf.Variable(initial)
+
+# %%
+def bias_variable(shape):
+    '''Helper function to create a bias variable initialized with
+    a constant value.
+    Parameters
+    ----------
+    shape : list
+        Size of weight variable
+    '''
+    initial = tf.random_normal(shape, mean=0.0, stddev=0.01)
+    return tf.Variable(initial)
+
+# %% 
+def dense_to_one_hot(labels, n_classes=2):
+    """Convert class labels from scalars to one-hot vectors."""
+    labels = np.array(labels)
+    n_labels = labels.shape[0]
+    index_offset = np.arange(n_labels) * n_classes
+    labels_one_hot = np.zeros((n_labels, n_classes), dtype=np.float32)
+    labels_one_hot.flat[index_offset + labels.ravel()] = 1
+    return labels_one_hot