@@ -196,67 +196,114 @@ def prophet_copy(m, cutoff=None):
     return m2
 
 
-def me(df):
-    return((df['yhat'] - df['y']).sum()/len(df['yhat']))
-def mse(df):
-    return((df['yhat'] - df['y']).pow(2).sum()/len(df))
-def rmse(df):
-    return(np.sqrt((df['yhat'] - df['y']).pow(2).sum()/len(df)))
-def mae(df):
-    return((df['yhat'] - df['y']).abs().sum()/len(df))
-def mpe(df):
-    return((df['yhat'] - df['y']).div(df['y']).sum()*(1/len(df)))
-def mape(df):
-    return((df['yhat'] - df['y']).div(df['y']).abs().sum()*(1/len(df)))
-
-def all_metrics(model, df_cv = None):
-    """Compute model fit metrics for time series.
-
-    Computes the following metrics about each time series that has been through
-    Cross Validation;
-
-    Mean Error (ME)
-    Mean Squared Error (MSE)
-    Root Mean Square Error (RMSE,
-    Mean Absolute Error (MAE)
-    Mean Percentage Error (MPE)
-    Mean Absolute Percentage Error (MAPE)
+def performance_metrics(df, metrics=None, aggregation='horizon'):
+    """Compute performance metrics from cross-validation results.
+
+    Computes a suite of performance metrics on the output of cross-validation.
+    By default the following metrics are included:
+    'mse': mean squared error
+    'mae': mean absolute error
+    'mape': mean absolute percent error
+    'coverage': coverage of the upper and lower intervals
+
+    A subset of these can be specified by passing a list of names as the
+    `metrics` argument.
+
+    By default, metrics will be computed for each horizon (ds - cutoff).
+    Alternatively, metrics can be computed at the level of individual ds/cutoff
+    pairs (aggregation='none'), or aggregated over all ds/cutoffs
+    (aggregation='all').
+
+    The output is a dataframe containing the columns corresponding to the level
+    of aggregation ('horizon', 'ds' and 'cutoff', or none) along with columns
+    for each of the metrics computed.
 
     Parameters
     ----------
-    df: A pandas dataframe. Contains y and yhat produced by cross-validation
+    df: The dataframe returned by cross_validation.
+    metrics: A list of performance metrics to compute. If not provided, will
+        use ['mse', 'mae', 'mape', 'coverage'].
+    aggregation: Level of aggregation for computing performance statistics.
+        Must be 'horizon', 'none', or 'all'.
 
     Returns
    -------
-    A dictionary where the key = the error type, and value is the value of the error
+    Dataframe with a column for each metric, and a combination of columns 'ds',
+    'cutoff', and 'horizon', depending on the aggregation level.
     """
+    # Input validation
+    valid_aggregations = ['horizon', 'all', 'none']
+    if aggregation not in valid_aggregations:
+        raise ValueError(
+            'Aggregation {} is not valid; must be one of {}'.format(
+                aggregation, valid_aggregations
+            )
+        )
+    valid_metrics = ['mse', 'mae', 'mape', 'coverage']
+    if metrics is None:
+        metrics = valid_metrics
+    if len(set(metrics)) != len(metrics):
+        raise ValueError('Input metrics must be a list of unique values')
+    if not set(metrics).issubset(set(valid_metrics)):
+        raise ValueError(
+            'Valid values for metrics are: {}'.format(valid_metrics)
+        )
+    # Get function for the metrics we want
+    metric_fns = {m: eval(m) for m in metrics}
+    def all_metrics(df_g):
+        return pd.Series({name: fn(df_g) for name, fn in metric_fns.items()})
+    # Apply functions to groupby
+    if aggregation == 'all':
+        return all_metrics(df)
+    # else,
+    df_m = df.copy()
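+    # Each row's horizon is the timedelta between its ds and its cutoff;
+    # this is the groupby key when aggregation='horizon'.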
+    df_m['horizon'] = df_m['ds'] - df_m['cutoff']
+    if aggregation == 'horizon':
+        return df_m.groupby('horizon').apply(all_metrics).reset_index()
+    # else,
+    for name, fn in metric_fns.items():
+        df_m[name] = fn(df_m, agg=False)
+    return df_m
+
+
+# The functions below specify performance metrics for cross-validation results.
+# Each takes as input the output of cross_validation, and has two modes of
+# return: if agg=True, returns a float that is the metric aggregated over the
+# input. If agg=False, returns results without aggregation (for
+# aggregation='none' in performance_metrics).
+
+
+def mse(df, agg=True):
+    """Mean squared error
+    """
+    se = (df['y'] - df['yhat']) ** 2
+    if agg:
+        return np.mean(se)
+    return se
 
-
 
-    df = []
+def mae(df, agg=True):
+    """Mean absolute error
+    """
+    ae = np.abs(df['y'] - df['yhat'])
+    if agg:
+        return np.mean(ae)
+    return ae
 
-    if df_cv is not None:
-        df = df_cv
-    else:
-        # run a forecast on your own data with period = 0 so that it is in-sample data onlyl
-        #df = model.predict(model.make_future_dataframe(periods=0))[['y', 'yhat']]
-        df = (model
-              .history[['ds', 'y']]
-              .merge(
-                  model.predict(model.make_future_dataframe(periods=0))[['ds', 'yhat']],
-                  how='inner', on='ds'
-              )
-        )
-
-    if 'yhat' not in df.columns:
-        raise ValueError(
-            'Please run Cross-Validation first before computing quality metrics.')
-
-    return {
-        'ME':me(df),
-        'MSE':mse(df),
-        'RMSE': rmse(df),
-        'MAE': mae(df),
-        'MPE': mpe(df),
-        'MAPE': mape(df)
-    }
+
+def mape(df, agg=True):
+    """Mean absolute percent error
+    """
+    ape = np.abs((df['y'] - df['yhat']) / df['y'])
+    if agg:
+        return np.mean(ape)
+    return ape
+
+
+def coverage(df, agg=True):
+    """Coverage
+    """
+    is_covered = (df['y'] >= df['yhat_lower']) & (df['y'] <= df['yhat_upper'])
+    if agg:
+        return np.mean(is_covered)
+    return is_covered
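
For reference, a minimal usage sketch of the new API (not part of the patch; it assumes a fitted Prophet model `m`, the existing `cross_validation` helper in this module, and the `fbprophet.diagnostics` import path; the horizon/period/initial values are illustrative):

    from fbprophet.diagnostics import cross_validation, performance_metrics, mse

    # Simulated historical forecasts from the existing cross_validation helper
    df_cv = cross_validation(m, horizon='30 days', period='15 days', initial='90 days')

    # Metrics grouped by forecast horizon (the default aggregation)
    df_horizon = performance_metrics(df_cv)

    # Per ds/cutoff rows, or a single aggregate over all rows
    df_rows = performance_metrics(df_cv, metrics=['mse', 'coverage'], aggregation='none')
    overall = performance_metrics(df_cv, aggregation='all')

    # The metric functions can also be applied directly; agg=False keeps per-row values
    per_row_se = mse(df_cv, agg=False)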