prophet.R 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Makes R CMD CHECK happy due to the dplyr/tidyr non-standard evaluation
  ## below: these bare column names are declared as global variables so that
  ## codetools does not report them as undefined.
  ## globalVariables() is exported by utils, which is not imported into the
  ## package namespace by default, so it is called namespace-qualified.
  utils::globalVariables(c(
    "ds", "y", "cap", ".",
    "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
    "holidays_upper", "ix",
    "lower", "n", "stat", "trend",
    "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower",
    "weekly_upper",
    "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower",
    "yhat_upper"))
  #' Prophet forecast.
  #'
  #' Creates a Prophet model object, validates its settings and, unless
  #' \code{fit = FALSE}, fits it to \code{df}.
  #'
  #' @param df Data frame with columns ds (date type) and y, the time series.
  #' If growth is logistic, then df must also have a column cap that specifies
  #' the capacity at each ds. Required when fit is TRUE; may be left NULL when
  #' fit is FALSE.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #' @param changepoints Vector of dates at which to include potential
  #' changepoints. Each date must be present in df$ds. If not specified,
  #' potential changepoints are selected automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  #' then n.changepoints potential changepoints are selected uniformly from the
  #' first 80 percent of df$ds.
  #' @param yearly.seasonality Boolean, fit yearly seasonality.
  #' @param weekly.seasonality Boolean, fit weekly seasonality.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #' type) and optionally columns lower_window and upper_window which specify a
  #' range of days around the date to be included as holidays.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #' seasonality model. Larger values allow the model to fit larger seasonal
  #' fluctuations, smaller values dampen the seasonality.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #' automatic changepoint selection. Large values will allow many changepoints,
  #' small values will allow few changepoints.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #' components model.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #' inference with the specified number of MCMC samples. If 0, will do MAP
  #' estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #' in the trend using the MAP estimate of the extrapolated generative model.
  #' If mcmc.samples>0, this will be integrated over all model parameters,
  #' which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #' uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = TRUE,
                      weekly.seasonality = TRUE,
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      holidays.prior.scale = 10,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
  ) {
    # fb-block 1
    # An explicit changepoint list overrides the requested number.
    if (!is.null(changepoints)) {
      n.changepoints <- length(changepoints)
    }
    m <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      start = NULL, # This and following attributes are set during fitting
      end = NULL,
      y.scale = NULL,
      stan.fit = NULL,
      params = list(),
      history = NULL
    )
    validate_inputs(m)
    class(m) <- append(class(m), "prophet")
    if (fit) {
      # The old default `df = df` referred to itself and failed with a
      # "promise already under evaluation" error; report the real problem.
      if (is.null(df)) {
        stop("A data frame 'df' must be supplied when fit = TRUE.")
      }
      m <- fit.prophet(m, df, ...)
    }
    # fb-block 2
    return(m)
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error if growth is not a single 'linear' or 'logistic'
  #' value, or if the holidays data frame is malformed. Malformed means it is
  #' missing the holiday or ds column, has only one of lower_window and
  #' upper_window, or uses a holiday name that contains "_" or is reserved for
  #' a model component.
  #'
  #' @param m Prophet object.
  #'
  #' @return NULL, invisibly. Called for its error checks.
  #'
  validate_inputs <- function(m) {
    # The length check keeps the if() condition scalar for vector input.
    if (length(m$growth) != 1 || !(m$growth %in% c('linear', 'logistic'))) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }
    if (!is.null(m$holidays)) {
      holiday.cols <- names(m$holidays)
      if (!('holiday' %in% holiday.cols)) {
        stop('Holidays dataframe must have holiday field.')
      }
      if (!('ds' %in% holiday.cols)) {
        stop('Holidays dataframe must have ds field.')
      }
      # Windows are expanded as seq(lower_window, upper_window), so one of the
      # two columns without the other cannot be used.
      has.lower <- 'lower_window' %in% holiday.cols
      has.upper <- 'upper_window' %in% holiday.cols
      if (has.lower != has.upper) {
        stop('Holidays must have both lower_window and upper_window, or neither.')
      }
      for (h in unique(m$holidays$holiday)) {
        # "_" separates the holiday name from its day offset in feature names.
        if (grepl("_", h)) {
          stop('Holiday name cannot contain "_"')
        }
        if (h %in% c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')) {
          stop(paste0('Holiday name "', h, '" reserved.'))
        }
      }
    }
    invisible(NULL)
  }
  #' Load compiled Stan model
  #'
  #' Loads the precompiled model for the requested trend type from the
  #' package's libs directory and checks that it works. If any step fails, the
  #' model is compiled from source instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #'
  #' @return Stan model.
  get_prophet_stan_model <- function(model) {
    fn <- paste('prophet', model, 'growth.RData', sep = '_')
    obj.name <- paste(model, 'growth.stanm', sep = '.')
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
                            package = 'prophet',
                            mustWork = TRUE)
      # Load into a private environment. Objects stored in the .RData file then
      # cannot overwrite this function's locals (e.g. `model`), and the lookup
      # below cannot pick up a same-named object from an enclosing scope.
      cache <- new.env(parent = emptyenv())
      load(binary, envir = cache)
      stanm <- get(obj.name, envir = cache, inherits = FALSE)
      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model(model)
    })
  }
  #' Compile Stan model
  #'
  #' Translates the Stan program shipped with the package for the requested
  #' trend type (stan/prophet_<model>_growth.stan) and compiles it with rstan.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #'
  #' @return Stan model.
  compile_stan_model <- function(model) {
    stan.path <- system.file(
      paste('stan/prophet', model, 'growth.stan', sep = '_'),
      package = 'prophet',
      mustWork = TRUE
    )
    translated <- rstan::stanc(stan.path)
    rstan::stan_model(
      stanc_ret = translated,
      model_name = paste(model, 'growth', sep = '_')
    )
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Converts ds to Date and sorts rows by ds. Adds a time index t, scaled so
  #' that m$start maps to 0 and m$end maps to 1. Scales y (and cap for
  #' logistic growth) by m$y.scale.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    if ('y' %in% names(df)) {
      df$y <- as.numeric(df$y)
    }
    df$ds <- zoo::as.Date(df$ds)
    df <- df %>%
      dplyr::arrange(ds)
    if (initialize_scales) {
      m$y.scale <- max(df$y)
      # A series whose maximum is 0 (e.g. identically zero) would give
      # y.scale = 0 and make every scaled value NaN or infinite; use a unit
      # scale instead.
      if (isTRUE(m$y.scale == 0)) {
        m$y.scale <- 1
      }
      m$start <- min(df$ds)
      m$end <- max(df$ds)
    }
    t.scale <- as.numeric(m$end - m$start)
    df$t <- as.numeric(df$ds - m$start) / t.scale
    if ('y' %in% names(df)) {
      df$y_scaled <- df$y / m$y.scale
    }
    if (m$growth == 'logistic') {
      if (!('cap' %in% names(df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df <- df %>%
        dplyr::mutate(cap_scaled = cap / m$y.scale)
    }
    return(list("m" = m, "df" = df))
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints. User-supplied
  #' changepoints are checked to lie within the history. Otherwise
  #' m$n.changepoints dates are placed evenly through the first 80 percent of
  #' the history. If that window has too few rows, n.changepoints is reduced
  #' (with a message) so the dates are distinct.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints (and possibly n.changepoints) set.
  #'
  set_changepoints <- function(m) {
    if (!is.null(m$changepoints)) {
      if (length(m$changepoints) > 0) {
        if (min(m$changepoints) < min(m$history$ds)
            || max(m$changepoints) > max(m$history$ds)) {
          stop('Changepoints must fall within training data.')
        }
      }
    } else {
      # Place potential changepoints evenly through the first 80 pcnt of
      # the history.
      hist.size <- floor(nrow(m$history) * .8)
      # More changepoints than rows in the window would make the rounded
      # indexes below collide, giving duplicate changepoints.
      if (m$n.changepoints > 0 && m$n.changepoints + 1 > hist.size) {
        m$n.changepoints <- max(0, hist.size - 1)
        message('n.changepoints greater than number of observations. Using ',
                m$n.changepoints, '.')
      }
      if (m$n.changepoints > 0) {
        # The first evenly spaced index is row 1 (t = 0); drop it.
        cp.indexes <- round(seq.int(1, hist.size,
                                    length.out = (m$n.changepoints + 1)))[-1]
        m$changepoints <- m$history$ds[cp.indexes]
      } else {
        m$changepoints <- c()
      }
    }
    return(m)
  }
  #' Gets changepoint indexes in history dataframe.
  #'
  #' @param m Prophet object.
  #'
  #' @return array of row indexes into m$history, one per changepoint. When
  #' there are no changepoints, returns 1 (the first history row), so callers
  #' always get at least one index.
  #'
  get_changepoint_indexes <- function(m) {
    if (length(m$changepoints) == 0) {
      return(c(1))
    }
    indexes <- match(zoo::as.Date(m$changepoints), m$history$ds)
    # A changepoint date absent from the history would otherwise produce NA
    # indexes that only fail later, with an unrelated error.
    if (anyNA(indexes)) {
      stop('Changepoints must be dates present in the history.')
    }
    return(indexes)
  }
  #' Gets changepoint times, in scaled space.
  #'
  #' Looks up the scaled time t of the history row of each changepoint.
  #'
  #' @param m Prophet object.
  #'
  #' @return array of times.
  #'
  get_changepoint_times <- function(m) {
    m$history$t[get_changepoint_indexes(m)]
  }
  #' Gets changepoint matrix for history dataframe.
  #'
  #' Builds the 0/1 design matrix for the trend's rate changes. Entry [r, j]
  #' is 1 when history row r is at or after the row of changepoint j, and 0
  #' otherwise.
  #'
  #' @param m Prophet object.
  #'
  #' @return Numeric matrix with one row per history row and one column per
  #' changepoint index.
  #'
  get_changepoint_matrix <- function(m) {
    n.rows <- nrow(m$history)
    cp.rows <- get_changepoint_indexes(m)
    indicator <- matrix(0, nrow = n.rows, ncol = length(cp.rows))
    for (j in seq_along(cp.rows)) {
      indicator[cp.rows[j]:n.rows, j] <- 1
    }
    indicator
  }
  #' Provides fourier series components with the specified frequency.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features: one row per date and
  #' 2 * series.order columns, holding sin and cos of each harmonic
  #' interleaved (sin_1, cos_1, sin_2, cos_2, ...). A series.order of 0 gives
  #' a zero-column matrix.
  #'
  fourier_series <- function(dates, period, series.order) {
    # Days since the Unix epoch.
    t <- as.numeric(dates - as.Date('1970-01-01'))
    features <- matrix(0, length(t), 2 * series.order)
    # seq_len: with series.order = 0 the loop must not run (1:0 would index
    # columns that do not exist).
    for (i in seq_len(series.order)) {
      x <- 2 * i * pi * t / period
      features[, i * 2 - 1] <- sin(x)
      features[, i * 2] <- cos(x)
    }
    return(features)
  }
  #' Data frame with seasonality features.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix
  #'
  #' @return Dataframe with seasonality: one row per date, with columns
  #' <prefix>_1 ... <prefix>_<2 * series.order>.
  #'
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    features <- fourier_series(dates, period, series.order)
    # Only name columns when there are any. paste() would still return one
    # name for a zero-length index vector, and assigning it to a zero-column
    # matrix errors.
    if (ncol(features) > 0) {
      colnames(features) <- paste(prefix, seq_len(ncol(features)), sep = '_')
    }
    return(data.frame(features))
  }
  #' Construct a matrix of holiday features.
  #'
  #' Each (holiday, ds) pair in m$holidays is expanded to one feature per day
  #' offset in [lower_window, upper_window]. When the window columns are
  #' absent or NA, only offset 0 is used. Features are named "<holiday>_+k" or
  #' "<holiday>_-k" for day offset k. A feature takes the value
  #' holidays.prior.scale / seasonality.prior.scale on its date and 0
  #' elsewhere.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A dataframe with one row per element of dates and one column for
  #' each holiday/day-offset combination.
  #'
  #' @importFrom dplyr "%>%"
  make_holiday_features <- function(m, dates) {
    # Features are set to this ratio rather than 1. NOTE(review): only
    # seasonality.prior.scale appears to be passed to Stan, so this presumably
    # gives holiday coefficients an effective prior scale of
    # holidays.prior.scale -- confirm against the Stan model.
    scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
    wide <- m$holidays %>%
      dplyr::mutate(ds = zoo::as.Date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      dplyr::do({
        # NOTE(review): `.` holds every row of this (holiday, ds) group. If
        # m$holidays has duplicate (holiday, ds) rows, the window values below
        # have length > 1 and the scalar && may fail (an error in R >= 4.3).
        if (exists('lower_window', where = .) && !is.na(.$lower_window)
            && !is.na(.$upper_window)) {
          offsets <- seq(.$lower_window, .$upper_window)
        } else {
          offsets <- c(0)
        }
        # The name encodes the signed day offset, e.g. "xmas_-1", "xmas_+0".
        names <- paste(
          .$holiday, '_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
        dplyr::data_frame(ds = .$ds + offsets, holiday = names)
      }) %>%
      dplyr::mutate(x = scale.ratio) %>%
      # One column per feature name; dates without that feature get 0.
      tidyr::spread(holiday, x, fill = 0)
    # Align to the requested dates. Dates with no holiday get NA from the join,
    # which is replaced by 0 below.
    holiday.mat <- data.frame(ds = dates) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    holiday.mat[is.na(holiday.mat)] <- 0
    return(holiday.mat)
  }
  #' Data frame seasonality features.
  #'
  #' Column-binds an all-zero "zeros" column with the enabled yearly (10
  #' Fourier terms, period 365.25 days) and weekly (3 terms, period 7 days)
  #' features and any holiday features.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features.
  #'
  #' @return Dataframe with seasonality, one row per row of df.
  #'
  make_all_seasonality_features <- function(m, df) {
    # Always-present zero column; with no seasonality or holidays enabled it is
    # the only feature, so the frame is never empty.
    seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
    if (m$yearly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
    }
    if (m$weekly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 7, 3, 'weekly'))
    }
    if (!is.null(m$holidays)) {
      # make_holiday_features applies the holiday/seasonality prior-scale ratio
      # itself, so nothing is scaled here.
      seasonal.features <- cbind(
        seasonal.features,
        make_holiday_features(m, df$ds))
    }
    return(seasonal.features)
  }
  #' Initialize linear growth
  #'
  #' Provides a strong initialization for linear growth: the rate and offset of
  #' the straight line through the earliest and latest observations.
  #'
  #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  #' and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  #' growth function.
  #'
  linear_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    span <- df$t[last] - df$t[first]
    # Slope of the chord between the two end points.
    rate <- (df$y_scaled[last] - df$y_scaled[first]) / span
    # Value of that chord at t = 0.
    offset <- df$y_scaled[first] - rate * df$t[first]
    c(rate, offset)
  }
  #' Initialize logistic growth
  #'
  #' Provides a strong initialization for logistic growth by calculating the
  #' growth and offset parameters that pass the function through the first and
  #' last points in the time series.
  #'
  #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  #' y_scaled (scaled time series), and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  #' growth function.
  #'
  logistic_growth_init <- function(df) {
    i0 <- which.min(df$ds)
    i1 <- which.max(df$ds)
    time.span <- df$t[i1] - df$t[i0]
    # Ratio cap / y at row i, kept in a range where log(r - 1) is finite.
    cap_ratio <- function(i) {
      cap <- df$cap_scaled[i]
      y <- df$y_scaled[i]
      # A non-positive y gives an infinite (y == 0) or negative ratio; treat it
      # as 1% of capacity, i.e. just above the lower asymptote.
      if (!is.na(y) && y <= 0) {
        y <- 0.01 * cap
      }
      # Force valid values, in case y > cap.
      max(1.01, cap / y)
    }
    r0 <- cap_ratio(i0)
    r1 <- cap_ratio(i1)
    # Nudge equal ratios apart so that L0 - L1 below is not zero.
    if (abs(r0 - r1) <= 0.01) {
      r0 <- 1.05 * r0
    }
    L0 <- log(r0 - 1)
    L1 <- log(r1 - 1)
    # Initialize the offset
    m <- L0 * time.span / (L0 - L1)
    # And the rate. This equals L0 / m algebraically. It stays finite when
    # L0 == 0 (y is half the capacity at the first point), where L0 / m is 0/0.
    k <- (L0 - L1) / time.span
    return(c(k, m))
  }
  #' Fit the prophet model.
  #'
  #' Drops rows with missing y, scales the history, and builds the
  #' seasonal/holiday feature matrix and the changepoint matrix. Then fits the
  #' Stan model: MAP optimization when m$mcmc.samples is 0, MCMC sampling when
  #' it is greater than 0.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds and y (and cap for logistic growth).
  #' Rows where y is NA are excluded from the fit.
  #' @param ... Additional arguments passed to the \code{optimizing} or
  #' \code{sampling} functions in Stan.
  #'
  #' @return The Prophet object with history, scaling factors, changepoints and
  #' params filled in. params$k, params$m and params$sigma_obs are vectors.
  #' params$delta and params$beta are matrices. Each has one entry or row per
  #' parameter draw (a single one for MAP).
  #'
  #' @export
  fit.prophet <- function(m, df, ...) {
    # Only rows with an observed y are used for fitting.
    history <- df %>%
      dplyr::filter(!is.na(y))
    # Sorts by ds, sets m$start / m$end / m$y.scale, and adds t and y_scaled.
    out <- setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    # Must be stored before the changepoint helpers below, which read it.
    m$history <- history
    seasonal.features <- make_all_seasonality_features(m, history)
    m <- set_changepoints(m)
    A <- get_changepoint_matrix(m)
    changepoint.indexes <- get_changepoint_indexes(m)
    # Construct input to stan
    # T: number of observations; K: number of seasonal/holiday feature columns;
    # S: number of changepoints (columns of A), with s_indx their history rows.
    # sigma and tau are the prior scales handed to the Stan program; how they
    # are used is defined in the .stan source, not here.
    dat <- list(
      T = nrow(history),
      K = ncol(seasonal.features),
      S = length(changepoint.indexes),
      y = history$y_scaled,
      t = history$t,
      A = A,
      s_indx = array(changepoint.indexes),
      X = as.matrix(seasonal.features),
      sigma = m$seasonality.prior.scale,
      tau = m$changepoint.prior.scale
    )
    # Run stan
    # kinit is a trend through the first and last observations, used as the
    # starting point for optimization/sampling.
    if (m$growth == 'linear') {
      kinit <- linear_growth_init(history)
      model <- get_prophet_stan_model('linear')
    } else {
      dat$cap <- history$cap_scaled # Add capacities to the Stan data
      kinit <- logistic_growth_init(history)
      model <- get_prophet_stan_model('logistic')
    }
    # Initial values: growth from kinit, all changepoint deltas and feature
    # coefficients at 0, observation noise at 1.
    stan_init <- function() {
      list(k = kinit[1],
           m = kinit[2],
           delta = array(rep(0, length(changepoint.indexes))),
           beta = array(rep(0, ncol(seasonal.features))),
           sigma_obs = 1
      )
    }
    if (m$mcmc.samples > 0) {
      # NOTE(review): mcmc.samples is passed as rstan's `iter`, which
      # presumably includes warmup draws -- confirm the intended sample count.
      stan.fit <- rstan::sampling(
        model,
        data = dat,
        init = stan_init,
        iter = m$mcmc.samples,
        ...
      )
      m$params <- rstan::extract(stan.fit)
      n.iteration <- length(m$params$k)
    } else {
      # MAP: a single point estimate, treated as one "iteration".
      stan.fit <- rstan::optimizing(
        model,
        data = dat,
        init = stan_init,
        iter = 1e4,
        as_vector = FALSE,
        ...
      )
      m$params <- stan.fit$par
      n.iteration <- 1
    }
    # Cast the parameters to have consistent form, whether full bayes or MAP
    for (name in c('delta', 'beta')){
      m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
    }
    # rstan::sampling returns 1d arrays; converts to atomic vectors.
    for (name in c('k', 'm', 'sigma_obs')){
      m$params[[name]] <- c(m$params[[name]])
    }
    # If no changepoints were requested, replace delta with 0s
    if (m$n.changepoints == 0) {
      # Fold delta into the base rate k
      # (with no changepoints the only column of A is the all-ones column of
      # the dummy changepoint at the first history row, so its delta acts on
      # the whole series exactly like a change of k).
      m$params$k <- m$params$k + m$params$delta[, 1]
      m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
    }
    return(m)
  }
  #' Predict using the prophet model.
  #'
  #' Adds the trend, the uncertainty intervals, the seasonal components and
  #' yhat = trend + seasonal to the prediction data frame.
  #'
  #' @param object Prophet object.
  #' @param df Dataframe with dates for predictions, and capacity if logistic
  #' growth. If not provided, predictions are made on the history.
  #' @param ... additional arguments
  #'
  #' @return A data_frame with a forecast
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  predict.prophet <- function(object, df = NULL, ...) {
    if (!is.null(df)) {
      df <- setup_dataframe(object, df)$df
    } else {
      df <- object$history
    }
    df$trend <- predict_trend(object, df)
    # Simulated intervals are computed before the (deterministic) seasonal
    # components, so random draws happen in the same order as before.
    intervals <- predict_uncertainty(object, df)
    components <- predict_seasonal_components(object, df)
    df <- dplyr::bind_cols(df, intervals)
    df <- dplyr::bind_cols(df, components)
    df$yhat <- df$trend + df$seasonal
    df
  }
  #' Evaluate the piecewise linear function.
  #'
  #' The slope starts at k and changes by deltas[s] at each changepoint time.
  #' The intercept is adjusted by -changepoint.ts[s] * deltas[s] at the same
  #' point, so the function stays continuous.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times. May be empty, in which
  #' case the result is the straight line k * t + m.
  #'
  #' @return Vector y(t).
  #'
  piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
    # Intercept changes
    gammas <- -changepoint.ts * deltas
    # Get cumulative slope and intercept at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    # seq_along: with no changepoints the loop is skipped (1:length() would
    # run it for s = 1 and s = 0 and fail on NA indexes).
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- k_t * t + m_t
    return(y)
  }
  #' Evaluate the piecewise logistic function.
  #'
  #' The growth rate starts at k and changes by deltas[s] at each changepoint
  #' time. The offset is shifted at each changepoint by gammas[s], chosen so
  #' that the curve is continuous there.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param cap Vector of capacities at each t.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times. May be empty.
  #'
  #' @return Vector y(t).
  #'
  piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
    # Compute offset changes
    # k.cum[i] is the rate in effect before changepoint i, k.cum[i + 1] after.
    k.cum <- c(k, cumsum(deltas) + k)
    gammas <- rep(0, length(changepoint.ts))
    # seq_along: both loops must be skipped when there are no changepoints.
    for (i in seq_along(changepoint.ts)) {
      # sum(gammas) is the offset shift accumulated so far (later entries are
      # still 0).
      gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
                    * (1 - k.cum[i] / k.cum[i + 1]))
    }
    # Get cumulative rate and offset at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- cap / (1 + exp(-k_t * (t - m_t)))
    return(y)
  }
  #' Predict trend using the prophet model.
  #'
  #' Evaluates the piecewise trend at df$t using the mean of each fitted
  #' parameter across draws, and returns it in the original scale of y.
  #'
  #' @param model Prophet object.
  #' @param df Data frame.
  #'
  #' @return Vector of trend values, one per row of df.
  #'
  predict_trend <- function(model, df) {
    avg.rate <- mean(model$params$k, na.rm = TRUE)
    avg.offset <- mean(model$params$m, na.rm = TRUE)
    avg.deltas <- colMeans(model$params$delta, na.rm = TRUE)
    times <- df$t
    cp.times <- get_changepoint_times(model)
    scaled.trend <- if (model$growth == 'linear') {
      piecewise_linear(times, avg.deltas, avg.rate, avg.offset, cp.times)
    } else {
      piecewise_logistic(times, df$cap_scaled, avg.deltas, avg.rate,
                         avg.offset, cp.times)
    }
    scaled.trend * model$y.scale
  }
  #' Seasonality broken down into components
  #'
  #' Groups the seasonal feature columns by component, i.e. the part of the
  #' column name before the first "_" (e.g. "yearly", "weekly" or a holiday
  #' name). For each component, computes its contribution for every parameter
  #' draw, in the original scale of y, and summarises the draws by their mean
  #' and lower/upper quantiles.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #'
  #' @return Data frame with columns <component>, <component>_lower and
  #' <component>_upper for each component, plus seasonal, the row sum of the
  #' component means. When there are no components, seasonal is all zeros.
  #'
  predict_seasonal_components <- function(m, df) {
    seasonal.features <- make_all_seasonality_features(m, df)
    # Quantile probabilities of a central interval of width m$interval.width,
    # e.g. 0.8 -> 0.1 and 0.9.
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2
    # Broken down into components
    # One row per feature column: its component name and column position. The
    # "zeros" placeholder column is not a component.
    components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
      dplyr::mutate(col = 1:n()) %>%
      tidyr::separate(component, c('component', 'part'), sep = "_",
                      extra = "merge", fill = "right") %>%
      dplyr::filter(component != 'zeros')
    if (nrow(components) > 0) {
      component.predictions <- components %>%
        dplyr::group_by(component) %>% dplyr::do({
          # comp: one row per date, one column per parameter draw (row of beta).
          comp <- (as.matrix(seasonal.features[, .$col])
                   %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
          dplyr::data_frame(ix = 1:nrow(seasonal.features),
                            mean = rowMeans(comp, na.rm = TRUE),
                            lower = apply(comp, 1, stats::quantile, lower.p,
                                          na.rm = TRUE),
                            upper = apply(comp, 1, stats::quantile, upper.p,
                                          na.rm = TRUE))
        }) %>%
        # Reshape to wide columns named <component>, <component>_lower and
        # <component>_upper, one row per date (ix).
        tidyr::gather(stat, value, c(mean, lower, upper)) %>%
        dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
        tidyr::unite(component, component, stat, sep="") %>%
        tidyr::spread(component, value) %>%
        dplyr::select(-ix)
      # Total seasonality is the sum of the component means.
      component.predictions$seasonal <- rowSums(
        component.predictions[unique(components$component)])
    } else {
      component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
    }
    return(component.predictions)
  }
  #' Prophet uncertainty intervals.
  #'
  #' Simulates trend, seasonality and yhat from the extrapolated generative
  #' model and summarises each by its lower and upper quantiles for
  #' m$interval.width.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #'
  #' @return Data frame with columns yhat_lower, yhat_upper, trend_lower,
  #' trend_upper, seasonal_lower and seasonal_upper, one row per row of df.
  #'
  predict_uncertainty <- function(m, df) {
    # Sample trend, seasonality, and yhat from the extrapolation model.
    n.iterations <- length(m$params$k)
    # Spread the requested samples over the available parameter draws (one for
    # MAP), with at least one simulation per draw.
    samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
    nsamp <- n.iterations * samp.per.iter # The actual number of samples
    seasonal.features <- make_all_seasonality_features(m, df)
    # One column per simulation; every column is filled in the loop below.
    sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
                       "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
                       "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
    # The loop order fixes the order of random draws; reordering it changes
    # the simulated values for a given seed.
    for (i in 1:n.iterations) {
      # For each set of parameters from MCMC (or just 1 set for MAP),
      for (j in 1:samp.per.iter) {
        # Do a simulation with this set of parameters,
        sim <- sample_model(m, df, seasonal.features, i)
        # Store the results
        for (key in c("trend", "seasonal", "yhat")) {
          sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
        }
      }
    }
    # Add uncertainty estimates
    # Quantile probabilities of a central interval of width m$interval.width,
    # e.g. 0.8 -> 0.1 and 0.9.
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2
    # Each t(apply(...)) block is an nrow(df) x 2 matrix of (lower, upper)
    # quantiles taken across simulations for every row of df.
    intervals <- cbind(
      t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
              na.rm = TRUE)),
      t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
              na.rm = TRUE)),
      t(apply(t(sim.values$seasonal), 2, stats::quantile, c(lower.p, upper.p),
              na.rm = TRUE))
    ) %>% dplyr::as_data_frame()
    colnames(intervals) <- paste(rep(c('yhat', 'trend', 'seasonal'), each=2),
                                 c('lower', 'upper'), sep = "_")
    return(intervals)
  }
  #' Simulate observations from the extrapolated generative model.
  #'
  #' Draws one simulated trend, computes the seasonal part from the given
  #' iteration's coefficients, and adds Gaussian observation noise with that
  #' iteration's sigma_obs. Everything is in the original scale of y.
  #'
  #' @param m Prophet object.
  #' @param df Data frame on which to simulate.
  #' @param seasonal.features Data frame of seasonal features
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return List of yhat, trend and seasonal, each with one value per row of
  #' df.
  #'
  sample_model <- function(m, df, seasonal.features, iteration) {
    # The trend simulation draws random numbers before the observation noise.
    trend <- sample_predictive_trend(m, df, iteration)
    feature.mat <- as.matrix(seasonal.features)
    coefs <- m$params$beta[iteration,]
    seasonal <- (feature.mat %*% coefs) * m$y.scale
    noise.sd <- m$params$sigma_obs[iteration]
    noise <- stats::rnorm(nrow(df), mean = 0, sd = noise.sd) * m$y.scale
    list(yhat = trend + seasonal + noise,
         trend = trend,
         seasonal = seasonal)
  }
  #' Simulate the trend using the extrapolated generative model.
  #'
  #' Uses the given iteration's trend parameters. For times beyond the history
  #' (t > 1), it samples new changepoints at the same average rate as in the
  #' history. Their rate changes are drawn from a Laplace distribution centred
  #' at 0 with scale equal to the mean absolute fitted rate change.
  #'
  #' @param model Prophet object.
  #' @param df Data frame with scaled time t (and cap_scaled for logistic
  #' growth) on which to simulate the trend.
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return Vector of simulated trend over df$t.
  #'
  sample_predictive_trend <- function(model, df, iteration) {
    k <- model$params$k[iteration]
    param.m <- model$params$m[iteration]
    deltas <- model$params$delta[iteration,]
    t <- df$t
    # Scaled times of the fitted changepoints. With no changepoints this is the
    # single time of the first history row.
    changepoint.ts <- get_changepoint_times(model)
    T <- max(t)
    # The history spans t in [0, 1], so T > 1 means df extends into the future.
    if (T > 1) {
      # Get the time discretization of the history
      dt <- diff(model$history$t)
      dt <- min(dt[dt > 0])
      # Number of time periods in the future
      N <- ceiling((T - 1) / dt)
      S <- length(changepoint.ts)
      # The history had S split points, over t = [0, 1].
      # The forecast is on [1, T], and should have the same average frequency of
      # rate changes. Thus for N time periods in the future, we want an average
      # of S * (T - 1) changepoints in expectation.
      prob.change <- min(1, (S * (T - 1)) / N)
      # This calculation works for both history and df not uniformly spaced.
      n.changes <- stats::rbinom(1, N, prob.change)
      # Sample ts
      if (n.changes == 0) {
        changepoint.ts.new <- c()
      } else {
        changepoint.ts.new <- sort(stats::runif(n.changes, min = 1, max = T))
      }
    } else {
      changepoint.ts.new <- c()
      n.changes <- 0
    }
    # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
    lambda <- mean(abs(c(deltas))) + 1e-8
    # Sample deltas
    deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
    # Combine with changepoints from the history
    changepoint.ts <- c(changepoint.ts, changepoint.ts.new)
    deltas <- c(deltas, deltas.new)
    # Get the corresponding trend
    if (model$growth == 'linear') {
      trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
    } else {
      cap <- df$cap_scaled
      trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
    }
    return(trend * model$y.scale)
  }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object.
  #' @param periods Int number of periods to forecast forward (non-negative).
  #' @param freq Spacing of the future dates, passed to seq() as `by`, e.g.
  #' 'day' (default 'd'), 'week', 'month', 'quarter', or 'year'.
  #' @param include_history Boolean to include the historical dates in the data
  #' frame for predictions.
  #'
  #' @return Dataframe with column ds that extends forward from the end of
  #' m$history by exactly `periods` dates, preceded by the historical dates
  #' when include_history is TRUE.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'd',
                                    include_history = TRUE) {
    if (periods < 0) {
      stop("periods must be non-negative.", call. = FALSE)
    }
    # seq() starts at the last observed date, which is not a future date.
    # Request one extra date and drop that first element so exactly `periods`
    # future dates remain (none when periods is 0).
    dates <- seq(max(m$history$ds), length.out = periods + 1, by = freq)[-1]
    if (include_history) {
      dates <- c(m$history$ds, dates)
    }
    return(data.frame(ds = dates))
  }
  #' Merge history and forecast for plotting.
  #'
  #' Full-joins the observed history (columns ds and y) with the forecast on ds
  #' and sorts the result by ds.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @importFrom dplyr "%>%"
  df_for_plotting <- function(m, fcst) {
    # Remove any y column from the forecast so only the history's y remains.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observed points, the forecast line, and, when available, a
  #' dashed cap line and a shaded uncertainty band.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #' should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, ...) {
    df <- df_for_plotting(x, fcst)
    forecast.color <- "#0072B2"
    df.cols <- colnames(df)
    show.cap <- 'cap' %in% df.cols
    show.interval <- uncertainty && 'yhat_lower' %in% df.cols

    gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y))
    if (show.cap) {
      gg <- gg + ggplot2::geom_line(ggplot2::aes(y = cap),
                                    linetype = 'dashed',
                                    na.rm = TRUE)
    }
    if (show.interval) {
      gg <- gg + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2,
        fill = forecast.color,
        na.rm = TRUE
      )
    }
    # Observed points and forecast line are drawn on top of the band.
    gg +
      ggplot2::geom_point(na.rm = TRUE) +
      ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.color,
                         na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }
  #' Plot the components of a prophet forecast.
  #' Prints a ggplot2 with panels for trend, weekly and yearly seasonalities if
  #' present, and holidays if present.
  #'
  #' Uncertainty bands are drawn only when `uncertainty` is TRUE and the
  #' matching *_lower / *_upper columns are present, so a forecast without
  #' interval columns still plots.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty intervals should
  #' be plotted, e.g. for the trend from fcst columns trend_lower and
  #' trend_upper.
  #'
  #' @return Invisibly, the list of ggplot2 panels that were drawn.
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
    df <- df_for_plotting(m, fcst)
    forecast.color <- "#0072B2"
    # TRUE when every name in `cols` is a column of df.
    has_cols <- function(cols) all(cols %in% colnames(df))
    # TRUE when intervals were requested and their columns are available.
    show_interval <- function(cols) uncertainty && has_cols(cols)

    # Plot the trend
    gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
      ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
    if (has_cols('cap')) {
      gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                                linetype = 'dashed',
                                                na.rm = TRUE)
    }
    if (show_interval(c('trend_lower', 'trend_upper'))) {
      gg.trend <- gg.trend +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
                                          ymax = trend_upper),
                             alpha = 0.2,
                             fill = forecast.color,
                             na.rm = TRUE)
    }
    panels <- list(gg.trend)

    # Plot holiday components, if present.
    if (!is.null(m$holidays)) {
      # as.character() matters when holiday is a factor: indexing a data frame
      # by a factor uses its integer codes rather than its labels.
      holiday.comps <- as.character(unique(m$holidays$holiday))
      # Row-wise sum of the named columns. drop = FALSE keeps a one-column
      # selection as a data frame so rowSums() also works for one holiday.
      sum_cols <- function(cols) rowSums(df[, cols, drop = FALSE])
      holiday.lower <- paste0(holiday.comps, "_lower")
      holiday.upper <- paste0(holiday.comps, "_upper")
      show.holiday.interval <- show_interval(c(holiday.lower, holiday.upper))
      df.s <- data.frame(ds = df$ds, holidays = sum_cols(holiday.comps))
      if (show.holiday.interval) {
        # NOTE summing per-holiday bounds is incorrect if holidays overlap in
        # time. Since it is just for the visualization we will not worry about
        # it now.
        df.s$holidays_lower <- sum_cols(holiday.lower)
        df.s$holidays_upper <- sum_cols(holiday.upper)
      }
      gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
      if (show.holiday.interval) {
        gg.holidays <- gg.holidays +
          ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
                                            ymax = holidays_upper),
                               alpha = 0.2,
                               fill = forecast.color,
                               na.rm = TRUE)
      }
      panels[[length(panels) + 1]] <- gg.holidays
    }

    # Plot weekly seasonality, if present
    if ("weekly" %in% colnames(df)) {
      day.names <- c('Sunday', 'Monday', 'Tuesday', 'Wednesday', 'Thursday',
                     'Friday', 'Saturday')
      # One row per weekday. POSIXlt$wday (0 = Sunday) does not depend on the
      # locale. weekdays() returns localized names, which would not match
      # these English factor levels outside an English locale.
      df.s <- df %>%
        dplyr::mutate(dow = factor(day.names[as.POSIXlt(ds)$wday + 1],
                                   levels = day.names)) %>%
        dplyr::group_by(dow) %>%
        dplyr::slice(1) %>%
        dplyr::ungroup() %>%
        dplyr::arrange(dow)
      gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
                                                      group = 1)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
        ggplot2::labs(x = "Day of week")
      if (show_interval(c('weekly_lower', 'weekly_upper'))) {
        gg.weekly <- gg.weekly +
          ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
                                            ymax = weekly_upper),
                               alpha = 0.2,
                               fill = forecast.color,
                               na.rm = TRUE)
      }
      panels[[length(panels) + 1]] <- gg.weekly
    }

    # Plot yearly seasonality, if present
    if ("yearly" %in% colnames(df)) {
      # Map every date onto the leap year 2000 so one row per calendar day
      # remains (Feb 29 included).
      df.s <- df %>%
        dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
        dplyr::group_by(doy) %>%
        dplyr::slice(1) %>%
        dplyr::ungroup() %>%
        dplyr::mutate(doy = zoo::as.Date(doy)) %>%
        dplyr::arrange(doy)
      gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
                                                      group = 1)) +
        ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
        ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
        ggplot2::labs(x = "Day of year")
      if (show_interval(c('yearly_lower', 'yearly_upper'))) {
        gg.yearly <- gg.yearly +
          ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
                                            ymax = yearly_upper),
                               alpha = 0.2,
                               fill = forecast.color,
                               na.rm = TRUE)
      }
      panels[[length(panels) + 1]] <- gg.yearly
    }

    # Make the plot: one row per panel in a single column.
    grid::grid.newpage()
    grid::pushViewport(grid::viewport(
      layout = grid::grid.layout(length(panels), 1)))
    for (i in seq_along(panels)) {
      print(panels[[i]], vp = grid::viewport(layout.pos.row = i,
                                             layout.pos.col = 1))
    }
    invisible(panels)
  }
  959. # fb-block 3