# random_forest.py 2.6 KB

  1. """ Random Forest.
  2. Implement Random Forest algorithm with TensorFlow, and apply it to classify
  3. handwritten digit images. This example is using the MNIST database of
  4. handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).
  5. Author: Aymeric Damien
  6. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  7. """
  8. from __future__ import print_function
  9. import tensorflow as tf
  10. from tensorflow.contrib.tensor_forest.python import tensor_forest
  11. # Ignore all GPUs, tf random forest does not benefit from it.
  12. import os
  13. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  14. # Import MNIST data
  15. from tensorflow.examples.tutorials.mnist import input_data
  16. mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
  17. # Parameters
  18. num_steps = 500 # Total steps to train
  19. batch_size = 1024 # The number of samples per batch
  20. num_classes = 10 # The 10 digits
  21. num_features = 784 # Each image is 28x28 pixels
  22. num_trees = 10
  23. max_nodes = 1000
  24. # Input and Target data
  25. X = tf.placeholder(tf.float32, shape=[None, num_features])
  26. # For random forest, labels must be integers (the class id)
  27. Y = tf.placeholder(tf.int32, shape=[None])
  28. # Random Forest Parameters
  29. hparams = tensor_forest.ForestHParams(num_classes=num_classes,
  30. num_features=num_features,
  31. num_trees=num_trees,
  32. max_nodes=max_nodes).fill()
  33. # Build the Random Forest
  34. forest_graph = tensor_forest.RandomForestGraphs(hparams)
  35. # Get training graph and loss
  36. train_op = forest_graph.training_graph(X, Y)
  37. loss_op = forest_graph.training_loss(X, Y)
  38. # Measure the accuracy
  39. infer_op, _, _ = forest_graph.inference_graph(X)
  40. correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
  41. accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  42. # Initialize the variables (i.e. assign their default value)
  43. init_vars = tf.global_variables_initializer()
  44. # Start TensorFlow session
  45. sess = tf.train.MonitoredSession()
  46. # Run the initializer
  47. sess.run(init_vars)
  48. # Training
  49. for i in range(1, num_steps + 1):
  50. # Prepare Data
  51. # Get the next batch of MNIST data (only images are needed, not labels)
  52. batch_x, batch_y = mnist.train.next_batch(batch_size)
  53. _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
  54. if i % 50 == 0 or i == 1:
  55. acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
  56. print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
  57. # Test Model
  58. test_x, test_y = mnist.test.images, mnist.test.labels
  59. print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))