Explorar o código

update random_forest

aymericdamien %!s(int64=7) %!d(string=hai) anos
pai
achega
d43c58c948

+ 5 - 3
examples/2_BasicModels/random_forest.py

@@ -12,6 +12,7 @@ from __future__ import print_function
 
 import tensorflow as tf
 from tensorflow.contrib.tensor_forest.python import tensor_forest
+from tensorflow.python.ops import resources
 
 # Ignore all GPUs, tf random forest does not benefit from it.
 import os
@@ -51,11 +52,12 @@ infer_op, _, _ = forest_graph.inference_graph(X)
 correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
 accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
 
-# Initialize the variables (i.e. assign their default value)
-init_vars = tf.global_variables_initializer()
+# Initialize the variables (i.e. assign their default value) and forest resources
+init_vars = tf.group(tf.global_variables_initializer(),
+    resources.initialize_resources(resources.shared_resources()))
 
 # Start TensorFlow session
-sess = tf.train.MonitoredSession()
+sess = tf.Session()
 
 # Run the initializer
 sess.run(init_vars)

+ 3 - 2
notebooks/2_BasicModels/random_forest.ipynb

@@ -126,8 +126,9 @@
     "correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))\n",
     "accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))\n",
     "\n",
-    "# Initialize the variables (i.e. assign their default value)\n",
-    "init_vars = tf.global_variables_initializer()"
+    "# Initialize the variables (i.e. assign their default value) and forest resources\n",
+    "init_vars = tf.group(tf.global_variables_initializer(),\n",
+    "    resources.initialize_resources(resources.shared_resources()))"
    ]
   },
   {