Add a visualization of cross validation prediction performance vs. horizon

Ben Letham, 7 years ago
parent commit 8198afe17a
3 changed files with 333 additions and 62 deletions
  1. notebooks/diagnostics.ipynb (+251 −58)
  2. python/fbprophet/diagnostics.py (+4 −4)
  3. python/fbprophet/plot.py (+78 −0)

notebooks/diagnostics.ipynb: +251 −58
(Diff not shown because it is too large.)


python/fbprophet/diagnostics.py: +4 −4

@@ -196,7 +196,7 @@ def prophet_copy(m, cutoff=None):
     return m2
 
 
-def performance_metrics(df, metrics=None, rolling_window=0.05):
+def performance_metrics(df, metrics=None, rolling_window=0.1):
     """Compute performance metrics from cross-validation results.
 
     Computes a suite of performance metrics on the output of cross-validation.
@@ -216,7 +216,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     which specifies a proportion of simulated forecast points to include in
     each window. rolling_window=0 will compute it separately for each simulated
     forecast point (i.e., 'mse' will actually be squared error with no mean).
-    The default of rolling_window=0.05 will use 5% of the rows in df in each
+    The default of rolling_window=0.1 will use 10% of the rows in df in each
     window. rolling_window=1 will compute the metric across all simulated forecast
     points. The results are set to the right edge of the window.
 
@@ -227,9 +227,9 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     ----------
     df: The dataframe returned by cross_validation.
     metrics: A list of performance metrics to compute. If not provided, will
-        use ['mse', 'mae', 'mape', 'coverage', 'rmse'].
+        use ['mse', 'rmse', 'mae', 'mape', 'coverage'].
     rolling_window: Proportion of data to use in each rolling window for
-        computing the metrics.
+        computing the metrics. Should be in [0, 1].
 
     Returns
     -------
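
As a brief usage sketch of the updated defaults (illustrative only: it assumes a Prophet model m already fit to a daily series, and the horizon/period/initial values are made up):

    from fbprophet.diagnostics import cross_validation, performance_metrics

    # Simulated historical forecasts; df_cv has columns
    # ds, y, yhat, yhat_lower, yhat_upper, and cutoff.
    df_cv = cross_validation(m, horizon='365 days', period='180 days',
                             initial='730 days')

    # rolling_window=0: per-point values, e.g. squared error rather than MSE.
    df_point = performance_metrics(df_cv, metrics=['mse'], rolling_window=0)

    # New default rolling_window=0.1: each window spans 10% of the rows.
    df_rolled = performance_metrics(df_cv)

    # rolling_window=1: a single value over all simulated forecast points.
    df_all = performance_metrics(df_cv, rolling_window=1)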

python/fbprophet/plot.py: +78 −0

@@ -15,6 +15,9 @@ import logging
 import numpy as np
 import pandas as pd
 
+from fbprophet.diagnostics import performance_metrics
+
+
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 
@@ -367,3 +370,78 @@ def add_changepoints_to_plot(
     for cp in signif_changepoints:
         artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
     return artists
+
+
+def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
+    """Plot a performance metric vs. forecast horizon from cross validation.
+
+    Cross validation produces a collection of out-of-sample model predictions
+    that can be compared to actual values, at a range of different horizons
+    (distance from the cutoff). This computes a specified performance metric
+    for each prediction, and aggregates it over a rolling window of horizons.
+
+    This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
+    Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
+
+    rolling_window is the proportion of data included in the rolling window of
+    aggregation. The default value of 0.1 means 10% of data are included in the
+    aggregation for computing the metric.
+
+    As a concrete example, if metric='mse', then this plot will show the
+    squared error for each cross validation prediction, along with the MSE
+    averaged over rolling windows of 10% of the data.
+
+    Parameters
+    ----------
+    df_cv: The output from fbprophet.diagnostics.cross_validation.
+    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
+    rolling_window: Proportion of data to use for rolling average of metric.
+        In [0, 1]. Defaults to 0.1.
+    ax: Optional matplotlib axis on which to plot. If not given, a new figure
+        will be created.
+
+    Returns
+    -------
+    a matplotlib figure.
+    """
+    if ax is None:
+        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        ax = fig.add_subplot(111)
+    else:
+        fig = ax.get_figure()
+    # Get the metric at the level of individual predictions, and with the rolling window.
+    df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
+    df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)
+
+    # Some work because matplotlib does not handle timedelta
+    # Target ~10 ticks.
+    tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
+    # Find the largest time resolution that has <1 unit per bin.
+    dts = ['D', 'h', 'm', 's', 'ms', 'us', 'ns']
+    dt_names = [
+        'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds',
+        'nanoseconds'
+    ]
+    dt_conversions = [
+        24 * 60 * 60 * 10 ** 9,
+        60 * 60 * 10 ** 9,
+        60 * 10 ** 9,
+        10 ** 9,
+        10 ** 6,
+        10 ** 3,
+        1.,
+    ]
+    for i, dt in enumerate(dts):
+        if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'):
+            break
+
+    x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(int) / float(dt_conversions[i])
+    x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(int) / float(dt_conversions[i])
+
+    ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
+    ax.plot(x_plt_h, df_h[metric], '-', c='b')
+    ax.grid(True)
+
+    ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
+    ax.set_ylabel(metric)
+    return fig
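
A minimal end-to-end sketch of the new plot (assuming, as above, a fitted Prophet model m on daily data; the cutoff spacing and horizon are illustrative):

    from fbprophet.diagnostics import cross_validation
    from fbprophet.plot import plot_cross_validation_metric

    df_cv = cross_validation(m, horizon='365 days', period='180 days',
                             initial='730 days')

    # Gray dots: squared error of each individual prediction (rolling_window=0).
    # Blue line: MSE over rolling windows covering 10% of the points, by horizon.
    fig = plot_cross_validation_metric(df_cv, metric='mse', rolling_window=0.1)
    fig.savefig('cv_mse_vs_horizon.png')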