|
@@ -122,7 +122,7 @@
|
|
|
"loss_op = forest_graph.training_loss(X, Y)\n",
|
|
|
"\n",
|
|
|
"# Measure the accuracy\n",
|
|
|
- "infer_op = forest_graph.inference_graph(X)\n",
|
|
|
+ "infer_op, _, _ = forest_graph.inference_graph(X)\n",
|
|
|
"correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
|
|
|
"accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
|
|
|
"\n",
|
|
@@ -158,7 +158,7 @@
|
|
|
],
|
|
|
"source": [
|
|
|
"# Start TensorFlow session\n",
|
|
|
- "sess = tf.Session()\n",
|
|
|
+ "sess = tf.train.MonitoredSession()\n",
|
|
|
"\n",
|
|
|
"# Run the initializer\n",
|
|
|
"sess.run(init_vars)\n",
|