Explorar el Código

Simpilify RNN examples data transformation (#136)

* gittest

* Simpilify RNN examples data transform.
zxiaomzxm hace 8 años
padre
commit
373d9810c4

+ 2 - 6
examples/3_NeuralNetworks/bidirectional_rnn.py

@@ -55,12 +55,8 @@ def BiRNN(x, weights, biases):
     # Current data input shape: (batch_size, n_steps, n_input)
     # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
 
-    # Permuting batch_size and n_steps
-    x = tf.transpose(x, [1, 0, 2])
-    # Reshape to (n_steps*batch_size, n_input)
-    x = tf.reshape(x, [-1, n_input])
-    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
-    x = tf.split(x, n_steps, 0)
+    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
+    x = tf.unstack(x, n_steps, 1)
 
     # Define lstm cells with tensorflow
     # Forward direction cell

+ 3 - 7
examples/3_NeuralNetworks/dynamic_rnn.py

@@ -113,13 +113,9 @@ def dynamicRNN(x, seqlen, weights, biases):
     # Prepare data shape to match `rnn` function requirements
     # Current data input shape: (batch_size, n_steps, n_input)
     # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
-
-    # Permuting batch_size and n_steps
-    x = tf.transpose(x, [1, 0, 2])
-    # Reshaping to (n_steps*batch_size, n_input)
-    x = tf.reshape(x, [-1, 1])
-    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
-    x = tf.split(axis=0, num_or_size_splits=seq_max_len, value=x)
+    
+    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
+    x = tf.unstack(x, seq_max_len, 1)
 
     # Define a lstm cell with tensorflow
     lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden)

+ 2 - 6
examples/3_NeuralNetworks/recurrent_network.py

@@ -53,12 +53,8 @@ def RNN(x, weights, biases):
     # Current data input shape: (batch_size, n_steps, n_input)
     # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)
 
-    # Permuting batch_size and n_steps
-    x = tf.transpose(x, [1, 0, 2])
-    # Reshaping to (n_steps*batch_size, n_input)
-    x = tf.reshape(x, [-1, n_input])
-    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)
-    x = tf.split(x, n_steps, 0)
+    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)
+    x = tf.unstack(x, n_steps, 1)
 
     # Define a lstm cell with tensorflow
     lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)

+ 2 - 6
notebooks/3_NeuralNetworks/bidirectional_rnn.ipynb

@@ -94,12 +94,8 @@
     "    # Current data input shape: (batch_size, n_steps, n_input)\n",
     "    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
     "    \n",
-    "    # Permuting batch_size and n_steps\n",
-    "    x = tf.transpose(x, [1, 0, 2])\n",
-    "    # Reshape to (n_steps*batch_size, n_input)\n",
-    "    x = tf.reshape(x, [-1, n_input])\n",
-    "    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
-    "    x = tf.split(x, n_steps, 0)\n",
+    "    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
+    "    x = tf.unstack(x, n_steps, 1)\n",
     "\n",
     "    # Define lstm cells with tensorflow\n",
     "    # Forward direction cell\n",

+ 2 - 6
notebooks/3_NeuralNetworks/recurrent_network.ipynb

@@ -93,12 +93,8 @@
     "    # Current data input shape: (batch_size, n_steps, n_input)\n",
     "    # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
     "    \n",
-    "    # Permuting batch_size and n_steps\n",
-    "    x = tf.transpose(x, [1, 0, 2])\n",
-    "    # Reshaping to (n_steps*batch_size, n_input)\n",
-    "    x = tf.reshape(x, [-1, n_input])\n",
-    "    # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
-    "    x = tf.split(x, n_steps, 0)\n",
+    "    # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
+    "    x = tf.unstack(x, n_steps, 1)\n",
     "\n",
     "    # Define a lstm cell with tensorflow\n",
     "    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",