diagnostics.R 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  6. ## Makes R CMD CHECK happy due to dplyr syntax below
  7. globalVariables(c(
  8. "ds", "y", "cap", "yhat", "yhat_lower", "yhat_upper"))
  9. #' Generate cutoff dates
  10. #'
  11. #' @param df Dataframe with historical data
  12. #' @param horizon timediff forecast horizon
  13. #' @param k integer number of forecast points
  14. #' @param period timediff Simulated forecasts are done with this period.
  15. #'
  16. #' @return Array of datetimes
  17. #'
  18. #' @keywords internal
  19. generate_cutoffs <- function(df, horizon, k, period) {
  20. # Last cutoff is (latest date in data) - (horizon).
  21. cutoff <- max(df$ds) - horizon
  22. if (cutoff < min(df$ds)) {
  23. stop('Less data than horizon.')
  24. }
  25. tzone <- attr(cutoff, "tzone") # Timezone is wiped by putting in array
  26. result <- cutoff
  27. if (k > 1) {
  28. for (i in 2:k) {
  29. cutoff <- cutoff - period
  30. # If data does not exist in data range (cutoff, cutoff + horizon]
  31. if (!any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
  32. # Next cutoff point is 'closest date before cutoff in data - horizon'
  33. closest.date <- max(df$ds[df$ds <= cutoff])
  34. cutoff <- closest.date - horizon
  35. }
  36. if (cutoff < min(df$ds)) {
  37. warning('Not enough data for requested number of cutoffs! Using ', i)
  38. break
  39. }
  40. result <- c(result, cutoff)
  41. }
  42. }
  43. # Reset timezones
  44. attr(result, "tzone") <- tzone
  45. return(rev(result))
  46. }
  47. #' Simulated historical forecasts.
  48. #'
  49. #' Make forecasts from k historical cutoff points, working backwards from
  50. #' (end - horizon) with a spacing of period between each cutoff.
  51. #'
  52. #' @param model Fitted Prophet model.
  53. #' @param horizon Integer size of the horizon
  54. #' @param units String unit of the horizon, e.g., "days", "secs".
  55. #' @param k integer number of forecast points
  56. #' @param period Integer amount of time between cutoff dates. Same units as
  57. #' horizon. If not provided, will use 0.5 * horizon.
  58. #'
  59. #' @return A dataframe with the forecast, actual value, and cutoff date.
  60. #'
  61. #' @export
  62. simulated_historical_forecasts <- function(model, horizon, units, k,
  63. period = NULL) {
  64. df <- model$history
  65. horizon <- as.difftime(horizon, units = units)
  66. if (is.null(period)) {
  67. period <- horizon / 2
  68. } else {
  69. period <- as.difftime(period, units = units)
  70. }
  71. cutoffs <- generate_cutoffs(df, horizon, k, period)
  72. predicts <- data.frame()
  73. for (i in seq_along(cutoffs)) {
  74. cutoff <- cutoffs[i]
  75. # Copy the model
  76. m <- prophet_copy(model, cutoff)
  77. # Train model
  78. history.c <- dplyr::filter(df, ds <= cutoff)
  79. m <- fit.prophet(m, history.c)
  80. # Calculate yhat
  81. df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon)
  82. # Get the columns for the future dataframe
  83. columns <- 'ds'
  84. if (m$growth == 'logistic') {
  85. columns <- c(columns, 'cap')
  86. if (m$logistic.floor) {
  87. columns <- c(columns, 'floor')
  88. }
  89. }
  90. columns <- c(columns, names(m$extra_regressors))
  91. future <- df[columns]
  92. yhat <- stats::predict(m, future)
  93. # Merge yhat, y, and cutoff.
  94. df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
  95. df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
  96. df.c$cutoff <- cutoff
  97. predicts <- rbind(predicts, df.c)
  98. }
  99. return(predicts)
  100. }
  101. #' Cross-validation for time series.
  102. #'
  103. #' Computes forecasts from historical cutoff points. Beginning from initial,
  104. #' makes cutoffs with a spacing of period up to (end - horizon).
  105. #'
  106. #' When period is equal to the time interval of the data, this is the
  107. #' technique described in https://robjhyndman.com/hyndsight/tscv/ .
  108. #'
  109. #' @param model Fitted Prophet model.
  110. #' @param horizon Integer size of the horizon
  111. #' @param units String unit of the horizon, e.g., "days", "secs".
  112. #' @param period Integer amount of time between cutoff dates. Same units as
  113. #' horizon. If not provided, 0.5 * horizon is used.
  114. #' @param initial Integer size of the first training period. If not provided,
  115. #' 3 * horizon is used. Same units as horizon.
  116. #'
  117. #' @return A dataframe with the forecast, actual value, and cutoff date.
  118. #'
  119. #' @export
  120. cross_validation <- function(
  121. model, horizon, units, period = NULL, initial = NULL) {
  122. te <- max(model$history$ds)
  123. ts <- min(model$history$ds)
  124. if (is.null(period)) {
  125. period <- 0.5 * horizon
  126. }
  127. if (is.null(initial)) {
  128. initial <- 3 * horizon
  129. }
  130. horizon.dt <- as.difftime(horizon, units = units)
  131. initial.dt <- as.difftime(initial, units = units)
  132. period.dt <- as.difftime(period, units = units)
  133. k <- ceiling(
  134. as.double((te - horizon.dt) - (ts + initial.dt), units='secs') /
  135. as.double(period.dt, units = 'secs')
  136. )
  137. if (k < 1) {
  138. stop('Not enough data for specified horizon, period, and initial.')
  139. }
  140. return(simulated_historical_forecasts(model, horizon, units, k, period))
  141. }
  142. #' Copy Prophet object.
  143. #'
  144. #' @param m Prophet model object.
  145. #' @param cutoff Date, possibly as string. Changepoints are only retained if
  146. #' changepoints <= cutoff.
  147. #'
  148. #' @return An unfitted Prophet model object with the same parameters as the
  149. #' input model.
  150. #'
  151. #' @keywords internal
  152. prophet_copy <- function(m, cutoff = NULL) {
  153. if (is.null(m$history)) {
  154. stop("This is for copying a fitted Prophet object.")
  155. }
  156. if (m$specified.changepoints) {
  157. changepoints <- m$changepoints
  158. if (!is.null(cutoff)) {
  159. cutoff <- set_date(cutoff)
  160. changepoints <- changepoints[changepoints <= cutoff]
  161. }
  162. } else {
  163. changepoints <- NULL
  164. }
  165. # Auto seasonalities are set to FALSE because they are already set in
  166. # m$seasonalities.
  167. m2 <- prophet(
  168. growth = m$growth,
  169. changepoints = changepoints,
  170. n.changepoints = m$n.changepoints,
  171. yearly.seasonality = FALSE,
  172. weekly.seasonality = FALSE,
  173. daily.seasonality = FALSE,
  174. holidays = m$holidays,
  175. seasonality.prior.scale = m$seasonality.prior.scale,
  176. changepoint.prior.scale = m$changepoint.prior.scale,
  177. holidays.prior.scale = m$holidays.prior.scale,
  178. mcmc.samples = m$mcmc.samples,
  179. interval.width = m$interval.width,
  180. uncertainty.samples = m$uncertainty.samples,
  181. fit = FALSE
  182. )
  183. m2$extra_regressors <- m$extra_regressors
  184. m2$seasonalities <- m$seasonalities
  185. return(m2)
  186. }
  187. #' Compute performance metrics from cross-validation results.
  188. #'
  189. #' Computes a suite of performance metrics on the output of cross-validation.
  190. #' By default the following metrics are included:
  191. #' 'mse': mean squared error
  192. #' 'rmse': root mean squared error
  193. #' 'mae': mean absolute error
  194. #' 'mape': mean percent error
  195. #' 'coverage': coverage of the upper and lower intervals
  196. #'
  197. #' A subset of these can be specified by passing a list of names as the
  198. #' `metrics` argument.
  199. #'
  200. #' Metrics are calculated over a rolling window of cross validation
  201. #' predictions, after sorting by horizon. The size of that window (number of
  202. #' simulated forecast points) is determined by the rolling_window argument,
  203. #' which specifies a proportion of simulated forecast points to include in
  204. #' each window. rolling_window=0 will compute it separately for each simulated
  205. #' forecast point (i.e., 'mse' will actually be squared error with no mean).
  206. #' The default of rolling_window=0.1 will use 10% of the rows in df in each
  207. #' window. rolling_window=1 will compute the metric across all simulated
  208. #' forecast points. The results are set to the right edge of the window.
  209. #'
  210. #' The output is a dataframe containing column 'horizon' along with columns
  211. #' for each of the metrics computed.
  212. #'
  213. #' @param df The dataframe returned by cross_validation.
  214. #' @param metrics An array of performance metrics to compute. If not provided,
  215. #' will use c('mse', 'rmse', 'mae', 'mape', 'coverage').
  216. #' @param rolling_window Proportion of data to use in each rolling window for
  217. #' computing the metrics. Should be in [0, 1].
  218. #'
  219. #' @return A dataframe with a column for each metric, and column 'horizon'.
  220. #'
  221. #' @export
  222. performance_metrics <- function(df, metrics = NULL, rolling_window = 0.1) {
  223. valid_metrics <- c('mse', 'rmse', 'mae', 'mape', 'coverage')
  224. if (is.null(metrics)) {
  225. metrics <- valid_metrics
  226. }
  227. if (length(metrics) != length(unique(metrics))) {
  228. stop('Input metrics must be an array of unique values.')
  229. }
  230. if (!all(metrics %in% valid_metrics)) {
  231. stop(
  232. paste('Valid values for metrics are:', paste(metrics, collapse = ", "))
  233. )
  234. }
  235. df_m <- df
  236. df_m$horizon <- df_m$ds - df_m$cutoff
  237. df_m <- df_m[order(df_m$horizon),]
  238. # Window size
  239. w <- as.integer(rolling_window * nrow(df_m))
  240. w <- max(w, 1)
  241. w <- min(w, nrow(df_m))
  242. cols <- c('horizon')
  243. for (metric in metrics) {
  244. df_m[[metric]] <- get(metric)(df_m, w)
  245. cols <- c(cols, metric)
  246. }
  247. df_m <- df_m[cols]
  248. return(na.omit(df_m))
  249. }
  250. #' Compute a rolling mean of x
  251. #'
  252. #' Right-aligned. Padded with NAs on the front so the output is the same
  253. #' size as x.
  254. #'
  255. #' @param x Array.
  256. #' @param w Integer window size (number of elements).
  257. #'
  258. #' @return Rolling mean of x with window size w.
  259. #'
  260. #' @keywords internal
  261. rolling_mean <- function(x, w) {
  262. s <- cumsum(c(0, x))
  263. prefix <- rep(NA, w - 1)
  264. return(c(prefix, (s[(w + 1):length(s)] - s[1:(length(s) - w)]) / w))
  265. }
  266. # The functions below specify performance metrics for cross-validation results.
  267. # Each takes as input the output of cross_validation, and returns the statistic
  268. # as an array, given a window size for rolling aggregation.
  269. #' Mean squared error
  270. #'
  271. #' @param df Cross-validation results dataframe.
  272. #' @param w Aggregation window size.
  273. #'
  274. #' @return Array of mean squared errors.
  275. #'
  276. #' @keywords internal
  277. mse <- function(df, w) {
  278. se <- (df$y - df$yhat) ** 2
  279. return(rolling_mean(se, w))
  280. }
  281. #' Root mean squared error
  282. #'
  283. #' @param df Cross-validation results dataframe.
  284. #' @param w Aggregation window size.
  285. #'
  286. #' @return Array of root mean squared errors.
  287. #'
  288. #' @keywords internal
  289. rmse <- function(df, w) {
  290. return(sqrt(mse(df, w)))
  291. }
  292. #' Mean absolute error
  293. #'
  294. #' @param df Cross-validation results dataframe.
  295. #' @param w Aggregation window size.
  296. #'
  297. #' @return Array of mean absolute errors.
  298. #'
  299. #' @keywords internal
  300. mae <- function(df, w) {
  301. ae <- abs(df$y - df$yhat)
  302. return(rolling_mean(ae, w))
  303. }
  304. #' Mean absolute percent error
  305. #'
  306. #' @param df Cross-validation results dataframe.
  307. #' @param w Aggregation window size.
  308. #'
  309. #' @return Array of mean absolute percent errors.
  310. #'
  311. #' @keywords internal
  312. mape <- function(df, w) {
  313. ape <- abs((df$y - df$yhat) / df$y)
  314. return(rolling_mean(ape, w))
  315. }
  316. #' Coverage
  317. #'
  318. #' @param df Cross-validation results dataframe.
  319. #' @param w Aggregation window size.
  320. #'
  321. #' @return Array of coverages
  322. #'
  323. #' @keywords internal
  324. coverage <- function(df, w) {
  325. is_covered <- (df$y >= df$yhat_lower) & (df$y <= df$yhat_upper)
  326. return(rolling_mean(is_covered, w))
  327. }