Browse Source

fix random forest TF 1.4 compatibility

aymericdamien 7 years ago
parent
commit
0c4e6661de

+ 2 - 2
examples/2_BasicModels/random_forest.py

@@ -47,7 +47,7 @@ train_op = forest_graph.training_graph(X, Y)
 loss_op = forest_graph.training_loss(X, Y)
 
 # Measure the accuracy
-infer_op = forest_graph.inference_graph(X)
+infer_op, _, _ = forest_graph.inference_graph(X)
 correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
 accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
@@ -55,7 +55,7 @@ accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 init_vars = tf.global_variables_initializer()
 
 # Start TensorFlow session
-sess = tf.Session()
+sess = tf.train.MonitoredSession()
 
 # Run the initializer
 sess.run(init_vars)

+ 2 - 2
notebooks/2_BasicModels/random_forest.ipynb

@@ -122,7 +122,7 @@
     "loss_op = forest_graph.training_loss(X, Y)\n",
     "\n",
     "# Measure the accuracy\n",
-    "infer_op = forest_graph.inference_graph(X)\n",
+    "infer_op, _, _ = forest_graph.inference_graph(X)\n",
     "correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
     "accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
     "\n",
@@ -158,7 +158,7 @@
    ],
    "source": [
     "# Start TensorFlow session\n",
-    "sess = tf.Session()\n",
+    "sess = tf.train.MonitoredSession()\n",
     "\n",
     "# Run the initializer\n",
     "sess.run(init_vars)\n",