prophet.R 39 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  6. ## Makes R CMD CHECK happy due to dplyr syntax below
  7. globalVariables(c(
  8. "ds", "y", "cap", ".",
  9. "component", "dow", "doy", "holiday", "holidays", "holidays_lower", "holidays_upper", "ix",
  10. "lower", "n", "stat", "trend", "row_number",
  11. "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
  12. "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
  13. #' Prophet forecaster.
  14. #'
  15. #' @param df Dataframe containing the history. Must have columns ds (date type)
  16. #' and y, the time series. If growth is logistic, then df must also have a
  17. #' column cap that specifies the capacity at each ds.
  18. #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  19. #' trend.
  20. #' @param changepoints Vector of dates at which to include potential
  21. #' changepoints. If not specified, potential changepoints are selected
  22. #' automatically.
  23. #' @param n.changepoints Number of potential changepoints to include. Not used
  24. #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  25. #' then n.changepoints potential changepoints are selected uniformly from the
  26. #' first 80 percent of df$ds.
  27. #' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
  28. #' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
  29. #' @param holidays data frame with columns holiday (character) and ds (date
  30. #' type)and optionally columns lower_window and upper_window which specify a
  31. #' range of days around the date to be included as holidays. lower_window=-2
  32. #' will include 2 days prior to the date as holidays.
  33. #' @param seasonality.prior.scale Parameter modulating the strength of the
  34. #' seasonality model. Larger values allow the model to fit larger seasonal
  35. #' fluctuations, smaller values dampen the seasonality.
  36. #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  37. #' components model.
  38. #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  39. #' automatic changepoint selection. Large values will allow many changepoints,
  40. #' small values will allow few changepoints.
  41. #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  42. #' inference with the specified number of MCMC samples. If 0, will do MAP
  43. #' estimation.
  44. #' @param interval.width Numeric, width of the uncertainty intervals provided
  45. #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  46. #' in the trend using the MAP estimate of the extrapolated generative model.
  47. #' If mcmc.samples>0, this will be integrated over all model parameters,
  48. #' which will include uncertainty in seasonality.
  49. #' @param uncertainty.samples Number of simulated draws used to estimate
  50. #' uncertainty intervals.
  51. #' @param fit Boolean, if FALSE the model is initialized but not fit.
  52. #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  53. #'
  54. #' @return A prophet model.
  55. #'
  56. #' @examples
  57. #' \dontrun{
  58. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  59. #' y = sin(1:366/200) + rnorm(366)/10)
  60. #' m <- prophet(history)
  61. #' }
  62. #'
  63. #' @export
  64. #' @importFrom dplyr "%>%"
  65. #' @import Rcpp
  66. prophet <- function(df = df,
  67. growth = 'linear',
  68. changepoints = NULL,
  69. n.changepoints = 25,
  70. yearly.seasonality = 'auto',
  71. weekly.seasonality = 'auto',
  72. holidays = NULL,
  73. seasonality.prior.scale = 10,
  74. holidays.prior.scale = 10,
  75. changepoint.prior.scale = 0.05,
  76. mcmc.samples = 0,
  77. interval.width = 0.80,
  78. uncertainty.samples = 1000,
  79. fit = TRUE,
  80. ...
  81. ) {
  82. # fb-block 1
  83. if (!is.null(changepoints)) {
  84. n.changepoints <- length(changepoints)
  85. }
  86. m <- list(
  87. growth = growth,
  88. changepoints = changepoints,
  89. n.changepoints = n.changepoints,
  90. yearly.seasonality = yearly.seasonality,
  91. weekly.seasonality = weekly.seasonality,
  92. holidays = holidays,
  93. seasonality.prior.scale = seasonality.prior.scale,
  94. changepoint.prior.scale = changepoint.prior.scale,
  95. holidays.prior.scale = holidays.prior.scale,
  96. mcmc.samples = mcmc.samples,
  97. interval.width = interval.width,
  98. uncertainty.samples = uncertainty.samples,
  99. start = NULL, # This and following attributes are set during fitting
  100. y.scale = NULL,
  101. t.scale = NULL,
  102. changepoints.t = NULL,
  103. stan.fit = NULL,
  104. params = list(),
  105. history = NULL,
  106. history.dates = NULL
  107. )
  108. validate_inputs(m)
  109. class(m) <- append("prophet", class(m))
  110. if (fit) {
  111. m <- fit.prophet(m, df, ...)
  112. }
  113. # fb-block 2
  114. return(m)
  115. }
  116. #' Validates the inputs to Prophet.
  117. #'
  118. #' @param m Prophet object.
  119. #'
  120. validate_inputs <- function(m) {
  121. if (!(m$growth %in% c('linear', 'logistic'))) {
  122. stop("Parameter 'growth' should be 'linear' or 'logistic'.")
  123. }
  124. if (!is.null(m$holidays)) {
  125. if (!(exists('holiday', where = m$holidays))) {
  126. stop('Holidays dataframe must have holiday field.')
  127. }
  128. if (!(exists('ds', where = m$holidays))) {
  129. stop('Holidays dataframe must have ds field.')
  130. }
  131. has.lower <- exists('lower_window', where = m$holidays)
  132. has.upper <- exists('upper_window', where = m$holidays)
  133. if (has.lower + has.upper == 1) {
  134. stop(paste('Holidays must have both lower_window and upper_window,',
  135. 'or neither.'))
  136. }
  137. if (has.lower) {
  138. if(max(m$holidays$lower_window, na.rm=TRUE) > 0) {
  139. stop('Holiday lower_window should be <= 0')
  140. }
  141. if(min(m$holidays$upper_window, na.rm=TRUE) < 0) {
  142. stop('Holiday upper_window should be >= 0')
  143. }
  144. }
  145. for (h in unique(m$holidays$holiday)) {
  146. if (grepl("_delim_", h)) {
  147. stop('Holiday name cannot contain "_delim_"')
  148. }
  149. if (h %in% c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')) {
  150. stop(paste0('Holiday name "', h, '" reserved.'))
  151. }
  152. }
  153. }
  154. }
  155. #' Load compiled Stan model
  156. #'
  157. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  158. #' trend.
  159. #'
  160. #' @return Stan model.
  161. get_prophet_stan_model <- function(model) {
  162. fn <- paste('prophet', model, 'growth.RData', sep = '_')
  163. ## If the cached model doesn't work, just compile a new one.
  164. tryCatch({
  165. binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
  166. package = 'prophet',
  167. mustWork = TRUE)
  168. load(binary)
  169. obj.name <- paste(model, 'growth.stanm', sep = '.')
  170. stanm <- eval(parse(text = obj.name))
  171. ## Should cause an error if the model doesn't work.
  172. stanm@mk_cppmodule(stanm)
  173. stanm
  174. }, error = function(cond) {
  175. compile_stan_model(model)
  176. })
  177. }
  178. #' Compile Stan model
  179. #'
  180. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  181. #' trend.
  182. #'
  183. #' @return Stan model.
  184. compile_stan_model <- function(model) {
  185. fn <- paste('stan/prophet', model, 'growth.stan', sep = '_')
  186. stan.src <- system.file(fn, package = 'prophet', mustWork = TRUE)
  187. stanc <- rstan::stanc(stan.src)
  188. model.name <- paste(model, 'growth', sep = '_')
  189. return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
  190. }
  191. #' Prepare dataframe for fitting or predicting.
  192. #'
  193. #' Adds a time index and scales y. Creates auxillary columns 't', 't_ix',
  194. #' 'y_scaled', and 'cap_scaled'. These columns are used during both fitting
  195. #' and predicting.
  196. #'
  197. #' @param m Prophet object.
  198. #' @param df Data frame with columns ds, y, and cap if logistic growth.
  199. #' @param initialize_scales Boolean set scaling factors in m from df.
  200. #'
  201. #' @return list with items 'df' and 'm'.
  202. #'
  203. setup_dataframe <- function(m, df, initialize_scales = FALSE) {
  204. if (exists('y', where=df)) {
  205. df$y <- as.numeric(df$y)
  206. }
  207. df$ds <- zoo::as.Date(df$ds)
  208. if (anyNA(df$ds)) {
  209. stop('Unable to parse date format in column ds. Convert to date format.')
  210. }
  211. df <- df %>%
  212. dplyr::arrange(ds)
  213. if (initialize_scales) {
  214. m$y.scale <- max(abs(df$y))
  215. m$start <- min(df$ds)
  216. m$t.scale <- as.numeric(max(df$ds) - m$start)
  217. }
  218. df$t <- as.numeric(df$ds - m$start) / m$t.scale
  219. if (exists('y', where=df)) {
  220. df$y_scaled <- df$y / m$y.scale
  221. }
  222. if (m$growth == 'logistic') {
  223. if (!(exists('cap', where=df))) {
  224. stop('Capacities must be supplied for logistic growth.')
  225. }
  226. df <- df %>%
  227. dplyr::mutate(cap_scaled = cap / m$y.scale)
  228. }
  229. return(list("m" = m, "df" = df))
  230. }
  231. #' Set changepoints
  232. #'
  233. #' Sets m$changepoints to the dates of changepoints. Either:
  234. #' 1) The changepoints were passed in explicitly.
  235. #' A) They are empty.
  236. #' B) They are not empty, and need validation.
  237. #' 2) We are generating a grid of them.
  238. #' 3) The user prefers no changepoints be used.
  239. #'
  240. #' @param m Prophet object.
  241. #'
  242. #' @return m with changepoints set.
  243. #'
  244. set_changepoints <- function(m) {
  245. if (!is.null(m$changepoints)) {
  246. if (length(m$changepoints) > 0) {
  247. if (min(m$changepoints) < min(m$history$ds)
  248. || max(m$changepoints) > max(m$history$ds)) {
  249. stop('Changepoints must fall within training data.')
  250. }
  251. }
  252. } else {
  253. if (m$n.changepoints > 0) {
  254. # Place potential changepoints evenly through the first 80 pcnt of
  255. # the history.
  256. cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
  257. length.out = (m$n.changepoints + 1))) %>%
  258. utils::tail(-1)
  259. m$changepoints <- m$history$ds[cp.indexes]
  260. } else {
  261. m$changepoints <- c()
  262. }
  263. }
  264. if (length(m$changepoints) > 0) {
  265. m$changepoints <- zoo::as.Date(m$changepoints)
  266. m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
  267. } else {
  268. m$changepoints.t <- c(0) # dummy changepoint
  269. }
  270. return(m)
  271. }
  272. #' Gets changepoint matrix for history dataframe.
  273. #'
  274. #' @param m Prophet object.
  275. #'
  276. #' @return array of indexes.
  277. #'
  278. get_changepoint_matrix <- function(m) {
  279. A <- matrix(0, nrow(m$history), length(m$changepoints.t))
  280. for (i in 1:length(m$changepoints.t)) {
  281. A[m$history$t >= m$changepoints.t[i], i] <- 1
  282. }
  283. return(A)
  284. }
  285. #' Provides Fourier series components with the specified frequency and order.
  286. #'
  287. #' @param dates Vector of dates.
  288. #' @param period Number of days of the period.
  289. #' @param series.order Number of components.
  290. #'
  291. #' @return Matrix with seasonality features.
  292. #'
  293. fourier_series <- function(dates, period, series.order) {
  294. t <- dates - zoo::as.Date('1970-01-01')
  295. features <- matrix(0, length(t), 2 * series.order)
  296. for (i in 1:series.order) {
  297. x <- as.numeric(2 * i * pi * t / period)
  298. features[, i * 2 - 1] <- sin(x)
  299. features[, i * 2] <- cos(x)
  300. }
  301. return(features)
  302. }
  303. #' Data frame with seasonality features.
  304. #'
  305. #' @param dates Vector of dates.
  306. #' @param period Number of days of the period.
  307. #' @param series.order Number of components.
  308. #' @param prefix Column name prefix.
  309. #'
  310. #' @return Dataframe with seasonality.
  311. #'
  312. make_seasonality_features <- function(dates, period, series.order, prefix) {
  313. features <- fourier_series(dates, period, series.order)
  314. colnames(features) <- paste(prefix, 1:ncol(features), sep = '_delim_')
  315. return(data.frame(features))
  316. }
  317. #' Construct a matrix of holiday features.
  318. #'
  319. #' @param m Prophet object.
  320. #' @param dates Vector with dates used for computing seasonality.
  321. #'
  322. #' @return A dataframe with a column for each holiday.
  323. #'
  324. #' @importFrom dplyr "%>%"
  325. make_holiday_features <- function(m, dates) {
  326. scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
  327. wide <- m$holidays %>%
  328. dplyr::mutate(ds = zoo::as.Date(ds)) %>%
  329. dplyr::group_by(holiday, ds) %>%
  330. dplyr::filter(row_number() == 1) %>%
  331. dplyr::do({
  332. if (exists('lower_window', where = .) && !is.na(.$lower_window)
  333. && !is.na(.$upper_window)) {
  334. offsets <- seq(.$lower_window, .$upper_window)
  335. } else {
  336. offsets <- c(0)
  337. }
  338. names <- paste(
  339. .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
  340. dplyr::data_frame(ds = .$ds + offsets, holiday = names)
  341. }) %>%
  342. dplyr::mutate(x = scale.ratio) %>%
  343. tidyr::spread(holiday, x, fill = 0)
  344. holiday.mat <- data.frame(ds = dates) %>%
  345. dplyr::left_join(wide, by = 'ds') %>%
  346. dplyr::select(-ds)
  347. holiday.mat[is.na(holiday.mat)] <- 0
  348. return(holiday.mat)
  349. }
  350. #' Dataframe with seasonality features.
  351. #'
  352. #' @param m Prophet object.
  353. #' @param df Dataframe with dates for computing seasonality features.
  354. #'
  355. #' @return Dataframe with seasonality.
  356. #'
  357. make_all_seasonality_features <- function(m, df) {
  358. seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
  359. if (m$yearly.seasonality) {
  360. seasonal.features <- cbind(
  361. seasonal.features,
  362. make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
  363. }
  364. if (m$weekly.seasonality) {
  365. seasonal.features <- cbind(
  366. seasonal.features,
  367. make_seasonality_features(df$ds, 7, 3, 'weekly'))
  368. }
  369. if(!is.null(m$holidays)) {
  370. # A smaller prior scale will shrink holiday estimates more than seasonality
  371. scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
  372. seasonal.features <- cbind(
  373. seasonal.features,
  374. make_holiday_features(m, df$ds))
  375. }
  376. return(seasonal.features)
  377. }
  378. #' Set seasonalities that were left on auto.
  379. #'
  380. #' Turns on yearly seasonality if there is >=2 years of history.
  381. #' Turns on weekly seasonality if there is >=2 weeks of history, and the
  382. #' spacing between dates in the history is <7 days.
  383. #'
  384. #' @param m Prophet object.
  385. #'
  386. #' @return The prophet model with seasonalities set.
  387. #'
  388. set_auto_seasonalities <- function(m) {
  389. first <- min(m$history$ds)
  390. last <- max(m$history$ds)
  391. if (m$yearly.seasonality == 'auto') {
  392. if (last - first < 730) {
  393. warning('Disabling yearly seasonality. ',
  394. 'Run prophet with `yearly.seasonality=TRUE` to override this.')
  395. m$yearly.seasonality <- FALSE
  396. } else {
  397. m$yearly.seasonality <- TRUE
  398. }
  399. }
  400. if (m$weekly.seasonality == 'auto') {
  401. dt <- diff(m$history$ds)
  402. min.dt <- min(dt[dt > 0])
  403. if ((last - first < 14) || (min.dt >= 7)) {
  404. warning('Disabling weekly seasonality. ',
  405. 'Run prophet with `weekly.seasonality=TRUE` to override this.')
  406. m$weekly.seasonality <- FALSE
  407. } else {
  408. m$weekly.seasonality <- TRUE
  409. }
  410. }
  411. return(m)
  412. }
  413. #' Initialize linear growth.
  414. #'
  415. #' Provides a strong initialization for linear growth by calculating the
  416. #' growth and offset parameters that pass the function through the first and
  417. #' last points in the time series.
  418. #'
  419. #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  420. #' and t (scaled time).
  421. #'
  422. #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  423. #' growth function.
  424. #'
  425. linear_growth_init <- function(df) {
  426. i0 <- which.min(df$ds)
  427. i1 <- which.max(df$ds)
  428. T <- df$t[i1] - df$t[i0]
  429. # Initialize the rate
  430. k <- (df$y_scaled[i1] - df$y_scaled[i0]) / T
  431. # And the offset
  432. m <- df$y_scaled[i0] - k * df$t[i0]
  433. return(c(k, m))
  434. }
  435. #' Initialize logistic growth.
  436. #'
  437. #' Provides a strong initialization for logistic growth by calculating the
  438. #' growth and offset parameters that pass the function through the first and
  439. #' last points in the time series.
  440. #'
  441. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  442. #' y_scaled (scaled time series), and t (scaled time).
  443. #'
  444. #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  445. #' growth function.
  446. #'
  447. logistic_growth_init <- function(df) {
  448. i0 <- which.min(df$ds)
  449. i1 <- which.max(df$ds)
  450. T <- df$t[i1] - df$t[i0]
  451. # Force valid values, in case y > cap.
  452. r0 <- max(1.01, df$cap_scaled[i0] / df$y_scaled[i0])
  453. r1 <- max(1.01, df$cap_scaled[i1] / df$y_scaled[i1])
  454. if (abs(r0 - r1) <= 0.01) {
  455. r0 <- 1.05 * r0
  456. }
  457. L0 <- log(r0 - 1)
  458. L1 <- log(r1 - 1)
  459. # Initialize the offset
  460. m <- L0 * T / (L0 - L1)
  461. # And the rate
  462. k <- L0 / m
  463. return(c(k, m))
  464. }
  465. #' Fit the prophet model.
  466. #'
  467. #' This sets m$params to contain the fitted model parameters. It is a list
  468. #' with the following elements:
  469. #' k (M array): M posterior samples of the initial slope.
  470. #' m (M array): The initial intercept.
  471. #' delta (MxN matrix): The slope change at each of N changepoints.
  472. #' beta (MxK matrix): Coefficients for K seasonality features.
  473. #' sigma_obs (M array): Noise level.
  474. #' Note that M=1 if MAP estimation.
  475. #'
  476. #' @param m Prophet object.
  477. #' @param df Data frame.
  478. #' @param ... Additional arguments passed to the \code{optimizing} or
  479. #' \code{sampling} functions in Stan.
  480. #'
  481. #' @export
  482. fit.prophet <- function(m, df, ...) {
  483. if (!is.null(m$history)) {
  484. stop("Prophet object can only be fit once. Instantiate a new object.")
  485. }
  486. history <- df %>%
  487. dplyr::filter(!is.na(y))
  488. if (any(is.infinite(history$y))) {
  489. stop("Found infinity in column y.")
  490. }
  491. m$history.dates <- sort(zoo::as.Date(df$ds))
  492. out <- setup_dataframe(m, history, initialize_scales = TRUE)
  493. history <- out$df
  494. m <- out$m
  495. m$history <- history
  496. m <- set_auto_seasonalities(m)
  497. seasonal.features <- make_all_seasonality_features(m, history)
  498. m <- set_changepoints(m)
  499. A <- get_changepoint_matrix(m)
  500. # Construct input to stan
  501. dat <- list(
  502. T = nrow(history),
  503. K = ncol(seasonal.features),
  504. S = length(m$changepoints.t),
  505. y = history$y_scaled,
  506. t = history$t,
  507. A = A,
  508. t_change = array(m$changepoints.t),
  509. X = as.matrix(seasonal.features),
  510. sigma = m$seasonality.prior.scale,
  511. tau = m$changepoint.prior.scale
  512. )
  513. # Run stan
  514. if (m$growth == 'linear') {
  515. kinit <- linear_growth_init(history)
  516. } else {
  517. dat$cap <- history$cap_scaled # Add capacities to the Stan data
  518. kinit <- logistic_growth_init(history)
  519. }
  520. if (exists(".prophet.stan.models")) {
  521. model <- .prophet.stan.models[[m$growth]]
  522. } else {
  523. model <- get_prophet_stan_model(m$growth)
  524. }
  525. stan_init <- function() {
  526. list(k = kinit[1],
  527. m = kinit[2],
  528. delta = array(rep(0, length(m$changepoints.t))),
  529. beta = array(rep(0, ncol(seasonal.features))),
  530. sigma_obs = 1
  531. )
  532. }
  533. if (m$mcmc.samples > 0) {
  534. stan.fit <- rstan::sampling(
  535. model,
  536. data = dat,
  537. init = stan_init,
  538. iter = m$mcmc.samples,
  539. ...
  540. )
  541. m$params <- rstan::extract(stan.fit)
  542. n.iteration <- length(m$params$k)
  543. } else {
  544. stan.fit <- rstan::optimizing(
  545. model,
  546. data = dat,
  547. init = stan_init,
  548. iter = 1e4,
  549. as_vector = FALSE,
  550. ...
  551. )
  552. m$params <- stan.fit$par
  553. n.iteration <- 1
  554. }
  555. # Cast the parameters to have consistent form, whether full bayes or MAP
  556. for (name in c('delta', 'beta')){
  557. m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
  558. }
  559. # rstan::sampling returns 1d arrays; converts to atomic vectors.
  560. for (name in c('k', 'm', 'sigma_obs')){
  561. m$params[[name]] <- c(m$params[[name]])
  562. }
  563. # If no changepoints were requested, replace delta with 0s
  564. if (m$n.changepoints == 0) {
  565. # Fold delta into the base rate k
  566. m$params$k <- m$params$k + m$params$delta[, 1]
  567. m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
  568. }
  569. return(m)
  570. }
  571. #' Predict using the prophet model.
  572. #'
  573. #' @param object Prophet object.
  574. #' @param df Dataframe with dates for predictions (column ds), and capacity
  575. #' (column cap) if logistic growth. If not provided, predictions are made on
  576. #' the history.
  577. #' @param ... additional arguments.
  578. #'
  579. #' @return A dataframe with the forecast components.
  580. #'
  581. #' @examples
  582. #' \dontrun{
  583. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  584. #' y = sin(1:366/200) + rnorm(366)/10)
  585. #' m <- prophet(history)
  586. #' future <- make_future_dataframe(m, periods = 365)
  587. #' forecast <- predict(m, future)
  588. #' plot(m, forecast)
  589. #' }
  590. #'
  591. #' @export
  592. predict.prophet <- function(object, df = NULL, ...) {
  593. if (is.null(df)) {
  594. df <- object$history
  595. } else {
  596. out <- setup_dataframe(object, df)
  597. df <- out$df
  598. }
  599. df$trend <- predict_trend(object, df)
  600. df <- df %>%
  601. dplyr::bind_cols(predict_uncertainty(object, df)) %>%
  602. dplyr::bind_cols(predict_seasonal_components(object, df))
  603. df$yhat <- df$trend + df$seasonal
  604. return(df)
  605. }
  606. #' Evaluate the piecewise linear function.
  607. #'
  608. #' @param t Vector of times on which the function is evaluated.
  609. #' @param deltas Vector of rate changes at each changepoint.
  610. #' @param k Float initial rate.
  611. #' @param m Float initial offset.
  612. #' @param changepoint.ts Vector of changepoint times.
  613. #'
  614. #' @return Vector y(t).
  615. #'
  616. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  617. # Intercept changes
  618. gammas <- -changepoint.ts * deltas
  619. # Get cumulative slope and intercept at each t
  620. k_t <- rep(k, length(t))
  621. m_t <- rep(m, length(t))
  622. for (s in 1:length(changepoint.ts)) {
  623. indx <- t >= changepoint.ts[s]
  624. k_t[indx] <- k_t[indx] + deltas[s]
  625. m_t[indx] <- m_t[indx] + gammas[s]
  626. }
  627. y <- k_t * t + m_t
  628. return(y)
  629. }
  630. #' Evaluate the piecewise logistic function.
  631. #'
  632. #' @param t Vector of times on which the function is evaluated.
  633. #' @param cap Vector of capacities at each t.
  634. #' @param deltas Vector of rate changes at each changepoint.
  635. #' @param k Float initial rate.
  636. #' @param m Float initial offset.
  637. #' @param changepoint.ts Vector of changepoint times.
  638. #'
  639. #' @return Vector y(t).
  640. #'
  641. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  642. # Compute offset changes
  643. k.cum <- c(k, cumsum(deltas) + k)
  644. gammas <- rep(0, length(changepoint.ts))
  645. for (i in 1:length(changepoint.ts)) {
  646. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  647. * (1 - k.cum[i] / k.cum[i + 1]))
  648. }
  649. # Get cumulative rate and offset at each t
  650. k_t <- rep(k, length(t))
  651. m_t <- rep(m, length(t))
  652. for (s in 1:length(changepoint.ts)) {
  653. indx <- t >= changepoint.ts[s]
  654. k_t[indx] <- k_t[indx] + deltas[s]
  655. m_t[indx] <- m_t[indx] + gammas[s]
  656. }
  657. y <- cap / (1 + exp(-k_t * (t - m_t)))
  658. return(y)
  659. }
  660. #' Predict trend using the prophet model.
  661. #'
  662. #' @param model Prophet object.
  663. #' @param df Prediction dataframe.
  664. #'
  665. #' @return Vector with trend on prediction dates.
  666. #'
  667. predict_trend <- function(model, df) {
  668. k <- mean(model$params$k, na.rm = TRUE)
  669. param.m <- mean(model$params$m, na.rm = TRUE)
  670. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  671. t <- df$t
  672. if (model$growth == 'linear') {
  673. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  674. } else {
  675. cap <- df$cap_scaled
  676. trend <- piecewise_logistic(
  677. t, cap, deltas, k, param.m, model$changepoints.t)
  678. }
  679. return(trend * model$y.scale)
  680. }
  681. #' Predict seasonality broken down into components.
  682. #'
  683. #' @param m Prophet object.
  684. #' @param df Prediction dataframe.
  685. #'
  686. #' @return Dataframe with seasonal components.
  687. #'
  688. predict_seasonal_components <- function(m, df) {
  689. seasonal.features <- make_all_seasonality_features(m, df)
  690. lower.p <- (1 - m$interval.width)/2
  691. upper.p <- (1 + m$interval.width)/2
  692. # Broken down into components
  693. components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
  694. dplyr::mutate(col = 1:n()) %>%
  695. tidyr::separate(component, c('component', 'part'), sep = "_delim_",
  696. extra = "merge", fill = "right") %>%
  697. dplyr::filter(component != 'zeros')
  698. if (nrow(components) > 0) {
  699. component.predictions <- components %>%
  700. dplyr::group_by(component) %>% dplyr::do({
  701. comp <- (as.matrix(seasonal.features[, .$col])
  702. %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
  703. dplyr::data_frame(ix = 1:nrow(seasonal.features),
  704. mean = rowMeans(comp, na.rm = TRUE),
  705. lower = apply(comp, 1, stats::quantile, lower.p,
  706. na.rm = TRUE),
  707. upper = apply(comp, 1, stats::quantile, upper.p,
  708. na.rm = TRUE))
  709. }) %>%
  710. tidyr::gather(stat, value, mean, lower, upper) %>%
  711. dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
  712. tidyr::unite(component, component, stat, sep="") %>%
  713. tidyr::spread(component, value) %>%
  714. dplyr::select(-ix)
  715. component.predictions$seasonal <- rowSums(
  716. component.predictions[unique(components$component)])
  717. } else {
  718. component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
  719. }
  720. return(component.predictions)
  721. }
  #' Simulation-based uncertainty intervals for a forecast.
  #'
  #' Repeatedly simulates trend, seasonal, and yhat values over the rows of df
  #' and summarises each with the lower and upper quantiles implied by
  #' m$interval.width.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with columns yhat_lower, yhat_upper, trend_lower,
  #'  trend_upper, seasonal_lower, and seasonal_upper.
  #'
  predict_uncertainty <- function(m, df) {
    # One parameter set per MCMC iteration (or just 1 set for MAP); each set is
    # simulated often enough to reach at least m$uncertainty.samples in total.
    n.draws <- length(m$params$k)
    sims.per.draw <- max(1, ceiling(m$uncertainty.samples / n.draws))
    n.sims <- n.draws * sims.per.draw  # The actual number of simulations
    seasonal.features <- make_all_seasonality_features(m, df)

    # One (rows of df) x (simulations) matrix per simulated quantity.
    outputs <- c("trend", "seasonal", "yhat")
    sim.values <- lapply(outputs, function(key) {
      matrix(NA, nrow = nrow(df), ncol = n.sims)
    })
    names(sim.values) <- outputs

    # Fill the matrices column by column, in simulation order.
    column <- 0
    for (draw in seq_len(n.draws)) {
      for (repetition in seq_len(sims.per.draw)) {
        column <- column + 1
        sim <- sample_model(m, df, seasonal.features, draw)
        for (key in outputs) {
          sim.values[[key]][, column] <- sim[[key]]
        }
      }
    }

    # Summarise every row of a simulation matrix by its two interval quantiles.
    probs <- c((1 - m$interval.width) / 2, (1 + m$interval.width) / 2)
    row.quantiles <- function(sims) {
      t(apply(sims, 1, stats::quantile, probs = probs, na.rm = TRUE))
    }
    intervals <- dplyr::as_data_frame(cbind(
      row.quantiles(sim.values$yhat),
      row.quantiles(sim.values$trend),
      row.quantiles(sim.values$seasonal)
    ))
    colnames(intervals) <- paste(
      rep(c('yhat', 'trend', 'seasonal'), each = 2),
      c('lower', 'upper'),
      sep = "_"
    )
    return(intervals)
  }
  #' Draw one simulated forecast from the extrapolated generative model.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #' @param seasonal.features Data frame of seasonal features
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return List with elements yhat, trend, and seasonal, each holding one
  #'  value per row of df.
  #'
  sample_model <- function(m, df, seasonal.features, iteration) {
    # The trend is simulated before the observation noise is drawn.
    trend <- sample_predictive_trend(m, df, iteration)

    # Seasonal effect: feature matrix times this iteration's coefficients,
    # multiplied by m$y.scale.
    seasonal.design <- as.matrix(seasonal.features)
    seasonal <- (seasonal.design %*% m$params$beta[iteration, ]) * m$y.scale

    # Gaussian observation noise, multiplied by m$y.scale like the seasonal term.
    obs.sd <- m$params$sigma_obs[iteration]
    noise <- m$y.scale * stats::rnorm(nrow(df), mean = 0, sd = obs.sd)

    simulated <- list("yhat" = trend + seasonal + noise,
                      "trend" = trend,
                      "seasonal" = seasonal)
    return(simulated)
  }
  #' Simulate the trend using the extrapolated generative model.
  #'
  #' For times beyond the history (t > 1) new changepoints are sampled at random
  #' positions, with rate changes drawn from a Laplace distribution whose scale
  #' is the mean absolute value of the fitted deltas.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return Vector of simulated trend over df$t.
  #'
  sample_predictive_trend <- function(model, df, iteration) {
    base.rate <- model$params$k[iteration]
    offset <- model$params$m[iteration]
    fitted.deltas <- model$params$delta[iteration, ]

    t <- df$t
    t.max <- max(t)

    # Defaults cover a prediction frame that does not extend past the history.
    n.changes <- 0
    changepoint.ts.new <- c()
    if (t.max > 1) {
      # Time discretization of the history: the smallest positive step.
      steps <- diff(model$history$t)
      dt <- min(steps[steps > 0])
      # Number of time periods in the future
      n.periods <- ceiling((t.max - 1) / dt)
      n.hist.changepoints <- length(model$changepoints.t)
      # The history had its changepoints over t = [0, 1]. The forecast is on
      # [1, t.max] and should have the same average frequency of rate changes,
      # i.e. n.hist.changepoints * (t.max - 1) changepoints in expectation
      # across the n.periods future periods. This works whether or not history
      # and df are uniformly spaced.
      prob.change <- min(1, (n.hist.changepoints * (t.max - 1)) / n.periods)
      n.changes <- stats::rbinom(1, n.periods, prob.change)
      if (n.changes != 0) {
        # Place the new changepoints uniformly over the forecast range.
        changepoint.ts.new <- sort(
          stats::runif(n.changes, min = 1, max = t.max))
      }
    }

    # Empirical scale of the fitted deltas, plus epsilon to avoid NaNs.
    lambda <- mean(abs(c(fitted.deltas))) + 1e-8
    # Rate changes for the newly sampled changepoints
    new.deltas <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)

    # Append the simulated changepoints to those from the history.
    all.changepoints <- c(model$changepoints.t, changepoint.ts.new)
    all.deltas <- c(fitted.deltas, new.deltas)

    # Evaluate the trend implied by the combined changepoints.
    if (model$growth == 'linear') {
      trend <- piecewise_linear(
        t, all.deltas, base.rate, offset, all.changepoints)
    } else {
      trend <- piecewise_logistic(
        t, df$cap_scaled, all.deltas, base.rate, offset, all.changepoints)
    }
    return(trend * model$y.scale)
  }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object. Only m$history.dates is read.
  #' @param periods Int number of periods to forecast forward. May be 0, in
  #'  which case no future dates are generated.
  #' @param freq Step between consecutive future dates, passed to seq() as its
  #'  'by' argument: 'day', 'week', 'month', 'quarter', or 'year'. The default
  #'  'd' is an abbreviation of 'day'.
  #' @param include_history Boolean to include the historical dates in the data
  #'  frame for predictions.
  #'
  #' @return Dataframe with a single column ds that extends forward from the
  #'  end of m$history.dates for the requested number of periods.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'd',
                                    include_history = TRUE) {
    # seq() starts at the last historical date, so one extra element is
    # generated and then discarded.
    dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
    # Drop the first, which is max(history$ds). Negative indexing is used
    # because 2:(periods + 1) counts down to c(2, 1) when periods is 0, which
    # would return an NA plus a duplicate of the last historical date.
    dates <- dates[-1]
    if (include_history) {
      dates <- c(m$history.dates, dates)
    }
    return(data.frame(ds = dates))
  }
  #' Combine the observed history with a forecast for plotting.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @return Data frame holding the ds and y columns of m$history full-joined
  #'  with fcst on ds, sorted by ds.
  #'
  #' @importFrom dplyr "%>%"
  df_for_plotting <- function(m, fcst) {
    # Drop any y in fcst so the only y after the join is the observed one.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    return(dplyr::arrange(merged, ds))
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observed points, the yhat line and, optionally, the capacity
  #' line and the yhat uncertainty band.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(
  #'   ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'   y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
                           xlabel = 'ds', ylabel = 'y', ...) {
    df <- df_for_plotting(x, fcst)
    columns <- colnames(df)
    forecast.color <- "#0072B2"

    fig <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y))
    fig <- fig + ggplot2::labs(x = xlabel, y = ylabel)

    # Capacity line, only when the merged data carries a cap column.
    if ('cap' %in% columns && plot_cap) {
      cap.line <- ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
      fig <- fig + cap.line
    }

    # Uncertainty band, only when the forecast carries yhat_lower.
    if (uncertainty && 'yhat_lower' %in% columns) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2,
        fill = forecast.color,
        na.rm = TRUE)
      fig <- fig + band
    }

    # Observed points and the forecast line are added after the optional
    # layers, followed by the fixed aspect ratio.
    fig <- fig + ggplot2::geom_point(na.rm=TRUE)
    fig <- fig + ggplot2::geom_line(
      ggplot2::aes(y = yhat), color = forecast.color, na.rm = TRUE)
    fig <- fig + ggplot2::theme(aspect.ratio = 3 / 5)
    return(fig)
  }
  #' Plot the components of a prophet forecast.
  #'
  #' Prints a ggplot2 with panels for trend, weekly and yearly seasonalities if
  #' present, and holidays if present. The panels are stacked in a single
  #' column on a new grid page.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #'  plotted for the trend, from fcst columns trend_lower and trend_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #' @param weekly_start Integer specifying the start day of the weekly
  #'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #'  to Monday, and so on.
  #' @param yearly_start Integer specifying the start day of the yearly
  #'  seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #'  to Jan 2, and so on.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(
      m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
      yearly_start = 0) {
    df <- df_for_plotting(m, fcst)
    available <- colnames(df)

    # The trend panel is always present; the others depend on the model and on
    # which component columns the forecast contains.
    panels <- list(plot_trend(df, uncertainty, plot_cap))
    if (!is.null(m$holidays)) {
      panels <- c(panels, list(plot_holidays(m, df, uncertainty)))
    }
    if ("weekly" %in% available) {
      panels <- c(panels, list(plot_weekly(m, uncertainty, weekly_start)))
    }
    if ("yearly" %in% available) {
      panels <- c(panels, list(plot_yearly(m, uncertainty, yearly_start)))
    }

    # Lay the panels out as rows of a one-column grid and print each in turn.
    n.panels <- length(panels)
    grid::grid.newpage()
    page.layout <- grid::grid.layout(n.panels, 1)
    grid::pushViewport(grid::viewport(layout = page.layout))
    for (row in seq_len(n.panels)) {
      cell <- grid::viewport(layout.pos.row = row, layout.pos.col = 1)
      print(panels[[row]], vp = cell)
    }
    return(invisible(panels))
  }
  #' Plot the prophet trend.
  #'
  #' Rows of df whose trend is NA are dropped before plotting.
  #'
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals. The interval is
  #'  only drawn when df has both trend_lower and trend_upper columns.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #'
  #' @return A ggplot2 plot.
  plot_trend <- function(df, uncertainty = TRUE, plot_cap = TRUE) {
    df.t <- df[!is.na(df$trend),]
    columns <- colnames(df.t)
    gg.trend <- ggplot2::ggplot(df.t, ggplot2::aes(x = ds, y = trend)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    if ('cap' %in% columns && plot_cap) {
      gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                                linetype = 'dashed',
                                                na.rm = TRUE)
    }
    # Only add the band when both bounds exist, in the same way plot.prophet
    # checks for yhat_lower. An unconditional ribbon refers to columns that may
    # be absent, which makes the plot fail when it is rendered.
    has.bounds <- all(c('trend_lower', 'trend_upper') %in% columns)
    if (uncertainty && has.bounds) {
      gg.trend <- gg.trend +
        ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
                                          ymax = trend_upper),
                             alpha = 0.2,
                             fill = "#0072B2",
                             na.rm = TRUE)
    }
    return(gg.trend)
  }
  #' Plot the holidays component of the forecast.
  #'
  #' The per-holiday columns of df are summed row-wise into one combined
  #' holidays effect, together with its lower and upper bounds.
  #'
  #' @param m Prophet model
  #' @param df Forecast dataframe for plotting.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  plot_holidays <- function(m, df, uncertainty = TRUE) {
    holiday.comps <- as.character(unique(m$holidays$holiday))
    # Row-wise total over a set of forecast columns.
    sum.columns <- function(cols) {
      rowSums(df[, cols, drop = FALSE])
    }
    # NOTE summing the bounds is not a correct interval if holidays overlap in
    # time. It is only used for the visualization, so it is left as is for now.
    df.s <- data.frame(
      ds = df$ds,
      holidays = sum.columns(holiday.comps),
      holidays_lower = sum.columns(paste0(holiday.comps, "_lower")),
      holidays_upper = sum.columns(paste0(holiday.comps, "_upper")))
    # Keep only the rows for which a holiday effect is available.
    df.s <- df.s[!is.na(df.s$holidays),]

    line.color <- "#0072B2"
    gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays))
    gg.holidays <- gg.holidays +
      ggplot2::geom_line(color = line.color, na.rm = TRUE)
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = holidays_lower, ymax = holidays_upper),
        alpha = 0.2,
        fill = line.color,
        na.rm = TRUE)
      gg.holidays <- gg.holidays + band
    }
    return(gg.holidays)
  }
  #' Plot the weekly component of the forecast.
  #'
  #' The weekly seasonality is evaluated on seven consecutive reference days
  #' and drawn against the day of the week.
  #'
  #' @param m Prophet model object
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param weekly_start Integer specifying the start day of the weekly
  #'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #'  to Monday, and so on.
  #'
  #' @return A ggplot2 plot.
  plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
    # Seven days beginning on Sunday 2017-01-01, shifted by weekly_start.
    week.dates <- seq.Date(
      zoo::as.Date('2017-01-01'), by = 'd', length.out = 7) + weekly_start
    # NOTE(review): cap = 1. looks like a placeholder so that setup_dataframe
    # also accepts logistic-growth models -- confirm against setup_dataframe.
    df.w <- setup_dataframe(m, data.frame(ds = week.dates, cap = 1.))$df
    seas <- predict_seasonal_components(m, df.w)
    # Factor levels follow the date order so the x-axis starts on the first day.
    day.names <- weekdays(df.w$ds)
    seas$dow <- factor(day.names, levels = day.names)

    line.color <- "#0072B2"
    gg.weekly <- ggplot2::ggplot(
      seas, ggplot2::aes(x = dow, y = weekly, group = 1))
    gg.weekly <- gg.weekly +
      ggplot2::geom_line(color = line.color, na.rm = TRUE) +
      ggplot2::labs(x = "Day of week")
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = weekly_lower, ymax = weekly_upper),
        alpha = 0.2,
        fill = line.color,
        na.rm = TRUE)
      gg.weekly <- gg.weekly + band
    }
    return(gg.weekly)
  }
  #' Plot the yearly component of the forecast.
  #'
  #' The yearly seasonality is evaluated on 365 consecutive reference days and
  #' drawn against the date, labelled by month and day.
  #'
  #' @param m Prophet model object.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param yearly_start Integer specifying the start day of the yearly
  #'  seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #'  to Jan 2, and so on.
  #'
  #' @return A ggplot2 plot.
  plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
    # 365 days beginning on 2017-01-01, shifted by yearly_start.
    year.dates <- seq.Date(
      zoo::as.Date('2017-01-01'), by = 'd', length.out = 365) + yearly_start
    # NOTE(review): cap = 1. looks like a placeholder so that setup_dataframe
    # also accepts logistic-growth models -- confirm against setup_dataframe.
    df.y <- setup_dataframe(m, data.frame(ds = year.dates, cap = 1.))$df
    seas <- predict_seasonal_components(m, df.y)
    seas$ds <- df.y$ds

    line.color <- "#0072B2"
    gg.yearly <- ggplot2::ggplot(
      seas, ggplot2::aes(x = ds, y = yearly, group = 1))
    gg.yearly <- gg.yearly +
      ggplot2::geom_line(color = line.color, na.rm = TRUE) +
      ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
      ggplot2::labs(x = "Day of year")
    if (uncertainty) {
      band <- ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yearly_lower, ymax = yearly_upper),
        alpha = 0.2,
        fill = line.color,
        na.rm = TRUE)
      gg.yearly <- gg.yearly + band
    }
    return(gg.yearly)
  }
  1090. # fb-block 3