{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Gradient Boosted Decision Tree (GBDT)\n", "Implement a Gradient Boosted Decision Tree (GBDT) with TensorFlow. This example uses the Boston Housing dataset as training samples. It supports both classification (2 classes: home value >= $23,000 or not) and regression (raw home value as target).\n", "\n", "- Author: Aymeric Damien\n", "- Project: https://github.com/aymericdamien/TensorFlow-Examples/" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Boston Housing Dataset\n", "\n", "**Link:** https://www.cs.toronto.edu/~delve/data/boston/bostonDetail.html\n", "\n", "**Description:**\n", "\n", "The dataset contains information collected by the U.S. Census Service concerning housing in the area of Boston, Mass. It was obtained from the StatLib archive (http://lib.stat.cmu.edu/datasets/boston) and has been used extensively throughout the literature to benchmark algorithms. However, these comparisons were primarily done outside of Delve and are thus somewhat suspect. The dataset is small, with only 506 cases.\n", "\n", "The data was originally published by Harrison, D. and Rubinfeld, D.L. 'Hedonic prices and the demand for clean air', J. Environ. 
Economics & Management, vol.5, 81-102, 1978.\n", "\n", "*For the full feature list, please see the link above.*" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "from __future__ import print_function\n", "\n", "# Ignore all GPUs (current TF GBDT does not support GPU).\n", "import os\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n", "os.environ['TF_CPP_MIN_LOG_LEVEL'] = \"1\"\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "import copy" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Dataset parameters.\n", "num_classes = 2  # Total classes: home value greater than or equal to $23,000, or not (see the labeling step below).\n", "num_features = 13  # Number of input features.\n", "\n", "# Training parameters.\n", "max_steps = 2000\n", "batch_size = 256\n", "learning_rate = 1.0\n", "l1_regul = 0.0\n", "l2_regul = 0.1\n", "\n", "# GBDT parameters.\n", "num_batches_per_layer = 1000\n", "num_trees = 10\n", "max_depth = 4" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# Prepare Boston Housing Dataset.\n", "from tensorflow.keras.datasets import boston_housing\n", "(x_train, y_train), (x_test, y_test) = boston_housing.load_data()\n", "\n", "# For classification, we build 2 classes: price greater than or equal to $23,000 (label 1), or lower (label 0).\n", "# Targets are expressed in thousands of dollars, hence the 23.0 threshold.\n", "def to_binary_class(y):\n", "    for i, label in enumerate(y):\n", "        if label >= 23.0:\n", "            y[i] = 1\n", "        else:\n", "            y[i] = 0\n", "    return y\n", "\n", "y_train_binary = to_binary_class(copy.deepcopy(y_train))\n", "y_test_binary = to_binary_class(copy.deepcopy(y_test))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GBDT Classifier" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Build the input functions.\n", "train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(\n", "    x={'x': x_train}, y=y_train_binary,\n", "    batch_size=batch_size, num_epochs=None, 
shuffle=True)\n", "test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(\n", "    x={'x': x_test}, y=y_test_binary,\n", "    batch_size=batch_size, num_epochs=1, shuffle=False)\n", "# Single unshuffled pass over the training set, used to evaluate training metrics.\n", "test_train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(\n", "    x={'x': x_train}, y=y_train_binary,\n", "    batch_size=batch_size, num_epochs=1, shuffle=False)\n", "# GBDT models from TF Estimator require the 'feature_column' data format.\n", "feature_columns = [tf.feature_column.numeric_column(key='x', shape=(num_features,))]" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n", "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmp5h6BoR\n", "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': ClusterSpec({}), '_model_dir': '/tmp/tmp5h6BoR', '_protocol': None, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", '_tf_random_seed': None, '_save_summary_steps': 100, '_device_fn': None, '_session_creation_timeout_secs': 7200, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_experimental_max_worker_delay_secs': None, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': None, '_master': ''}\n" ] } ], "source": [ "gbdt_classifier = tf.estimator.BoostedTreesClassifier(\n", "    n_batches_per_layer=num_batches_per_layer,\n", "    feature_columns=feature_columns,\n", "    n_classes=num_classes,\n", "    learning_rate=learning_rate,\n", "    n_trees=num_trees,\n", "    max_depth=max_depth,\n", "    l1_regularization=l1_regul,\n", "    
l2_regularization=l2_regul\n", ")" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1635: calling __init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "If using Keras pass *_constraint arguments to layers.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_core/python/training/training_util.py:236: initialized_value (from tensorflow.python.ops.variables) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Use Variable.read_value. Variables in 2.X are initialized automatically both in eager and graph (inside tf.defun) contexts.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_queue_runner.py:62: __init__ (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "To construct input pipelines, use the `tf.data` module.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/inputs/queues/feeding_functions.py:500: add_queue_runner (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "To construct input pipelines, use the `tf.data` module.\n", "INFO:tensorflow:Calling model_fn.\n", "INFO:tensorflow:Done calling model_fn.\n", "INFO:tensorflow:Create CheckpointSaverHook.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match 
field type in CollectionDef. Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Graph was finalized.\n", "INFO:tensorflow:Running local_init_op.\n", "INFO:tensorflow:Done running local_init_op.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_core/python/training/monitored_session.py:906: start_queue_runners (from tensorflow.python.training.queue_runner_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "To construct input pipelines, use the `tf.data` module.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmp5h6BoR/model.ckpt.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:loss = 0.6931475, step = 0\n", "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n", "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n", "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. 
Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n", "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n", "WARNING:tensorflow:It seems that global step (tf.train.get_global_step) has not been increased. Current value (could be stable): 0 vs previous value: 0. You could increase the global step by passing tf.train.get_global_step() to Optimizer.apply_gradients or Optimizer.minimize.\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.406 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.156 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.167 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.156 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.161 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.156 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.154 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.155 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.158 sec)\n", "INFO:tensorflow:loss = 0.6931475, step = 0 (0.150 sec)\n", "INFO:tensorflow:global_step/sec: 47.2392\n", "INFO:tensorflow:loss = 0.6931475, step = 100 (0.301 sec)\n", "INFO:tensorflow:global_step/sec: 605.484\n", "INFO:tensorflow:loss = 0.6931475, step = 200 (0.165 sec)\n", "INFO:tensorflow:global_step/sec: 616.234\n", "INFO:tensorflow:loss = 0.6931475, step = 300 (0.162 sec)\n", "INFO:tensorflow:global_step/sec: 607.741\n", "INFO:tensorflow:loss = 0.6931475, step = 400 (0.165 sec)\n", "INFO:tensorflow:global_step/sec: 591.803\n", "INFO:tensorflow:loss = 0.6931475, step = 500 (0.170 sec)\n", "INFO:tensorflow:global_step/sec: 627.369\n", "INFO:tensorflow:loss = 0.6931475, step = 600 (0.159 sec)\n", 
"INFO:tensorflow:global_step/sec: 617.083\n", "INFO:tensorflow:loss = 0.6931475, step = 700 (0.162 sec)\n", "INFO:tensorflow:global_step/sec: 608.765\n", "INFO:tensorflow:loss = 0.6931475, step = 800 (0.164 sec)\n", "INFO:tensorflow:global_step/sec: 619.62\n", "INFO:tensorflow:loss = 0.6931475, step = 900 (0.161 sec)\n", "INFO:tensorflow:global_step/sec: 582.581\n", "INFO:tensorflow:loss = 0.44474202, step = 1000 (0.172 sec)\n", "INFO:tensorflow:global_step/sec: 587.127\n", "INFO:tensorflow:loss = 0.46633375, step = 1100 (0.170 sec)\n", "INFO:tensorflow:global_step/sec: 583.294\n", "INFO:tensorflow:loss = 0.45393157, step = 1200 (0.171 sec)\n", "INFO:tensorflow:global_step/sec: 590.375\n", "INFO:tensorflow:loss = 0.44438446, step = 1300 (0.170 sec)\n", "INFO:tensorflow:global_step/sec: 572.479\n", "INFO:tensorflow:loss = 0.4523462, step = 1400 (0.175 sec)\n", "INFO:tensorflow:global_step/sec: 580.282\n", "INFO:tensorflow:loss = 0.4581305, step = 1500 (0.172 sec)\n", "INFO:tensorflow:global_step/sec: 570.032\n", "INFO:tensorflow:loss = 0.45298833, step = 1600 (0.175 sec)\n", "INFO:tensorflow:global_step/sec: 615.6\n", "INFO:tensorflow:loss = 0.4474975, step = 1700 (0.162 sec)\n", "INFO:tensorflow:global_step/sec: 603.042\n", "INFO:tensorflow:loss = 0.47046587, step = 1800 (0.166 sec)\n", "INFO:tensorflow:global_step/sec: 598.262\n", "INFO:tensorflow:loss = 0.46371317, step = 1900 (0.167 sec)\n", "INFO:tensorflow:global_step/sec: 591.323\n", "INFO:tensorflow:Saving checkpoints for 2000 into /tmp/tmp5h6BoR/model.ckpt.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. 
Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Loss for final step: 0.46488184.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gbdt_classifier.train(train_input_fn, max_steps=max_steps)" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_core/python/ops/metrics_impl.py:2029: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "Deprecated in favor of operator or tf.math.divide.\n", "WARNING:tensorflow:From /home/user/anaconda3/envs/py2/lib/python2.7/site-packages/tensorflow_estimator/python/estimator/canned/head.py:619: auc (from tensorflow.python.ops.metrics_impl) is deprecated and will be removed in a future version.\n", "Instructions for updating:\n", "The value of AUC returned by this may race with the update so this is deprected. 
Please use tf.keras.metrics.AUC instead.\n", "WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n", "WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n", "INFO:tensorflow:Done calling model_fn.\n", "INFO:tensorflow:Starting evaluation at 2020-07-15T00:50:36Z\n", "INFO:tensorflow:Graph was finalized.\n", "INFO:tensorflow:Restoring parameters from /tmp/tmp5h6BoR/model.ckpt-2000\n", "INFO:tensorflow:Running local_init_op.\n", "INFO:tensorflow:Done running local_init_op.\n", "INFO:tensorflow:Inference Time : 0.56490s\n", "INFO:tensorflow:Finished evaluation at 2020-07-15-00:50:37\n", "INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.87376237, accuracy_baseline = 0.63118815, auc = 0.92280567, auc_precision_recall = 0.9104949, average_loss = 0.38236493, global_step = 2000, label/mean = 0.36881188, loss = 0.38619137, precision = 0.8888889, prediction/mean = 0.378958, recall = 0.7516779\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. 
Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: /tmp/tmp5h6BoR/model.ckpt-2000\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.87376237,\n", " 'accuracy_baseline': 0.63118815,\n", " 'auc': 0.92280567,\n", " 'auc_precision_recall': 0.9104949,\n", " 'average_loss': 0.38236493,\n", " 'global_step': 2000,\n", " 'label/mean': 0.36881188,\n", " 'loss': 0.38619137,\n", " 'precision': 0.8888889,\n", " 'prediction/mean': 0.378958,\n", " 'recall': 0.7516779}" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gbdt_classifier.evaluate(test_train_input_fn)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n", "WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n", "WARNING:tensorflow:Trapezoidal rule is known to produce incorrect PR-AUCs; please switch to \"careful_interpolation\" instead.\n", "INFO:tensorflow:Done calling model_fn.\n", "INFO:tensorflow:Starting evaluation at 2020-07-15T00:50:38Z\n", "INFO:tensorflow:Graph was finalized.\n", "INFO:tensorflow:Restoring parameters from /tmp/tmp5h6BoR/model.ckpt-2000\n", "INFO:tensorflow:Running local_init_op.\n", "INFO:tensorflow:Done running local_init_op.\n", "INFO:tensorflow:Inference Time : 0.56883s\n", "INFO:tensorflow:Finished evaluation at 2020-07-15-00:50:38\n", "INFO:tensorflow:Saving dict for global step 2000: accuracy = 0.78431374, accuracy_baseline = 0.5588235, auc = 0.8458089, auc_precision_recall = 0.86285317, average_loss = 0.49404, global_step = 2000, label/mean = 0.44117647, loss = 0.49404, precision = 0.87096775, prediction/mean = 0.37467176, recall = 0.6\n", "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: 
/tmp/tmp5h6BoR/model.ckpt-2000\n" ] }, { "data": { "text/plain": [ "{'accuracy': 0.78431374,\n", " 'accuracy_baseline': 0.5588235,\n", " 'auc': 0.8458089,\n", " 'auc_precision_recall': 0.86285317,\n", " 'average_loss': 0.49404,\n", " 'global_step': 2000,\n", " 'label/mean': 0.44117647,\n", " 'loss': 0.49404,\n", " 'precision': 0.87096775,\n", " 'prediction/mean': 0.37467176,\n", " 'recall': 0.6}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gbdt_classifier.evaluate(test_input_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### GBDT Regressor" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "# Build the input functions (raw home values as regression targets).\n", "train_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(\n", "    x={'x': x_train}, y=y_train,\n", "    batch_size=batch_size, num_epochs=None, shuffle=True)\n", "test_input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(\n", "    x={'x': x_test}, y=y_test,\n", "    batch_size=batch_size, num_epochs=1, shuffle=False)\n", "# GBDT models from TF Estimator require the 'feature_column' data format.\n", "feature_columns = [tf.feature_column.numeric_column(key='x', shape=(num_features,))]" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Using default config.\n", "WARNING:tensorflow:Using temporary folder as model directory: /tmp/tmpts3Kmu\n", "INFO:tensorflow:Using config: {'_save_checkpoints_secs': 600, '_num_ps_replicas': 0, '_keep_checkpoint_max': 5, '_task_type': 'worker', '_global_id_in_cluster': 0, '_is_chief': True, '_cluster_spec': ClusterSpec({}), '_model_dir': '/tmp/tmpts3Kmu', '_protocol': None, '_save_checkpoints_steps': None, '_keep_checkpoint_every_n_hours': 10000, '_service': None, '_session_config': allow_soft_placement: true\n", "graph_options {\n", " rewrite_options {\n", " meta_optimizer_iterations: ONE\n", " }\n", "}\n", ", 
'_tf_random_seed': None, '_save_summary_steps': 100, '_device_fn': None, '_session_creation_timeout_secs': 7200, '_experimental_distribute': None, '_num_worker_replicas': 1, '_task_id': 0, '_log_step_count_steps': 100, '_experimental_max_worker_delay_secs': None, '_evaluation_master': '', '_eval_distribute': None, '_train_distribute': None, '_master': ''}\n" ] } ], "source": [ "gbdt_regressor = tf.estimator.BoostedTreesRegressor(\n", "    n_batches_per_layer=num_batches_per_layer,\n", "    feature_columns=feature_columns,\n", "    learning_rate=learning_rate,\n", "    n_trees=num_trees,\n", "    max_depth=max_depth,\n", "    l1_regularization=l1_regul,\n", "    l2_regularization=l2_regul\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "scrolled": false }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n", "INFO:tensorflow:Done calling model_fn.\n", "INFO:tensorflow:Create CheckpointSaverHook.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Graph was finalized.\n", "INFO:tensorflow:Running local_init_op.\n", "INFO:tensorflow:Done running local_init_op.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Saving checkpoints for 0 into /tmp/tmpts3Kmu/model.ckpt.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. 
Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:loss = 584.82294, step = 0\n", "INFO:tensorflow:loss = 560.2794, step = 0 (0.369 sec)\n", "INFO:tensorflow:loss = 606.68115, step = 0 (0.156 sec)\n", "INFO:tensorflow:loss = 583.2771, step = 0 (0.155 sec)\n", "INFO:tensorflow:loss = 603.4647, step = 0 (0.160 sec)\n", "INFO:tensorflow:loss = 605.8213, step = 0 (0.153 sec)\n", "INFO:tensorflow:loss = 577.5599, step = 0 (0.157 sec)\n", "INFO:tensorflow:loss = 585.297, step = 0 (0.157 sec)\n", "INFO:tensorflow:loss = 545.26074, step = 0 (0.156 sec)\n", "INFO:tensorflow:loss = 597.91046, step = 0 (0.190 sec)\n", "INFO:tensorflow:loss = 600.55396, step = 0 (0.174 sec)\n", "INFO:tensorflow:global_step/sec: 47.5449\n", "INFO:tensorflow:loss = 539.62646, step = 100 (0.280 sec)\n", "INFO:tensorflow:global_step/sec: 592.267\n", "INFO:tensorflow:loss = 573.9592, step = 200 (0.169 sec)\n", "INFO:tensorflow:global_step/sec: 573.943\n", "INFO:tensorflow:loss = 617.79407, step = 300 (0.175 sec)\n", "INFO:tensorflow:global_step/sec: 583.88\n", "INFO:tensorflow:loss = 593.62915, step = 400 (0.171 sec)\n", "INFO:tensorflow:global_step/sec: 595.888\n", "INFO:tensorflow:loss = 594.5435, step = 500 (0.168 sec)\n", "INFO:tensorflow:global_step/sec: 610.997\n", "INFO:tensorflow:loss = 579.5427, step = 600 (0.163 sec)\n", "INFO:tensorflow:global_step/sec: 625.07\n", "INFO:tensorflow:loss = 555.19604, step = 700 (0.160 sec)\n", "INFO:tensorflow:global_step/sec: 674.427\n", "INFO:tensorflow:loss = 585.61127, step = 800 (0.149 sec)\n", "INFO:tensorflow:global_step/sec: 652.597\n", "INFO:tensorflow:loss = 645.147, step = 900 (0.153 sec)\n", "INFO:tensorflow:global_step/sec: 656.608\n", "INFO:tensorflow:loss = 65.438034, step = 1000 (0.152 sec)\n", "INFO:tensorflow:global_step/sec: 660.171\n", "INFO:tensorflow:loss = 57.25811, step = 1100 (0.151 sec)\n", "INFO:tensorflow:global_step/sec: 676.676\n", "INFO:tensorflow:loss = 
70.39737, step = 1200 (0.148 sec)\n", "INFO:tensorflow:global_step/sec: 664.916\n", "INFO:tensorflow:loss = 63.969463, step = 1300 (0.150 sec)\n", "INFO:tensorflow:global_step/sec: 679.204\n", "INFO:tensorflow:loss = 55.910896, step = 1400 (0.147 sec)\n", "INFO:tensorflow:global_step/sec: 680.936\n", "INFO:tensorflow:loss = 58.16027, step = 1500 (0.147 sec)\n", "INFO:tensorflow:global_step/sec: 670.412\n", "INFO:tensorflow:loss = 66.20054, step = 1600 (0.149 sec)\n", "INFO:tensorflow:global_step/sec: 673.441\n", "INFO:tensorflow:loss = 52.643417, step = 1700 (0.149 sec)\n", "INFO:tensorflow:global_step/sec: 684.782\n", "INFO:tensorflow:loss = 59.981026, step = 1800 (0.145 sec)\n", "INFO:tensorflow:global_step/sec: 684.191\n", "INFO:tensorflow:loss = 65.427055, step = 1900 (0.146 sec)\n", "INFO:tensorflow:global_step/sec: 683.812\n", "INFO:tensorflow:Saving checkpoints for 2000 into /tmp/tmpts3Kmu/model.ckpt.\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. 
Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Loss for final step: 42.740192.\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gbdt_regressor.train(train_input_fn, max_steps=max_steps)" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "INFO:tensorflow:Calling model_fn.\n", "INFO:tensorflow:Done calling model_fn.\n", "INFO:tensorflow:Starting evaluation at 2020-07-15T00:50:45Z\n", "INFO:tensorflow:Graph was finalized.\n", "INFO:tensorflow:Restoring parameters from /tmp/tmpts3Kmu/model.ckpt-2000\n", "INFO:tensorflow:Running local_init_op.\n", "INFO:tensorflow:Done running local_init_op.\n", "INFO:tensorflow:Inference Time : 0.24467s\n", "INFO:tensorflow:Finished evaluation at 2020-07-15-00:50:45\n", "INFO:tensorflow:Saving dict for global step 2000: average_loss = 30.202602, global_step = 2000, label/mean = 23.078432, loss = 30.202602, prediction/mean = 22.536291\n", "WARNING:tensorflow:Issue encountered when serializing resources.\n", "Type is unsupported, or the types of the items don't match field type in CollectionDef. 
Note this is a warning and probably safe to ignore.\n", "'_Resource' object has no attribute 'name'\n", "INFO:tensorflow:Saving 'checkpoint_path' summary for global step 2000: /tmp/tmpts3Kmu/model.ckpt-2000\n" ] }, { "data": { "text/plain": [ "{'average_loss': 30.202602,\n", " 'global_step': 2000,\n", " 'label/mean': 23.078432,\n", " 'loss': 30.202602,\n", " 'prediction/mean': 22.536291}" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gbdt_regressor.evaluate(test_input_fn)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.18" } }, "nbformat": 4, "nbformat_minor": 2 }