
Refactor cross validation metrics for rolling window, add visualization, put example in notebook (R)

Ben Letham 7 years ago
parent
commit
b052b56d33

+ 0 - 2
R/DESCRIPTION

@@ -25,8 +25,6 @@ Imports:
     stats,
     tidyr (>= 0.6.1),
     utils,
-    purrr,
-    rlang,
     xts
 Suggests:
     knitr,

+ 2 - 7
R/NAMESPACE

@@ -5,21 +5,16 @@ S3method(predict,prophet)
 export(add_changepoints_to_plot)
 export(add_regressor)
 export(add_seasonality)
-export(all_metrics)
 export(cross_validation)
 export(dyplot.prophet)
 export(fit.prophet)
-export(mae)
 export(make_future_dataframe)
-export(mape)
-export(me)
-export(mpe)
-export(mse)
+export(performance_metrics)
+export(plot_cross_validation_metric)
 export(plot_forecast_component)
 export(predictive_samples)
 export(prophet)
 export(prophet_plot_components)
-export(rmse)
 export(simulated_historical_forecasts)
 import(Rcpp)
 importFrom(dplyr,"%>%")

+ 149 - 0
R/R/diagnostics.R

@@ -191,3 +191,152 @@ prophet_copy <- function(m, cutoff = NULL) {
   m2$seasonalities <- m$seasonalities
   return(m2)
 }
+
+#' Compute performance metrics from cross-validation results.
+#'
+#' Computes a suite of performance metrics on the output of cross-validation.
+#' By default the following metrics are included:
+#' 'mse': mean squared error
+#' 'rmse': root mean squared error
+#' 'mae': mean absolute error
+#' 'mape': mean absolute percent error
+#' 'coverage': coverage of the upper and lower intervals
+#'
+#' A subset of these can be specified by passing a list of names as the
+#' `metrics` argument.
+#'
+#' Metrics are calculated over a rolling window of cross validation
+#' predictions, after sorting by horizon. The size of that window (number of
+#' simulated forecast points) is determined by the rolling_window argument,
+#' which specifies a proportion of simulated forecast points to include in
+#' each window. rolling_window=0 will compute it separately for each simulated
+#' forecast point (i.e., 'mse' will actually be squared error with no mean).
+#' The default of rolling_window=0.1 will use 10% of the rows in df in each
+#' window. rolling_window=1 will compute the metric across all simulated
+#' forecast points. The results are set to the right edge of the window.
+#'
+#' The output is a dataframe containing column 'horizon' along with columns
+#' for each of the metrics computed.
+#'
+#' @param df The dataframe returned by cross_validation.
+#' @param metrics An array of performance metrics to compute. If not provided,
+#'  will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
+#' @param rolling_window Proportion of data to use in each rolling window for
+#'  computing the metrics. Should be in [0, 1].
+#'
+#' @return A dataframe with a column for each metric, and column 'horizon'.
+#'
+#' @export
+performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
+  valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'coverage')
+  if (is.null(metrics)) {
+    metrics <- valid_metrics
+  }
+  if (length(metrics) != length(unique(metrics))) {
+    stop('Input metrics must be an array of unique values.')
+  }
+  if (!all(metrics %in% valid_metrics)) {
+    stop(
+      paste('Valid values for metrics are:', paste(valid_metrics, collapse = ", "))
+    )
+  }
+  df_m <- df
+  df_m$horizon <- df_m$ds - df_m$cutoff
+  df_m <- df_m[order(df_m$horizon),]
+  # Window size
+  w <- as.integer(rolling_window * nrow(df_m))
+  w <- max(w, 1)
+  w <- min(w, nrow(df_m))
+  cols <- c('horizon')
+  for (metric in metrics) {
+    df_m[[metric]] <- get(metric)(df_m, w)
+    cols <- c(cols, metric)
+  }
+  df_m <- df_m[cols]
+  return(na.omit(df_m))
+}
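
For orientation, a minimal usage sketch (assuming a prophet model m already fitted on daily data; the horizon, period, and initial values below are illustrative, not taken from this commit):

# Cross-validation predictions, then rolling-window metrics by horizon
df.cv <- cross_validation(m, horizon = 30, units = 'days', period = 15, initial = 180)
df.p <- performance_metrics(df.cv, metrics = c('mse', 'mape'), rolling_window = 0.1)
head(df.p)  # columns horizon, mse, mape; leading rows lost to the rolling window are already dropped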
+
+#' Compute a rolling mean of x
+#'
+#' Right-aligned. Padded with NAs on the front so the output is the same
+#' size as x.
+#'
+#' @param x Array.
+#' @param w Integer window size (number of elements).
+#'
+#' @return Rolling mean of x with window size w.
+#'
+#' @keywords internal
+rolling_mean <- function(x, w) {
+  s <- cumsum(c(0, x))
+  prefix <- rep(NA, w - 1)
+  return(c(prefix, (s[(w + 1):length(s)] - s[1:(length(s) - w)]) / w))
+}
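
The cumulative-sum construction above is just a right-aligned trailing mean; a small sketch of the equivalence (values chosen arbitrarily):

x <- c(1, 2, 3, 4, 5)
rolling_mean(x, w = 3)
# NA NA  2  3  4   (first w - 1 entries padded with NA)
naive <- sapply(seq_along(x), function(i) if (i < 3) NA_real_ else mean(x[(i - 2):i]))
all.equal(rolling_mean(x, 3), naive)  # TRUE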
+
+# The functions below specify performance metrics for cross-validation results.
+# Each takes as input the output of cross_validation, and returns the statistic
+# as an array, given a window size for rolling aggregation.
+
+#' Mean squared error
+#'
+#' @param df Cross-validation results dataframe.
+#' @param w Aggregation window size.
+#'
+#' @return Array of mean squared errors.
+#'
+#' @keywords internal
+mse <- function(df, w) {
+  se <- (df$y - df$yhat) ** 2
+  return(rolling_mean(se, w))
+}
+
+#' Root mean squared error
+#'
+#' @param df Cross-validation results dataframe.
+#' @param w Aggregation window size.
+#'
+#' @return Array of root mean squared errors.
+#'
+#' @keywords internal
+rmse <- function(df, w) {
+  return(sqrt(mse(df, w)))
+}
+
+#' Mean absolute error
+#'
+#' @param df Cross-validation results dataframe.
+#' @param w Aggregation window size.
+#'
+#' @return Array of mean absolute errors.
+#'
+#' @keywords internal
+mae <- function(df, w) {
+  ae <- abs(df$y - df$yhat)
+  return(rolling_mean(ae, w))
+}
+
+#' Mean absolute percent error
+#'
+#' @param df Cross-validation results dataframe.
+#' @param w Aggregation window size.
+#'
+#' @return Array of mean absolute percent errors.
+#'
+#' @keywords internal
+mape <- function(df, w) {
+  ape <- abs((df$y - df$yhat) / df$y)
+  return(rolling_mean(ape, w))
+}
+
+#' Coverage
+#'
+#' @param df Cross-validation results dataframe.
+#' @param w Aggregation window size.
+#'
+#' @return Array of coverages.
+#'
+#' @keywords internal
+coverage <- function(df, w) {
+  is_covered <- (df$y >= df$yhat_lower) & (df$y <= df$yhat_upper)
+  return(rolling_mean(is_covered, w))
+}

+ 0 - 150
R/R/metrics.R

@@ -1,150 +0,0 @@
-## Copyright (c) 2017-present, Facebook, Inc.
-## All rights reserved.
-
-## This source code is licensed under the BSD-style license found in the
-## LICENSE file in the root directory of this source tree. An additional grant
-## of patent rights can be found in the PATENTS file in the same directory.
-
-#' @title Metrics for Time Series Forecasts
-#'
-#' @description
-#' A time-series forecast requires making a quantitative prediction of future values.
-#' After forecast, we also have to provide accurracy of forecasts to check wether the forecast serves our need.
-#' Metrics for time series forecasts are so useful in telling you how your model is good and helping you determine which particular forecasting models work best.
-#'
-#' @details
-#' Here, as a notation, we assume that \eqn{y} is the actual value and \eqn{yhat} is the forecast value.
-#'
-#' Mean Error (ME, \code{me})
-#'
-#' The Mean Error (ME)  is defined by the formula:
-#' \deqn{ \frac{1}{n} \sum_{t=1}^{n} y_{t}-yhat_{t} .}
-#'
-#' Mean Squared Error (MSE, \code{mse})
-#'
-#' The Mean Squared Error (MSE)  is defined by the formula:
-#' \deqn{ \frac{1}{n} \sum_{t=1}^{n} (y_{t}-yhat_{t})^2 .}
-#'
-#' Root Mean Square Error (RMSE, \code{rmse})
-#'
-#' Root Mean Square Error (RMSE) is define by the formula:
-#' \deqn{ \sqrt{\frac{1}{n} \sum_{t=1}^{n} (y_{t}-yhat_{t})^2} .}
-#'
-#' Mean Absolute Error (MAE, \code{mae})
-#'
-#' The Mean Absolute Error (MAE) is defined by the formula:
-#' \deqn{ \frac{1}{n} \sum_{t=1}^{n} | y_{t}-yhat_{t} | .}
-#'
-#' Mean Percentage Error (MPE, \code{mpe})
-#'
-#' The Mean Percentage Error (MPE) is usually expressed as a percentage
-#' and is defined by the formula:
-#' \deqn{ \frac{100}{n} \sum_{t=1}^{n} \frac {y_{t}-yhat_{t}}{y_{t}} .}
-#'
-#' Mean Absolute Percentage Error (MAPE, \code{mape})
-#'
-#' The Mean absolute Percentage Error (MAPE), also known as Mean Absolute Percentage Deviation (MAPD), is usually expressed as a percentage,
-#' and is defined by the formula:
-#' \deqn{ \frac{100}{n} \sum_{t=1}^{n} | \frac {y_{t}-yhat_{t}}{y_{t}}| .}
-#'
-#' @param m Prophet object. Default NULL
-#' @param df A dataframe which is output of `simulated_historical_forecasts` or `cross_validation` Default NULL
-#'
-#' @return metrics value (numeric)
-#'
-#'@examples
-#'\dontrun{
-#' # Create example model
-#' library(readr)
-#' library(prophet)
-#' df <- read_csv('../tests/testthat/data.csv')
-#' m <- prophet(df)
-#' future <- make_future_dataframe(m, periods = 365)
-#' forecast <- predict(m, future)
-#' all_metrics(forecast)
-#' df.cv <- cross_validation(m, horizon = 100, units = 'days')
-#' all_metrics(df.cv)
-#' # You can check your models's accuracy using me, mse, rmse ...etc.
-#' print(rmse(m))
-#'}
-#' @name metrics
-NULL
-
-#' Prepare dataframe for metrics calculation.
-#'
-#' @param m Prophet object. Default NULL
-#' @param df A dataframe which is output of `simulated_historical_forecasts` or `cross_validation` Default NULL
-#'
-#' @return A dataframe only with y and yhat as a column.
-#'
-#' @keywords internal
-create_metric_data <- function(m=NULL, df=NULL)
-{
-  if(is.null(m) && is.null(df))
-  {
-    stop("You have to specify one of `m` and `df` at least.")
-  }
-  if(!is.null(m) && !is.null(df))
-  {
-    warning("You specify both of `m` and `df`. `df` is used for metrics calclation.")
-  }
-
-  data <- if(!is.null(df)){
-    df
-  } else if("prophet" %in% class(m)) {
-    dplyr::inner_join(m$history, predict(m, NULL), by="ds")
-  }
-
-  dplyr::select(data, y, yhat) %>% na.omit()
-}
-
-#' Meta function to make the function which evaluate metrics.
-#'
-#' @param metrics metrics function
-#'
-#' @return A function using for metrics evaluation.
-#'
-#' @keywords internal
-make_metrics_function <- function(metrics)
-{
-  function(m=NULL, df=NULL)
-  {
-    data <- create_metric_data(m, df)
-    metrics(data$y, data$yhat)
-  }
-}
-
-#' @rdname metrics
-#' @export
-me <- make_metrics_function(function(y, yhat){mean(y - yhat)})
-
-#' @rdname metrics
-#' @export
-mse <- make_metrics_function(function(y, yhat){mean((y - yhat)^2)})
-
-#' @rdname metrics
-#' @export
-rmse <- make_metrics_function(function(y, yhat){sqrt(mean((y - yhat)^2))})
-
-#' @rdname metrics
-#' @export
-mae <- make_metrics_function(function(y, yhat){mean(abs(y - yhat))})
-
-#' @rdname metrics
-#' @export
-mpe <- make_metrics_function(function(y, yhat){100*mean((y - yhat)/y)})
-
-#' @rdname metrics
-#' @export
-mape <- make_metrics_function(function(y, yhat){100*mean(abs((y - yhat)/y))})
-
-#' @rdname metrics
-#' @export
-all_metrics <- function(m=NULL, df=NULL)
-{
-  # Define all metrics functions as a character
-  metrics <- rlang::set_names(c("me", "mse", "rmse", "mae", "mpe", "mape"))
-  # Convert character to function and evalate each metrics in invoke_map_df
-  # The result is data.frame with each metrics name
-  purrr::invoke_map_df(metrics, list(list(m, df)))
-}

+ 65 - 0
R/R/plot.R

@@ -406,3 +406,68 @@ dyplot.prophet <- function(x, fcst, uncertainty=TRUE,
   return(dyBase)
 }
 
+#' Plot a performance metric vs. forecast horizon from cross validation.
+
+#' Cross validation produces a collection of out-of-sample model predictions
+#' that can be compared to actual values, at a range of different horizons
+#' (distance from the cutoff). This computes a specified performance metric
+#' for each prediction, and aggregates it over a rolling window with horizon.
+#'
+#' This uses the performance_metrics function to compute the metrics.
+#' Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
+#'
+#' rolling_window is the proportion of data included in the rolling window of
+#' aggregation. The default value of 0.1 means 10% of data are included in the
+#' aggregation for computing the metric.
+#'
+#' As a concrete example, if metric='mse', then this plot will show the
+#' squared error for each cross validation prediction, along with the MSE
+#' averaged over rolling windows of 10% of the data.
+#'
+#' @param df_cv The dataframe returned by cross_validation.
+#' @param metric Metric name, one of 'mse', 'rmse', 'mae', 'mape', 'coverage'.
+#' @param rolling_window Proportion of data to use for rolling average of
+#'  metric. In [0, 1]. Defaults to 0.1.
+#'
+#' @return A ggplot2 plot.
+#'
+#' @export
+plot_cross_validation_metric <- function(df_cv, metric, rolling_window=0.1) {
+  df_none <- performance_metrics(df_cv, metrics = metric, rolling_window = 0)
+  df_h <- performance_metrics(
+    df_cv, metrics = metric, rolling_window = rolling_window
+  )
+
+  # Better plotting of difftime
+  # Target ~10 ticks
+  tick_w <- max(as.double(df_none$horizon, units = 'secs')) / 10.
+  # Find the largest time resolution that has <1 unit per bin
+  dts <- c('days', 'hours', 'mins', 'secs')
+  dt_conversions <- c(
+    24 * 60 * 60,
+    60 * 60,
+    60,
+    1
+  )
+  for (i in seq_along(dts)) {
+    if (as.difftime(1, units = dts[i]) < as.difftime(tick_w, units = 'secs')) {
+      break
+    }
+  }
+  df_none$x_plt <- (
+    as.double(df_none$horizon, units = 'secs') / dt_conversions[i]
+  )
+  df_h$x_plt <- as.double(df_h$horizon, units = 'secs') / dt_conversions[i]
+
+  gg <- (
+    ggplot2::ggplot(df_none, ggplot2::aes_string(x = 'x_plt', y = metric)) +
+    ggplot2::labs(x = paste0('Horizon (', dts[i], ')'), y = metric) +
+    ggplot2::geom_point(color = 'gray') +
+    ggplot2::geom_line(
+      data = df_h, ggplot2::aes_string(x = 'x_plt', y = metric), color = 'blue'
+    ) +
+    ggplot2::theme(aspect.ratio = 3 / 5)
+  )
+
+  return(gg)
+}
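
A minimal sketch of the intended call, assuming df.cv is the output of cross_validation as in the sketch further up (the metric choice is illustrative):

plot_cross_validation_metric(df.cv, metric = 'mape')
# Gray points: absolute percent error of each cross-validation prediction (rolling_window = 0).
# Blue line: MAPE over a rolling window covering 10% of the predictions, plotted against horizon.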

+ 20 - 0
R/man/coverage.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{coverage}
+\alias{coverage}
+\title{Coverage}
+\usage{
+coverage(df, w)
+}
+\arguments{
+\item{df}{Cross-validation results dataframe.}
+
+\item{w}{Aggregation window size.}
+}
+\value{
+Array of coverages.
+}
+\description{
+Coverage
+}
+\keyword{internal}

+ 0 - 20
R/man/create_metric_data.Rd

@@ -1,20 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/metrics.R
-\name{create_metric_data}
-\alias{create_metric_data}
-\title{Prepare dataframe for metrics calculation.}
-\usage{
-create_metric_data(m = NULL, df = NULL)
-}
-\arguments{
-\item{m}{Prophet object. Default NULL}
-
-\item{df}{A dataframe which is output of `simulated_historical_forecasts` or `cross_validation` Default NULL}
-}
-\value{
-A dataframe only with y and yhat as a column.
-}
-\description{
-Prepare dataframe for metrics calculation.
-}
-\keyword{internal}

+ 20 - 0
R/man/mae.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{mae}
+\alias{mae}
+\title{Mean absolute error}
+\usage{
+mae(df, w)
+}
+\arguments{
+\item{df}{Cross-validation results dataframe.}
+
+\item{w}{Aggregation window size.}
+}
+\value{
+Array of mean absolute errors.
+}
+\description{
+Mean absolute error
+}
+\keyword{internal}

+ 0 - 18
R/man/make_metrics_function.Rd

@@ -1,18 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/metrics.R
-\name{make_metrics_function}
-\alias{make_metrics_function}
-\title{Meta function to make the function which evaluate metrics.}
-\usage{
-make_metrics_function(metrics)
-}
-\arguments{
-\item{metrics}{metrics function}
-}
-\value{
-A function using for metrics evaluation.
-}
-\description{
-Meta function to make the function which evaluate metrics.
-}
-\keyword{internal}

+ 20 - 0
R/man/mape.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{mape}
+\alias{mape}
+\title{Mean absolute percent error}
+\usage{
+mape(df, w)
+}
+\arguments{
+\item{df}{Cross-validation results dataframe.}
+
+\item{w}{Aggregation window size.}
+}
+\value{
+Array of mean absolute percent errors.
+}
+\description{
+Mean absolute percent error
+}
+\keyword{internal}

+ 0 - 91
R/man/metrics.Rd

@@ -1,91 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/metrics.R
-\name{metrics}
-\alias{metrics}
-\alias{me}
-\alias{mse}
-\alias{rmse}
-\alias{mae}
-\alias{mpe}
-\alias{mape}
-\alias{all_metrics}
-\title{Metrics for Time Series Forecasts}
-\usage{
-me(m = NULL, df = NULL)
-
-mse(m = NULL, df = NULL)
-
-rmse(m = NULL, df = NULL)
-
-mae(m = NULL, df = NULL)
-
-mpe(m = NULL, df = NULL)
-
-mape(m = NULL, df = NULL)
-
-all_metrics(m = NULL, df = NULL)
-}
-\arguments{
-\item{m}{Prophet object. Default NULL}
-
-\item{df}{A dataframe which is output of `simulated_historical_forecasts` or `cross_validation` Default NULL}
-}
-\value{
-metrics value (numeric)
-}
-\description{
-A time-series forecast requires making a quantitative prediction of future values.
-After forecast, we also have to provide accurracy of forecasts to check wether the forecast serves our need.
-Metrics for time series forecasts are so useful in telling you how your model is good and helping you determine which particular forecasting models work best.
-}
-\details{
-Here, as a notation, we assume that \eqn{y} is the actual value and \eqn{yhat} is the forecast value.
-
-Mean Error (ME, \code{me})
-
-The Mean Error (ME)  is defined by the formula:
-\deqn{ \frac{1}{n} \sum_{t=1}^{n} y_{t}-yhat_{t} .}
-
-Mean Squared Error (MSE, \code{mse})
-
-The Mean Squared Error (MSE)  is defined by the formula:
-\deqn{ \frac{1}{n} \sum_{t=1}^{n} (y_{t}-yhat_{t})^2 .}
-
-Root Mean Square Error (RMSE, \code{rmse})
-
-Root Mean Square Error (RMSE) is define by the formula:
-\deqn{ \sqrt{\frac{1}{n} \sum_{t=1}^{n} (y_{t}-yhat_{t})^2} .}
-
-Mean Absolute Error (MAE, \code{mae})
-
-The Mean Absolute Error (MAE) is defined by the formula:
-\deqn{ \frac{1}{n} \sum_{t=1}^{n} | y_{t}-yhat_{t} | .}
-
-Mean Percentage Error (MPE, \code{mpe})
-
-The Mean Percentage Error (MPE) is usually expressed as a percentage
-and is defined by the formula:
-\deqn{ \frac{100}{n} \sum_{t=1}^{n} \frac {y_{t}-yhat_{t}}{y_{t}} .}
-
-Mean Absolute Percentage Error (MAPE, \code{mape})
-
-The Mean absolute Percentage Error (MAPE), also known as Mean Absolute Percentage Deviation (MAPD), is usually expressed as a percentage,
-and is defined by the formula:
-\deqn{ \frac{100}{n} \sum_{t=1}^{n} | \frac {y_{t}-yhat_{t}}{y_{t}}| .}
-}
-\examples{
-\dontrun{
-# Create example model
-library(readr)
-library(prophet)
-df <- read_csv('../tests/testthat/data.csv')
-m <- prophet(df)
-future <- make_future_dataframe(m, periods = 365)
-forecast <- predict(m, future)
-all_metrics(forecast)
-df.cv <- cross_validation(m, horizon = 100, units = 'days')
-all_metrics(df.cv)
-# You can check your models's accuracy using me, mse, rmse ...etc.
-print(rmse(m))
-}
-}

+ 20 - 0
R/man/mse.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{mse}
+\alias{mse}
+\title{Mean squared error}
+\usage{
+mse(df, w)
+}
+\arguments{
+\item{df}{Cross-validation results dataframe.}
+
+\item{w}{Aggregation window size.}
+}
+\value{
+Array of mean squared errors.
+}
+\description{
+Mean squared error
+}
+\keyword{internal}

+ 46 - 0
R/man/performance_metrics.Rd

@@ -0,0 +1,46 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{performance_metrics}
+\alias{performance_metrics}
+\title{Compute performance metrics from cross-validation results.}
+\usage{
+performance_metrics(df, metrics = NULL, rolling_window = 0.1)
+}
+\arguments{
+\item{df}{The dataframe returned by cross_validation.}
+
+\item{metrics}{An array of performance metrics to compute. If not provided,
+will use c('mse', 'rmse', 'mae', 'mape', 'coverage').}
+
+\item{rolling_window}{Proportion of data to use in each rolling window for
+computing the metrics. Should be in [0, 1].}
+}
+\value{
+A dataframe with a column for each metric, and column 'horizon'.
+}
+\description{
+Computes a suite of performance metrics on the output of cross-validation.
+By default the following metrics are included:
+'mse': mean squared error
+'rmse': root mean squared error
+'mae': mean absolute error
+'mape': mean absolute percent error
+'coverage': coverage of the upper and lower intervals
+}
+\details{
+A subset of these can be specified by passing a list of names as the
+`metrics` argument.
+
+Metrics are calculated over a rolling window of cross validation
+predictions, after sorting by horizon. The size of that window (number of
+simulated forecast points) is determined by the rolling_window argument,
+which specifies a proportion of simulated forecast points to include in
+each window. rolling_window=0 will compute it separately for each simulated
+forecast point (i.e., 'mse' will actually be squared error with no mean).
+The default of rolling_window=0.1 will use 10% of the rows in df in each
+window. rolling_window=1 will compute the metric across all simulated
+forecast points. The results are set to the right edge of the window.
+
+The output is a dataframe containing column 'horizon' along with columns
+for each of the metrics computed.
+}

+ 36 - 0
R/man/plot_cross_validation_metric.Rd

@@ -0,0 +1,36 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/plot.R
+\name{plot_cross_validation_metric}
+\alias{plot_cross_validation_metric}
+\title{Plot a performance metric vs. forecast horizon from cross validation.
+Cross validation produces a collection of out-of-sample model predictions
+that can be compared to actual values, at a range of different horizons
+(distance from the cutoff). This computes a specified performance metric
+for each prediction, and aggregates it over a rolling window with horizon.}
+\usage{
+plot_cross_validation_metric(df_cv, metric, rolling_window = 0.1)
+}
+\arguments{
+\item{df_cv}{The dataframe returned by cross_validation.}
+
+\item{metric}{Metric name, one of 'mse', 'rmse', 'mae', 'mape', 'coverage'.}
+
+\item{rolling_window}{Proportion of data to use for rolling average of
+metric. In [0, 1]. Defaults to 0.1.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+This uses the performance_metrics function to compute the metrics.
+Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
+}
+\details{
+rolling_window is the proportion of data included in the rolling window of
+aggregation. The default value of 0.1 means 10% of data are included in the
+aggregation for computing the metric.
+
+As a concrete example, if metric='mse', then this plot will show the
+squared error for each cross validation prediction, along with the MSE
+averaged over rolling windows of 10% of the data.
+}

+ 20 - 0
R/man/rmse.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{rmse}
+\alias{rmse}
+\title{Root mean squared error}
+\usage{
+rmse(df, w)
+}
+\arguments{
+\item{df}{Cross-validation results dataframe.}
+
+\item{w}{Aggregation window size.}
+}
+\value{
+Array of root mean squared errors.
+}
+\description{
+Root mean squared error
+}
+\keyword{internal}

+ 21 - 0
R/man/rolling_mean.Rd

@@ -0,0 +1,21 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{rolling_mean}
+\alias{rolling_mean}
+\title{Compute a rolling mean of x}
+\usage{
+rolling_mean(x, w)
+}
+\arguments{
+\item{x}{Array.}
+
+\item{w}{Integer window size (number of elements).}
+}
+\value{
+Rolling mean of x with window size w.
+}
+\description{
+Right-aligned. Padded with NAs on the front so the output is the same
+size as x.
+}
+\keyword{internal}

+ 29 - 0
R/tests/testthat/test_diagnostics.R

@@ -103,3 +103,32 @@ test_that("cross_validation_default_value_check", {
     m, horizon = 32, units = 'days', period = 10, initial = 96)
   expect_equal(sum(dplyr::select(df.cv1 - df.cv2, y, yhat)), 0)
 })
+
+test_that("performance_metrics", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(DATA)
+  df_cv <- cross_validation(
+    m, horizon = 4, units = "days", period = 10, initial = 90)
+  # Aggregation level none
+  df_none <- performance_metrics(df_cv, rolling_window = 0)
+  expect_true(all(
+    sort(colnames(df_none))
+    == sort(c('horizon', 'coverage', 'mae', 'mape', 'mse', 'rmse'))
+  ))
+  expect_equal(nrow(df_none), 14)
+  # Aggregation level 0.2
+  df_horizon <- performance_metrics(df_cv, rolling_window = 0.2)
+  expect_equal(length(unique(df_horizon$horizon)), 4)
+  expect_equal(nrow(df_horizon), 13)
+  # Aggregation level all
+  df_all <- performance_metrics(df_cv, rolling_window = 1)
+  expect_equal(nrow(df_all), 1)
+  for (metric in c('mse', 'mape', 'mae', 'coverage')) {
+    expect_equal(df_all[[metric]][1], mean(df_none[[metric]]))
+  }
+  # Custom list of metrics
+  df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mse'))
+  expect_true(all(
+    sort(colnames(df_horizon)) == sort(c('coverage', 'mse', 'horizon'))
+  ))
+})
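
The expected row counts in this test follow from the window arithmetic in performance_metrics; assuming the 14 simulated forecast points asserted above for df_none, a short check:

n <- 14                                  # rows in df_cv, as asserted via df_none
w <- function(rw) min(max(as.integer(rw * n), 1), n)
n - (w(0) - 1)    # 14: window of 1, na.omit drops nothing
n - (w(0.2) - 1)  # 13: window of 2, one leading NA dropped
n - (w(1) - 1)    #  1: window of 14, only the full-window mean remains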

+ 54 - 8
notebooks/diagnostics.ipynb
Changes are not shown because the diff is too large.


+ 59 - 0
python/fbprophet/diagnostics.py

@@ -260,6 +260,20 @@ def performance_metrics(df, metrics=None, rolling_window=0.1):
 
 
 def rolling_mean(x, w):
+    """Compute a rolling mean of x
+
+    Right-aligned. Padded with NaNs on the front so the output is the same
+    size as x.
+
+    Parameters
+    ----------
+    x: Array.
+    w: Integer window size (number of elements).
+
+    Returns
+    -------
+    Rolling mean of x with window size w.
+    """
     s = np.cumsum(np.insert(x, 0, 0))
     prefix = np.empty(w - 1)
     prefix.fill(np.nan)
@@ -273,6 +287,15 @@ def rolling_mean(x, w):
 
 def mse(df, w):
     """Mean squared error
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Array of mean squared errors.
     """
     se = (df['y'] - df['yhat']) ** 2
     return rolling_mean(se.values, w)
@@ -280,12 +303,30 @@ def mse(df, w):
 
 def rmse(df, w):
     """Root mean squared error
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Array of root mean squared errors.
     """
     return np.sqrt(mse(df, w))
 
 
 def mae(df, w):
     """Mean absolute error
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Array of mean absolute errors.
     """
     ae = np.abs(df['y'] - df['yhat'])
     return rolling_mean(ae.values, w)
@@ -293,6 +334,15 @@ def mae(df, w):
 
 def mape(df, w):
     """Mean absolute percent error
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Array of mean absolute percent errors.
     """
     ape = np.abs((df['y'] - df['yhat']) / df['y'])
     return rolling_mean(ape.values, w)
@@ -300,6 +350,15 @@ def mape(df, w):
 
 def coverage(df, w):
     """Coverage
+
+    Parameters
+    ----------
+    df: Cross-validation results dataframe.
+    w: Aggregation window size.
+
+    Returns
+    -------
+    Array of coverages.
     """
     is_covered = (df['y'] >= df['yhat_lower']) & (df['y'] <= df['yhat_upper'])
     return rolling_mean(is_covered.values, w)