example.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # ==============================================================================
  15. import tensorflow as tf
  16. from spatial_transformer import transformer
  17. from scipy import ndimage
  18. import numpy as np
  19. import matplotlib.pyplot as plt
  20. from tf_utils import conv2d, linear, weight_variable, bias_variable
  21. # %% Create a batch of three images (1600 x 1200)
  22. # %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
  23. im = ndimage.imread('cat.jpg')
  24. im = im / 255.
  25. im = im.reshape(1, 1200, 1600, 3)
  26. im = im.astype('float32')
  27. # %% Simulate batch
  28. batch = np.append(im, im, axis=0)
  29. batch = np.append(batch, im, axis=0)
  30. num_batch = 3
  31. x = tf.placeholder(tf.float32, [None, 1200, 1600, 3])
  32. x = tf.cast(batch,'float32')
  33. # %% Create localisation network and convolutional layer
  34. with tf.variable_scope('spatial_transformer_0'):
  35. # %% Create a fully-connected layer with 6 output nodes
  36. n_fc = 6
  37. W_fc1 = tf.Variable(tf.zeros([1200 * 1600 * 3, n_fc]), name='W_fc1')
  38. # %% Zoom into the image
  39. initial = np.array([[0.5,0, 0],[0,0.5,0]])
  40. initial = initial.astype('float32')
  41. initial = initial.flatten()
  42. b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
  43. h_fc1 = tf.matmul(tf.zeros([num_batch ,1200 * 1600 * 3]), W_fc1) + b_fc1
  44. h_trans = transformer(x, h_fc1, downsample_factor=2)
  45. # %% Run session
  46. sess = tf.Session()
  47. sess.run(tf.initialize_all_variables())
  48. y = sess.run(h_trans, feed_dict={x: batch})
  49. # plt.imshow(y[0])