prophet.R 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Register the column names that are referenced through non-standard
  ## evaluation (dplyr / tidyr verbs) further down, so that R CMD CHECK does
  ## not report them as undefined global variables.
  globalVariables(
    c(
      "ds", "y", "cap", ".",
      "component", "dow", "doy",
      "holiday", "holidays", "holidays_lower", "holidays_upper",
      "ix", "lower", "n", "stat", "trend", "row_number",
      "trend_lower", "trend_upper", "upper", "value",
      "weekly", "weekly_lower", "weekly_upper",
      "x",
      "yearly", "yearly_lower", "yearly_upper",
      "yhat", "yhat_lower", "yhat_upper"
    )
  )
  #' Prophet forecaster.
  #'
  #' Builds a prophet model object from the supplied settings and, unless
  #' \code{fit = FALSE}, fits it to \code{df}.
  #'
  #' @param df Dataframe containing the history. Must have columns ds (date type)
  #'  and y, the time series. If growth is logistic, then df must also have a
  #'  column cap that specifies the capacity at each ds. Required when
  #'  \code{fit} is TRUE; may be omitted when \code{fit} is FALSE.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #' @param changepoints Vector of dates at which to include potential
  #'  changepoints. If not specified, potential changepoints are selected
  #'  automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
  #'  then n.changepoints potential changepoints are selected uniformly from the
  #'  first 80 percent of df$ds.
  #' @param yearly.seasonality Boolean, fit yearly seasonality.
  #' @param weekly.seasonality Boolean, fit weekly seasonality.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #'  type) and optionally columns lower_window and upper_window which specify a
  #'  range of days around the date to be included as holidays. lower_window=-2
  #'  will include 2 days prior to the date as holidays.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #'  seasonality model. Larger values allow the model to fit larger seasonal
  #'  fluctuations, smaller values dampen the seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #'  components model.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #'  automatic changepoint selection. Large values will allow many changepoints,
  #'  small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #'  inference with the specified number of MCMC samples. If 0, will do MAP
  #'  estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #'  for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #'  in the trend using the MAP estimate of the extrapolated generative model.
  #'  If mcmc.samples>0, this will be integrated over all model parameters,
  #'  which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #'  uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = TRUE,
                      weekly.seasonality = TRUE,
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
  ) {
    # fb-block 1

    # Explicit changepoints override the requested count.
    if (!is.null(changepoints)) {
      n.changepoints <- length(changepoints)
    }

    m <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      start = NULL,  # This and following attributes are set during fitting
      y.scale = NULL,
      t.scale = NULL,
      changepoints.t = NULL,
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    validate_inputs(m)
    class(m) <- append("prophet", class(m))
    if (fit) {
      # The former default `df = df` referred to itself, so a missing df
      # surfaced as "promise already under evaluation". Fail clearly instead.
      if (is.null(df)) {
        stop("Dataframe 'df' must be supplied when fit = TRUE.")
      }
      m <- fit.prophet(m, df, ...)
    }

    # fb-block 2
    return(m)
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error if the growth type is not recognised or if the
  #' holidays data frame is malformed (missing columns, only one of the two
  #' window columns, windows with the wrong sign, or a reserved holiday name).
  #'
  #' @param m Prophet object.
  #'
  validate_inputs <- function(m) {
    allowed.growth <- c('linear', 'logistic')
    if (!(m$growth %in% allowed.growth)) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }

    hol <- m$holidays
    if (is.null(hol)) {
      # Nothing else to validate without a holidays table.
      return(invisible(NULL))
    }

    has.field <- function(field) exists(field, where = hol)

    if (!has.field('holiday')) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!has.field('ds')) {
      stop('Holidays dataframe must have ds field.')
    }

    lower.present <- has.field('lower_window')
    upper.present <- has.field('upper_window')
    # Exactly one of the two window columns present is an error.
    if (lower.present + upper.present == 1) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (lower.present) {
      if (max(hol$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(hol$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }

    # Names that are used for other forecast columns / feature prefixes.
    reserved.names <- c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')
    for (holiday.name in unique(hol$holiday)) {
      # "_delim_" separates the holiday name from its window offset later on.
      if (grepl("_delim_", holiday.name)) {
        stop('Holiday name cannot contain "_delim_"')
      }
      if (holiday.name %in% reserved.names) {
        stop(paste0('Holiday name "', holiday.name, '" reserved.'))
      }
    }
    invisible(NULL)
  }
  #' Load compiled Stan model
  #'
  #' Tries to load the pre-compiled model shipped in the package's libs
  #' directory. If anything goes wrong (file missing, object missing, or the
  #' cached binary cannot be instantiated), the model is compiled from source
  #' via \code{compile_stan_model} instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  get_prophet_stan_model <- function(model) {
    fn <- paste('prophet', model, 'growth.RData', sep = '_')
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
                            package = 'prophet',
                            mustWork = TRUE)
      # Load into a private environment and fetch the object by name rather
      # than eval(parse(text = ...)) in the function frame.
      cache.env <- new.env(parent = emptyenv())
      load(binary, envir = cache.env)
      obj.name <- paste(model, 'growth.stanm', sep = '.')
      stanm <- get(obj.name, envir = cache.env, inherits = FALSE)

      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model(model)
    })
  }
  #' Compile Stan model
  #'
  #' Locates the Stan source file for the requested growth type inside the
  #' installed prophet package, translates it with \code{rstan::stanc} and
  #' builds it with \code{rstan::stan_model}.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  compile_stan_model <- function(model) {
    rel.path <- paste('stan/prophet', model, 'growth.stan', sep = '_')
    src.path <- system.file(rel.path, package = 'prophet', mustWork = TRUE)
    translated <- rstan::stanc(src.path)
    rstan::stan_model(
      stanc_ret = translated,
      model_name = paste(model, 'growth', sep = '_')
    )
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Adds a time index and scales y. Creates auxiliary columns 't',
  #' 'y_scaled' (when df has a y column), and 'cap_scaled' (logistic growth
  #' only). These columns are used during both fitting and predicting.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    has.y <- exists('y', where = df)
    if (has.y) {
      df$y <- as.numeric(df$y)
    }
    df$ds <- zoo::as.Date(df$ds)
    if (anyNA(df$ds)) {
      stop('Unable to parse date format in column ds. Convert to date format.')
    }

    df <- dplyr::arrange(df, ds)

    if (initialize_scales) {
      # Scale by the largest magnitude so the factor is never negative
      # (all-negative series); fall back to 1 for an all-zero series so the
      # divisions below stay finite.
      m$y.scale <- max(abs(as.numeric(df$y)))
      if (isTRUE(m$y.scale == 0)) {
        m$y.scale <- 1
      }
      m$start <- min(df$ds)
      m$t.scale <- as.numeric(max(df$ds) - m$start)
    }

    # Time since the start of the history, as a fraction of the history span.
    df$t <- as.numeric(df$ds - m$start) / m$t.scale
    if (has.y) {
      df$y_scaled <- df$y / m$y.scale
    }

    if (m$growth == 'logistic') {
      if (!(exists('cap', where = df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df <- dplyr::mutate(df, cap_scaled = cap / m$y.scale)
    }
    return(list("m" = m, "df" = df))
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints. Either:
  #' 1) The changepoints were passed in explicitly.
  #'   A) They are empty.
  #'   B) They are not empty, and need validation.
  #' 2) We are generating a grid of them.
  #' 3) The user prefers no changepoints be used.
  #'
  #' Also sets m$changepoints.t, the changepoints on the scaled time axis
  #' (a single dummy changepoint at 0 when there are none).
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  set_changepoints <- function(m) {
    if (!is.null(m$changepoints)) {
      if (length(m$changepoints) > 0) {
        # Convert to Date *before* the range check so that character, factor
        # or POSIXct input is compared against history$ds on the same scale.
        m$changepoints <- zoo::as.Date(m$changepoints)
        if (min(m$changepoints) < min(m$history$ds)
            || max(m$changepoints) > max(m$history$ds)) {
          stop('Changepoints must fall within training data.')
        }
      }
    } else if (m$n.changepoints > 0) {
      # Place potential changepoints evenly through the first 80 pcnt of
      # the history.
      last.index <- floor(nrow(m$history) * .8)
      grid <- round(seq.int(1, last.index,
                            length.out = (m$n.changepoints + 1)))
      # The first grid point is the start of the history; drop it.
      cp.indexes <- utils::tail(grid, -1)
      m$changepoints <- m$history$ds[cp.indexes]
    } else {
      m$changepoints <- c()
    }

    if (length(m$changepoints) > 0) {
      m$changepoints <- zoo::as.Date(m$changepoints)
      m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
    } else {
      m$changepoints.t <- c(0)  # dummy changepoint
    }
    return(m)
  }
  #' Gets changepoint matrix for history dataframe.
  #'
  #' Entry [r, c] is 1 when history row r is at or after changepoint c
  #' (m$history$t[r] >= m$changepoints.t[c]) and 0 otherwise.
  #'
  #' @param m Prophet object.
  #'
  #' @return Numeric 0/1 matrix with one row per history row and one column
  #'  per changepoint.
  #'
  get_changepoint_matrix <- function(m) {
    A <- matrix(0, nrow(m$history), length(m$changepoints.t))
    # seq_along() so that an empty changepoint vector yields an n x 0 matrix
    # instead of iterating over c(1, 0).
    for (i in seq_along(m$changepoints.t)) {
      A[m$history$t >= m$changepoints.t[i], i] <- 1
    }
    return(A)
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' Column 2i-1 holds sin(2 * i * pi * t / period) and column 2i holds the
  #' matching cosine, where t is the number of days since 1970-01-01.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features (length(dates) rows,
  #'  2 * series.order columns).
  #'
  fourier_series <- function(dates, period, series.order) {
    # Days since the epoch, as plain numbers.
    t <- as.numeric(dates - as.Date('1970-01-01'))
    features <- matrix(0, length(t), 2 * series.order)
    # seq_len() so that series.order = 0 returns an empty feature matrix
    # rather than indexing columns 1 and 0 of a zero-column matrix.
    for (i in seq_len(series.order)) {
      x <- 2 * i * pi * t / period
      features[, i * 2 - 1] <- sin(x)
      features[, i * 2] <- cos(x)
    }
    return(features)
  }
  #' Data frame with seasonality features.
  #'
  #' Wraps \code{fourier_series} and names the columns
  #' "<prefix>_delim_<index>".
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality.
  #'
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    features <- fourier_series(dates, period, series.order)
    n.features <- ncol(features)
    # Only name columns when there are any: 1:0 would produce two names for a
    # zero-column matrix and fail.
    if (n.features > 0) {
      colnames(features) <- paste(prefix, seq_len(n.features), sep = '_delim_')
    }
    return(data.frame(features))
  }
  #' Construct a matrix of holiday features.
  #'
  #' Each (holiday, date) pair in m$holidays is expanded over its
  #' lower_window..upper_window offsets (offset 0 only when no window is
  #' given). Every "<holiday>_delim_<+/-offset>" combination becomes a column
  #' whose value is holidays.prior.scale / seasonality.prior.scale on the
  #' matching dates and 0 elsewhere.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A dataframe with a column for each holiday.
  #'
  #' @importFrom dplyr "%>%"
  make_holiday_features <- function(m, dates) {
    scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale

    # Expand one (holiday, ds) row into one row per day of its window.
    expand.window <- function(row) {
      windowed <- exists('lower_window', where = row) &&
        !is.na(row$lower_window) &&
        !is.na(row$upper_window)
      if (windowed) {
        offsets <- seq(row$lower_window, row$upper_window)
      } else {
        offsets <- c(0)
      }
      labels <- paste(
        row$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets),
        sep = '')
      dplyr::data_frame(ds = row$ds + offsets, holiday = labels)
    }

    wide <- m$holidays %>%
      dplyr::mutate(ds = zoo::as.Date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      # Keep a single row per (holiday, ds) pair.
      dplyr::filter(row_number() == 1) %>%
      dplyr::do(expand.window(.)) %>%
      dplyr::mutate(x = scale.ratio) %>%
      tidyr::spread(holiday, x, fill = 0)

    # Align to the requested dates; dates without any holiday get zeros.
    holiday.mat <- data.frame(ds = dates) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    holiday.mat[is.na(holiday.mat)] <- 0
    return(holiday.mat)
  }
  #' Dataframe with seasonality features.
  #'
  #' Always starts with an all-zero column 'zeros', then appends yearly
  #' (period 365.25, order 10) and weekly (period 7, order 3) Fourier features
  #' and the holiday features, each only when enabled on the model.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features.
  #'
  #' @return Dataframe with seasonality.
  #'
  make_all_seasonality_features <- function(m, df) {
    seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
    if (m$yearly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
    }
    if (m$weekly.seasonality) {
      seasonal.features <- cbind(
        seasonal.features,
        make_seasonality_features(df$ds, 7, 3, 'weekly'))
    }
    if (!is.null(m$holidays)) {
      # Holiday columns are scaled inside make_holiday_features: a smaller
      # holiday prior scale shrinks holiday estimates more than seasonality.
      seasonal.features <- cbind(
        seasonal.features,
        make_holiday_features(m, df$ds))
    }
    return(seasonal.features)
  }
  #' Initialize linear growth.
  #'
  #' Provides a strong initialization for linear growth by calculating the
  #' growth and offset parameters that pass the function through the first and
  #' last points in the time series.
  #'
  #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  #'  and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  #'  growth function.
  #'
  linear_growth_init <- function(df) {
    # Rows holding the earliest and the latest date.
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    elapsed <- df$t[last] - df$t[first]
    # Slope of the straight line through the two end points ...
    rate <- (df$y_scaled[last] - df$y_scaled[first]) / elapsed
    # ... and its value at t = 0.
    offset <- df$y_scaled[first] - rate * df$t[first]
    c(rate, offset)
  }
  #' Initialize logistic growth.
  #'
  #' Provides a strong initialization for logistic growth by calculating the
  #' growth and offset parameters that pass the function through the first and
  #' last points in the time series.
  #'
  #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  #'  y_scaled (scaled time series), and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  #'  growth function.
  #'
  logistic_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    elapsed <- df$t[last] - df$t[first]

    # Capacity-to-value ratios at both ends, floored at 1.01 so the logs
    # below stay defined even when y > cap.
    ratio.first <- max(1.01, df$cap_scaled[first] / df$y_scaled[first])
    ratio.last <- max(1.01, df$cap_scaled[last] / df$y_scaled[last])

    # Nudge the first ratio when both are (nearly) equal, otherwise the
    # denominator of the offset would be (nearly) zero.
    if (abs(ratio.first - ratio.last) <= 0.01) {
      ratio.first <- 1.05 * ratio.first
    }

    log.first <- log(ratio.first - 1)
    log.last <- log(ratio.last - 1)

    # Offset first, then the rate derived from it.
    offset <- log.first * elapsed / (log.first - log.last)
    rate <- log.first / offset
    c(rate, offset)
  }
  #' Fit the prophet model.
  #'
  #' This sets m$params to contain the fitted model parameters. It is a list
  #' with the following elements:
  #'   k (M array): M posterior samples of the initial slope.
  #'   m (M array): The initial intercept.
  #'   delta (MxN matrix): The slope change at each of N changepoints.
  #'   beta (MxK matrix): Coefficients for K seasonality features.
  #'   sigma_obs (M array): Noise level.
  #' Note that M=1 if MAP estimation.
  #'
  #' Rows of df with a missing y are dropped before fitting. Stops with an
  #' error if the model was already fit or if fewer than two rows remain.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #' @param ... Additional arguments passed to the \code{optimizing} or
  #'  \code{sampling} functions in Stan.
  #'
  #' @export
  fit.prophet <- function(m, df, ...) {
    if (!is.null(m$history)) {
      stop("Prophet object can only be fit once. Instantiate a new object.")
    }
    history <- df %>%
      dplyr::filter(!is.na(y))
    # With fewer than two observations the time scale is zero and every
    # scaled time becomes NaN, so refuse to fit.
    if (nrow(history) < 2) {
      stop("Dataframe has less than 2 non-NA rows.")
    }
    # All dates of the input, including those whose y is missing.
    m$history.dates <- sort(zoo::as.Date(df$ds))

    out <- setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    seasonal.features <- make_all_seasonality_features(m, history)

    m <- set_changepoints(m)
    A <- get_changepoint_matrix(m)

    # Construct input to stan
    dat <- list(
      T = nrow(history),
      K = ncol(seasonal.features),
      S = length(m$changepoints.t),
      y = history$y_scaled,
      t = history$t,
      A = A,
      t_change = array(m$changepoints.t),
      X = as.matrix(seasonal.features),
      sigma = m$seasonality.prior.scale,
      tau = m$changepoint.prior.scale
    )

    # Run stan
    if (m$growth == 'linear') {
      kinit <- linear_growth_init(history)
      model <- get_prophet_stan_model('linear')
    } else {
      dat$cap <- history$cap_scaled  # Add capacities to the Stan data
      kinit <- logistic_growth_init(history)
      model <- get_prophet_stan_model('logistic')
    }

    # Initial values: end-point trend fit, no rate changes, no seasonality.
    stan_init <- function() {
      list(k = kinit[1],
           m = kinit[2],
           delta = array(rep(0, length(m$changepoints.t))),
           beta = array(rep(0, ncol(seasonal.features))),
           sigma_obs = 1
      )
    }

    if (m$mcmc.samples > 0) {
      stan.fit <- rstan::sampling(
        model,
        data = dat,
        init = stan_init,
        iter = m$mcmc.samples,
        ...
      )
      m$params <- rstan::extract(stan.fit)
      n.iteration <- length(m$params$k)
    } else {
      stan.fit <- rstan::optimizing(
        model,
        data = dat,
        init = stan_init,
        iter = 1e4,
        as_vector = FALSE,
        ...
      )
      m$params <- stan.fit$par
      n.iteration <- 1
    }

    # Cast the parameters to have consistent form, whether full bayes or MAP
    for (name in c('delta', 'beta')) {
      m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
    }
    # rstan::sampling returns 1d arrays; converts to atomic vectors.
    for (name in c('k', 'm', 'sigma_obs')) {
      m$params[[name]] <- c(m$params[[name]])
    }
    # If no changepoints were requested, replace delta with 0s
    if (m$n.changepoints == 0) {
      # Fold delta into the base rate k
      m$params$k <- m$params$k + m$params$delta[, 1]
      m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
    }
    return(m)
  }
  #' Predict using the prophet model.
  #'
  #' Computes the trend, the uncertainty intervals and the seasonal components
  #' for each row of df, and sets yhat to trend + seasonal. Stops with an
  #' error if the model has not been fit.
  #'
  #' @param object Prophet object.
  #' @param df Dataframe with dates for predictions (column ds), and capacity
  #'  (column cap) if logistic growth. If not provided, predictions are made on
  #'  the history.
  #' @param ... additional arguments.
  #'
  #' @return A dataframe with the forecast components.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  predict.prophet <- function(object, df = NULL, ...) {
    # history is only set by fitting; without it the scales and parameters
    # used below do not exist.
    if (is.null(object$history)) {
      stop("Model must be fit before predictions can be made.")
    }
    if (is.null(df)) {
      df <- object$history
    } else {
      out <- setup_dataframe(object, df)
      df <- out$df
    }

    df$trend <- predict_trend(object, df)

    df <- df %>%
      dplyr::bind_cols(predict_uncertainty(object, df)) %>%
      dplyr::bind_cols(predict_seasonal_components(object, df))
    df$yhat <- df$trend + df$seasonal
    return(df)
  }
  #' Evaluate the piecewise linear function.
  #'
  #' The slope starts at k and changes by deltas[s] at changepoint.ts[s]; the
  #' intercept is adjusted at each changepoint so the function stays
  #' continuous.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
    # Intercept changes
    gammas <- -changepoint.ts * deltas
    # Get cumulative slope and intercept at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    # seq_along() so that an empty changepoint vector gives the plain line
    # k * t + m instead of iterating over c(1, 0).
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- k_t * t + m_t
    return(y)
  }
  #' Evaluate the piecewise logistic function.
  #'
  #' The rate starts at k and changes by deltas[s] at changepoint.ts[s]; the
  #' offset is adjusted at each changepoint by a matching gamma.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param cap Vector of capacities at each t.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
    # Compute offset changes
    k.cum <- c(k, cumsum(deltas) + k)  # rate before/after each changepoint
    gammas <- rep(0, length(changepoint.ts))
    # seq_along() in both loops so that an empty changepoint vector gives
    # the plain logistic curve instead of iterating over c(1, 0).
    for (i in seq_along(changepoint.ts)) {
      gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
                    * (1 - k.cum[i] / k.cum[i + 1]))
    }
    # Get cumulative rate and offset at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- cap / (1 + exp(-k_t * (t - m_t)))
    return(y)
  }
  #' Predict trend using the prophet model.
  #'
  #' Averages the fitted k, m and delta parameters over all parameter draws,
  #' evaluates the piecewise linear or logistic trend on df$t, and rescales
  #' the result by the model's y.scale.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Vector with trend on prediction dates.
  #'
  predict_trend <- function(model, df) {
    fitted <- model$params
    rate <- mean(fitted$k, na.rm = TRUE)
    offset <- mean(fitted$m, na.rm = TRUE)
    rate.changes <- colMeans(fitted$delta, na.rm = TRUE)

    if (model$growth == 'linear') {
      scaled.trend <- piecewise_linear(
        df$t, rate.changes, rate, offset, model$changepoints.t)
    } else {
      scaled.trend <- piecewise_logistic(
        df$t, df$cap_scaled, rate.changes, rate, offset, model$changepoints.t)
    }
    # Back to the original units of y.
    scaled.trend * model$y.scale
  }
  #' Predict seasonality broken down into components.
  #'
  #' Feature columns are grouped by the part of their name before "_delim_".
  #' For each group the features are multiplied by every draw of the matching
  #' beta coefficients; the row-wise mean and the interval.width quantiles
  #' become columns "<component>", "<component>_lower" and
  #' "<component>_upper". Column 'seasonal' is the sum of the component means.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with seasonal components.
  #'
  predict_seasonal_components <- function(m, df) {
    seasonal.features <- make_all_seasonality_features(m, df)
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2

    # Map every feature column to its component; the 'zeros' placeholder
    # column is not a component.
    components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
      dplyr::mutate(col = 1:n()) %>%
      tidyr::separate(component, c('component', 'part'), sep = "_delim_",
                      extra = "merge", fill = "right") %>%
      dplyr::filter(component != 'zeros')

    if (nrow(components) == 0) {
      return(data.frame(seasonal = rep(0, nrow(df))))
    }

    # Mean and interval of one component, one row per prediction date.
    summarise.component <- function(grp) {
      comp <- (as.matrix(seasonal.features[, grp$col])
               %*% t(m$params$beta[, grp$col, drop = FALSE])) * m$y.scale
      dplyr::data_frame(ix = 1:nrow(seasonal.features),
                        mean = rowMeans(comp, na.rm = TRUE),
                        lower = apply(comp, 1, stats::quantile, lower.p,
                                      na.rm = TRUE),
                        upper = apply(comp, 1, stats::quantile, upper.p,
                                      na.rm = TRUE))
    }

    long <- components %>%
      dplyr::group_by(component) %>%
      dplyr::do(summarise.component(.)) %>%
      tidyr::gather(stat, value, c(mean, lower, upper))

    # Build the output column names: the mean keeps the bare component name.
    component.predictions <- long %>%
      dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
      tidyr::unite(component, component, stat, sep="") %>%
      tidyr::spread(component, value) %>%
      dplyr::select(-ix)

    component.predictions$seasonal <- rowSums(
      component.predictions[unique(components$component)])
    return(component.predictions)
  }
  #' Prophet uncertainty intervals.
  #'
  #' Runs sample_model repeatedly (at least one simulation per parameter draw,
  #' enough to reach about m$uncertainty.samples in total) and takes the
  #' interval.width quantiles of the simulated yhat, trend and seasonal values
  #' for every row of df.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with uncertainty intervals.
  #'
  predict_uncertainty <- function(m, df) {
    # Sample trend, seasonality, and yhat from the extrapolation model.
    n.iterations <- length(m$params$k)
    samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
    nsamp <- n.iterations * samp.per.iter  # The actual number of samples

    seasonal.features <- make_all_seasonality_features(m, df)

    keys <- c("trend", "seasonal", "yhat")
    # One row per prediction date, one column per simulation.
    blank <- matrix(NA, nrow = nrow(df), ncol = nsamp)
    sim.values <- list("trend" = blank, "seasonal" = blank, "yhat" = blank)

    sim.col <- 0
    # For each set of parameters from MCMC (or just 1 set for MAP),
    for (i in seq_len(n.iterations)) {
      for (j in seq_len(samp.per.iter)) {
        sim.col <- sim.col + 1
        # Do a simulation with this set of parameters,
        sim <- sample_model(m, df, seasonal.features, i)
        # Store the results
        for (key in keys) {
          sim.values[[key]][, sim.col] <- sim[[key]]
        }
      }
    }

    # Add uncertainty estimates
    probs <- c((1 - m$interval.width)/2, (1 + m$interval.width)/2)
    # Lower/upper quantile of every row, returned as a two-column matrix.
    row.interval <- function(sims) {
      t(apply(sims, 1, stats::quantile, probs, na.rm = TRUE))
    }
    intervals <- cbind(
      row.interval(sim.values$yhat),
      row.interval(sim.values$trend),
      row.interval(sim.values$seasonal)
    ) %>% dplyr::as_data_frame()

    colnames(intervals) <- paste(rep(c('yhat', 'trend', 'seasonal'), each=2),
                                 c('lower', 'upper'), sep = "_")
    return(intervals)
  }
  #' Simulate observations from the extrapolated generative model.
  #'
  #' Draws one simulated trend, computes the seasonal effect from the beta
  #' coefficients of the given iteration, and adds Gaussian observation noise
  #' with that iteration's sigma_obs. Seasonal effect and noise are rescaled
  #' by m$y.scale.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #' @param seasonal.features Data frame of seasonal features
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  #'
  sample_model <- function(m, df, seasonal.features, iteration) {
    # The trend is simulated before the noise is drawn; keep this order so
    # the random number stream is consumed the same way.
    sim.trend <- sample_predictive_trend(m, df, iteration)

    coefs <- m$params$beta[iteration,]
    sim.seasonal <- (as.matrix(seasonal.features) %*% coefs) * m$y.scale

    obs.sd <- m$params$sigma_obs[iteration]
    sim.noise <- stats::rnorm(nrow(df), mean = 0, sd = obs.sd) * m$y.scale

    list("yhat" = sim.trend + sim.seasonal + sim.noise,
         "trend" = sim.trend,
         "seasonal" = sim.seasonal)
  }
  #' Simulate one trend path from the extrapolated generative model.
  #'
  #' Keeps the fitted changepoints and, when df extends past the history
  #' (t > 1), adds randomly placed future changepoints whose rate changes are
  #' drawn from a Laplace distribution scaled by the fitted deltas.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return Vector of simulated trend over df$t.
  #'
  sample_predictive_trend <- function(model, df, iteration) {
    # Trend parameters for the requested draw.
    rate <- model$params$k[iteration]
    intercept <- model$params$m[iteration]
    hist.deltas <- model$params$delta[iteration,]

    t <- df$t
    t.max <- max(t)

    # With no future period (t.max <= 1) no new changepoints are simulated.
    n.new <- 0
    new.ts <- c()
    if (t.max > 1) {
      # The smallest positive step in the history is the time discretization.
      steps <- diff(model$history$t)
      step <- min(steps[steps > 0])
      # Number of time periods in the future
      n.periods <- ceiling((t.max - 1) / step)
      n.hist <- length(model$changepoints.t)
      # The history had n.hist changepoints over t = [0, 1]. The forecast
      # covers [1, t.max] and should change rate equally often on average, so
      # the per-period probability targets n.hist * (t.max - 1) changepoints
      # in expectation (capped at probability 1).
      # This calculation works for both history and df not uniformly spaced.
      p.change <- min(1, (n.hist * (t.max - 1)) / n.periods)
      n.new <- stats::rbinom(1, n.periods, p.change)
      # Place the new changepoints uniformly over the future, in time order.
      if (n.new != 0) {
        new.ts <- sort(stats::runif(n.new, min = 1, max = t.max))
      }
    }

    # Empirical scale of the fitted deltas, plus epsilon to avoid NaNs.
    lambda <- mean(abs(c(hist.deltas))) + 1e-8
    # Rate changes at the simulated changepoints.
    new.deltas <- extraDistr::rlaplace(n.new, mu = 0, sigma = lambda)

    # Historical changepoints first, simulated ones after.
    all.ts <- c(model$changepoints.t, new.ts)
    all.deltas <- c(hist.deltas, new.deltas)

    # Any growth other than 'linear' is treated as logistic.
    trend <- if (model$growth == 'linear') {
      piecewise_linear(t, all.deltas, rate, intercept, all.ts)
    } else {
      piecewise_logistic(t, df$cap_scaled, all.deltas, rate, intercept, all.ts)
    }
    trend * model$y.scale
  }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object. Its history.dates element supplies the
  #'  historical dates.
  #' @param periods Int number of periods to forecast forward. Zero is allowed
  #'  and adds no future dates.
  #' @param freq 'day', 'week', 'month', 'quarter', or 'year'. Passed as the
  #'  `by` argument of seq().
  #' @param include_history Boolean to include the historical dates in the data
  #'  frame for predictions.
  #'
  #' @return Dataframe with a single column ds that extends forward from the
  #'  end of the history for the requested number of periods.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'd',
                                    include_history = TRUE) {
    # seq() starts at the last historical date, so it yields that date
    # followed by `periods` future dates.
    dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
    # Drop the first element, which is the last historical date. A negative
    # index is used instead of 2:(periods + 1) because for periods = 0 that
    # range is 2:1, which would select an NA and the last historical date.
    dates <- dates[-1]
    if (include_history) {
      dates <- c(m$history.dates, dates)
    }
    return(data.frame(ds = dates))
  }
  #' Combine the observed history with a forecast for plotting.
  #'
  #' Full-joins the ds and y columns of the history with the forecast on ds
  #' and sorts the result by ds.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @importFrom dplyr "%>%"
  df_for_plotting <- function(m, fcst) {
    # Remove any y column from the forecast so the only y after the join is
    # the one taken from the history.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observations as points and the forecast yhat as a line. A
  #' dashed cap line and a yhat uncertainty ribbon are added when the
  #' corresponding columns are available.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(
  #'   ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'   y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds',
                           ylabel = 'y', ...) {
    plot.df <- df_for_plotting(x, fcst)
    available <- names(plot.df)
    forecast.color <- "#0072B2"

    # Base plot: default aesthetics are the observations over time.
    gg <- ggplot2::ggplot(plot.df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)

    # Dashed capacity line, only when a cap column is present.
    if ('cap' %in% available) {
      cap.line <- ggplot2::geom_line(ggplot2::aes(y = cap),
                                     linetype = 'dashed', na.rm = TRUE)
      gg <- gg + cap.line
    }

    # Uncertainty ribbon, only when requested and yhat_lower is present.
    if (uncertainty && ('yhat_lower' %in% available)) {
      ribbon <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2, fill = forecast.color, na.rm = TRUE)
      gg <- gg + ribbon
    }

    # Observed points, then the forecast line on top.
    gg +
      ggplot2::geom_point(na.rm=TRUE) +
      ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.color,
                         na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }
  #' Plot the components of a prophet forecast.
  #'
  #' Prints a stack of ggplot2 panels, one per row: the trend, then holidays
  #' if the model has them, then weekly and yearly seasonalities if present
  #' in the forecast.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #'  plotted for the trend, from fcst columns trend_lower and trend_upper.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
    plot.df <- df_for_plotting(m, fcst)
    available <- colnames(plot.df)

    # The trend panel is always present and always comes first.
    panels <- list(plot_trend(plot.df, uncertainty))
    # Holidays panel, only when the model was given holidays.
    if (!is.null(m$holidays)) {
      panels[[length(panels) + 1]] <- plot_holidays(m, plot.df, uncertainty)
    }
    # Weekly seasonality panel, only when the forecast has that column.
    if ("weekly" %in% available) {
      panels[[length(panels) + 1]] <- plot_weekly(plot.df, uncertainty)
    }
    # Yearly seasonality panel, only when the forecast has that column.
    if ("yearly" %in% available) {
      panels[[length(panels) + 1]] <- plot_yearly(plot.df, uncertainty)
    }

    # Draw the panels on a fresh page, one layout row per panel.
    n.panels <- length(panels)
    grid::grid.newpage()
    page.layout <- grid::grid.layout(n.panels, 1)
    grid::pushViewport(grid::viewport(layout = page.layout))
    for (row in seq_len(n.panels)) {
      cell <- grid::viewport(layout.pos.row = row, layout.pos.col = 1)
      print(panels[[row]], vp = cell)
    }
    invisible(panels)
  }
  #' Plot the trend component of the forecast.
  #'
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals, taken from the
  #'  trend_lower and trend_upper columns of df.
  #'
  #' @return A ggplot2 plot.
  plot_trend <- function(df, uncertainty = TRUE) {
    trend.color <- "#0072B2"
    gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend))
    gg.trend <- gg.trend +
      ggplot2::geom_line(color = trend.color, na.rm = TRUE)

    # Dashed capacity line, only when df has a cap column.
    if ('cap' %in% names(df)) {
      cap.line <- ggplot2::geom_line(ggplot2::aes(y = cap),
                                     linetype = 'dashed', na.rm = TRUE)
      gg.trend <- gg.trend + cap.line
    }

    # Shaded band between trend_lower and trend_upper.
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = trend_lower, ymax = trend_upper),
        alpha = 0.2, fill = trend.color, na.rm = TRUE)
      gg.trend <- gg.trend + band
    }
    gg.trend
  }
  #' Plot the combined holidays component of the forecast.
  #'
  #' Sums the per-holiday columns of df (and their _lower / _upper
  #' counterparts) into a single holidays series and plots it over ds.
  #'
  #' @param m Prophet model
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_holidays <- function(m, df, uncertainty = TRUE) {
    holiday.names <- as.character(unique(m$holidays$holiday))
    # Row-wise total over a set of df columns; drop = FALSE keeps a data
    # frame even when a single column is selected.
    row_total <- function(cols) {
      rowSums(df[, cols, drop = FALSE])
    }
    # NOTE summing the interval bounds is incorrect if holidays overlap in
    # time. Since it is just for the visualization we will not worry about it
    # now.
    df.s <- data.frame(
      ds = df$ds,
      holidays = row_total(holiday.names),
      holidays_lower = row_total(paste0(holiday.names, "_lower")),
      holidays_upper = row_total(paste0(holiday.names, "_upper")))

    holiday.color <- "#0072B2"
    gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays))
    gg.holidays <- gg.holidays +
      ggplot2::geom_line(color = holiday.color, na.rm = TRUE)
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = holidays_lower, ymax = holidays_upper),
        alpha = 0.2, fill = holiday.color, na.rm = TRUE)
      gg.holidays <- gg.holidays + band
    }
    gg.holidays
  }
  #' Plot the weekly seasonality component of the forecast.
  #'
  #' Uses the first row of df found for each day of the week and plots its
  #' weekly value against the weekday name.
  #'
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_weekly <- function(df, uncertainty = TRUE) {
    # Weekday names in the current locale. 2017-01-01 fell on a Sunday, so the
    # factor levels run from Sunday through Saturday.
    week.start <- as.Date('2017-01-01')
    days <- weekdays(seq.Date(week.start, by='d', length.out=7))

    # Keep one representative row per weekday, ordered by factor level.
    with.dow <- dplyr::mutate(df, dow = factor(weekdays(ds), levels = days))
    first.per.day <- dplyr::slice(dplyr::group_by(with.dow, dow), 1)
    df.s <- dplyr::arrange(dplyr::ungroup(first.per.day), dow)

    weekly.color <- "#0072B2"
    # group = 1 joins the discrete weekday positions into a single line.
    gg.weekly <- ggplot2::ggplot(
      df.s, ggplot2::aes(x = dow, y = weekly, group = 1))
    gg.weekly <- gg.weekly +
      ggplot2::geom_line(color = weekly.color, na.rm = TRUE) +
      ggplot2::labs(x = "Day of week")
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = weekly_lower, ymax = weekly_upper),
        alpha = 0.2, fill = weekly.color, na.rm = TRUE)
      gg.weekly <- gg.weekly + band
    }
    gg.weekly
  }
  #' Plot the yearly seasonality component of the forecast.
  #'
  #' Uses the first row of df found for each calendar day (month and day) and
  #' plots its yearly value against that day placed in the year 2000.
  #'
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_yearly <- function(df, uncertainty = TRUE) {
    # Replace the year with 2000 so dates from different years collapse onto
    # one calendar; 2000 is a leap year, so 02-29 remains a valid date.
    with.doy <- dplyr::mutate(df, doy = strftime(ds, format = "2000-%m-%d"))
    # Keep one representative row per calendar day.
    first.per.doy <- dplyr::slice(dplyr::group_by(with.doy, doy), 1)
    df.s <- dplyr::ungroup(first.per.doy)
    # Convert the day-of-year strings back to dates and sort by them.
    df.s <- dplyr::mutate(df.s, doy = zoo::as.Date(doy))
    df.s <- dplyr::arrange(df.s, doy)

    yearly.color <- "#0072B2"
    gg.yearly <- ggplot2::ggplot(
      df.s, ggplot2::aes(x = doy, y = yearly, group = 1))
    gg.yearly <- gg.yearly +
      ggplot2::geom_line(color = yearly.color, na.rm = TRUE) +
      ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
      ggplot2::labs(x = "Day of year")
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yearly_lower, ymax = yearly_upper),
        alpha = 0.2, fill = yearly.color, na.rm = TRUE)
      gg.yearly <- gg.yearly + band
    }
    gg.yearly
  }
  1028. # fb-block 3