|
@@ -12,6 +12,7 @@ from __future__ import print_function
|
|
|
|
|
|
import tensorflow as tf
|
|
|
from tensorflow.contrib.tensor_forest.python import tensor_forest
|
|
|
+from tensorflow.python.ops import resources
|
|
|
|
|
|
# Ignore all GPUs, tf random forest does not benefit from it.
|
|
|
import os
|
|
@@ -51,11 +52,12 @@ infer_op, _, _ = forest_graph.inference_graph(X)
|
|
|
correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
|
|
|
accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
|
|
|
|
|
|
-# Initialize the variables (i.e. assign their default value)
|
|
|
-init_vars = tf.global_variables_initializer()
|
|
|
+# Initialize the variables (i.e. assign their default value) and forest resources
|
|
|
+init_vars = tf.group(tf.global_variables_initializer(),
|
|
|
+ resources.initialize_resources(resources.shared_resources()))
|
|
|
|
|
|
# Start TensorFlow session
|
|
|
-sess = tf.train.MonitoredSession()
|
|
|
+sess = tf.Session()
|
|
|
|
|
|
# Run the initializer
|
|
|
sess.run(init_vars)
|