Forráskód Böngészése

Simplify code in dynamic_rnn (#178)

* Simplify code

* Simplify code in dynamic_rnn.ipynb
张志豪 7 éve
szülő
commit
07cb125ce8

+ 2 - 5
examples/3_NeuralNetworks/dynamic_rnn.py

@@ -175,12 +175,9 @@ with tf.Session() as sess:
         sess.run(optimizer, feed_dict={x: batch_x, y: batch_y,
                                        seqlen: batch_seqlen})
         if step % display_step == 0 or step == 1:
-            # Calculate batch accuracy
-            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y,
+            # Calculate batch accuracy & loss
+            acc, loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y,
                                                 seqlen: batch_seqlen})
-            # Calculate batch loss
-            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y,
-                                             seqlen: batch_seqlen})
             print("Step " + str(step*batch_size) + ", Minibatch Loss= " + \
                   "{:.6f}".format(loss) + ", Training Accuracy= " + \
                   "{:.5f}".format(acc))

+ 2 - 5
notebooks/3_NeuralNetworks/dynamic_rnn.ipynb

@@ -308,12 +308,9 @@
     "        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y,\n",
     "                                       seqlen: batch_seqlen})\n",
     "        if step % display_step == 0 or step == 1:\n",
-    "            # Calculate batch accuracy\n",
-    "            acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y,\n",
+    "            # Calculate batch accuracy & loss\n",
+    "            acc, loss = sess.run([accuracy, cost], feed_dict={x: batch_x, y: batch_y,\n",
     "                                                seqlen: batch_seqlen})\n",
-    "            # Calculate batch loss\n",
-    "            loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y,\n",
-    "                                             seqlen: batch_seqlen})\n",
     "            print(\"Step \" + str(step) + \", Minibatch Loss= \" + \\\n",
     "                  \"{:.6f}\".format(loss) + \", Training Accuracy= \" + \\\n",
     "                  \"{:.5f}\".format(acc))\n",