prophet.R 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  6. ## Makes R CMD CHECK happy due to dplyr syntax below
  7. globalVariables(c(
  8. "ds", "y", "cap", ".",
  9. "component", "dow", "doy", "holiday", "holidays", "holidays_lower", "holidays_upper", "ix",
  10. "lower", "n", "stat", "trend",
  11. "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
  12. "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
  13. #' Prophet forecast.
  14. #'
  15. #' @param df Data frame with columns ds (date type) and y, the time series.
  16. #' If growth is logistic, then df must also have a column cap that specifies
  17. #' the capacity at each ds.
  18. #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  19. #' trend.
  20. #' @param changepoints Vector of dates at which to include potential
  21. #' changepoints. Each date must be present in df$ds. If not specified,
  22. #' potential changepoints are selected automatically.
  23. #' @param n.changepoints Number of potential changepoints to include. Not used
  24. #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  25. #' then n.changepoints potential changepoints are selected uniformly from the
  26. #' first 80 percent of df$ds.
  27. #' @param yearly.seasonality Boolean, fit yearly seasonality.
  28. #' @param weekly.seasonality Boolean, fit weekly seasonality.
  29. #' @param holidays data frame with columns holiday (character) and ds (date
  30. #' type)and optionally columns lower_window and upper_window which specify a
  31. #' range of days around the date to be included as holidays.
  32. #' @param seasonality.prior.scale Parameter modulating the strength of the
  33. #' seasonality model. Larger values allow the model to fit larger seasonal
  34. #' fluctuations, smaller values dampen the seasonality.
  35. #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  36. #' automatic changepoint selection. Large values will allow many changepoints,
  37. #' small values will allow few changepoints.
  38. #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  39. #' components model.
  40. #' @param mcmc.samples Integer, if great than 0, will do full Bayesian
  41. #' inference with the specified number of MCMC samples. If 0, will do MAP
  42. #' estimation.
  43. #' @param interval.width Numeric, width of the uncertainty intervals provided
  44. #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  45. #' in the trend using the MAP estimate of the extrapolated generative model.
  46. #' If mcmc.samples>0, this will be integrated over all model parameters,
  47. #' which will include uncertainty in seasonality.
  48. #' @param uncertainty.samples Number of simulated draws used to estimate
  49. #' uncertainty intervals.
  50. #' @param fit Boolean, if FALSE the model is initialized but not fit.
  51. #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  52. #'
  53. #' @return A prophet model.
  54. #'
  55. #' @examples
  56. #' \dontrun{
  57. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  58. #' y = sin(1:366/200) + rnorm(366)/10)
  59. #' m <- prophet(history)
  60. #' }
  61. #'
  62. #' @export
  63. #' @importFrom dplyr "%>%"
  64. #' @import Rcpp
  65. prophet <- function(df = df,
  66. growth = 'linear',
  67. changepoints = NULL,
  68. n.changepoints = 25,
  69. yearly.seasonality = TRUE,
  70. weekly.seasonality = TRUE,
  71. holidays = NULL,
  72. seasonality.prior.scale = 10,
  73. changepoint.prior.scale = 0.05,
  74. holidays.prior.scale = 10,
  75. mcmc.samples = 0,
  76. interval.width = 0.80,
  77. uncertainty.samples = 1000,
  78. fit = TRUE,
  79. ...
  80. ) {
  81. # fb-block 1
  82. if (!is.null(changepoints)) {
  83. n.changepoints <- length(changepoints)
  84. }
  85. m <- list(
  86. growth = growth,
  87. changepoints = changepoints,
  88. n.changepoints = n.changepoints,
  89. yearly.seasonality = yearly.seasonality,
  90. weekly.seasonality = weekly.seasonality,
  91. holidays = holidays,
  92. seasonality.prior.scale = seasonality.prior.scale,
  93. changepoint.prior.scale = changepoint.prior.scale,
  94. holidays.prior.scale = holidays.prior.scale,
  95. mcmc.samples = mcmc.samples,
  96. interval.width = interval.width,
  97. uncertainty.samples = uncertainty.samples,
  98. start = NULL, # This and following attributes are set during fitting
  99. y.scale = NULL,
  100. t.scale = NULL,
  101. changepoints.t = NULL,
  102. stan.fit = NULL,
  103. params = list(),
  104. history = NULL
  105. )
  106. validate_inputs(m)
  107. class(m) <- append(class(m), "prophet")
  108. if (fit) {
  109. m <- fit.prophet(m, df, ...)
  110. }
  111. # fb-block 2
  112. return(m)
  113. }
  114. #' Validates the inputs to Prophet.
  115. #'
  116. #' @param m Prophet object.
  117. #'
  118. validate_inputs <- function(m) {
  119. if (!(m$growth %in% c('linear', 'logistic'))) {
  120. stop("Parameter 'growth' should be 'linear' or 'logistic'.")
  121. }
  122. if (!is.null(m$holidays)) {
  123. if (!(exists('holiday', where = m$holidays))) {
  124. stop('Holidays dataframe must have holiday field.')
  125. }
  126. if (!(exists('ds', where = m$holidays))) {
  127. stop('Holidays dataframe must have ds field.')
  128. }
  129. for (h in unique(m$holidays$holiday)) {
  130. if (grepl("_", h)) {
  131. stop('Holiday name cannot contain "_"')
  132. }
  133. if (h %in% c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')) {
  134. stop(paste0('Holiday name "', h, '" reserved.'))
  135. }
  136. }
  137. }
  138. }
  139. #' Load compiled Stan model
  140. #'
  141. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  142. #' trend.
  143. #'
  144. #' @return Stan model.
  145. get_prophet_stan_model <- function(model) {
  146. fn <- paste('prophet', model, 'growth.RData', sep = '_')
  147. ## If the cached model doesn't work, just compile a new one.
  148. tryCatch({
  149. binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
  150. package = 'prophet',
  151. mustWork = TRUE)
  152. load(binary)
  153. obj.name <- paste(model, 'growth.stanm', sep = '.')
  154. stanm <- eval(parse(text = obj.name))
  155. ## Should cause an error if the model doesn't work.
  156. stanm@mk_cppmodule(stanm)
  157. stanm
  158. }, error = function(cond) {
  159. compile_stan_model(model)
  160. })
  161. }
  162. #' Compile Stan model
  163. #'
  164. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  165. #' trend.
  166. #'
  167. #' @return Stan model.
  168. compile_stan_model <- function(model) {
  169. fn <- paste('stan/prophet', model, 'growth.stan', sep = '_')
  170. stan.src <- system.file(fn, package = 'prophet', mustWork = TRUE)
  171. stanc <- rstan::stanc(stan.src)
  172. model.name <- paste(model, 'growth', sep = '_')
  173. return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
  174. }
  175. #' Prepare dataframe for fitting or predicting.
  176. #'
  177. #' Adds a time index and scales y.
  178. #'
  179. #' @param m Prophet object.
  180. #' @param df Data frame with columns ds, y, and cap if logistic growth.
  181. #' @param initialize_scales Boolean set scaling factors in m from df.
  182. #'
  183. #' @return list with items 'df' and 'm'.
  184. #'
  185. setup_dataframe <- function(m, df, initialize_scales = FALSE) {
  186. if (exists('y', where=df)) {
  187. df$y <- as.numeric(df$y)
  188. }
  189. df$ds <- zoo::as.Date(df$ds)
  190. df <- df %>%
  191. dplyr::arrange(ds)
  192. if (initialize_scales) {
  193. m$y.scale <- max(df$y)
  194. m$start <- min(df$ds)
  195. m$t.scale <- as.numeric(max(df$ds) - m$start)
  196. }
  197. df$t <- as.numeric(df$ds - m$start) / m$t.scale
  198. if (exists('y', where=df)) {
  199. df$y_scaled <- df$y / m$y.scale
  200. }
  201. if (m$growth == 'logistic') {
  202. if (!(exists('cap', where=df))) {
  203. stop('Capacities must be supplied for logistic growth.')
  204. }
  205. df <- df %>%
  206. dplyr::mutate(cap_scaled = cap / m$y.scale)
  207. }
  208. return(list("m" = m, "df" = df))
  209. }
  210. #' Set changepoints
  211. #'
  212. #' Sets m$changepoints to the dates of changepoints.
  213. #'
  214. #' @param m Prophet object.
  215. #'
  216. #' @return m with changepoints set.
  217. #'
  218. set_changepoints <- function(m) {
  219. if (!is.null(m$changepoints)) {
  220. if (length(m$changepoints) > 0) {
  221. if (min(m$changepoints) < min(m$history$ds)
  222. || max(m$changepoints) > max(m$history$ds)) {
  223. stop('Changepoints must fall within training data.')
  224. }
  225. }
  226. } else {
  227. if (m$n.changepoints > 0) {
  228. # Place potential changepoints evenly through the first 80 pcnt of
  229. # the history.
  230. cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
  231. length.out = (m$n.changepoints + 1))) %>%
  232. utils::tail(-1)
  233. m$changepoints <- m$history$ds[cp.indexes]
  234. } else {
  235. m$changepoints <- c()
  236. }
  237. }
  238. if (length(m$changepoints) > 0) {
  239. m$changepoints <- zoo::as.Date(m$changepoints)
  240. m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
  241. } else {
  242. m$changepoints.t <- c(0) # dummy changepoint
  243. }
  244. return(m)
  245. }
  246. #' Gets changepoint matrix for history dataframe.
  247. #'
  248. #' @param m Prophet object.
  249. #'
  250. #' @return array of indexes.
  251. #'
  252. get_changepoint_matrix <- function(m) {
  253. A <- matrix(0, nrow(m$history), length(m$changepoints.t))
  254. for (i in 1:length(m$changepoints.t)) {
  255. A[m$history$t >= m$changepoints.t[i], i] <- 1
  256. }
  257. return(A)
  258. }
  259. #' Provides fourier series components with the specified frequency.
  260. #'
  261. #' @param dates Vector of dates.
  262. #' @param period Number of days of the period.
  263. #' @param series.order Number of components.
  264. #'
  265. #' @return Matrix with seasonality features.
  266. #'
  267. fourier_series <- function(dates, period, series.order) {
  268. t <- dates - zoo::as.Date('1970-01-01')
  269. features <- matrix(0, length(t), 2 * series.order)
  270. for (i in 1:series.order) {
  271. x <- as.numeric(2 * i * pi * t / period)
  272. features[, i * 2 - 1] <- sin(x)
  273. features[, i * 2] <- cos(x)
  274. }
  275. return(features)
  276. }
  277. #' Data frame with seasonality features.
  278. #'
  279. #' @param dates Vector of dates.
  280. #' @param period Number of days of the period.
  281. #' @param series.order Number of components.
  282. #' @param prefix Column name prefix
  283. #'
  284. #' @return Dataframe with seasonality.
  285. #'
  286. make_seasonality_features <- function(dates, period, series.order, prefix) {
  287. features <- fourier_series(dates, period, series.order)
  288. colnames(features) <- paste(prefix, 1:ncol(features), sep = '_')
  289. return(data.frame(features))
  290. }
  291. #' Construct a matrix of holiday features.
  292. #'
  293. #' @param m Prophet object.
  294. #' @param dates Vector with dates used for computing seasonality.
  295. #'
  296. #' @return A dataframe with a column for each holiday
  297. #'
  298. #' @importFrom dplyr "%>%"
  299. make_holiday_features <- function(m, dates) {
  300. scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
  301. wide <- m$holidays %>%
  302. dplyr::mutate(ds = zoo::as.Date(ds)) %>%
  303. dplyr::group_by(holiday, ds) %>%
  304. dplyr::do({
  305. if (exists('lower_window', where = .) && !is.na(.$lower_window)
  306. && !is.na(.$upper_window)) {
  307. offsets <- seq(.$lower_window, .$upper_window)
  308. } else {
  309. offsets <- c(0)
  310. }
  311. names <- paste(
  312. .$holiday, '_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
  313. dplyr::data_frame(ds = .$ds + offsets, holiday = names)
  314. }) %>%
  315. dplyr::mutate(x = scale.ratio) %>%
  316. tidyr::spread(holiday, x, fill = 0)
  317. holiday.mat <- data.frame(ds = dates) %>%
  318. dplyr::left_join(wide, by = 'ds') %>%
  319. dplyr::select(-ds)
  320. holiday.mat[is.na(holiday.mat)] <- 0
  321. return(holiday.mat)
  322. }
  323. #' Data frame seasonality features.
  324. #'
  325. #' @param m Prophet object.
  326. #' @param df Dataframe with dates for computing seasonality features.
  327. #'
  328. #' @return Dataframe with seasonality.
  329. #'
  330. make_all_seasonality_features <- function(m, df) {
  331. seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
  332. if (m$yearly.seasonality) {
  333. seasonal.features <- cbind(
  334. seasonal.features,
  335. make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
  336. }
  337. if (m$weekly.seasonality) {
  338. seasonal.features <- cbind(
  339. seasonal.features,
  340. make_seasonality_features(df$ds, 7, 3, 'weekly'))
  341. }
  342. if(!is.null(m$holidays)) {
  343. # A smaller prior scale will shrink holiday estimates more than seasonality
  344. scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
  345. seasonal.features <- cbind(
  346. seasonal.features,
  347. make_holiday_features(m, df$ds))
  348. }
  349. return(seasonal.features)
  350. }
  351. #' Initialize linear growth
  352. #'
  353. #' Provides a strong initialization for linear growth by calculating the
  354. #' growth and offset parameters that pass the function through the first and
  355. #' last points in the time series.
  356. #'
  357. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  358. #' y_scaled (scaled time series), and t (scaled time).
  359. #'
  360. #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  361. #' growth function.
  362. #'
  363. linear_growth_init <- function(df) {
  364. i0 <- which.min(df$ds)
  365. i1 <- which.max(df$ds)
  366. T <- df$t[i1] - df$t[i0]
  367. # Initialize the rate
  368. k <- (df$y_scaled[i1] - df$y_scaled[i0]) / T
  369. # And the offset
  370. m <- df$y_scaled[i0] - k * df$t[i0]
  371. return(c(k, m))
  372. }
  373. #' Initialize logistic growth
  374. #'
  375. #' Provides a strong initialization for logistic growth by calculating the
  376. #' growth and offset parameters that pass the function through the first and
  377. #' last points in the time series.
  378. #'
  379. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  380. #' y_scaled (scaled time series), and t (scaled time).
  381. #'
  382. #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  383. #' growth function.
  384. #'
  385. logistic_growth_init <- function(df) {
  386. i0 <- which.min(df$ds)
  387. i1 <- which.max(df$ds)
  388. T <- df$t[i1] - df$t[i0]
  389. # Force valid values, in case y > cap.
  390. r0 <- max(1.01, df$cap_scaled[i0] / df$y_scaled[i0])
  391. r1 <- max(1.01, df$cap_scaled[i1] / df$y_scaled[i1])
  392. if (abs(r0 - r1) <= 0.01) {
  393. r0 <- 1.05 * r0
  394. }
  395. L0 <- log(r0 - 1)
  396. L1 <- log(r1 - 1)
  397. # Initialize the offset
  398. m <- L0 * T / (L0 - L1)
  399. # And the rate
  400. k <- L0 / m
  401. return(c(k, m))
  402. }
  403. #' Fit the prophet model.
  404. #'
  405. #' @param m Prophet object.
  406. #' @param df Data frame.
  407. #' @param ... Additional arguments passed to the \code{optimizing} or
  408. #' \code{sampling} functions in Stan.
  409. #'
  410. #' @export
  411. fit.prophet <- function(m, df, ...) {
  412. history <- df %>%
  413. dplyr::filter(!is.na(y))
  414. out <- setup_dataframe(m, history, initialize_scales = TRUE)
  415. history <- out$df
  416. m <- out$m
  417. m$history <- history
  418. seasonal.features <- make_all_seasonality_features(m, history)
  419. m <- set_changepoints(m)
  420. A <- get_changepoint_matrix(m)
  421. # Construct input to stan
  422. dat <- list(
  423. T = nrow(history),
  424. K = ncol(seasonal.features),
  425. S = length(m$changepoints.t),
  426. y = history$y_scaled,
  427. t = history$t,
  428. A = A,
  429. t_change = array(m$changepoints.t),
  430. X = as.matrix(seasonal.features),
  431. sigma = m$seasonality.prior.scale,
  432. tau = m$changepoint.prior.scale
  433. )
  434. # Run stan
  435. if (m$growth == 'linear') {
  436. kinit <- linear_growth_init(history)
  437. model <- get_prophet_stan_model('linear')
  438. } else {
  439. dat$cap <- history$cap_scaled # Add capacities to the Stan data
  440. kinit <- logistic_growth_init(history)
  441. model <- get_prophet_stan_model('logistic')
  442. }
  443. stan_init <- function() {
  444. list(k = kinit[1],
  445. m = kinit[2],
  446. delta = array(rep(0, length(m$changepoints.t))),
  447. beta = array(rep(0, ncol(seasonal.features))),
  448. sigma_obs = 1
  449. )
  450. }
  451. if (m$mcmc.samples > 0) {
  452. stan.fit <- rstan::sampling(
  453. model,
  454. data = dat,
  455. init = stan_init,
  456. iter = m$mcmc.samples,
  457. ...
  458. )
  459. m$params <- rstan::extract(stan.fit)
  460. n.iteration <- length(m$params$k)
  461. } else {
  462. stan.fit <- rstan::optimizing(
  463. model,
  464. data = dat,
  465. init = stan_init,
  466. iter = 1e4,
  467. as_vector = FALSE,
  468. ...
  469. )
  470. m$params <- stan.fit$par
  471. n.iteration <- 1
  472. }
  473. # Cast the parameters to have consistent form, whether full bayes or MAP
  474. for (name in c('delta', 'beta')){
  475. m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
  476. }
  477. # rstan::sampling returns 1d arrays; converts to atomic vectors.
  478. for (name in c('k', 'm', 'sigma_obs')){
  479. m$params[[name]] <- c(m$params[[name]])
  480. }
  481. # If no changepoints were requested, replace delta with 0s
  482. if (m$n.changepoints == 0) {
  483. # Fold delta into the base rate k
  484. m$params$k <- m$params$k + m$params$delta[, 1]
  485. m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
  486. }
  487. return(m)
  488. }
  489. #' Predict using the prophet model.
  490. #'
  491. #' @param object Prophet object.
  492. #' @param df Dataframe with dates for predictions, and capacity if logistic
  493. #' growth. If not provided, predictions are made on the history.
  494. #' @param ... additional arguments
  495. #'
  496. #' @return A data_frame with a forecast
  497. #'
  498. #' @examples
  499. #' \dontrun{
  500. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  501. #' y = sin(1:366/200) + rnorm(366)/10)
  502. #' m <- prophet(history)
  503. #' future <- make_future_dataframe(m, periods = 365)
  504. #' forecast <- predict(m, future)
  505. #' plot(m, forecast)
  506. #' }
  507. #'
  508. #' @export
  509. predict.prophet <- function(object, df = NULL, ...) {
  510. if (is.null(df)) {
  511. df <- object$history
  512. } else {
  513. out <- setup_dataframe(object, df)
  514. df <- out$df
  515. }
  516. df$trend <- predict_trend(object, df)
  517. df <- df %>%
  518. dplyr::bind_cols(predict_uncertainty(object, df)) %>%
  519. dplyr::bind_cols(predict_seasonal_components(object, df))
  520. df$yhat <- df$trend + df$seasonal
  521. return(df)
  522. }
  523. #' Evaluate the piecewise linear function.
  524. #'
  525. #' @param t Vector of times on which the function is evaluated.
  526. #' @param deltas Vector of rate changes at each changepoint.
  527. #' @param k Float initial rate.
  528. #' @param m Float initial offset.
  529. #' @param changepoint.ts Vector of changepoint times.
  530. #'
  531. #' @return Vector y(t).
  532. #'
  533. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  534. # Intercept changes
  535. gammas <- -changepoint.ts * deltas
  536. # Get cumulative slope and intercept at each t
  537. k_t <- rep(k, length(t))
  538. m_t <- rep(m, length(t))
  539. for (s in 1:length(changepoint.ts)) {
  540. indx <- t >= changepoint.ts[s]
  541. k_t[indx] <- k_t[indx] + deltas[s]
  542. m_t[indx] <- m_t[indx] + gammas[s]
  543. }
  544. y <- k_t * t + m_t
  545. return(y)
  546. }
  547. #' Evaluate the piecewise logistic function.
  548. #'
  549. #' @param t Vector of times on which the function is evaluated.
  550. #' @param cap Vector of capacities at each t.
  551. #' @param deltas Vector of rate changes at each changepoint.
  552. #' @param k Float initial rate.
  553. #' @param m Float initial offset.
  554. #' @param changepoint.ts Vector of changepoint times.
  555. #'
  556. #' @return Vector y(t).
  557. #'
  558. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  559. # Compute offset changes
  560. k.cum <- c(k, cumsum(deltas) + k)
  561. gammas <- rep(0, length(changepoint.ts))
  562. for (i in 1:length(changepoint.ts)) {
  563. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  564. * (1 - k.cum[i] / k.cum[i + 1]))
  565. }
  566. # Get cumulative rate and offset at each t
  567. k_t <- rep(k, length(t))
  568. m_t <- rep(m, length(t))
  569. for (s in 1:length(changepoint.ts)) {
  570. indx <- t >= changepoint.ts[s]
  571. k_t[indx] <- k_t[indx] + deltas[s]
  572. m_t[indx] <- m_t[indx] + gammas[s]
  573. }
  574. y <- cap / (1 + exp(-k_t * (t - m_t)))
  575. return(y)
  576. }
  577. #' Predict trend using the prophet model.
  578. #'
  579. #' @param model Prophet object.
  580. #' @param df Data frame.
  581. #'
  582. predict_trend <- function(model, df) {
  583. k <- mean(model$params$k, na.rm = TRUE)
  584. param.m <- mean(model$params$m, na.rm = TRUE)
  585. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  586. t <- df$t
  587. if (model$growth == 'linear') {
  588. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  589. } else {
  590. cap <- df$cap_scaled
  591. trend <- piecewise_logistic(
  592. t, cap, deltas, k, param.m, model$changepoints.t)
  593. }
  594. return(trend * model$y.scale)
  595. }
  596. #' Seasonality broken down into components
  597. #'
  598. #' @param m Prophet object.
  599. #' @param df Data frame.
  600. #'
  601. predict_seasonal_components <- function(m, df) {
  602. seasonal.features <- make_all_seasonality_features(m, df)
  603. lower.p <- (1 - m$interval.width)/2
  604. upper.p <- (1 + m$interval.width)/2
  605. # Broken down into components
  606. components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
  607. dplyr::mutate(col = 1:n()) %>%
  608. tidyr::separate(component, c('component', 'part'), sep = "_",
  609. extra = "merge", fill = "right") %>%
  610. dplyr::filter(component != 'zeros')
  611. if (nrow(components) > 0) {
  612. component.predictions <- components %>%
  613. dplyr::group_by(component) %>% dplyr::do({
  614. comp <- (as.matrix(seasonal.features[, .$col])
  615. %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
  616. dplyr::data_frame(ix = 1:nrow(seasonal.features),
  617. mean = rowMeans(comp, na.rm = TRUE),
  618. lower = apply(comp, 1, stats::quantile, lower.p,
  619. na.rm = TRUE),
  620. upper = apply(comp, 1, stats::quantile, upper.p,
  621. na.rm = TRUE))
  622. }) %>%
  623. tidyr::gather(stat, value, c(mean, lower, upper)) %>%
  624. dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
  625. tidyr::unite(component, component, stat, sep="") %>%
  626. tidyr::spread(component, value) %>%
  627. dplyr::select(-ix)
  628. component.predictions$seasonal <- rowSums(
  629. component.predictions[unique(components$component)])
  630. } else {
  631. component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
  632. }
  633. return(component.predictions)
  634. }
  635. #' Prophet uncertainty intervals.
  636. #'
  637. #' @param m Prophet object.
  638. #' @param df Data frame.
  639. #'
  640. predict_uncertainty <- function(m, df) {
  641. # Sample trend, seasonality, and yhat from the extrapolation model.
  642. n.iterations <- length(m$params$k)
  643. samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
  644. nsamp <- n.iterations * samp.per.iter # The actual number of samples
  645. seasonal.features <- make_all_seasonality_features(m, df)
  646. sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
  647. "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
  648. "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
  649. for (i in 1:n.iterations) {
  650. # For each set of parameters from MCMC (or just 1 set for MAP),
  651. for (j in 1:samp.per.iter) {
  652. # Do a simulation with this set of parameters,
  653. sim <- sample_model(m, df, seasonal.features, i)
  654. # Store the results
  655. for (key in c("trend", "seasonal", "yhat")) {
  656. sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
  657. }
  658. }
  659. }
  660. # Add uncertainty estimates
  661. lower.p <- (1 - m$interval.width)/2
  662. upper.p <- (1 + m$interval.width)/2
  663. intervals <- cbind(
  664. t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
  665. na.rm = TRUE)),
  666. t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
  667. na.rm = TRUE)),
  668. t(apply(t(sim.values$seasonal), 2, stats::quantile, c(lower.p, upper.p),
  669. na.rm = TRUE))
  670. ) %>% dplyr::as_data_frame()
  671. colnames(intervals) <- paste(rep(c('yhat', 'trend', 'seasonal'), each=2),
  672. c('lower', 'upper'), sep = "_")
  673. return(intervals)
  674. }
  675. #' Simulate observations from the extrapolated generative model.
  676. #'
  677. #' @param m Prophet object.
  678. #' @param df Dataframe that was fit by Prophet.
  679. #' @param seasonal.features Data frame of seasonal features
  680. #' @param iteration Int sampling iteration ot use parameters from.
  681. #'
  682. #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  683. #'
  684. sample_model <- function(m, df, seasonal.features, iteration) {
  685. trend <- sample_predictive_trend(m, df, iteration)
  686. beta <- m$params$beta[iteration,]
  687. seasonal <- (as.matrix(seasonal.features) %*% beta) * m$y.scale
  688. sigma <- m$params$sigma_obs[iteration]
  689. noise <- stats::rnorm(nrow(df), mean = 0, sd = sigma) * m$y.scale
  690. return(list("yhat" = trend + seasonal + noise,
  691. "trend" = trend,
  692. "seasonal" = seasonal))
  693. }
  694. #' Simulate the trend using the extrapolated generative model.
  695. #'
  696. #' @param model Prophet object.
  697. #' @param df Dataframe that was fit by Prophet.
  698. #' @param iteration Int sampling iteration ot use parameters from.
  699. #'
  700. #' @return Vector of simulated trend over df$t.
  701. #'
  702. sample_predictive_trend <- function(model, df, iteration) {
  703. k <- model$params$k[iteration]
  704. param.m <- model$params$m[iteration]
  705. deltas <- model$params$delta[iteration,]
  706. t <- df$t
  707. T <- max(t)
  708. if (T > 1) {
  709. # Get the time discretization of the history
  710. dt <- diff(model$history$t)
  711. dt <- min(dt[dt > 0])
  712. # Number of time periods in the future
  713. N <- ceiling((T - 1) / dt)
  714. S <- length(model$changepoints.t)
  715. # The history had S split points, over t = [0, 1].
  716. # The forecast is on [1, T], and should have the same average frequency of
  717. # rate changes. Thus for N time periods in the future, we want an average
  718. # of S * (T - 1) changepoints in expectation.
  719. prob.change <- min(1, (S * (T - 1)) / N)
  720. # This calculation works for both history and df not uniformly spaced.
  721. n.changes <- stats::rbinom(1, N, prob.change)
  722. # Sample ts
  723. if (n.changes == 0) {
  724. changepoint.ts.new <- c()
  725. } else {
  726. changepoint.ts.new <- sort(stats::runif(n.changes, min = 1, max = T))
  727. }
  728. } else {
  729. changepoint.ts.new <- c()
  730. n.changes <- 0
  731. }
  732. # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  733. lambda <- mean(abs(c(deltas))) + 1e-8
  734. # Sample deltas
  735. deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  736. # Combine with changepoints from the history
  737. changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  738. deltas <- c(deltas, deltas.new)
  739. # Get the corresponding trend
  740. if (model$growth == 'linear') {
  741. trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  742. } else {
  743. cap <- df$cap_scaled
  744. trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
  745. }
  746. return(trend * model$y.scale)
  747. }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object.
  #' @param periods Int number of periods to forecast forward. Must be a single
  #' non-negative number; 0 yields no future dates.
  #' @param freq Spacing of the future dates, passed as `by` to seq(): e.g.
  #' 'day', 'week', 'month', 'quarter', or 'year' (default 'd', daily).
  #' @param include_history Boolean to include the historical dates in the data
  #' frame for predictions.
  #'
  #' @return Dataframe with a single column ds holding exactly `periods` dates
  #' after the last date in m$history (that date itself excluded), preceded by
  #' the historical dates when include_history is TRUE.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'd',
                                    include_history = TRUE) {
    if (!is.numeric(periods) || length(periods) != 1 || is.na(periods) ||
        periods < 0) {
      stop("periods must be a single non-negative number.", call. = FALSE)
    }
    # seq() starts at the last historical date, so generate one extra date and
    # drop that first one. Indexing with [2:periods] instead would yield only
    # periods - 1 future dates, and for periods < 2 would produce an NA and/or
    # repeat the last historical date.
    dates <- seq(max(m$history$ds), length.out = periods + 1, by = freq)[-1]
    if (include_history) {
      dates <- c(m$history$ds, dates)
    }
    return(data.frame(ds = dates))
  }
  #' Merge history and forecast for plotting.
  #'
  #' Full-joins the observed ds and y columns of the model history with the
  #' forecast on ds, keeping every date that appears in either, sorted by ds.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @return Data frame with columns ds, y, followed by the other fcst columns.
  #'
  #' @importFrom dplyr "%>%"
  df_for_plotting <- function(m, fcst) {
    # A y column in fcst would collide with the observed y after the join.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    combined <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(combined, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #' should be plotted. The interval is drawn only when fcst contains both
  #' yhat_lower and yhat_upper.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... Additional arguments; currently ignored (accepted for
  #' compatibility with the plot() generic).
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds',
                           ylabel = 'y', ...) {
    df <- df_for_plotting(x, fcst)
    forecast.color <- "#0072B2"
    gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)
    # Dashed line for the capacity, when the forecast carries a cap column.
    if ("cap" %in% colnames(df)) {
      gg <- gg + ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
    }
    # The ribbon maps both bounds, so require both columns; checking only
    # yhat_lower would let a missing yhat_upper fail when the plot is drawn.
    has.interval <- all(c('yhat_lower', 'yhat_upper') %in% colnames(df))
    if (uncertainty && has.interval) {
      gg <- gg +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
                             alpha = 0.2,
                             fill = forecast.color,
                             na.rm = TRUE)
    }
    # Observed points (NA on forecast-only rows) and the forecast line.
    gg <- gg +
      ggplot2::geom_point(na.rm = TRUE) +
      ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.color,
                         na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
    return(gg)
  }
  #' Plot the components of a prophet forecast.
  #' Draws, on a new grid page, ggplot2 panels for the trend, holidays if
  #' present, and weekly and yearly seasonalities if present.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if uncertainty intervals should be
  #' plotted for each component, from the <component>_lower and
  #' <component>_upper columns of fcst (e.g. trend_lower and trend_upper).
  #' Intervals whose columns are missing are skipped.
  #'
  #' @return Invisibly, the list of ggplot2 panels that were drawn (trend first,
  #' then holidays, weekly and yearly when present).
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
    df <- df_for_plotting(m, fcst)
    forecast.color <- "#0072B2"
    # Add a shaded interval to `gg`; `mapping` supplies ymin and ymax.
    add_interval <- function(gg, mapping) {
      gg + ggplot2::geom_ribbon(mapping, alpha = 0.2, fill = forecast.color,
                                na.rm = TRUE)
    }
    # TRUE when an interval should be drawn for component `name`: uncertainty
    # was requested and `data` has both <name>_lower and <name>_upper.
    wants_interval <- function(data, name) {
      bounds <- paste0(name, c("_lower", "_upper"))
      uncertainty && all(bounds %in% colnames(data))
    }

    # Trend panel, with the capacity as a dashed line when present.
    gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
      ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
    if ("cap" %in% colnames(df)) {
      gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                                linetype = 'dashed',
                                                na.rm = TRUE)
    }
    if (wants_interval(df, "trend")) {
      gg.trend <- add_interval(
        gg.trend, ggplot2::aes(ymin = trend_lower, ymax = trend_upper))
    }
    panels <- list(gg.trend)

    # Holiday panel: the summed effect of all holidays, if present.
    if (!is.null(m$holidays)) {
      holiday.comps <- unique(m$holidays$holiday) %>% as.character()
      df.s <- data.frame(
        ds = df$ds,
        holidays = rowSums(df[, holiday.comps, drop = FALSE]),
        holidays_lower = rowSums(df[, paste0(holiday.comps, "_lower"),
                                    drop = FALSE]),
        holidays_upper = rowSums(df[, paste0(holiday.comps, "_upper"),
                                    drop = FALSE]))
      # NOTE the above CI calculation is incorrect if holidays overlap in time.
      # Since it is just for the visualization we will not worry about it now.
      gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
      if (wants_interval(df.s, "holidays")) {
        gg.holidays <- add_interval(
          gg.holidays,
          ggplot2::aes(ymin = holidays_lower, ymax = holidays_upper))
      }
      panels[[length(panels) + 1]] <- gg.holidays
    }

    # Weekly seasonality panel: one row per day of the week, if present.
    if ("weekly" %in% colnames(df)) {
      day.names <- c('Sunday', 'Monday', 'Tuesday', 'Wednesday', 'Thursday',
                     'Friday', 'Saturday')
      # Use POSIXlt$wday (0 = Sunday ... 6 = Saturday) rather than weekdays():
      # weekdays() returns names in the session locale, which never match these
      # English levels outside English locales and would turn every dow into NA.
      df.s <- df %>%
        dplyr::mutate(dow = factor(day.names[as.POSIXlt(ds)$wday + 1],
                                   levels = day.names)) %>%
        dplyr::group_by(dow) %>%
        dplyr::slice(1) %>%
        dplyr::ungroup() %>%
        dplyr::arrange(dow)
      gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
                                                      group = 1)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
        ggplot2::labs(x = "Day of week")
      if (wants_interval(df.s, "weekly")) {
        gg.weekly <- add_interval(
          gg.weekly, ggplot2::aes(ymin = weekly_lower, ymax = weekly_upper))
      }
      panels[[length(panels) + 1]] <- gg.weekly
    }

    # Yearly seasonality panel: one row per calendar day, if present.
    if ("yearly" %in% colnames(df)) {
      # Drop the year by mapping every date onto the leap year 2000, so that
      # Feb 29 remains a valid date.
      df.s <- df %>%
        dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
        dplyr::group_by(doy) %>%
        dplyr::slice(1) %>%
        dplyr::ungroup() %>%
        dplyr::mutate(doy = zoo::as.Date(doy)) %>%
        dplyr::arrange(doy)
      gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
                                                      group = 1)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
        ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
        ggplot2::labs(x = "Day of year")
      if (wants_interval(df.s, "yearly")) {
        gg.yearly <- add_interval(
          gg.yearly, ggplot2::aes(ymin = yearly_lower, ymax = yearly_upper))
      }
      panels[[length(panels) + 1]] <- gg.yearly
    }

    # Stack the panels vertically in a single-column grid layout.
    grid::grid.newpage()
    grid::pushViewport(grid::viewport(
      layout = grid::grid.layout(length(panels), 1)))
    for (i in seq_along(panels)) {
      print(panels[[i]], vp = grid::viewport(layout.pos.row = i,
                                             layout.pos.col = 1))
    }
    invisible(panels)
  }
  943. # fb-block 3