|
@@ -15,6 +15,9 @@ import logging
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
|
|
|
+from fbprophet.diagnostics import performance_metrics
|
|
|
+
|
|
|
+
|
|
|
logging.basicConfig()
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
@@ -367,3 +370,78 @@ def add_changepoints_to_plot(
|
|
|
for cp in signif_changepoints:
|
|
|
artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
|
|
|
return artists
|
|
|
+
|
|
|
+
|
|
|
def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
    """Plot a cross-validation performance metric against forecast horizon.

    Cross validation produces out-of-sample predictions at a range of
    horizons (distance from the cutoff). This plots the chosen metric for
    every individual prediction (gray dots) together with the same metric
    aggregated over a rolling window of the data (blue line), both computed
    via fbprophet.diagnostics.performance_metrics.

    Valid metric names are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.

    Parameters
    ----------
    df_cv: The output from fbprophet.diagnostics.cross_validation.
    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
    rolling_window: Proportion of data to use for rolling average of metric.
        In [0, 1]. Defaults to 0.1, i.e. 10% of the data per window.
    ax: Optional matplotlib axis on which to plot. If not given, a new figure
        will be created.

    Returns
    -------
    a matplotlib figure.
    """
    if ax is None:
        fig = plt.figure(facecolor='w', figsize=(10, 6))
        ax = fig.add_subplot(111)
    else:
        fig = ax.get_figure()
    # Metric per individual prediction (rolling_window=0) and aggregated
    # over the requested rolling window.
    df_point = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
    df_rolled = performance_metrics(
        df_cv, metrics=[metric], rolling_window=rolling_window
    )

    # matplotlib cannot plot timedeltas directly, so express the horizon as
    # a float count of some time unit. Aim for roughly 10 axis ticks and
    # pick the largest unit that has less than one unit per tick.
    horizon_ns = df_point['horizon'].astype('timedelta64[ns]')
    tick_w = max(horizon_ns) / 10.
    unit_table = [
        # (numpy unit code, axis-label name, nanoseconds per unit)
        ('D', 'days', 24 * 60 * 60 * 10 ** 9),
        ('h', 'hours', 60 * 60 * 10 ** 9),
        ('m', 'minutes', 60 * 10 ** 9),
        ('s', 'seconds', 10 ** 9),
        ('ms', 'milliseconds', 10 ** 6),
        ('us', 'microseconds', 10 ** 3),
        ('ns', 'nanoseconds', 1.),
    ]
    for unit, unit_name, ns_per_unit in unit_table:
        if np.timedelta64(1, unit) < np.timedelta64(tick_w, 'ns'):
            break
    # If no unit qualified, the loop falls through leaving nanoseconds.

    x_point = horizon_ns.astype(int) / float(ns_per_unit)
    x_rolled = (
        df_rolled['horizon'].astype('timedelta64[ns]').astype(int)
        / float(ns_per_unit)
    )

    ax.plot(x_point, df_point[metric], '.', alpha=0.5, c='gray')
    ax.plot(x_rolled, df_rolled[metric], '-', c='b')
    ax.grid(True)

    ax.set_xlabel('Horizon ({})'.format(unit_name))
    ax.set_ylabel(metric)
    return fig
|