Browse Source

Refactor recurrent network for TF1.0

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
Norman Heckscher 8 years ago
parent
commit
f3e6051fb9
1 changed files with 6 additions and 6 deletions
  1. 6 6
      examples/3_NeuralNetworks/recurrent_network.py

+ 6 - 6
examples/3_NeuralNetworks/recurrent_network.py

@@ -10,7 +10,7 @@ Project: https://github.com/aymericdamien/TensorFlow-Examples/
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow.python.ops import rnn, rnn_cell
+from tensorflow.contrib import rnn
 
 # Import MNIST data
 from tensorflow.examples.tutorials.mnist import input_data
@@ -58,13 +58,13 @@ def RNN(x, weights, biases):
     # Reshaping to (n_steps*batch_size, n_input)
     x = tf.reshape(x, [-1, n_input])
     # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
-    x = tf.split(0, n_steps, x)
+    x = tf.split(x, n_steps, 0)
 
     # Define a lstm cell with tensorflow
-    lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
+    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
 
     # Get lstm cell output
-    outputs, states = rnn.rnn(lstm_cell, x, dtype=tf.float32)
+    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
 
     # Linear activation, using rnn inner loop last output
     return tf.matmul(outputs[-1], weights['out']) + biases['out']
@@ -72,7 +72,7 @@ def RNN(x, weights, biases):
 pred = RNN(x, weights, biases)
 
 # Define loss and optimizer
-cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
+cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
 # Evaluate model
@@ -80,7 +80,7 @@ correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
 
 # Initializing the variables
-init = tf.initialize_all_variables()
+init = tf.global_variables_initializer()
 
 # Launch the graph
 with tf.Session() as sess: