|
@@ -93,12 +93,8 @@
|
|
|
" # Current data input shape: (batch_size, n_steps, n_input)\n",
|
|
|
" # Required shape: 'n_steps' tensors list of shape (batch_size, n_input)\n",
|
|
|
" \n",
|
|
|
- " # Permuting batch_size and n_steps\n",
|
|
|
- " x = tf.transpose(x, [1, 0, 2])\n",
|
|
|
- " # Reshaping to (n_steps*batch_size, n_input)\n",
|
|
|
- " x = tf.reshape(x, [-1, n_input])\n",
|
|
|
- " # Split to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
|
|
|
- " x = tf.split(x, n_steps, 0)\n",
|
|
|
+ " # Unstack to get a list of 'n_steps' tensors of shape (batch_size, n_input)\n",
|
|
|
+ " x = tf.unstack(x, n_steps, 1)\n",
|
|
|
"\n",
|
|
|
" # Define a lstm cell with tensorflow\n",
|
|
|
" lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)\n",
|