浏览代码

Add a visualization of cross validation prediction performance vs. horizon

Ben Letham 7 年之前
父节点
当前提交
8198afe17a
共有 3 个文件被更改,包括 333 次插入和 62 次删除
  1. 251 58
      notebooks/diagnostics.ipynb
  2. 4 4
      python/fbprophet/diagnostics.py
  3. 78 0
      python/fbprophet/plot.py

文件差异内容过多而无法显示
+ 251 - 58
notebooks/diagnostics.ipynb


+ 4 - 4
python/fbprophet/diagnostics.py

@@ -196,7 +196,7 @@ def prophet_copy(m, cutoff=None):
     return m2
     return m2
 
 
 
 
-def performance_metrics(df, metrics=None, rolling_window=0.05):
+def performance_metrics(df, metrics=None, rolling_window=0.1):
     """Compute performance metrics from cross-validation results.
     """Compute performance metrics from cross-validation results.
 
 
     Computes a suite of performance metrics on the output of cross-validation.
     Computes a suite of performance metrics on the output of cross-validation.
@@ -216,7 +216,7 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     which specifies a proportion of simulated forecast points to include in
     which specifies a proportion of simulated forecast points to include in
     each window. rolling_window=0 will compute it separately for each simulated
     each window. rolling_window=0 will compute it separately for each simulated
     forecast point (i.e., 'mse' will actually be squared error with no mean).
     forecast point (i.e., 'mse' will actually be squared error with no mean).
-    The default of rolling_window=0.05 will use 5% of the rows in df in each
+    The default of rolling_window=0.1 will use 10% of the rows in df in each
     window. rolling_window=1 will compute the metric across all simulated forecast
     window. rolling_window=1 will compute the metric across all simulated forecast
     points. The results are set to the right edge of the window.
     points. The results are set to the right edge of the window.
 
 
@@ -227,9 +227,9 @@ def performance_metrics(df, metrics=None, rolling_window=0.05):
     ----------
     ----------
     df: The dataframe returned by cross_validation.
     df: The dataframe returned by cross_validation.
     metrics: A list of performance metrics to compute. If not provided, will
     metrics: A list of performance metrics to compute. If not provided, will
-        use ['mse', 'mae', 'mape', 'coverage', 'rmse'].
+        use ['mse', 'rmse', 'mae', 'mape', 'coverage'].
     rolling_window: Proportion of data to use in each rolling window for
     rolling_window: Proportion of data to use in each rolling window for
-        computing the metrics.
+        computing the metrics. Should be in [0, 1].
 
 
     Returns
     Returns
     -------
     -------

+ 78 - 0
python/fbprophet/plot.py

@@ -15,6 +15,9 @@ import logging
 import numpy as np
 import numpy as np
 import pandas as pd
 import pandas as pd
 
 
+from fbprophet.diagnostics import performance_metrics
+
+
 logging.basicConfig()
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
@@ -367,3 +370,78 @@ def add_changepoints_to_plot(
     for cp in signif_changepoints:
     for cp in signif_changepoints:
         artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
         artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
     return artists
     return artists
+
+
def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
    """Plot a performance metric vs. forecast horizon from cross validation.

    Cross validation produces a collection of out-of-sample model predictions
    that can be compared to actual values, at a range of different horizons
    (distance from the cutoff). This computes a specified performance metric
    for each prediction, and aggregated over a rolling window with horizon.

    This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
    Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.

    rolling_window is the proportion of data included in the rolling window of
    aggregation. The default value of 0.1 means 10% of data are included in the
    aggregation for computing the metric.

    As a concrete example, if metric='mse', then this plot will show the
    squared error for each cross validation prediction, along with the MSE
    averaged over rolling windows of 10% of the data.

    Parameters
    ----------
    df_cv: The output from fbprophet.diagnostics.cross_validation.
    metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
    rolling_window: Proportion of data to use for rolling average of metric.
        In [0, 1]. Defaults to 0.1.
    ax: Optional matplotlib axis on which to plot. If not given, a new figure
        will be created.

    Returns
    -------
    a matplotlib figure.
    """
    if ax is not None:
        fig = ax.get_figure()
    else:
        fig = plt.figure(facecolor='w', figsize=(10, 6))
        ax = fig.add_subplot(111)

    # Metric at the level of individual predictions, and over the rolling window.
    per_point = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
    rolled = performance_metrics(
        df_cv, metrics=[metric], rolling_window=rolling_window
    )

    # matplotlib does not handle timedelta on the x-axis, so convert the
    # horizon to a float count of a chosen time unit. Aim for ~10 ticks and
    # pick the largest unit that still yields less than one unit per bin.
    target_tick = max(per_point['horizon'].astype('timedelta64[ns]')) / 10.
    # (numpy unit code, axis-label name, nanoseconds per unit)
    unit_table = [
        ('D', 'days', 24 * 60 * 60 * 10 ** 9),
        ('h', 'hours', 60 * 60 * 10 ** 9),
        ('m', 'minutes', 60 * 10 ** 9),
        ('s', 'seconds', 10 ** 9),
        ('ms', 'milliseconds', 10 ** 6),
        ('us', 'microseconds', 10 ** 3),
        ('ns', 'nanoseconds', 1.),
    ]
    for unit, unit_name, ns_per_unit in unit_table:
        if np.timedelta64(1, unit) < np.timedelta64(target_tick, 'ns'):
            break
    # If no unit matched, the loop leaves us on nanoseconds, as the original
    # fall-through did.

    def _as_unit(horizons):
        # Horizon column as a float number of the selected unit.
        return horizons.astype('timedelta64[ns]').astype(int) / float(ns_per_unit)

    ax.plot(_as_unit(per_point['horizon']), per_point[metric], '.', alpha=0.5,
            c='gray')
    ax.plot(_as_unit(rolled['horizon']), rolled[metric], '-', c='b')
    ax.grid(True)

    ax.set_xlabel('Horizon ({})'.format(unit_name))
    ax.set_ylabel(metric)
    return fig