12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- import tensorflow as tf
- from spatial_transformer import transformer
- from scipy import ndimage
- import numpy as np
- import matplotlib.pyplot as plt
- from tf_utils import conv2d, linear, weight_variable, bias_variable
# %% Create a batch of three images (1600 x 1200)
# %% Image retrieved from https://raw.githubusercontent.com/skaae/transformer_network/master/cat.jpg
# NOTE(review): `ndimage.imread` is only available in older SciPy releases
# (it was removed in SciPy 1.2) -- confirm against the installed version
# before running this script.
im = ndimage.imread('cat.jpg')
# Scale to [0, 1] (assumes 8-bit pixel values) and store as float32.
im = (im / 255.).astype('float32')
# Read the dimensions off the image rather than hard-coding 1200 x 1600 x 3,
# so every shape used below agrees with the file that was actually loaded.
height, width, channels = im.shape
im = im.reshape(1, height, width, channels)

# %% Simulate batch
num_batch = 3
# Stack `num_batch` copies of the single image along the batch axis.
batch = np.repeat(im, num_batch, axis=0)

# Feed point for the image batch. `x` must stay a placeholder: rebinding it to
# `tf.cast(batch, ...)` would embed the whole batch in the graph as a constant
# and leave the placeholder unused.
x = tf.placeholder(tf.float32, [None, height, width, channels])

# %% Create localisation network and convolutional layer
with tf.variable_scope('spatial_transformer_0'):
    # %% Create a fully-connected layer with 6 output nodes
    n_fc = 6
    n_inputs = height * width * channels
    W_fc1 = tf.Variable(tf.zeros([n_inputs, n_fc]), name='W_fc1')

    # %% Zoom into the image
    # Presumably read by `transformer` as a flattened 2 x 3 affine matrix
    # (0.5 scaling on both axes, no translation) -- confirm against
    # spatial_transformer.
    initial = np.array([[0.5, 0, 0], [0, 0.5, 0]], dtype='float32').flatten()
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')

    # Weights and inputs are all zeros, so every one of the `num_batch` rows
    # of `h_fc1` equals the bias `b_fc1`.
    h_fc1 = tf.matmul(tf.zeros([num_batch, n_inputs]), W_fc1) + b_fc1
    h_trans = transformer(x, h_fc1, downsample_factor=2)

# %% Run session
# The context manager closes the session (and frees its resources) once the
# transformed batch has been fetched into the NumPy array `y`.
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    y = sess.run(h_trans, feed_dict={x: batch})

# plt.imshow(y[0])
|