# random_forest.py
"""Random Forest.

Implement the Random Forest algorithm with TensorFlow and apply it to classify
handwritten digit images. This example uses the MNIST database of
handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).

Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""
  8. from __future__ import print_function
  9. import tensorflow as tf
  10. from tensorflow.contrib.tensor_forest.python import tensor_forest
  11. from tensorflow.python.ops import resources
  12. # Ignore all GPUs, tf random forest does not benefit from it.
  13. import os
  14. os.environ["CUDA_VISIBLE_DEVICES"] = ""
  15. # Import MNIST data
  16. from tensorflow.examples.tutorials.mnist import input_data
  17. mnist = input_data.read_data_sets("/tmp/data/", one_hot=False)
  18. # Parameters
  19. num_steps = 500 # Total steps to train
  20. batch_size = 1024 # The number of samples per batch
  21. num_classes = 10 # The 10 digits
  22. num_features = 784 # Each image is 28x28 pixels
  23. num_trees = 10
  24. max_nodes = 1000
  25. # Input and Target data
  26. X = tf.placeholder(tf.float32, shape=[None, num_features])
  27. # For random forest, labels must be integers (the class id)
  28. Y = tf.placeholder(tf.int32, shape=[None])
  29. # Random Forest Parameters
  30. hparams = tensor_forest.ForestHParams(num_classes=num_classes,
  31. num_features=num_features,
  32. num_trees=num_trees,
  33. max_nodes=max_nodes).fill()
  34. # Build the Random Forest
  35. forest_graph = tensor_forest.RandomForestGraphs(hparams)
  36. # Get training graph and loss
  37. train_op = forest_graph.training_graph(X, Y)
  38. loss_op = forest_graph.training_loss(X, Y)
  39. # Measure the accuracy
  40. infer_op, _, _ = forest_graph.inference_graph(X)
  41. correct_prediction = tf.equal(tf.argmax(infer_op, 1), tf.cast(Y, tf.int64))
  42. accuracy_op = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  43. # Initialize the variables (i.e. assign their default value) and forest resources
  44. init_vars = tf.group(tf.global_variables_initializer(),
  45. resources.initialize_resources(resources.shared_resources()))
  46. # Start TensorFlow session
  47. sess = tf.Session()
  48. # Run the initializer
  49. sess.run(init_vars)
  50. # Training
  51. for i in range(1, num_steps + 1):
  52. # Prepare Data
  53. # Get the next batch of MNIST data (only images are needed, not labels)
  54. batch_x, batch_y = mnist.train.next_batch(batch_size)
  55. _, l = sess.run([train_op, loss_op], feed_dict={X: batch_x, Y: batch_y})
  56. if i % 50 == 0 or i == 1:
  57. acc = sess.run(accuracy_op, feed_dict={X: batch_x, Y: batch_y})
  58. print('Step %i, Loss: %f, Acc: %f' % (i, l, acc))
  59. # Test Model
  60. test_x, test_y = mnist.test.images, mnist.test.labels
  61. print("Test Accuracy:", sess.run(accuracy_op, feed_dict={X: test_x, Y: test_y}))