|
@@ -191,3 +191,152 @@ prophet_copy <- function(m, cutoff = NULL) {
|
|
|
m2$seasonalities <- m$seasonalities
|
|
|
return(m2)
|
|
|
}
|
|
|
+
|
|
|
+#' Compute performance metrics from cross-validation results.
|
|
|
+#'
|
|
|
+#' Computes a suite of performance metrics on the output of cross-validation.
|
|
|
+#' By default the following metrics are included:
|
|
|
+#' 'mse': mean squared error
|
|
|
+#' 'rmse': root mean squared error
|
|
|
+#' 'mae': mean absolute error
|
|
|
+#' 'mape': mean percent error
|
|
|
+#' 'coverage': coverage of the upper and lower intervals
|
|
|
+#'
|
|
|
+#' A subset of these can be specified by passing a list of names as the
|
|
|
+#' `metrics` argument.
|
|
|
+#'
|
|
|
+#' Metrics are calculated over a rolling window of cross validation
|
|
|
+#' predictions, after sorting by horizon. The size of that window (number of
|
|
|
+#' simulated forecast points) is determined by the rolling_window argument,
|
|
|
+#' which specifies a proportion of simulated forecast points to include in
|
|
|
+#' each window. rolling_window=0 will compute it separately for each simulated
|
|
|
+#' forecast point (i.e., 'mse' will actually be squared error with no mean).
|
|
|
+#' The default of rolling_window=0.1 will use 10% of the rows in df in each
|
|
|
+#' window. rolling_window=1 will compute the metric across all simulated
|
|
|
+#' forecast points. The results are set to the right edge of the window.
|
|
|
+#'
|
|
|
+#' The output is a dataframe containing column 'horizon' along with columns
|
|
|
+#' for each of the metrics computed.
|
|
|
+#'
|
|
|
+#' @param df The dataframe returned by cross_validation.
|
|
|
+#' @param metrics An array of performance metrics to compute. If not provided,
|
|
|
+#' will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
|
|
|
+#' @param rolling_window Proportion of data to use in each rolling window for
|
|
|
+#' computing the metrics. Should be in [0, 1].
|
|
|
+#'
|
|
|
+#' @return A dataframe with a column for each metric, and column 'horizon'.
|
|
|
+#'
|
|
|
+#' @export
|
|
|
+performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
|
|
|
+ valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'coverage')
|
|
|
+ if (is.null(metrics)) {
|
|
|
+ metrics <- valid_metrics
|
|
|
+ }
|
|
|
+ if (length(metrics) != length(unique(metrics))) {
|
|
|
+ stop('Input metrics must be an array of unique values.')
|
|
|
+ }
|
|
|
+ if (!all(metrics %in% valid_metrics)) {
|
|
|
+ stop(
|
|
|
+ paste('Valid values for metrics are:', paste(metrics, collapse = ", "))
|
|
|
+ )
|
|
|
+ }
|
|
|
+ df_m <- df
|
|
|
+ df_m$horizon <- df_m$ds - df_m$cutoff
|
|
|
+ df_m <- df_m[order(df_m$horizon),]
|
|
|
+ # Window size
|
|
|
+ w <- as.integer(rolling_window * nrow(df_m))
|
|
|
+ w <- max(w, 1)
|
|
|
+ w <- min(w, nrow(df_m))
|
|
|
+ cols <- c('horizon')
|
|
|
+ for (metric in metrics) {
|
|
|
+ df_m[[metric]] <- get(metric)(df_m, w)
|
|
|
+ cols <- c(cols, metric)
|
|
|
+ }
|
|
|
+ df_m <- df_m[cols]
|
|
|
+ return(na.omit(df_m))
|
|
|
+}
|
|
|
+
|
|
|
+#' Compute a rolling mean of x
|
|
|
+#'
|
|
|
+#' Right-aligned. Padded with NAs on the front so the output is the same
|
|
|
+#' size as x.
|
|
|
+#'
|
|
|
+#' @param x Array.
|
|
|
+#' @param w Integer window size (number of elements).
|
|
|
+#'
|
|
|
+#' @return Rolling mean of x with window size w.
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+rolling_mean <- function(x, w) {
|
|
|
+ s <- cumsum(c(0, x))
|
|
|
+ prefix <- rep(NA, w - 1)
|
|
|
+ return(c(prefix, (s[(w + 1):length(s)] - s[1:(length(s) - w)]) / w))
|
|
|
+}
|
|
|
+
|
|
|
+# The functions below specify performance metrics for cross-validation results.
|
|
|
+# Each takes as input the output of cross_validation, and returns the statistic
|
|
|
+# as an array, given a window size for rolling aggregation.
|
|
|
+
|
|
|
+#' Mean squared error
|
|
|
+#'
|
|
|
+#' @param df Cross-validation results dataframe.
|
|
|
+#' @param w Aggregation window size.
|
|
|
+#'
|
|
|
+#' @return Array of mean squared errors.
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+mse <- function(df, w) {
|
|
|
+ se <- (df$y - df$yhat) ** 2
|
|
|
+ return(rolling_mean(se, w))
|
|
|
+}
|
|
|
+
|
|
|
+#' Root mean squared error
|
|
|
+#'
|
|
|
+#' @param df Cross-validation results dataframe.
|
|
|
+#' @param w Aggregation window size.
|
|
|
+#'
|
|
|
+#' @return Array of root mean squared errors.
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+rmse <- function(df, w) {
|
|
|
+ return(sqrt(mse(df, w)))
|
|
|
+}
|
|
|
+
|
|
|
+#' Mean absolute error
|
|
|
+#'
|
|
|
+#' @param df Cross-validation results dataframe.
|
|
|
+#' @param w Aggregation window size.
|
|
|
+#'
|
|
|
+#' @return Array of mean absolute errors.
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+mae <- function(df, w) {
|
|
|
+ ae <- abs(df$y - df$yhat)
|
|
|
+ return(rolling_mean(ae, w))
|
|
|
+}
|
|
|
+
|
|
|
+#' Mean absolute percent error
|
|
|
+#'
|
|
|
+#' @param df Cross-validation results dataframe.
|
|
|
+#' @param w Aggregation window size.
|
|
|
+#'
|
|
|
+#' @return Array of mean absolute percent errors.
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+mape <- function(df, w) {
|
|
|
+ ape <- abs((df$y - df$yhat) / df$y)
|
|
|
+ return(rolling_mean(ape, w))
|
|
|
+}
|
|
|
+
|
|
|
+#' Coverage
|
|
|
+#'
|
|
|
+#' @param df Cross-validation results dataframe.
|
|
|
+#' @param w Aggregation window size.
|
|
|
+#'
|
|
|
+#' @return Array of coverages
|
|
|
+#'
|
|
|
+#' @keywords internal
|
|
|
+coverage <- function(df, w) {
|
|
|
+ is_covered <- (df$y >= df$yhat_lower) & (df$y <= df$yhat_upper)
|
|
|
+ return(rolling_mean(is_covered, w))
|
|
|
+}
|