diagnostics.R 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Declare the column names that are referenced through dplyr's non-standard
  ## evaluation further down, so that R CMD check does not report them as
  ## undefined global variables. The call is namespace-qualified like the other
  ## utils/stats calls in this file.
  utils::globalVariables(
    c("ds", "y", "cap", "yhat", "yhat_lower", "yhat_upper")
  )
  #' Generate cutoff dates
  #'
  #' Starts at (latest date in data) - (horizon) and walks backwards in steps
  #' of period, keeping every cutoff that is not earlier than
  #' (earliest date in data) + (initial). When a candidate cutoff has no data
  #' in (cutoff, cutoff + horizon], it is moved back to
  #' (closest date at or before the cutoff) - (horizon).
  #'
  #' @param df Dataframe with historical data. Only the ds column is used.
  #' @param horizon timediff forecast horizon.
  #' @param initial timediff initial window.
  #' @param period timediff Simulated forecasts are done with this period.
  #'  Must be positive.
  #'
  #' @return Array of datetimes, in increasing order.
  #'
  #' @keywords internal
  generate_cutoffs <- function(df, horizon, initial, period) {
    # A non-positive period never moves the cutoff backwards, so the loop
    # below would not terminate.
    if (as.numeric(period) <= 0) {
      stop("period must be positive.")
    }
    # Earliest admissible cutoff; invariant across the loop.
    earliest.allowed <- min(df$ds) + initial
    # Last cutoff is (latest date in data) - (horizon).
    cutoff <- max(df$ds) - horizon
    tzone <- attr(cutoff, "tzone")  # Timezone is wiped by putting in array
    # Zero-length vector of the same class as the cutoffs.
    result <- cutoff[0]
    while (cutoff >= earliest.allowed) {
      result <- c(result, cutoff)
      cutoff <- cutoff - period
      # If data does not exist in data range (cutoff, cutoff + horizon]
      if (!any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
        earlier.dates <- df$ds[df$ds <= cutoff]
        if (length(earlier.dates) == 0) {
          # Nothing at or before this cutoff: we are past the start of the
          # data, so there are no further cutoffs to generate.
          break
        }
        # Next cutoff point is 'closest date before cutoff in data - horizon'
        cutoff <- max(earlier.dates) - horizon
      }
    }
    if (length(result) == 0) {
      stop(paste(
        'Less data than horizon after initial window.',
        'Make horizon or initial shorter.'
      ))
    }
    # Reset timezones
    attr(result, "tzone") <- tzone
    # result is ordered latest-first here, so its last element is the earliest.
    message(paste(
      'Making', length(result), 'forecasts with cutoffs between',
      result[length(result)], 'and', result[1]
    ))
    return(rev(result))
  }
  #' Cross-validation for time series.
  #'
  #' Computes forecasts from historical cutoff points. Beginning from
  #' (end - horizon), works backwards making cutoffs with a spacing of period
  #' until initial is reached.
  #'
  #' When period is equal to the time interval of the data, this is the
  #' technique described in https://robjhyndman.com/hyndsight/tscv/ .
  #'
  #' For every cutoff the model is copied, refit on the history up to and
  #' including the cutoff, and used to predict the historical dates in
  #' (cutoff, cutoff + horizon].
  #'
  #' @param model Fitted Prophet model.
  #' @param horizon Integer size of the horizon
  #' @param units String unit of the horizon, e.g., "days", "secs".
  #' @param period Integer amount of time between cutoff dates. Same units as
  #'  horizon. If not provided, 0.5 * horizon is used.
  #' @param initial Integer size of the first training period. If not provided,
  #'  3 * horizon is used. Same units as horizon.
  #'
  #' @return A dataframe with the forecast, actual value, and cutoff date.
  #'
  #' @export
  cross_validation <- function(
      model, horizon, units, period = NULL, initial = NULL) {
    df <- model$history
    if (is.null(period)) {
      period <- 0.5 * horizon
    }
    if (is.null(initial)) {
      initial <- 3 * horizon
    }
    horizon.dt <- as.difftime(horizon, units = units)
    initial.dt <- as.difftime(initial, units = units)
    period.dt <- as.difftime(period, units = units)
    cutoffs <- generate_cutoffs(df, horizon.dt, initial.dt, period.dt)
    # One result dataframe per cutoff; bound together once after the loop
    # instead of growing a dataframe on every iteration.
    predicts <- vector("list", length(cutoffs))
    # Index into cutoffs rather than iterating over it so that each cutoff
    # keeps its date-time class.
    for (i in seq_along(cutoffs)) {
      cutoff <- cutoffs[i]
      # Copy the model
      m <- prophet_copy(model, cutoff)
      # Train model
      history.c <- dplyr::filter(df, ds <= cutoff)
      if (nrow(history.c) < 2) {
        stop('Less than two datapoints before cutoff. Increase initial window.')
      }
      m <- fit.prophet(m, history.c)
      # Calculate yhat
      df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon.dt)
      # Get the columns for the future dataframe
      columns <- 'ds'
      if (m$growth == 'logistic') {
        columns <- c(columns, 'cap')
        if (m$logistic.floor) {
          columns <- c(columns, 'floor')
        }
      }
      columns <- c(columns, names(m$extra_regressors))
      future <- df.predict[columns]
      yhat <- stats::predict(m, future)
      # Merge yhat, y, and cutoff.
      df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
      df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
      df.c$cutoff <- cutoff
      predicts[[i]] <- df.c
    }
    return(do.call(rbind, predicts))
  }
  #' Copy Prophet object.
  #'
  #' Builds a new, unfitted model from the settings stored on a fitted one.
  #' Stops with an error when the input model has no history.
  #'
  #' @param m Prophet model object.
  #' @param cutoff Date, possibly as string. Changepoints are only retained if
  #'  changepoints <= cutoff.
  #'
  #' @return An unfitted Prophet model object with the same parameters as the
  #'  input model.
  #'
  #' @keywords internal
  prophet_copy <- function(m, cutoff = NULL) {
    if (is.null(m$history)) {
      stop("This is for copying a fitted Prophet object.")
    }
    # Changepoints are carried over only when they were user-specified, and
    # then only those at or before the cutoff (when a cutoff is given).
    retained <- NULL
    if (m$specified.changepoints) {
      retained <- m$changepoints
      if (!is.null(cutoff)) {
        cutoff.date <- set_date(cutoff)
        retained <- retained[retained <= cutoff.date]
      }
    }
    # Auto seasonalities are set to FALSE because they are already set in
    # m$seasonalities.
    unfitted <- prophet(
      growth = m$growth,
      changepoints = retained,
      n.changepoints = m$n.changepoints,
      changepoint.range = m$changepoint.range,
      yearly.seasonality = FALSE,
      weekly.seasonality = FALSE,
      daily.seasonality = FALSE,
      holidays = m$holidays,
      seasonality.mode = m$seasonality.mode,
      seasonality.prior.scale = m$seasonality.prior.scale,
      changepoint.prior.scale = m$changepoint.prior.scale,
      holidays.prior.scale = m$holidays.prior.scale,
      mcmc.samples = m$mcmc.samples,
      interval.width = m$interval.width,
      uncertainty.samples = m$uncertainty.samples,
      fit = FALSE
    )
    unfitted$extra_regressors <- m$extra_regressors
    unfitted$seasonalities <- m$seasonalities
    unfitted
  }
  #' Compute performance metrics from cross-validation results.
  #'
  #' Computes a suite of performance metrics on the output of cross-validation.
  #' By default the following metrics are included:
  #' 'mse': mean squared error
  #' 'rmse': root mean squared error
  #' 'mae': mean absolute error
  #' 'mape': mean absolute percent error
  #' 'coverage': coverage of the upper and lower intervals
  #'
  #' A subset of these can be specified by passing a list of names as the
  #' `metrics` argument.
  #'
  #' Metrics are calculated over a rolling window of cross validation
  #' predictions, after sorting by horizon. The size of that window (number of
  #' simulated forecast points) is determined by the rolling_window argument,
  #' which specifies a proportion of simulated forecast points to include in
  #' each window. rolling_window=0 will compute it separately for each simulated
  #' forecast point (i.e., 'mse' will actually be squared error with no mean).
  #' The default of rolling_window=0.1 will use 10% of the rows in df in each
  #' window. rolling_window=1 will compute the metric across all simulated
  #' forecast points. The results are set to the right edge of the window.
  #'
  #' The output is a dataframe containing column 'horizon' along with columns
  #' for each of the metrics computed. Rows containing NA are dropped.
  #'
  #' @param df The dataframe returned by cross_validation.
  #' @param metrics An array of performance metrics to compute. If not provided,
  #'  will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
  #' @param rolling_window Proportion of data to use in each rolling window for
  #'  computing the metrics. Should be in [0, 1].
  #'
  #' @return A dataframe with a column for each metric, and column 'horizon'.
  #'
  #' @export
  performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
    valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'coverage')
    if (is.null(metrics)) {
      metrics <- valid_metrics
    }
    if (length(metrics) != length(unique(metrics))) {
      stop('Input metrics must be an array of unique values.')
    }
    if (!all(metrics %in% valid_metrics)) {
      # List the accepted names, not the names the caller passed in.
      stop(
        paste(
          'Valid values for metrics are:',
          paste(valid_metrics, collapse = ", ")
        )
      )
    }
    df_m <- df
    df_m$horizon <- df_m$ds - df_m$cutoff
    df_m <- df_m[order(df_m$horizon), ]
    # Window size: a proportion of the rows, clamped to [1, nrow(df_m)].
    w <- as.integer(rolling_window * nrow(df_m))
    w <- max(w, 1)
    w <- min(w, nrow(df_m))
    cols <- c('horizon')
    for (metric in metrics) {
      # Each metric name is also the name of the function that computes it.
      df_m[[metric]] <- get(metric, mode = "function")(df_m, w)
      cols <- c(cols, metric)
    }
    df_m <- df_m[cols]
    return(stats::na.omit(df_m))
  }
  #' Compute a rolling mean of x
  #'
  #' Right-aligned. Padded with NAs on the front so the output is the same
  #' size as x. When the window is larger than x, no complete window exists
  #' and every element of the output is NA.
  #'
  #' @param x Array.
  #' @param w Integer window size (number of elements). Must be at least 1.
  #'
  #' @return Rolling mean of x with window size w.
  #'
  #' @keywords internal
  rolling_mean <- function(x, w) {
    if (length(w) != 1 || is.na(w) || w < 1) {
      stop("Window size w must be a single value that is at least 1.")
    }
    n <- length(x)
    if (w > n) {
      return(rep(NA_real_, n))
    }
    # s[k + 1] is the sum of the first k elements, so the sum over a window
    # ending at position k is s[k + 1] - s[k + 1 - w].
    s <- cumsum(c(0, x))
    window.sums <- s[(w + 1):(n + 1)] - s[seq_len(n - w + 1)]
    return(c(rep(NA_real_, w - 1), window.sums / w))
  }
  242. # The functions below specify performance metrics for cross-validation results.
  243. # Each takes as input the output of cross_validation, and returns the statistic
  244. # as an array, given a window size for rolling aggregation.
  #' Mean squared error
  #'
  #' Rolling mean of the squared difference between y and yhat.
  #'
  #' @param df Cross-validation results dataframe.
  #' @param w Aggregation window size.
  #'
  #' @return Array of mean squared errors.
  #'
  #' @keywords internal
  mse <- function(df, w) {
    residual <- df$y - df$yhat
    rolling_mean(residual^2, w)
  }
  #' Root mean squared error
  #'
  #' Square root of the rolling mean squared error.
  #'
  #' @param df Cross-validation results dataframe.
  #' @param w Aggregation window size.
  #'
  #' @return Array of root mean squared errors.
  #'
  #' @keywords internal
  rmse <- function(df, w) {
    rolling.mse <- mse(df, w)
    sqrt(rolling.mse)
  }
  #' Mean absolute error
  #'
  #' Rolling mean of the absolute difference between y and yhat.
  #'
  #' @param df Cross-validation results dataframe.
  #' @param w Aggregation window size.
  #'
  #' @return Array of mean absolute errors.
  #'
  #' @keywords internal
  mae <- function(df, w) {
    residual <- df$y - df$yhat
    rolling_mean(abs(residual), w)
  }
  #' Mean absolute percent error
  #'
  #' Rolling mean of the absolute error relative to y. The values are
  #' fractions of y; they are not multiplied by 100.
  #'
  #' @param df Cross-validation results dataframe.
  #' @param w Aggregation window size.
  #'
  #' @return Array of mean absolute percent errors.
  #'
  #' @keywords internal
  mape <- function(df, w) {
    residual <- df$y - df$yhat
    relative.error <- abs(residual / df$y)
    rolling_mean(relative.error, w)
  }
  #' Coverage
  #'
  #' Rolling share of rows whose y lies within [yhat_lower, yhat_upper],
  #' bounds included.
  #'
  #' @param df Cross-validation results dataframe.
  #' @param w Aggregation window size.
  #'
  #' @return Array of coverages
  #'
  #' @keywords internal
  coverage <- function(df, w) {
    at.or.above.lower <- df$y >= df$yhat_lower
    at.or.below.upper <- df$y <= df$yhat_upper
    rolling_mean(at.or.above.lower & at.or.below.upper, w)
  }