example.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. from scipy import ndimage
  16. import tensorflow as tf
  17. from spatial_transformer import transformer
  18. import numpy as np
  19. import matplotlib.pyplot as plt
  20. # %% Create a batch of three images (1600 x 1200)
  21. # %% Image retrieved from:
  22. # %% https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
  23. im = ndimage.imread('cat.jpg')
  24. im = im / 255.
  25. im = im.reshape(1, 1200, 1600, 3)
  26. im = im.astype('float32')
  27. # %% Let the output size of the transformer be half the image size.
  28. out_size = (600, 800)
  29. # %% Simulate batch
  30. batch = np.append(im, im, axis=0)
  31. batch = np.append(batch, im, axis=0)
  32. num_batch = 3
  33. x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
  34. x = tf.cast(batch, 'float32')
  35. # %% Create localisation network and convolutional layer
  36. with tf.variable_scope('spatial_transformer_0'):
  37. # %% Create a fully-connected layer with 6 output nodes
  38. n_fc = 6
  39. W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
  40. # %% Zoom into the image
  41. initial = np.array([[0.5, 0, 0], [0, 0.5, 0]])
  42. initial = initial.astype('float32')
  43. initial = initial.flatten()
  44. b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
  45. h_fc1 = tf.matmul(tf.zeros([num_batch, 1200 * 1600 * 3]), W_fc1) + b_fc1
  46. h_trans = transformer(x, h_fc1, out_size)
  47. # %% Run session
  48. sess = tf.Session()
  49. sess.run(tf.global_variables_initializer())
  50. y = sess.run(h_trans, feed_dict={x: batch})
  51. # plt.imshow(y[0])