gradient_boosted_decision_tree.py

  1. """ Gradient Boosted Decision Tree (GBDT).
  2. Implement a Gradient Boosted Decision tree with TensorFlow to classify
  3. handwritten digit images. This example is using the MNIST database of
  4. handwritten digits as training samples (http://yann.lecun.com/exdb/mnist/).
  5. Links:
  6. [MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
  7. Author: Aymeric Damien
  8. Project: https://github.com/aymericdamien/TensorFlow-Examples/
  9. """
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier
from tensorflow.contrib.boosted_trees.proto import learner_pb2 as gbdt_learner

# Ignore all GPUs (current TF GBDT does not support GPU).
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# Import MNIST data
# Set verbosity to display errors only (remove this line to also show warnings)
tf.logging.set_verbosity(tf.logging.ERROR)
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False,
                                  source_url='http://yann.lecun.com/exdb/mnist/')
# Parameters
batch_size = 4096  # The number of samples per batch
num_classes = 10  # The 10 digits
num_features = 784  # Each image is 28x28 pixels
max_steps = 10000

# GBDT Parameters
learning_rate = 0.1
l1_regul = 0.
l2_regul = 1.
examples_per_layer = 1000
num_trees = 10
max_depth = 16
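
# Note: examples_per_layer is the number of examples accumulated before each
# new tree layer is grown (see the layer-by-layer growing mode set below).
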
# Fill GBDT parameters into the config proto
learner_config = gbdt_learner.LearnerConfig()
learner_config.learning_rate_tuner.fixed.learning_rate = learning_rate
learner_config.regularization.l1 = l1_regul
learner_config.regularization.l2 = l2_regul / examples_per_layer
learner_config.constraints.max_tree_depth = max_depth
growing_mode = gbdt_learner.LearnerConfig.LAYER_BY_LAYER
learner_config.growing_mode = growing_mode
run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300)
learner_config.multi_class_strategy = (
    gbdt_learner.LearnerConfig.DIAGONAL_HESSIAN)
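# Note: the L2 penalty is applied per example, so dividing by
# examples_per_layer keeps the total regularization over one layer roughly
# equal to l2_regul. DIAGONAL_HESSIAN is the multi-class strategy that uses a
# per-class (diagonal) Hessian approximation.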

# Create a TensorFlow GBDT Estimator
gbdt_model = GradientBoostedDecisionTreeClassifier(
    model_dir=None,  # No save directory specified
    learner_config=learner_config,
    n_classes=num_classes,
    examples_per_layer=examples_per_layer,
    num_trees=num_trees,
    center_bias=False,
    config=run_config)

# Display TF info logs
tf.logging.set_verbosity(tf.logging.INFO)

# Define the input function for training
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': mnist.train.images}, y=mnist.train.labels,
    batch_size=batch_size, num_epochs=None, shuffle=True)
# Train the Model
gbdt_model.fit(input_fn=input_fn, max_steps=max_steps)
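# Note: fit() and evaluate() follow the tf.contrib.learn Estimator interface;
# the core tf.estimator API uses train()/evaluate() instead.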

# Evaluate the Model
# Define the input function for evaluating
input_fn = tf.estimator.inputs.numpy_input_fn(
    x={'images': mnist.test.images}, y=mnist.test.labels,
    batch_size=batch_size, shuffle=False)
# Use the Estimator 'evaluate' method
e = gbdt_model.evaluate(input_fn=input_fn)

print("Testing Accuracy:", e['accuracy'])