Bläddra i källkod

Refactor bidirectional rnn for TF1.0

Signed-off-by: Norman Heckscher <norman.heckscher@gmail.com>
Norman Heckscher 8 år sedan
förälder
incheckning
53a722877a
1 ändrade filer med 8 tillägg och 8 borttagningar
  1. 8 8
      examples/3_NeuralNetworks/bidirectional_rnn.py

+ 8 - 8
examples/3_NeuralNetworks/bidirectional_rnn.py

@@ -10,7 +10,7 @@ Project: https://github.com/aymericdamien/TensorFlow-Examples/
 from __future__ import print_function
 
 import tensorflow as tf
-from tensorflow.python.ops import rnn, rnn_cell
+from tensorflow.contrib import rnn
 import numpy as np
 
 # Import MNIST data
@@ -60,20 +60,20 @@ def BiRNN(x, weights, biases):
     # Reshape to (n_steps*batch_size, n_input)
     x = tf.reshape(x, [-1, n_input])
     # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
-    x = tf.split(0, n_steps, x)
+    x = tf.split(x, n_steps, 0)
 
     # Define lstm cells with tensorflow
     # Forward direction cell
-    lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
+    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
     # Backward direction cell
-    lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
+    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
 
     # Get lstm cell output
     try:
-        outputs, _, _ = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
+        outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                               dtype=tf.float32)
     except Exception: # Old TensorFlow version only returns outputs not states
-        outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
+        outputs = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x,
                                         dtype=tf.float32)
 
     # Linear activation, using rnn inner loop last output
@@ -82,7 +82,7 @@ def BiRNN(x, weights, biases):
 pred = BiRNN(x, weights, biases)
 
 # Define loss and optimizer
-cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
+cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
 optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
 
 # Evaluate model
@@ -90,7 +90,7 @@ correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
 accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
 
 # Initializing the variables
-init = tf.initialize_all_variables()
+init = tf.global_variables_initializer()
 
 # Launch the graph
 with tf.Session() as sess: