prophet.R 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108
  ## Copyright (c) 2017-present, Facebook, Inc.
  ## All rights reserved.
  ## This source code is licensed under the BSD-style license found in the
  ## LICENSE file in the root directory of this source tree. An additional grant
  ## of patent rights can be found in the PATENTS file in the same directory.
  ## Register the column names that are referenced through non-standard
  ## evaluation (dplyr / tidyr verbs) so that R CMD check does not report them
  ## as undefined global variables.
  utils::globalVariables(
    c(
      "ds", "y", "cap", ".",
      "component", "dow", "doy",
      "holiday", "holidays", "holidays_lower", "holidays_upper",
      "ix",
      "lower", "n", "stat", "trend", "row_number",
      "trend_lower", "trend_upper",
      "upper", "value",
      "weekly", "weekly_lower", "weekly_upper",
      "x",
      "yearly", "yearly_lower", "yearly_upper",
      "yhat", "yhat_lower", "yhat_upper"
    )
  )
  #' Prophet forecaster.
  #'
  #' Builds a prophet model object from the supplied settings and, unless
  #' \code{fit = FALSE}, fits it to \code{df}.
  #'
  #' @param df Dataframe containing the history. Must have columns ds (date type)
  #'  and y, the time series. If growth is logistic, then df must also have a
  #'  column cap that specifies the capacity at each ds. Required when
  #'  \code{fit = TRUE}; may be omitted when \code{fit = FALSE}.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #' @param changepoints Vector of dates at which to include potential
  #'  changepoints. If not specified, potential changepoints are selected
  #'  automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
  #'  then n.changepoints potential changepoints are selected uniformly from the
  #'  first 80 percent of df$ds.
  #' @param yearly.seasonality Boolean, fit yearly seasonality.
  #' @param weekly.seasonality Boolean, fit weekly seasonality.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #'  type) and optionally columns lower_window and upper_window which specify a
  #'  range of days around the date to be included as holidays. lower_window=-2
  #'  will include 2 days prior to the date as holidays.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #'  seasonality model. Larger values allow the model to fit larger seasonal
  #'  fluctuations, smaller values dampen the seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #'  components model.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #'  automatic changepoint selection. Large values will allow many changepoints,
  #'  small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #'  inference with the specified number of MCMC samples. If 0, will do MAP
  #'  estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #'  for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #'  in the trend using the MAP estimate of the extrapolated generative model.
  #'  If mcmc.samples>0, this will be integrated over all model parameters,
  #'  which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #'  uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = TRUE,
                      weekly.seasonality = TRUE,
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
  ) {
    # fb-block 1

    # Explicitly supplied changepoints determine how many there are.
    if (!is.null(changepoints)) {
      n.changepoints <- length(changepoints)
    }

    m <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      start = NULL,  # This and following attributes are set during fitting
      y.scale = NULL,
      t.scale = NULL,
      changepoints.t = NULL,
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    validate_inputs(m)
    class(m) <- append("prophet", class(m))

    if (fit) {
      # The previous default `df = df` was a self-referential promise, so a
      # missing df failed with an opaque "promise already under evaluation"
      # error. Fail with an actionable message instead.
      if (is.null(df)) {
        stop("Argument 'df' must be supplied when fit = TRUE.")
      }
      m <- fit.prophet(m, df, ...)
    }

    # fb-block 2
    return(m)
  }
  #' Check the settings stored on a Prophet object.
  #'
  #' Raises an error when the growth type is not recognised or when the
  #' holidays data frame is missing required columns, has only one of the two
  #' window columns, has windows with the wrong sign, or uses a holiday name
  #' that is reserved or contains the internal "_delim_" separator.
  #'
  #' @param m Prophet object.
  #'
  validate_inputs <- function(m) {
    allowed.growth <- c('linear', 'logistic')
    if (!(m$growth %in% allowed.growth)) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }

    hol <- m$holidays
    if (is.null(hol)) {
      # Nothing else to validate without holidays.
      return(invisible(NULL))
    }

    has_field <- function(field) {
      exists(field, where = hol)
    }

    if (!has_field('holiday')) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!has_field('ds')) {
      stop('Holidays dataframe must have ds field.')
    }

    window.present <- c(has_field('lower_window'), has_field('upper_window'))
    if (sum(window.present) == 1) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (window.present[1]) {
      if (max(hol$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(hol$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }

    # Names that collide with columns produced elsewhere in the forecast.
    reserved <- c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')
    for (hol.name in unique(hol$holiday)) {
      if (grepl("_delim_", hol.name)) {
        stop('Holiday name cannot contain "_delim_"')
      }
      if (hol.name %in% reserved) {
        stop(paste0('Holiday name "', hol.name, '" reserved.'))
      }
    }
    invisible(NULL)
  }
  #' Load compiled Stan model
  #'
  #' Tries to load the pre-compiled model shipped in the package's libs
  #' directory. If anything goes wrong (file missing, object missing, or the
  #' cached binary cannot be instantiated), the model is compiled from source
  #' via \code{compile_stan_model} instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  get_prophet_stan_model <- function(model) {
    fn <- paste('prophet', model, 'growth.RData', sep = '_')
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
                            package = 'prophet',
                            mustWork = TRUE)
      # Load into a scratch environment so the contents of the cache file can
      # neither overwrite this function's locals nor be confused with them.
      cache <- new.env(parent = emptyenv())
      load(binary, envir = cache)
      obj.name <- paste(model, 'growth.stanm', sep = '.')
      # Look the object up by name rather than evaluating parsed text.
      stanm <- get(obj.name, envir = cache, inherits = FALSE)
      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model(model)
    })
  }
  #' Compile Stan model
  #'
  #' Locates the Stan source file for the requested growth type inside the
  #' installed prophet package, translates it with \code{rstan::stanc} and
  #' builds it with \code{rstan::stan_model}.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  compile_stan_model <- function(model) {
    stan.path <- system.file(
      paste('stan/prophet', model, 'growth.stan', sep = '_'),
      package = 'prophet',
      mustWork = TRUE
    )
    translated <- rstan::stanc(stan.path)
    rstan::stan_model(
      stanc_ret = translated,
      model_name = paste(model, 'growth', sep = '_')
    )
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Converts ds to Date, sorts by ds, adds a scaled time index and scales y.
  #' Creates auxiliary columns 't', 'y_scaled' (when y is present), and
  #' 'cap_scaled' (for logistic growth). These columns are used during both
  #' fitting and predicting.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    has.y <- exists('y', where = df)
    if (has.y) {
      df$y <- as.numeric(df$y)
    }
    df$ds <- zoo::as.Date(df$ds)
    if (anyNA(df$ds)) {
      stop('Unable to parse date format in column ds. Convert to date format.')
    }
    df <- dplyr::arrange(df, ds)

    if (initialize_scales) {
      # Scale by the largest magnitude so that an all-negative series is not
      # sign-flipped, and fall back to 1 for an all-zero series so the
      # division below cannot produce NaN/Inf.
      m$y.scale <- max(abs(df$y))
      if (isTRUE(m$y.scale == 0)) {
        m$y.scale <- 1
      }
      m$start <- min(df$ds)
      m$t.scale <- as.numeric(max(df$ds) - m$start)
    }

    # Time is measured relative to the start of the history, in units of the
    # history's total span.
    df$t <- as.numeric(df$ds - m$start) / m$t.scale
    if (has.y) {
      df$y_scaled <- df$y / m$y.scale
    }

    if (m$growth == 'logistic') {
      if (!(exists('cap', where = df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      # Plain column arithmetic: a data column called "m" cannot shadow the
      # model object here.
      df$cap_scaled <- df$cap / m$y.scale
    }
    return(list("m" = m, "df" = df))
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints. Either:
  #' 1) The changepoints were passed in explicitly.
  #'   A) They are empty.
  #'   B) They are not empty, and need validation.
  #' 2) We are generating a grid of them.
  #' 3) The user prefers no changepoints be used.
  #'
  #' Also sets m$changepoints.t, the sorted changepoints on the scaled time
  #' axis; a single dummy changepoint at 0 is used when there are none.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  set_changepoints <- function(m) {
    if (!is.null(m$changepoints)) {
      if (length(m$changepoints) > 0) {
        # Convert to Date before the range check: min/max on character input
        # would be lexicographic and could let out-of-range dates through.
        m$changepoints <- zoo::as.Date(m$changepoints)
        if (min(m$changepoints) < min(m$history$ds)
            || max(m$changepoints) > max(m$history$ds)) {
          stop('Changepoints must fall within training data.')
        }
      }
    } else {
      if (m$n.changepoints > 0) {
        # Place potential changepoints evenly through the first 80 pcnt of
        # the history.
        grid <- round(seq.int(1, floor(nrow(m$history) * .8),
                              length.out = (m$n.changepoints + 1)))
        # The first grid point is the start of the history, not a changepoint.
        cp.indexes <- utils::tail(grid, -1)
        m$changepoints <- m$history$ds[cp.indexes]
      } else {
        m$changepoints <- c()
      }
    }
    if (length(m$changepoints) > 0) {
      m$changepoints <- zoo::as.Date(m$changepoints)
      m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
    } else {
      m$changepoints.t <- c(0)  # dummy changepoint
    }
    return(m)
  }
  #' Gets changepoint matrix for history dataframe.
  #'
  #' Entry (r, c) is 1 when history row r has a scaled time t at or after
  #' changepoint c, and 0 otherwise.
  #'
  #' @param m Prophet object with history$t and changepoints.t set.
  #'
  #' @return Numeric 0/1 matrix with one row per history row and one column
  #'  per changepoint.
  #'
  get_changepoint_matrix <- function(m) {
    A <- matrix(0, nrow(m$history), length(m$changepoints.t))
    # seq_along() so that an empty changepoint vector yields an empty loop
    # instead of iterating over c(1, 0).
    for (i in seq_along(m$changepoints.t)) {
      A[m$history$t >= m$changepoints.t[i], i] <- 1
    }
    return(A)
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' Column 2i - 1 holds sin(2 * i * pi * t / period) and column 2i holds the
  #' matching cosine, where t is the number of days since 1970-01-01.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features: one row per date and
  #'  2 * series.order columns.
  #'
  fourier_series <- function(dates, period, series.order) {
    t <- dates - zoo::as.Date('1970-01-01')
    features <- matrix(0, length(t), 2 * series.order)
    # seq_len() keeps series.order = 0 valid (an empty feature matrix) instead
    # of looping over c(1, 0) and indexing out of bounds.
    for (i in seq_len(series.order)) {
      x <- as.numeric(2 * i * pi * t / period)
      features[, i * 2 - 1] <- sin(x)
      features[, i * 2] <- cos(x)
    }
    return(features)
  }
  #' Data frame with seasonality features.
  #'
  #' Wraps \code{fourier_series} and names the columns
  #' "<prefix>_delim_<index>".
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality.
  #'
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    features <- fourier_series(dates, period, series.order)
    # seq_len() so a zero-column feature matrix gets zero names rather than
    # the two bogus names produced by 1:0.
    colnames(features) <- paste(prefix, seq_len(ncol(features)),
                                sep = '_delim_')
    return(data.frame(features))
  }
  #' Construct a matrix of holiday features.
  #'
  #' Each holiday is expanded over its lower_window..upper_window offsets (or
  #' just offset 0 when no window is given). Every (holiday, offset) pair gets
  #' its own column named "<holiday>_delim_<sign><offset>", whose value is the
  #' ratio holidays.prior.scale / seasonality.prior.scale on matching dates
  #' and 0 elsewhere.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A dataframe with a column for each holiday.
  #'
  #' @importFrom dplyr "%>%"
  make_holiday_features <- function(m, dates) {
    scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
    wide <- m$holidays %>%
      dplyr::mutate(ds = zoo::as.Date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      # Keep one row per (holiday, ds). row_number() is namespace-qualified so
      # it resolves even when the user has not attached dplyr.
      dplyr::filter(dplyr::row_number() == 1) %>%
      dplyr::do({
        if (exists('lower_window', where = .) && !is.na(.$lower_window)
            && !is.na(.$upper_window)) {
          offsets <- seq(.$lower_window, .$upper_window)
        } else {
          offsets <- c(0)
        }
        feature.names <- paste0(
          .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets))
        dplyr::data_frame(ds = .$ds + offsets, holiday = feature.names)
      }) %>%
      dplyr::mutate(x = scale.ratio) %>%
      tidyr::spread(holiday, x, fill = 0)

    # Align to the requested dates; dates with no holiday get 0 everywhere.
    holiday.mat <- data.frame(ds = dates) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)

    holiday.mat[is.na(holiday.mat)] <- 0
    return(holiday.mat)
  }
  #' Dataframe with seasonality features.
  #'
  #' Always starts with an all-zero column named "zeros", then appends yearly
  #' Fourier features (period 365.25, order 10), weekly Fourier features
  #' (period 7, order 3) and holiday features, each only if enabled on m.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features.
  #'
  #' @return Dataframe with seasonality.
  #'
  make_all_seasonality_features <- function(m, df) {
    seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
    if (m$yearly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
    }
    if (m$weekly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 7, 3, 'weekly'))
    }
    if (!is.null(m$holidays)) {
      # make_holiday_features() applies the holiday/seasonality prior-scale
      # ratio itself, so nothing needs to be computed here.
      seasonal.features <- cbind(
        seasonal.features,
        make_holiday_features(m, df$ds))
    }
    return(seasonal.features)
  }
  #' Initialize linear growth.
  #'
  #' Computes the slope and intercept of the straight line through the
  #' earliest and latest observations (by ds), for use as a starting point
  #' when fitting.
  #'
  #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  #'  and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  #'  growth function.
  #'
  linear_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)

    # Slope between the two end points
    elapsed <- df$t[last] - df$t[first]
    rate <- (df$y_scaled[last] - df$y_scaled[first]) / elapsed

    # Intercept that makes the line pass through the first point
    offset <- df$y_scaled[first] - rate * df$t[first]

    c(rate, offset)
  }
  #' Initialize logistic growth.
  #'
  #' Computes the rate and offset of the logistic curve through the earliest
  #' and latest observations (by ds), for use as a starting point when
  #' fitting. Capacity-to-value ratios are floored at 1.01, and the first
  #' ratio is inflated by 5 percent when the two ratios are within 0.01 of
  #' each other.
  #'
  #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  #'  y_scaled (scaled time series), and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  #'  growth function.
  #'
  logistic_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    span <- df$t[last] - df$t[first]

    # Force valid values, in case y > cap.
    ratio_at <- function(i) {
      max(1.01, df$cap_scaled[i] / df$y_scaled[i])
    }
    ratio.first <- ratio_at(first)
    ratio.last <- ratio_at(last)
    if (abs(ratio.first - ratio.last) <= 0.01) {
      ratio.first <- 1.05 * ratio.first
    }

    log.first <- log(ratio.first - 1)
    log.last <- log(ratio.last - 1)

    # Offset first, then the rate that follows from it
    offset <- log.first * span / (log.first - log.last)
    rate <- log.first / offset

    c(rate, offset)
  }
  #' Fit the prophet model.
  #'
  #' Drops rows with missing y, prepares the history, builds the seasonality
  #' and changepoint inputs, and runs Stan (sampling if m$mcmc.samples > 0,
  #' otherwise optimizing).
  #'
  #' This sets m$params to contain the fitted model parameters. It is a list
  #' with the following elements:
  #'   k (M array): M posterior samples of the initial slope.
  #'   m (M array): The initial intercept.
  #'   delta (MxN matrix): The slope change at each of N changepoints.
  #'   beta (MxK matrix): Coefficients for K seasonality features.
  #'   sigma_obs (M array): Noise level.
  #' Note that M=1 if MAP estimation.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #' @param ... Additional arguments passed to the \code{optimizing} or
  #'  \code{sampling} functions in Stan.
  #'
  #' @return The fitted prophet object.
  #'
  #' @export
  fit.prophet <- function(m, df, ...) {
    if (!is.null(m$history)) {
      stop("Prophet object can only be fit once. Instantiate a new object.")
    }
    history <- df %>%
      dplyr::filter(!is.na(y))
    # With fewer than two observations the time span is zero, which would
    # make every scaled time NaN before the data ever reaches Stan.
    if (nrow(history) < 2) {
      stop("Dataframe has less than 2 non-NA rows.")
    }
    # All dates from the input, including rows whose y is missing.
    m$history.dates <- sort(zoo::as.Date(df$ds))

    out <- setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    seasonal.features <- make_all_seasonality_features(m, history)

    m <- set_changepoints(m)
    A <- get_changepoint_matrix(m)

    # Construct input to stan
    dat <- list(
      T = nrow(history),
      K = ncol(seasonal.features),
      S = length(m$changepoints.t),
      y = history$y_scaled,
      t = history$t,
      A = A,
      t_change = array(m$changepoints.t),
      X = as.matrix(seasonal.features),
      sigma = m$seasonality.prior.scale,
      tau = m$changepoint.prior.scale
    )

    # Run stan
    if (m$growth == 'linear') {
      kinit <- linear_growth_init(history)
      model <- get_prophet_stan_model('linear')
    } else {
      dat$cap <- history$cap_scaled  # Add capacities to the Stan data
      kinit <- logistic_growth_init(history)
      model <- get_prophet_stan_model('logistic')
    }

    # Start from the end-point growth estimate with all adjustments at zero.
    stan_init <- function() {
      list(k = kinit[1],
           m = kinit[2],
           delta = array(rep(0, length(m$changepoints.t))),
           beta = array(rep(0, ncol(seasonal.features))),
           sigma_obs = 1
      )
    }

    if (m$mcmc.samples > 0) {
      stan.fit <- rstan::sampling(
        model,
        data = dat,
        init = stan_init,
        iter = m$mcmc.samples,
        ...
      )
      m$params <- rstan::extract(stan.fit)
      n.iteration <- length(m$params$k)
    } else {
      stan.fit <- rstan::optimizing(
        model,
        data = dat,
        init = stan_init,
        iter = 1e4,
        as_vector = FALSE,
        ...
      )
      m$params <- stan.fit$par
      n.iteration <- 1
    }

    # Cast the parameters to have consistent form, whether full bayes or MAP
    for (name in c('delta', 'beta')) {
      m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
    }
    # rstan::sampling returns 1d arrays; converts to atomic vectors.
    for (name in c('k', 'm', 'sigma_obs')) {
      m$params[[name]] <- c(m$params[[name]])
    }
    # If no changepoints were requested, the model was still fit with one
    # dummy changepoint: fold its delta into the base rate k and zero delta.
    if (m$n.changepoints == 0) {
      m$params$k <- m$params$k + m$params$delta[, 1]
      m$params$delta <- matrix(rep(0, length(m$params$delta)),
                               nrow = n.iteration)
    }
    return(m)
  }
  #' Predict using the prophet model.
  #'
  #' Adds the trend, the uncertainty intervals, the seasonal components and
  #' yhat (trend plus seasonal) to the prediction dataframe.
  #'
  #' @param object Prophet object. Must already be fit.
  #' @param df Dataframe with dates for predictions (column ds), and capacity
  #'  (column cap) if logistic growth. If not provided, predictions are made on
  #'  the history.
  #' @param ... additional arguments.
  #'
  #' @return A dataframe with the forecast components.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  predict.prophet <- function(object, df = NULL, ...) {
    # An unfitted model has no scales or parameters; fail early with a clear
    # message instead of an obscure error further down.
    if (is.null(object$history)) {
      stop("Model must be fit before predictions can be made.")
    }

    if (is.null(df)) {
      df <- object$history
    } else {
      out <- setup_dataframe(object, df)
      df <- out$df
    }

    df$trend <- predict_trend(object, df)

    df <- df %>%
      dplyr::bind_cols(predict_uncertainty(object, df)) %>%
      dplyr::bind_cols(predict_seasonal_components(object, df))
    df$yhat <- df$trend + df$seasonal
    return(df)
  }
  #' Evaluate the piecewise linear function.
  #'
  #' The slope starts at k and changes by deltas[s] at changepoint.ts[s]; the
  #' intercept is adjusted at each changepoint so the function stays
  #' continuous.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
    # Intercept changes
    gammas <- -changepoint.ts * deltas
    # Get cumulative slope and intercept at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    # seq_along() so that an empty changepoint vector leaves the plain line
    # k * t + m instead of failing on the 1:0 sequence.
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- k_t * t + m_t
    return(y)
  }
  #' Evaluate the piecewise logistic function.
  #'
  #' The rate starts at k and changes by deltas[s] at changepoint.ts[s]; the
  #' offset is adjusted at each changepoint by an amount derived from the
  #' cumulative rates before and after it.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param cap Vector of capacities at each t.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
    # Compute offset changes
    k.cum <- c(k, cumsum(deltas) + k)
    gammas <- rep(0, length(changepoint.ts))
    # seq_along() in both loops so that an empty changepoint vector yields the
    # plain logistic curve instead of failing on the 1:0 sequence.
    for (i in seq_along(changepoint.ts)) {
      # sum(gammas) only includes earlier changepoints; later entries are 0.
      gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
                    * (1 - k.cum[i] / k.cum[i + 1]))
    }
    # Get cumulative rate and offset at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- cap / (1 + exp(-k_t * (t - m_t)))
    return(y)
  }
  #' Predict trend using the prophet model.
  #'
  #' Averages the fitted rate, offset and rate changes over all parameter
  #' draws, evaluates the piecewise linear or logistic trend with them, and
  #' converts the result back to the original scale of y.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Vector with trend on prediction dates.
  #'
  predict_trend <- function(model, df) {
    fitted <- model$params
    rate <- mean(fitted$k, na.rm = TRUE)
    offset <- mean(fitted$m, na.rm = TRUE)
    rate.changes <- colMeans(fitted$delta, na.rm = TRUE)

    scaled.trend <- if (model$growth == 'linear') {
      piecewise_linear(df$t, rate.changes, rate, offset,
                       model$changepoints.t)
    } else {
      piecewise_logistic(df$t, df$cap_scaled, rate.changes, rate, offset,
                         model$changepoints.t)
    }

    scaled.trend * model$y.scale
  }
  #' Predict seasonality broken down into components.
  #'
  #' Feature columns are grouped by the part of their name before "_delim_".
  #' For every group the contribution is computed for each parameter draw, and
  #' its mean and lower/upper quantiles (from m$interval.width) are returned
  #' as <component>, <component>_lower and <component>_upper. The column
  #' `seasonal` is the sum of the component means.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with seasonal components.
  #'
  predict_seasonal_components <- function(m, df) {
    seasonal.features <- make_all_seasonality_features(m, df)
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2

    # Broken down into components
    components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
      # n() is namespace-qualified so it resolves even when the user has not
      # attached dplyr; seq_len() is safe for zero rows.
      dplyr::mutate(col = seq_len(dplyr::n())) %>%
      tidyr::separate(component, c('component', 'part'), sep = "_delim_",
                      extra = "merge", fill = "right") %>%
      dplyr::filter(component != 'zeros')

    if (nrow(components) > 0) {
      component.predictions <- components %>%
        dplyr::group_by(component) %>% dplyr::do({
          # Rows are prediction dates, columns are parameter draws.
          comp <- (as.matrix(seasonal.features[, .$col])
                   %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
          dplyr::data_frame(ix = seq_len(nrow(seasonal.features)),
                            mean = rowMeans(comp, na.rm = TRUE),
                            lower = apply(comp, 1, stats::quantile, lower.p,
                                          na.rm = TRUE),
                            upper = apply(comp, 1, stats::quantile, upper.p,
                                          na.rm = TRUE))
        }) %>%
        tidyr::gather(stat, value, c(mean, lower, upper)) %>%
        dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
        tidyr::unite(component, component, stat, sep="") %>%
        tidyr::spread(component, value) %>%
        dplyr::select(-ix)

      component.predictions$seasonal <- rowSums(
        component.predictions[unique(components$component)])
    } else {
      component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
    }
    return(component.predictions)
  }
  #' Prophet uncertainty intervals.
  #'
  #' Simulates trend, seasonality and yhat from the extrapolated model and
  #' returns, for each prediction row, the lower and upper quantiles implied
  #' by m$interval.width. At least m$uncertainty.samples draws are made in
  #' total, spread evenly over the fitted parameter draws.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with uncertainty intervals: columns yhat_lower,
  #'  yhat_upper, trend_lower, trend_upper, seasonal_lower, seasonal_upper.
  #'
  predict_uncertainty <- function(m, df) {
    # Sample trend, seasonality, and yhat from the extrapolation model.
    n.iterations <- length(m$params$k)
    samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
    nsamp <- n.iterations * samp.per.iter  # The actual number of samples

    seasonal.features <- make_all_seasonality_features(m, df)

    # One column per simulated draw, one row per prediction date.
    empty_draws <- function() {
      matrix(NA_real_, nrow = nrow(df), ncol = nsamp)
    }
    sim.values <- list("trend" = empty_draws(),
                       "seasonal" = empty_draws(),
                       "yhat" = empty_draws())

    for (i in seq_len(n.iterations)) {
      # For each set of parameters from MCMC (or just 1 set for MAP),
      for (j in seq_len(samp.per.iter)) {
        # Do a simulation with this set of parameters,
        sim <- sample_model(m, df, seasonal.features, i)
        # Store the results
        for (key in c("trend", "seasonal", "yhat")) {
          sim.values[[key]][, (i - 1) * samp.per.iter + j] <- sim[[key]]
        }
      }
    }

    # Add uncertainty estimates
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2

    # Row-wise quantiles, returned with one row per prediction date.
    row_quantiles <- function(draws) {
      t(apply(draws, 1, stats::quantile, c(lower.p, upper.p), na.rm = TRUE))
    }
    intervals <- cbind(
      row_quantiles(sim.values$yhat),
      row_quantiles(sim.values$trend),
      row_quantiles(sim.values$seasonal)
    )

    # Name the columns before converting: straight out of quantile() the three
    # pairs carry identical names, which the tibble conversion rejects as
    # duplicates.
    colnames(intervals) <- paste(rep(c('yhat', 'trend', 'seasonal'), each=2),
                                 c('lower', 'upper'), sep = "_")
    return(dplyr::as_data_frame(intervals))
  }
  #' Simulate observations from the extrapolated generative model.
  #'
  #' Draws a trend for the given parameter iteration, computes the seasonal
  #' term from that iteration's beta, and adds Gaussian observation noise with
  #' that iteration's sigma_obs. Seasonal term and noise are multiplied by
  #' m$y.scale.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #' @param seasonal.features Data frame of seasonal features
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  #'
  sample_model <- function(m, df, seasonal.features, iteration) {
    # The trend is drawn first so the random-number stream is consumed in the
    # same order as before.
    trend <- sample_predictive_trend(m, df, iteration)

    feature.mat <- as.matrix(seasonal.features)
    seasonal <- (feature.mat %*% m$params$beta[iteration,]) * m$y.scale

    noise.sd <- m$params$sigma_obs[iteration]
    noise <- stats::rnorm(nrow(df), mean = 0, sd = noise.sd) * m$y.scale

    list("yhat" = trend + seasonal + noise,
         "trend" = trend,
         "seasonal" = seasonal)
  }
  #' Simulate a trend path, with randomly drawn future changepoints.
  #'
  #' Changepoints fitted on the history are kept. For the part of df that lies
  #' beyond t = 1, additional changepoints are sampled at the same average
  #' frequency as in the history, with rate changes drawn from a Laplace
  #' distribution whose scale is the mean absolute fitted rate change.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return Vector of simulated trend over df$t.
  #'
  sample_predictive_trend <- function(model, df, iteration) {
    rate <- model$params$k[iteration]
    offset <- model$params$m[iteration]
    hist.deltas <- model$params$delta[iteration, ]

    times <- df$t
    t.max <- max(times)

    # Defaults apply when df does not extend past the history (t.max <= 1).
    n.new <- 0
    new.changepoints <- c()

    if (t.max > 1) {
      # Time discretization of the history: smallest positive step.
      steps <- diff(model$history$t)
      step <- min(steps[steps > 0])
      # Number of time periods in the future.
      n.periods <- ceiling((t.max - 1) / step)
      n.hist.changepoints <- length(model$changepoints.t)
      # The history had n.hist.changepoints split points over t = [0, 1].
      # The forecast covers [1, t.max] and should show the same average
      # frequency of rate changes, i.e. n.hist.changepoints * (t.max - 1)
      # changepoints in expectation over the n.periods future periods.
      # This works whether or not history and df are uniformly spaced.
      p.change <- min(1, (n.hist.changepoints * (t.max - 1)) / n.periods)
      n.new <- stats::rbinom(1, n.periods, p.change)
      if (n.new != 0) {
        # Locations of the new changepoints, in increasing order.
        new.changepoints <- sort(stats::runif(n.new, min = 1, max = t.max))
      }
    }

    # Empirical scale of the fitted deltas, plus epsilon to avoid NaNs.
    laplace.scale <- mean(abs(c(hist.deltas))) + 1e-8
    new.deltas <- extraDistr::rlaplace(n.new, mu = 0, sigma = laplace.scale)

    # Historical changepoints first, simulated ones after.
    all.changepoints <- c(model$changepoints.t, new.changepoints)
    all.deltas <- c(hist.deltas, new.deltas)

    # Evaluate the trend implied by the combined changepoints.
    if (model$growth == 'linear') {
      path <- piecewise_linear(times, all.deltas, rate, offset,
                               all.changepoints)
    } else {
      capacity <- df$cap_scaled
      path <- piecewise_logistic(times, capacity, all.deltas, rate, offset,
                                 all.changepoints)
    }
    path * model$y.scale
  }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object.
  #' @param periods Int number of periods to forecast forward. May be 0, in
  #'  which case no future dates are added.
  #' @param freq 'day', 'week', 'month', 'quarter', or 'year'. Passed to seq()
  #'  as its `by` argument.
  #' @param include_history Boolean to include the historical dates in the data
  #'  frame for predictions.
  #'
  #' @return Dataframe with a single column ds that extends forward from the
  #'  end of m$history.dates for the requested number of periods.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'd',
                                    include_history = TRUE) {
    # Generate periods + 1 dates starting at the last history date.
    dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
    # Drop the first, which is the last history date itself. Negative indexing
    # is used because 2:(periods + 1) counts down to c(2, 1) when periods is 0,
    # which would yield an NA followed by a repeat of the last history date.
    dates <- dates[-1]
    if (include_history) {
      dates <- c(m$history.dates, dates)
    }
    return(data.frame(ds = dates))
  }
  #' Combine the observed history with a forecast for plotting.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @return Data frame with ds and y from m$history full-joined to the
  #'  columns of fcst on ds, ordered by ds.
  #'
  #' @importFrom dplyr "%>%"
  df_for_plotting <- function(m, fcst) {
    # Drop any y from fcst so the only y after the join is the observed one.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    # Full join: keep dates that appear in only one of the two inputs.
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observed history as points and the forecast yhat as a line,
  #' with an optional uncertainty band and an optional capacity line.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
                           xlabel = 'ds', ylabel = 'y', ...) {
    plot.df <- df_for_plotting(x, fcst)
    columns <- colnames(plot.df)

    fig <- ggplot2::ggplot(plot.df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)

    # Capacity line, only when the column exists and the caller wants it.
    if ('cap' %in% columns && plot_cap) {
      cap.line <- ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
      fig <- fig + cap.line
    }

    # Uncertainty band, only when requested and the bounds are available.
    if (uncertainty && 'yhat_lower' %in% columns) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2,
        fill = "#0072B2",
        na.rm = TRUE)
      fig <- fig + band
    }

    # Observed points, forecast line, and a fixed aspect ratio.
    fig <- fig + ggplot2::geom_point(na.rm=TRUE)
    fig <- fig + ggplot2::geom_line(ggplot2::aes(y = yhat), color = "#0072B2",
                                    na.rm = TRUE)
    fig <- fig + ggplot2::theme(aspect.ratio = 3 / 5)
    fig
  }
  #' Plot the components of a prophet forecast.
  #' Prints a ggplot2 with panels for trend, weekly and yearly seasonalities if
  #' present, and holidays if present.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #'  plotted for the trend, from fcst columns trend_lower and trend_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(m, fcst, uncertainty = TRUE,
                                      plot_cap = TRUE) {
    plot.df <- df_for_plotting(m, fcst)
    available <- colnames(plot.df)

    # The trend panel is always present; the others are appended in the order
    # holidays, weekly, yearly when the model or forecast has them.
    panels <- list(plot_trend(plot.df, uncertainty, plot_cap))
    if (!is.null(m$holidays)) {
      panels <- c(panels, list(plot_holidays(m, plot.df, uncertainty)))
    }
    if ("weekly" %in% available) {
      panels <- c(panels, list(plot_weekly(m, uncertainty)))
    }
    if ("yearly" %in% available) {
      panels <- c(panels, list(plot_yearly(m, uncertainty)))
    }

    # Draw the panels stacked in one column on a fresh grid page.
    n.panels <- length(panels)
    grid::grid.newpage()
    page.layout <- grid::grid.layout(n.panels, 1)
    grid::pushViewport(grid::viewport(layout = page.layout))
    for (row in seq_along(panels)) {
      cell <- grid::viewport(layout.pos.row = row, layout.pos.col = 1)
      print(panels[[row]], vp = cell)
    }
    invisible(panels)
  }
  #' Plot the prophet trend.
  #'
  #' Rows of df with a missing trend are dropped before plotting.
  #'
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals. The interval is
  #'  drawn only if df has both trend_lower and trend_upper columns.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #'
  #' @return A ggplot2 plot.
  plot_trend <- function(df, uncertainty = TRUE, plot_cap = TRUE) {
    df.t <- df[!is.na(df$trend),]
    columns <- colnames(df.t)
    gg.trend <- ggplot2::ggplot(df.t, ggplot2::aes(x = ds, y = trend)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    if ('cap' %in% columns && plot_cap) {
      gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                                linetype = 'dashed',
                                                na.rm = TRUE)
    }
    # Only add the ribbon when both bounds are present. A ribbon mapped to
    # missing columns would make the plot fail when it is rendered; the main
    # forecast plot guards its yhat ribbon on column presence the same way.
    has.bounds <- all(c('trend_lower', 'trend_upper') %in% columns)
    if (uncertainty && has.bounds) {
      gg.trend <- gg.trend +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
                                          ymax = trend_upper),
                             alpha = 0.2,
                             fill = "#0072B2",
                             na.rm = TRUE)
    }
    return(gg.trend)
  }
  #' Plot the holidays component of the forecast.
  #'
  #' The effects of all holidays named in m$holidays are summed per date, and
  #' dates where that sum is missing are dropped.
  #'
  #' @param m Prophet model
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_holidays <- function(m, df, uncertainty = TRUE) {
    holiday.names <- as.character(unique(m$holidays$holiday))
    # Row-wise total over the holiday columns carrying the given suffix.
    total_for <- function(suffix) {
      rowSums(df[, paste0(holiday.names, suffix), drop = FALSE])
    }
    # NOTE summing the per-holiday bounds is not a correct interval when
    # holidays overlap in time. It is used for visualization only.
    holiday.df <- data.frame(ds = df$ds,
                             holidays = total_for(""),
                             holidays_lower = total_for("_lower"),
                             holidays_upper = total_for("_upper"))
    holiday.df <- holiday.df[!is.na(holiday.df$holidays),]

    fig <- ggplot2::ggplot(holiday.df, ggplot2::aes(x = ds, y = holidays)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    if (!uncertainty) {
      return(fig)
    }
    fig +
      ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
                                        ymax = holidays_upper),
                           alpha = 0.2,
                           fill = "#0072B2",
                           na.rm = TRUE)
  }
  #' Plot the weekly component of the forecast.
  #'
  #' The weekly seasonality is evaluated on seven consecutive days starting
  #' on Sunday 2017-01-01 and plotted against the day of the week.
  #'
  #' @param m Prophet model object
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_weekly <- function(m, uncertainty = TRUE) {
    # One Sun-Sat week of dates, with a placeholder cap column.
    week.start <- zoo::as.Date('2017-01-01')
    week.df <- data.frame(
      ds=seq.Date(week.start, by='d', length.out=7), cap=1.)
    week.df <- setup_dataframe(m, week.df)$df
    weekly.comp <- predict_seasonal_components(m, week.df)
    # Use the day names in date order as factor levels so the x-axis is not
    # sorted alphabetically.
    day.names <- weekdays(week.df$ds)
    weekly.comp$dow <- factor(day.names, levels=day.names)

    fig <- ggplot2::ggplot(weekly.comp, ggplot2::aes(x = dow, y = weekly,
                                                     group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::labs(x = "Day of week")
    if (!uncertainty) {
      return(fig)
    }
    fig +
      ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
                                        ymax = weekly_upper),
                           alpha = 0.2,
                           fill = "#0072B2",
                           na.rm = TRUE)
  }
  #' Plot the yearly component of the forecast.
  #'
  #' The yearly seasonality is evaluated on 365 consecutive days starting on
  #' 2017-01-01 and plotted against the date, labelled as month and day.
  #'
  #' @param m Prophet model object.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_yearly <- function(m, uncertainty = TRUE) {
    # Jan 1 - Dec 31 sequence of dates, with a placeholder cap column.
    year.start <- zoo::as.Date('2017-01-01')
    year.df <- data.frame(
      ds=seq.Date(year.start, by='d', length.out=365), cap=1.)
    year.df <- setup_dataframe(m, year.df)$df
    yearly.comp <- predict_seasonal_components(m, year.df)
    yearly.comp$ds <- year.df$ds

    fig <- ggplot2::ggplot(yearly.comp, ggplot2::aes(x = ds, y = yearly,
                                                     group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
      ggplot2::labs(x = "Day of year")
    if (!uncertainty) {
      return(fig)
    }
    fig +
      ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
                                        ymax = yearly_upper),
                           alpha = 0.2,
                           fill = "#0072B2",
                           na.rm = TRUE)
  }
  1033. # fb-block 3