prophet.R 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Makes R CMD CHECK happy due to dplyr syntax below: these names are used
  ## unquoted inside dplyr/tidyr verbs (non-standard evaluation) and would
  ## otherwise be reported as undefined global variables.
  ## Namespace-qualified so this does not depend on utils being attached.
  utils::globalVariables(c(
    "ds", "y", "cap", ".",
    "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
    "holidays_upper", "ix",
    "lower", "n", "stat", "trend", "row_number",
    "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower",
    "weekly_upper",
    "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower",
    "yhat_upper"))
  #' Prophet forecaster.
  #'
  #' Builds a prophet model object from the given settings and, when a history
  #' dataframe is supplied and \code{fit} is TRUE, fits it immediately.
  #'
  #' @param df (optional) Dataframe containing the history. Must have columns ds
  #' (date type) and y, the time series. If growth is logistic, then df must
  #' also have a column cap that specifies the capacity at each ds. If not
  #' provided, then the model object will be instantiated but not fit; use
  #' fit.prophet(m, df) to fit the model.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #' @param changepoints Vector of dates at which to include potential
  #' changepoints. If not specified, potential changepoints are selected
  #' automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  #' then n.changepoints potential changepoints are selected uniformly from the
  #' first 80 percent of df$ds.
  #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #' type) and optionally columns lower_window and upper_window which specify a
  #' range of days around the date to be included as holidays. lower_window=-2
  #' will include 2 days prior to the date as holidays.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #' seasonality model. Larger values allow the model to fit larger seasonal
  #' fluctuations, smaller values dampen the seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #' components model.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #' automatic changepoint selection. Large values will allow many changepoints,
  #' small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #' inference with the specified number of MCMC samples. If 0, will do MAP
  #' estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #' in the trend using the MAP estimate of the extrapolated generative model.
  #' If mcmc.samples>0, this will be integrated over all model parameters,
  #' which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #' uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model (a list with class "prophet").
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = 'auto',
                      weekly.seasonality = 'auto',
                      daily.seasonality = 'auto',
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
                      ) {
    # fb-block 1
    user.changepoints <- !is.null(changepoints)
    # An explicit changepoint vector overrides the requested count.
    if (user.changepoints) {
      n.changepoints <- length(changepoints)
    }
    m <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      daily.seasonality = daily.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      specified.changepoints = user.changepoints,
      # Everything from here on is populated when the model is fit.
      start = NULL,
      y.scale = NULL,
      t.scale = NULL,
      changepoints.t = NULL,
      seasonalities = list(),
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    validate_inputs(m)
    class(m) <- c("prophet", class(m))
    if (fit && !is.null(df)) {
      m <- fit.prophet(m, df, ...)
    }
    # fb-block 2
    m
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error on the first problem found. It checks that the growth
  #' type is supported. When a holidays dataframe is present, it also checks:
  #' that the holiday and ds columns exist; that lower_window and upper_window
  #' are either both present or both absent, with the correct signs; and that
  #' no holiday name contains "_delim_" or uses a reserved name.
  #'
  #' @param m Prophet object.
  #'
  #' @return NULL, invisibly.
  #'
  #' @keywords internal
  validate_inputs <- function(m) {
    if (!(m$growth %in% c('linear', 'logistic'))) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }
    hdays <- m$holidays
    if (is.null(hdays)) {
      return(invisible(NULL))
    }
    if (!exists('holiday', where = hdays)) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!exists('ds', where = hdays)) {
      stop('Holidays dataframe must have ds field.')
    }
    has.lower <- exists('lower_window', where = hdays)
    has.upper <- exists('upper_window', where = hdays)
    # Exactly one of the two window columns is an error.
    if (xor(has.lower, has.upper)) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (has.lower) {
      if (max(hdays$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(hdays$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }
    reserved <- c('zeros', 'yearly', 'weekly', 'daily', 'yhat', 'seasonal',
                  'trend')
    # "_delim_" separates a component name from its feature suffix in column
    # names, so it cannot appear inside a holiday name.
    for (hname in unique(hdays$holiday)) {
      if (grepl("_delim_", hname)) {
        stop('Holiday name cannot contain "_delim_"')
      }
      if (hname %in% reserved) {
        stop(paste0('Holiday name "', hname, '" reserved.'))
      }
    }
    invisible(NULL)
  }
  #' Load compiled Stan model
  #'
  #' Tries to load the precompiled model shipped in the package's libs
  #' directory. If that fails for any reason (file missing, object missing,
  #' or the model cannot build its C++ module), it compiles the model from
  #' its Stan source instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  get_prophet_stan_model <- function(model) {
    fn <- paste('prophet', model, 'growth.RData', sep = '_')
    obj.name <- paste(model, 'growth.stanm', sep = '.')
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
                            package = 'prophet',
                            mustWork = TRUE)
      # Load into a private environment so objects in the file cannot
      # overwrite this function's locals. Fetch the model by name with get()
      # instead of eval(parse()), and do not fall back to same-named objects
      # in enclosing environments.
      cache <- new.env(parent = emptyenv())
      load(binary, envir = cache)
      stanm <- get(obj.name, envir = cache, inherits = FALSE)
      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model(model)
    })
  }
  #' Compile Stan model
  #'
  #' Translates the package's Stan source for the requested trend type with
  #' rstan and builds a model from it.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  compile_stan_model <- function(model) {
    source.file <- system.file(
      paste('stan/prophet', model, 'growth.stan', sep = '_'),
      package = 'prophet',
      mustWork = TRUE)
    translated <- rstan::stanc(source.file)
    rstan::stan_model(stanc_ret = translated,
                      model_name = paste(model, 'growth', sep = '_'))
  }
  #' Convert date vector
  #'
  #' Convert the dates to a POSIXct vector. Each element shorter than 12
  #' characters is parsed with "%Y-%m-%d". Every other element is parsed with
  #' "%Y-%m-%d %H:%M:%S". This means a vector that mixes bare dates and
  #' date-times keeps the time of day of the latter. Elements that cannot be
  #' parsed (including missing values) become NA.
  #'
  #' @param ds Date vector; may consist of characters or a factor.
  #' @param tz string time zone
  #'
  #' @return vector of POSIXct objects converted from ds, with its "tzone"
  #'  attribute set to tz; NULL when ds is empty.
  #'
  #' @keywords internal
  set_date <- function(ds = NULL, tz = "GMT") {
    if (length(ds) == 0) {
      return(NULL)
    }
    if (is.factor(ds)) {
      ds <- as.character(ds)
    }
    # Pick the format per element. Previously the shortest element decided
    # for the whole vector, so a single bare date made every date-time parse
    # as midnight. nchar() is NA for missing strings; those go to the
    # date-time branch and come out NA, so no if() ever sees NA.
    date.only <- nchar(ds) < 12
    date.only[is.na(date.only)] <- FALSE
    out <- .POSIXct(rep(NA_real_, length(ds)), tz = tz)
    if (any(date.only)) {
      out[date.only] <- as.POSIXct(ds[date.only], format = "%Y-%m-%d",
                                   tz = tz)
    }
    if (any(!date.only)) {
      out[!date.only] <- as.POSIXct(ds[!date.only],
                                    format = "%Y-%m-%d %H:%M:%S", tz = tz)
    }
    attr(out, "tzone") <- tz
    return(out)
  }
  #' Time difference between datetimes
  #'
  #' Compute ds1 - ds2 as a plain number in the requested units.
  #'
  #' @param ds1 POSIXct object
  #' @param ds2 POSIXct object
  #' @param units string units of difference, e.g. 'days' or 'secs'.
  #'
  #' @return numeric time difference (units attribute dropped)
  #'
  #' @keywords internal
  time_diff <- function(ds1, ds2, units = "days") {
    elapsed <- difftime(ds1, ds2, units = units)
    as.numeric(elapsed)
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Parses ds, sorts by it, adds a time index and scales y. Creates auxiliary
  #' columns 't', 'y_scaled' (when y is present) and 'cap_scaled' (for logistic
  #' growth). These columns are used during both fitting and predicting.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth.
  #' @param initialize_scales Boolean; if TRUE, set the scaling factors
  #'  (y.scale, start, t.scale) in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  #' @keywords internal
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    if (exists('y', where = df)) {
      df$y <- as.numeric(df$y)
    }
    df$ds <- set_date(df$ds)
    if (anyNA(df$ds)) {
      # paste0: the previous paste() inserted a second space ("date  format").
      stop(paste0('Unable to parse date format in column ds. Convert to date ',
                  'format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S'))
    }
    df <- df %>%
      dplyr::arrange(ds)
    if (initialize_scales) {
      # Scale y by its largest magnitude. Use 1 for an all-zero series so the
      # division below stays defined.
      m$y.scale <- max(abs(df$y))
      if (m$y.scale == 0) {
        m$y.scale <- 1
      }
      m$start <- min(df$ds)
      m$t.scale <- time_diff(max(df$ds), m$start, "secs")
    }
    # t is 0 at m$start and 1 at the last date seen when the scales were set.
    df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
    if (exists('y', where = df)) {
      df$y_scaled <- df$y / m$y.scale
    }
    if (m$growth == 'logistic') {
      if (!(exists('cap', where = df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df <- df %>%
        dplyr::mutate(cap_scaled = cap / m$y.scale)
    }
    return(list("m" = m, "df" = df))
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints. Either:
  #' 1) The changepoints were passed in explicitly.
  #'   A) They are empty.
  #'   B) They are not empty, and need validation.
  #' 2) We are generating a grid of them.
  #' 3) The user prefers no changepoints be used.
  #' Also sets m$changepoints.t to the sorted changepoint times on the scaled
  #' time axis. A single dummy changepoint at 0 is used when there are none.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  #' @keywords internal
  set_changepoints <- function(m) {
    if (!is.null(m$changepoints)) {
      if (length(m$changepoints) > 0) {
        m$changepoints <- set_date(m$changepoints)
        if (min(m$changepoints) < min(m$history$ds)
            || max(m$changepoints) > max(m$history$ds)) {
          stop('Changepoints must fall within training data.')
        }
      }
    } else {
      # Place potential changepoints evenly through the first 80 pcnt of
      # the history.
      hist.size <- floor(nrow(m$history) * .8)
      if (m$n.changepoints + 1 > hist.size) {
        # Clamp at 0. A one-row history has hist.size 0, which would
        # otherwise give n.changepoints = -1 and bypass the
        # n.changepoints == 0 handling in fit.prophet.
        m$n.changepoints <- max(hist.size - 1, 0)
        warning('n.changepoints greater than number of observations. Using ',
                m$n.changepoints)
      }
      if (m$n.changepoints > 0) {
        cp.indexes <- round(seq.int(1, hist.size,
                                    length.out = (m$n.changepoints + 1))[-1])
        m$changepoints <- m$history$ds[cp.indexes]
      } else {
        # Keep the slot, holding NULL. Assigning NULL with `$<-` would remove
        # it, and a later m$changepoints would then partially match
        # m$changepoints.t.
        m['changepoints'] <- list(NULL)
      }
    }
    if (length(m$changepoints) > 0) {
      m$changepoints.t <- sort(
        time_diff(m$changepoints, m$start, "secs")) / m$t.scale
    } else {
      m$changepoints.t <- c(0) # dummy changepoint
    }
    return(m)
  }
  #' Gets changepoint matrix for history dataframe.
  #'
  #' @param m Prophet object.
  #'
  #' @return Numeric 0/1 matrix with one row per history point and one column
  #'  per entry of m$changepoints.t. Entry [i, j] is 1 when history time t[i]
  #'  is at or after changepoint j.
  #'
  #' @keywords internal
  get_changepoint_matrix <- function(m) {
    # Vectorised indicator matrix. Unlike 1:length(), this also handles an
    # empty changepoint vector (it yields an n x 0 matrix).
    A <- outer(m$history$t, m$changepoints.t, ">=") * 1
    # Comparisons involving NA count as "not active", as in the old fill loop.
    A[is.na(A)] <- 0
    return(A)
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' Time is measured in days since 1970-01-01 00:00:00 GMT. For harmonic i,
  #' column 2i-1 holds sin(2 * i * pi * t / period) and column 2i holds the
  #' matching cosine.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features: length(dates) rows and
  #'  2 * series.order columns (none when series.order is 0).
  #'
  #' @keywords internal
  fourier_series <- function(dates, period, series.order) {
    t <- time_diff(dates, set_date('1970-01-01 00:00:00'))
    # seq_len() so that an order of 0 gives an empty matrix, not an error.
    orders <- seq_len(series.order)
    # Column i is the angle of harmonic i. The products are the same as
    # 2 * i * pi * t / period element by element, so the values are unchanged.
    x <- outer(as.numeric(t), 2 * orders * pi) / period
    features <- matrix(0, length(t), 2 * series.order)
    features[, 2 * orders - 1] <- sin(x)
    features[, 2 * orders] <- cos(x)
    return(features)
  }
  #' Data frame with seasonality features.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality. Columns are named
  #'  "<prefix>_delim_<k>" for k = 1, ..., 2 * series.order.
  #'
  #' @keywords internal
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    features <- fourier_series(dates, period, series.order)
    # sprintf over seq_len() gives zero names for a zero-column matrix. The
    # old paste(prefix, 1:ncol) produced two names there and failed.
    colnames(features) <- sprintf('%s_delim_%d', prefix,
                                  seq_len(ncol(features)))
    # check.names = FALSE keeps the exact "<prefix>_delim_<k>" names even when
    # the prefix is not a syntactic R name. For example, "bi-weekly" would
    # otherwise become "bi.weekly_delim_1". Holiday feature columns already
    # keep non-syntactic names.
    return(data.frame(features, check.names = FALSE))
  }
  #' Construct a matrix of holiday features.
  #'
  #' There is one column per holiday and day offset, named
  #' "<holiday>_delim_<sign><offset>" (e.g. "xmas_delim_-1", "xmas_delim_+0").
  #' Each row corresponds to an entry of dates. A cell holds
  #' holidays.prior.scale / seasonality.prior.scale when that holiday/offset
  #' falls on the row's day, and 0 otherwise.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A dataframe with a column for each holiday.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  make_holiday_features <- function(m, dates) {
    # NOTE(review): the indicator value is a prior-scale ratio rather than 1.
    # Presumably this is because every feature shares the one prior scale
    # passed to Stan as `sigma` (seasonality.prior.scale), so the rescaling
    # lets holiday effects be regularised as if by holidays.prior.scale.
    # Confirm against the Stan model.
    scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
    # Strip dates to be just days, for joining on holidays
    dates <- set_date(format(dates, "%Y-%m-%d"))
    wide <- m$holidays %>%
      dplyr::mutate(ds = set_date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      # Keep a single row per (holiday, date) pair.
      dplyr::filter(row_number() == 1) %>%
      dplyr::do({
        # Expand the occurrence over its [lower_window, upper_window] day
        # offsets, or just the day itself when no window is given.
        if (exists('lower_window', where = .) && !is.na(.$lower_window)
            && !is.na(.$upper_window)) {
          offsets <- seq(.$lower_window, .$upper_window)
        } else {
          offsets <- c(0)
        }
        names <- paste(
          .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets),
          sep = '')
        # dplyr::tibble replaces the deprecated dplyr::data_frame. Offsets are
        # in days; ds is POSIXct, so convert them to seconds.
        dplyr::tibble(ds = .$ds + offsets * 24 * 3600, holiday = names)
      }) %>%
      dplyr::mutate(x = scale.ratio) %>%
      # One column per holiday/offset, filled with 0 where inactive.
      tidyr::spread(holiday, x, fill = 0)
    holiday.mat <- data.frame(ds = dates) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    # Dates that matched no holiday come out of the join as NA.
    holiday.mat[is.na(holiday.mat)] <- 0
    return(holiday.mat)
  }
  #' Add a seasonal component with specified period and number of Fourier
  #' components.
  #'
  #' Increasing the number of Fourier components allows the seasonality to change
  #' more quickly (at risk of overfitting). Default values for yearly and weekly
  #' seasonalities are 10 and 3 respectively.
  #'
  #' Must be called before the model is fit. The name must not clash with a
  #' holiday name and must not contain "_delim_", which separates component
  #' names from feature suffixes in column names.
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param period Float number of days in one period.
  #' @param fourier.order Int number of Fourier components to use; must be > 0.
  #'
  #' @return The prophet model with the seasonality added.
  #'
  #' @importFrom dplyr "%>%"
  #' @export
  add_seasonality <- function(m, name, period, fourier.order) {
    # Features are built at fit time. A seasonality added afterwards would not
    # match the fitted coefficients.
    if (!is.null(m$history)) {
      stop('Seasonality must be added prior to model fitting.')
    }
    if (!is.null(m$holidays)) {
      if (name %in% as.character(unique(m$holidays$holiday))) {
        stop('Name "', name, '" already used for holiday')
      }
    }
    if (grepl('_delim_', name)) {
      stop('Name "', name, '" cannot contain "_delim_"')
    }
    if (fourier.order <= 0) {
      stop('Fourier order must be > 0.')
    }
    m$seasonalities[[name]] <- c(period, fourier.order)
    return(m)
  }
  #' Dataframe with seasonality features.
  #'
  #' Starts from a single all-zero "zeros" column. It then appends the Fourier
  #' features for every registered seasonality, in the order they were added,
  #' followed by the holiday features when holidays are configured.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features.
  #'
  #' @return Dataframe with seasonality.
  #'
  #' @keywords internal
  make_all_seasonality_features <- function(m, df) {
    pieces <- lapply(names(m$seasonalities), function(sname) {
      spec <- m$seasonalities[[sname]]  # c(period, fourier order)
      make_seasonality_features(df$ds, spec[1], spec[2], sname)
    })
    if (!is.null(m$holidays)) {
      pieces[[length(pieces) + 1]] <- make_holiday_features(m, df$ds)
    }
    zeros <- data.frame(zeros = rep(0, nrow(df)))
    if (length(pieces) == 0) {
      return(zeros)
    }
    do.call(cbind, c(list(zeros), pieces))
  }
  #' Get number of Fourier components for built-in seasonalities.
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
  #' provided.
  #' @param auto.disable Bool if seasonality should be disabled when 'auto'.
  #' @param default.order Int default Fourier order.
  #'
  #' @return Number of Fourier components, or 0 for disabled. With 'auto', a
  #'  warning is raised when the seasonality is disabled, either because a
  #'  custom seasonality with the same name exists or because auto.disable is
  #'  TRUE.
  #'
  #' @keywords internal
  parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
    # Compare with identical()/isTRUE() rather than `==`. `1 == TRUE` is TRUE
    # in R, so an explicit order of 1 used to be treated as TRUE and replaced
    # by default.order.
    if (identical(arg, 'auto')) {
      fourier.order <- 0
      if (name %in% names(m$seasonalities)) {
        warning('Found custom seasonality named "', name,
                '", disabling built-in ', name, ' seasonality.')
      } else if (auto.disable) {
        warning('Disabling ', name, ' seasonality. Run prophet with ', name,
                '.seasonality=TRUE to override this.')
      } else {
        fourier.order <- default.order
      }
    } else if (isTRUE(arg)) {
      fourier.order <- default.order
    } else if (identical(arg, FALSE)) {
      fourier.order <- 0
    } else {
      fourier.order <- arg
    }
    return(fourier.order)
  }
  #' Set seasonalities that were left on auto.
  #'
  #' Turns on yearly seasonality if there is >=2 years of history.
  #' Turns on weekly seasonality if there is >=2 weeks of history, and the
  #' spacing between dates in the history is <7 days.
  #' Turns on daily seasonality if there is >=2 days of history, and the spacing
  #' between dates in the history is <1 day.
  #' Explicit TRUE/FALSE/numeric settings are honoured as given.
  #'
  #' @param m Prophet object.
  #'
  #' @return The prophet model with seasonalities set.
  #'
  #' @keywords internal
  set_auto_seasonalities <- function(m) {
    span.days <- time_diff(max(m$history$ds), min(m$history$ds))
    gaps <- diff(time_diff(m$history$ds, m$start))
    min.dt <- min(gaps[gaps > 0])  # smallest positive spacing, in days
    builtins <- list(
      list(name = 'yearly', arg = m$yearly.seasonality,
           disable = span.days < 730, order = 10, period = 365.25),
      list(name = 'weekly', arg = m$weekly.seasonality,
           disable = (span.days < 14) || (min.dt >= 7), order = 3, period = 7),
      list(name = 'daily', arg = m$daily.seasonality,
           disable = (span.days < 2) || (min.dt >= 1), order = 4, period = 1)
    )
    for (b in builtins) {
      fourier.order <- parse_seasonality_args(
        m, b$name, b$arg, b$disable, b$order)
      if (fourier.order > 0) {
        m$seasonalities[[b$name]] <- c(b$period, fourier.order)
      }
    }
    return(m)
  }
  #' Initialize linear growth.
  #'
  #' Provides a strong initialization for linear growth by calculating the
  #' growth and offset parameters that pass the function through the first and
  #' last points in the time series.
  #'
  #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  #' and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  #' growth function.
  #'
  #' @keywords internal
  linear_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    span <- df$t[last] - df$t[first]
    # Slope of the line through the first and last observations
    rate <- (df$y_scaled[last] - df$y_scaled[first]) / span
    # Intercept of that line at t = 0
    offset <- df$y_scaled[first] - rate * df$t[first]
    c(rate, offset)
  }
  #' Initialize logistic growth.
  #'
  #' Provides a strong initialization for logistic growth by calculating the
  #' growth and offset parameters that pass the function through the first and
  #' last points in the time series.
  #'
  #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  #' y_scaled (scaled time series), and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  #' growth function.
  #'
  #' @keywords internal
  logistic_growth_init <- function(df) {
    first <- which.min(df$ds)
    last <- which.max(df$ds)
    span <- df$t[last] - df$t[first]
    # Capacity-to-value ratios at both ends, floored at 1.01 so log(r - 1)
    # stays finite when y >= cap.
    r.first <- max(1.01, df$cap_scaled[first] / df$y_scaled[first])
    r.last <- max(1.01, df$cap_scaled[last] / df$y_scaled[last])
    # Push the ratios apart when they nearly coincide, otherwise the
    # difference of logs below is (close to) zero.
    if (abs(r.first - r.last) <= 0.01) {
      r.first <- 1.05 * r.first
    }
    log.first <- log(r.first - 1)
    log.last <- log(r.last - 1)
    offset <- log.first * span / (log.first - log.last)
    rate <- (log.first - log.last) / span
    c(rate, offset)
  }
  #' Fit the prophet model.
  #'
  #' This sets m$params to contain the fitted model parameters. It is a list
  #' with the following elements:
  #' k (M array): M posterior samples of the initial slope.
  #' m (M array): The initial intercept.
  #' delta (MxN matrix): The slope change at each of N changepoints.
  #' beta (MxK matrix): Coefficients for K seasonality features.
  #' sigma_obs (M array): Noise level.
  #' Note that M=1 if MAP estimation.
  #'
  #' Rows with missing y are dropped before fitting. An error is raised if y
  #' contains infinite values, if fewer than two rows remain, or if the model
  #' was already fit.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #' @param ... Additional arguments passed to the \code{optimizing} or
  #' \code{sampling} functions in Stan.
  #'
  #' @return The fitted prophet object.
  #'
  #' @export
  fit.prophet <- function(m, df, ...) {
    if (!is.null(m$history)) {
      stop("Prophet object can only be fit once. Instantiate a new object.")
    }
    history <- df %>%
      dplyr::filter(!is.na(y))
    if (any(is.infinite(history$y))) {
      stop("Found infinity in column y.")
    }
    # With fewer than two observations the time scale (max(ds) - start) is
    # zero or undefined, and the scaled time t would be NaN.
    if (nrow(history) < 2) {
      stop("Dataframe has less than 2 non-NA rows.")
    }
    # All requested dates are kept, including rows whose y is missing.
    m$history.dates <- sort(set_date(df$ds))
    out <- setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- set_auto_seasonalities(m)
    seasonal.features <- make_all_seasonality_features(m, history)
    m <- set_changepoints(m)
    A <- get_changepoint_matrix(m)
    # Construct input to stan
    dat <- list(
      T = nrow(history),                 # number of observations
      K = ncol(seasonal.features),       # number of seasonal/holiday features
      S = length(m$changepoints.t),      # number of changepoints
      y = history$y_scaled,
      t = history$t,
      A = A,
      t_change = array(m$changepoints.t),
      X = as.matrix(seasonal.features),
      sigma = m$seasonality.prior.scale,
      tau = m$changepoint.prior.scale
    )
    # Run stan
    if (m$growth == 'linear') {
      kinit <- linear_growth_init(history)
    } else {
      dat$cap <- history$cap_scaled # Add capacities to the Stan data
      kinit <- logistic_growth_init(history)
    }
    # NOTE(review): .prophet.stan.models looks like a cache of compiled
    # models keyed by growth type, presumably set up elsewhere when the
    # package loads. Confirm.
    if (exists(".prophet.stan.models")) {
      model <- .prophet.stan.models[[m$growth]]
    } else {
      model <- get_prophet_stan_model(m$growth)
    }
    # Initial values: the growth guess from above, with no changepoint or
    # seasonal effects.
    stan_init <- function() {
      list(k = kinit[1],
           m = kinit[2],
           delta = array(rep(0, length(m$changepoints.t))),
           beta = array(rep(0, ncol(seasonal.features))),
           sigma_obs = 1
      )
    }
    if (min(history$y) == max(history$y)) {
      # Nothing to fit: a constant series. Use the initial values with zero
      # noise.
      m$params <- stan_init()
      m$params$sigma_obs <- 0.
      n.iteration <- 1.
    } else if (m$mcmc.samples > 0) {
      stan.fit <- rstan::sampling(
        model,
        data = dat,
        init = stan_init,
        iter = m$mcmc.samples,
        ...
      )
      m$params <- rstan::extract(stan.fit)
      n.iteration <- length(m$params$k)
    } else {
      stan.fit <- rstan::optimizing(
        model,
        data = dat,
        init = stan_init,
        iter = 1e4,
        as_vector = FALSE,
        ...
      )
      m$params <- stan.fit$par
      n.iteration <- 1
    }
    # Cast the parameters to have consistent form, whether full bayes or MAP
    for (name in c('delta', 'beta')) {
      m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
    }
    # rstan::sampling returns 1d arrays; converts to atomic vectors.
    for (name in c('k', 'm', 'sigma_obs')) {
      m$params[[name]] <- c(m$params[[name]])
    }
    # If no changepoints were requested, replace delta with 0s
    if (m$n.changepoints == 0) {
      # Fold delta into the base rate k. The only changepoint is then the
      # dummy one at t = 0, whose indicator column covers every row.
      m$params$k <- m$params$k + m$params$delta[, 1]
      m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
    }
    return(m)
  }
  #' Predict using the prophet model.
  #'
  #' Adds a trend column, then the columns returned by the uncertainty and
  #' seasonal-component predictors, and finally yhat = trend + seasonal.
  #'
  #' @param object Prophet object.
  #' @param df Dataframe with dates for predictions (column ds), and capacity
  #' (column cap) if logistic growth. If not provided, predictions are made on
  #' the history.
  #' @param ... additional arguments.
  #'
  #' @return A dataframe with the forecast components.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  predict.prophet <- function(object, df = NULL, ...) {
    df <- if (is.null(df)) {
      object$history
    } else {
      setup_dataframe(object, df)$df
    }
    df$trend <- predict_trend(object, df)
    # Both predictors see df including the trend column, and are called in
    # this order.
    intervals <- predict_uncertainty(object, df)
    components <- predict_seasonal_components(object, df)
    df <- dplyr::bind_cols(df, intervals)
    df <- dplyr::bind_cols(df, components)
    df$yhat <- df$trend + df$seasonal
    df
  }
  743. #' Evaluate the piecewise linear function.
  744. #'
  745. #' @param t Vector of times on which the function is evaluated.
  746. #' @param deltas Vector of rate changes at each changepoint.
  747. #' @param k Float initial rate.
  748. #' @param m Float initial offset.
  749. #' @param changepoint.ts Vector of changepoint times.
  750. #'
  751. #' @return Vector y(t).
  752. #'
  753. #' @keywords internal
  754. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  755. # Intercept changes
  756. gammas <- -changepoint.ts * deltas
  757. # Get cumulative slope and intercept at each t
  758. k_t <- rep(k, length(t))
  759. m_t <- rep(m, length(t))
  760. for (s in 1:length(changepoint.ts)) {
  761. indx <- t >= changepoint.ts[s]
  762. k_t[indx] <- k_t[indx] + deltas[s]
  763. m_t[indx] <- m_t[indx] + gammas[s]
  764. }
  765. y <- k_t * t + m_t
  766. return(y)
  767. }
  768. #' Evaluate the piecewise logistic function.
  769. #'
  770. #' @param t Vector of times on which the function is evaluated.
  771. #' @param cap Vector of capacities at each t.
  772. #' @param deltas Vector of rate changes at each changepoint.
  773. #' @param k Float initial rate.
  774. #' @param m Float initial offset.
  775. #' @param changepoint.ts Vector of changepoint times.
  776. #'
  777. #' @return Vector y(t).
  778. #'
  779. #' @keywords internal
  780. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  781. # Compute offset changes
  782. k.cum <- c(k, cumsum(deltas) + k)
  783. gammas <- rep(0, length(changepoint.ts))
  784. for (i in 1:length(changepoint.ts)) {
  785. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  786. * (1 - k.cum[i] / k.cum[i + 1]))
  787. }
  788. # Get cumulative rate and offset at each t
  789. k_t <- rep(k, length(t))
  790. m_t <- rep(m, length(t))
  791. for (s in 1:length(changepoint.ts)) {
  792. indx <- t >= changepoint.ts[s]
  793. k_t[indx] <- k_t[indx] + deltas[s]
  794. m_t[indx] <- m_t[indx] + gammas[s]
  795. }
  796. y <- cap / (1 + exp(-k_t * (t - m_t)))
  797. return(y)
  798. }
  799. #' Predict trend using the prophet model.
  800. #'
  801. #' @param model Prophet object.
  802. #' @param df Prediction dataframe.
  803. #'
  804. #' @return Vector with trend on prediction dates.
  805. #'
  806. #' @keywords internal
  807. predict_trend <- function(model, df) {
  808. k <- mean(model$params$k, na.rm = TRUE)
  809. param.m <- mean(model$params$m, na.rm = TRUE)
  810. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  811. t <- df$t
  812. if (model$growth == 'linear') {
  813. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  814. } else {
  815. cap <- df$cap_scaled
  816. trend <- piecewise_logistic(
  817. t, cap, deltas, k, param.m, model$changepoints.t)
  818. }
  819. return(trend * model$y.scale)
  820. }
  821. #' Predict seasonality broken down into components.
  822. #'
  823. #' @param m Prophet object.
  824. #' @param df Prediction dataframe.
  825. #'
  826. #' @return Dataframe with seasonal components.
  827. #'
  828. #' @keywords internal
  829. predict_seasonal_components <- function(m, df) {
  830. seasonal.features <- make_all_seasonality_features(m, df)
  831. lower.p <- (1 - m$interval.width)/2
  832. upper.p <- (1 + m$interval.width)/2
  833. # Broken down into components
  834. components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
  835. dplyr::mutate(col = 1:n()) %>%
  836. tidyr::separate(component, c('component', 'part'), sep = "_delim_",
  837. extra = "merge", fill = "right") %>%
  838. dplyr::filter(component != 'zeros')
  839. if (nrow(components) > 0) {
  840. component.predictions <- components %>%
  841. dplyr::group_by(component) %>% dplyr::do({
  842. comp <- (as.matrix(seasonal.features[, .$col])
  843. %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
  844. dplyr::data_frame(ix = 1:nrow(seasonal.features),
  845. mean = rowMeans(comp, na.rm = TRUE),
  846. lower = apply(comp, 1, stats::quantile, lower.p,
  847. na.rm = TRUE),
  848. upper = apply(comp, 1, stats::quantile, upper.p,
  849. na.rm = TRUE))
  850. }) %>%
  851. tidyr::gather(stat, value, mean, lower, upper) %>%
  852. dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
  853. tidyr::unite(component, component, stat, sep="") %>%
  854. tidyr::spread(component, value) %>%
  855. dplyr::select(-ix)
  856. component.predictions$seasonal <- rowSums(
  857. component.predictions[unique(components$component)])
  858. } else {
  859. component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
  860. }
  861. return(component.predictions)
  862. }
  863. #' Prophet posterior predictive samples.
  864. #'
  865. #' @param m Prophet object.
  866. #' @param df Prediction dataframe.
  867. #'
  868. #' @return List with posterior predictive samples for each component.
  869. #'
  870. #' @keywords internal
  871. sample_posterior_predictive <- function(m, df) {
  872. # Sample trend, seasonality, and yhat from the extrapolation model.
  873. n.iterations <- length(m$params$k)
  874. samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
  875. nsamp <- n.iterations * samp.per.iter # The actual number of samples
  876. seasonal.features <- make_all_seasonality_features(m, df)
  877. sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
  878. "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
  879. "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
  880. for (i in 1:n.iterations) {
  881. # For each set of parameters from MCMC (or just 1 set for MAP),
  882. for (j in 1:samp.per.iter) {
  883. # Do a simulation with this set of parameters,
  884. sim <- sample_model(m, df, seasonal.features, i)
  885. # Store the results
  886. for (key in c("trend", "seasonal", "yhat")) {
  887. sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
  888. }
  889. }
  890. }
  891. return(sim.values)
  892. }
  893. #' Sample from the posterior predictive distribution.
  894. #'
  895. #' @param m Prophet object.
  896. #' @param df Dataframe with dates for predictions (column ds), and capacity
  897. #' (column cap) if logistic growth.
  898. #'
  899. #' @return A list with items "trend", "seasonal", and "yhat" containing
  900. #' posterior predictive samples for that component.
  901. #'
  902. #' @export
  903. predictive_samples <- function(m, df) {
  904. df <- setup_dataframe(m, df)$df
  905. sim.values <- sample_posterior_predictive(m, df)
  906. return(sim.values)
  907. }
  908. #' Prophet uncertainty intervals.
  909. #'
  910. #' @param m Prophet object.
  911. #' @param df Prediction dataframe.
  912. #'
  913. #' @return Dataframe with uncertainty intervals.
  914. #'
  915. #' @keywords internal
  916. predict_uncertainty <- function(m, df) {
  917. sim.values <- sample_posterior_predictive(m, df)
  918. # Add uncertainty estimates
  919. lower.p <- (1 - m$interval.width)/2
  920. upper.p <- (1 + m$interval.width)/2
  921. intervals <- cbind(
  922. t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
  923. na.rm = TRUE)),
  924. t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
  925. na.rm = TRUE)),
  926. t(apply(t(sim.values$seasonal), 2, stats::quantile, c(lower.p, upper.p),
  927. na.rm = TRUE))
  928. ) %>% dplyr::as_data_frame()
  929. colnames(intervals) <- paste(rep(c('yhat', 'trend', 'seasonal'), each=2),
  930. c('lower', 'upper'), sep = "_")
  931. return(intervals)
  932. }
  933. #' Simulate observations from the extrapolated generative model.
  934. #'
  935. #' @param m Prophet object.
  936. #' @param df Prediction dataframe.
  937. #' @param seasonal.features Data frame of seasonal features
  938. #' @param iteration Int sampling iteration to use parameters from.
  939. #'
  940. #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  941. #'
  942. #' @keywords internal
  943. sample_model <- function(m, df, seasonal.features, iteration) {
  944. trend <- sample_predictive_trend(m, df, iteration)
  945. beta <- m$params$beta[iteration,]
  946. seasonal <- (as.matrix(seasonal.features) %*% beta) * m$y.scale
  947. sigma <- m$params$sigma_obs[iteration]
  948. noise <- stats::rnorm(nrow(df), mean = 0, sd = sigma) * m$y.scale
  949. return(list("yhat" = trend + seasonal + noise,
  950. "trend" = trend,
  951. "seasonal" = seasonal))
  952. }
  953. #' Simulate the trend using the extrapolated generative model.
  954. #'
  955. #' @param model Prophet object.
  956. #' @param df Prediction dataframe.
  957. #' @param iteration Int sampling iteration to use parameters from.
  958. #'
  959. #' @return Vector of simulated trend over df$t.
  960. #'
  961. #' @keywords internal
  962. sample_predictive_trend <- function(model, df, iteration) {
  963. k <- model$params$k[iteration]
  964. param.m <- model$params$m[iteration]
  965. deltas <- model$params$delta[iteration,]
  966. t <- df$t
  967. T <- max(t)
  968. if (T > 1) {
  969. # Get the time discretization of the history
  970. dt <- diff(model$history$t)
  971. dt <- min(dt[dt > 0])
  972. # Number of time periods in the future
  973. N <- ceiling((T - 1) / dt)
  974. S <- length(model$changepoints.t)
  975. # The history had S split points, over t = [0, 1].
  976. # The forecast is on [1, T], and should have the same average frequency of
  977. # rate changes. Thus for N time periods in the future, we want an average
  978. # of S * (T - 1) changepoints in expectation.
  979. prob.change <- min(1, (S * (T - 1)) / N)
  980. # This calculation works for both history and df not uniformly spaced.
  981. n.changes <- stats::rbinom(1, N, prob.change)
  982. # Sample ts
  983. if (n.changes == 0) {
  984. changepoint.ts.new <- c()
  985. } else {
  986. changepoint.ts.new <- sort(stats::runif(n.changes, min = 1, max = T))
  987. }
  988. } else {
  989. changepoint.ts.new <- c()
  990. n.changes <- 0
  991. }
  992. # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  993. lambda <- mean(abs(c(deltas))) + 1e-8
  994. # Sample deltas
  995. deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  996. # Combine with changepoints from the history
  997. changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  998. deltas <- c(deltas, deltas.new)
  999. # Get the corresponding trend
  1000. if (model$growth == 'linear') {
  1001. trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  1002. } else {
  1003. cap <- df$cap_scaled
  1004. trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
  1005. }
  1006. return(trend * model$y.scale)
  1007. }
  1008. #' Make dataframe with future dates for forecasting.
  1009. #'
  1010. #' @param m Prophet model object.
  1011. #' @param periods Int number of periods to forecast forward.
  1012. #' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
  1013. #' @param include_history Boolean to include the historical dates in the data
  1014. #' frame for predictions.
  1015. #'
  1016. #' @return Dataframe that extends forward from the end of m$history for the
  1017. #' requested number of periods.
  1018. #'
  1019. #' @export
  1020. make_future_dataframe <- function(m, periods, freq = 'day',
  1021. include_history = TRUE) {
  1022. # For backwards compatability with previous zoo date type,
  1023. if (freq == 'm') {
  1024. freq <- 'month'
  1025. }
  1026. dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
  1027. dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
  1028. if (include_history) {
  1029. dates <- c(m$history.dates, dates)
  1030. attr(dates, "tzone") <- "GMT"
  1031. }
  1032. return(data.frame(ds = dates))
  1033. }
  1034. #' Merge history and forecast for plotting.
  1035. #'
  1036. #' @param m Prophet object.
  1037. #' @param fcst Data frame returned by prophet predict.
  1038. #'
  1039. #' @importFrom dplyr "%>%"
  1040. #' @keywords internal
  1041. df_for_plotting <- function(m, fcst) {
  1042. # Make sure there is no y in fcst
  1043. fcst$y <- NULL
  1044. df <- m$history %>%
  1045. dplyr::select(ds, y) %>%
  1046. dplyr::full_join(fcst, by = "ds") %>%
  1047. dplyr::arrange(ds)
  1048. return(df)
  1049. }
  1050. #' Plot the prophet forecast.
  1051. #'
  1052. #' @param x Prophet object.
  1053. #' @param fcst Data frame returned by predict(m, df).
  1054. #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  1055. #' should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  1056. #' @param plot_cap Boolean indicating if the capacity should be shown in the
  1057. #' figure, if available.
  1058. #' @param xlabel Optional label for x-axis
  1059. #' @param ylabel Optional label for y-axis
  1060. #' @param ... additional arguments
  1061. #'
  1062. #' @return A ggplot2 plot.
  1063. #'
  1064. #' @examples
  1065. #' \dontrun{
  1066. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  1067. #' y = sin(1:366/200) + rnorm(366)/10)
  1068. #' m <- prophet(history)
  1069. #' future <- make_future_dataframe(m, periods = 365)
  1070. #' forecast <- predict(m, future)
  1071. #' plot(m, forecast)
  1072. #' }
  1073. #'
  1074. #' @export
  1075. plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
  1076. xlabel = 'ds', ylabel = 'y', ...) {
  1077. df <- df_for_plotting(x, fcst)
  1078. gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
  1079. ggplot2::labs(x = xlabel, y = ylabel)
  1080. if (exists('cap', where = df) && plot_cap) {
  1081. gg <- gg + ggplot2::geom_line(
  1082. ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
  1083. }
  1084. if (uncertainty && exists('yhat_lower', where = df)) {
  1085. gg <- gg +
  1086. ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
  1087. alpha = 0.2,
  1088. fill = "#0072B2",
  1089. na.rm = TRUE)
  1090. }
  1091. gg <- gg +
  1092. ggplot2::geom_point(na.rm=TRUE) +
  1093. ggplot2::geom_line(ggplot2::aes(y = yhat), color = "#0072B2",
  1094. na.rm = TRUE) +
  1095. ggplot2::theme(aspect.ratio = 3 / 5)
  1096. return(gg)
  1097. }
  1098. #' Plot the components of a prophet forecast.
  1099. #' Prints a ggplot2 with panels for trend, weekly and yearly seasonalities if
  1100. #' present, and holidays if present.
  1101. #'
  1102. #' @param m Prophet object.
  1103. #' @param fcst Data frame returned by predict(m, df).
  1104. #' @param uncertainty Boolean indicating if the uncertainty interval should be
  1105. #' plotted for the trend, from fcst columns trend_lower and trend_upper.
  1106. #' @param plot_cap Boolean indicating if the capacity should be shown in the
  1107. #' figure, if available.
  1108. #' @param weekly_start Integer specifying the start day of the weekly
  1109. #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  1110. #' to Monday, and so on.
  1111. #' @param yearly_start Integer specifying the start day of the yearly
  1112. #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  1113. #' to Jan 2, and so on.
  1114. #'
  1115. #' @return Invisibly return a list containing the plotted ggplot objects
  1116. #'
  1117. #' @export
  1118. #' @importFrom dplyr "%>%"
  1119. prophet_plot_components <- function(
  1120. m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
  1121. yearly_start = 0) {
  1122. df <- df_for_plotting(m, fcst)
  1123. # Plot the trend
  1124. panels <- list(plot_trend(df, uncertainty, plot_cap))
  1125. # Plot holiday components, if present.
  1126. if (!is.null(m$holidays)) {
  1127. panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty)
  1128. }
  1129. # Plot weekly seasonality, if present
  1130. if ("weekly" %in% colnames(df)) {
  1131. panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
  1132. }
  1133. # Plot yearly seasonality, if present
  1134. if ("yearly" %in% colnames(df)) {
  1135. panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
  1136. }
  1137. # Plot other seasonalities
  1138. for (name in names(m$seasonalities)) {
  1139. if (!(name %in% c('weekly', 'yearly')) &&
  1140. (name %in% colnames(df))) {
  1141. panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
  1142. }
  1143. }
  1144. # Make the plot.
  1145. grid::grid.newpage()
  1146. grid::pushViewport(grid::viewport(layout = grid::grid.layout(length(panels),
  1147. 1)))
  1148. for (i in 1:length(panels)) {
  1149. print(panels[[i]], vp = grid::viewport(layout.pos.row = i,
  1150. layout.pos.col = 1))
  1151. }
  1152. return(invisible(panels))
  1153. }
  1154. #' Plot the prophet trend.
  1155. #'
  1156. #' @param df Forecast dataframe for plotting.
  1157. #' @param uncertainty Boolean to plot uncertainty intervals.
  1158. #' @param plot_cap Boolean indicating if the capacity should be shown in the
  1159. #' figure, if available.
  1160. #'
  1161. #' @return A ggplot2 plot.
  1162. #'
  1163. #' @keywords internal
  1164. plot_trend <- function(df, uncertainty = TRUE, plot_cap = TRUE) {
  1165. df.t <- df[!is.na(df$trend),]
  1166. gg.trend <- ggplot2::ggplot(df.t, ggplot2::aes(x = ds, y = trend)) +
  1167. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
  1168. if (exists('cap', where = df.t) && plot_cap) {
  1169. gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
  1170. linetype = 'dashed',
  1171. na.rm = TRUE)
  1172. }
  1173. if (uncertainty) {
  1174. gg.trend <- gg.trend +
  1175. ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
  1176. ymax = trend_upper),
  1177. alpha = 0.2,
  1178. fill = "#0072B2",
  1179. na.rm = TRUE)
  1180. }
  1181. return(gg.trend)
  1182. }
  1183. #' Plot the holidays component of the forecast.
  1184. #'
  1185. #' @param m Prophet model
  1186. #' @param df Forecast dataframe for plotting.
  1187. #' @param uncertainty Boolean to plot uncertainty intervals.
  1188. #'
  1189. #' @return A ggplot2 plot.
  1190. #'
  1191. #' @keywords internal
  1192. plot_holidays <- function(m, df, uncertainty = TRUE) {
  1193. holiday.comps <- unique(m$holidays$holiday) %>% as.character()
  1194. df.s <- data.frame(ds = df$ds,
  1195. holidays = rowSums(df[, holiday.comps, drop = FALSE]),
  1196. holidays_lower = rowSums(df[, paste0(holiday.comps,
  1197. "_lower"), drop = FALSE]),
  1198. holidays_upper = rowSums(df[, paste0(holiday.comps,
  1199. "_upper"), drop = FALSE]))
  1200. df.s <- df.s[!is.na(df.s$holidays),]
  1201. # NOTE the above CI calculation is incorrect if holidays overlap in time.
  1202. # Since it is just for the visualization we will not worry about it now.
  1203. gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
  1204. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
  1205. if (uncertainty) {
  1206. gg.holidays <- gg.holidays +
  1207. ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
  1208. ymax = holidays_upper),
  1209. alpha = 0.2,
  1210. fill = "#0072B2",
  1211. na.rm = TRUE)
  1212. }
  1213. return(gg.holidays)
  1214. }
  1215. #' Plot the weekly component of the forecast.
  1216. #'
  1217. #' @param m Prophet model object
  1218. #' @param uncertainty Boolean to plot uncertainty intervals.
  1219. #' @param weekly_start Integer specifying the start day of the weekly
  1220. #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  1221. #' to Monday, and so on.
  1222. #'
  1223. #' @return A ggplot2 plot.
  1224. #'
  1225. #' @keywords internal
  1226. plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
  1227. # Compute weekly seasonality for a Sun-Sat sequence of dates.
  1228. df.w <- data.frame(
  1229. ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
  1230. weekly_start, cap=1.)
  1231. df.w <- setup_dataframe(m, df.w)$df
  1232. seas <- predict_seasonal_components(m, df.w)
  1233. seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
  1234. gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
  1235. group = 1)) +
  1236. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
  1237. ggplot2::labs(x = "Day of week")
  1238. if (uncertainty) {
  1239. gg.weekly <- gg.weekly +
  1240. ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
  1241. ymax = weekly_upper),
  1242. alpha = 0.2,
  1243. fill = "#0072B2",
  1244. na.rm = TRUE)
  1245. }
  1246. return(gg.weekly)
  1247. }
  1248. #' Plot the yearly component of the forecast.
  1249. #'
  1250. #' @param m Prophet model object.
  1251. #' @param uncertainty Boolean to plot uncertainty intervals.
  1252. #' @param yearly_start Integer specifying the start day of the yearly
  1253. #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  1254. #' to Jan 2, and so on.
  1255. #'
  1256. #' @return A ggplot2 plot.
  1257. #'
  1258. #' @keywords internal
  1259. plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
  1260. # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  1261. df.y <- data.frame(
  1262. ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
  1263. yearly_start, cap=1.)
  1264. df.y <- setup_dataframe(m, df.y)$df
  1265. seas <- predict_seasonal_components(m, df.y)
  1266. seas$ds <- df.y$ds
  1267. gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
  1268. group = 1)) +
  1269. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
  1270. ggplot2::labs(x = "Day of year") +
  1271. ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
  1272. if (uncertainty) {
  1273. gg.yearly <- gg.yearly +
  1274. ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
  1275. ymax = yearly_upper),
  1276. alpha = 0.2,
  1277. fill = "#0072B2",
  1278. na.rm = TRUE)
  1279. }
  1280. return(gg.yearly)
  1281. }
  1282. #' Plot a custom seasonal component.
  1283. #'
  1284. #' @param m Prophet model object.
  1285. #' @param name String name of the seasonality.
  1286. #' @param uncertainty Boolean to plot uncertainty intervals.
  1287. #'
  1288. #' @return A ggplot2 plot.
  1289. #'
  1290. #' @keywords internal
  1291. plot_seasonality <- function(m, name, uncertainty = TRUE) {
  1292. # Compute seasonality from Jan 1 through a single period.
  1293. start <- set_date('2017-01-01')
  1294. period <- m$seasonalities[[name]][1]
  1295. end <- start + period * 24 * 3600
  1296. plot.points <- 200
  1297. df.y <- data.frame(
  1298. ds=seq(from=start, to=end, length.out=plot.points), cap=1.)
  1299. df.y <- setup_dataframe(m, df.y)$df
  1300. seas <- predict_seasonal_components(m, df.y)
  1301. seas$ds <- df.y$ds
  1302. gg.s <- ggplot2::ggplot(
  1303. seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
  1304. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
  1305. if (period <= 2) {
  1306. fmt.str <- '%T'
  1307. }
  1308. else if (period < 14) {
  1309. fmt.str <- '%m/%d %R'
  1310. } else {
  1311. fmt.str <- '%m/%d'
  1312. }
  1313. gg.s <- gg.s +
  1314. ggplot2::scale_x_datetime(labels = scales::date_format(fmt.str))
  1315. if (uncertainty) {
  1316. gg.s <- gg.s +
  1317. ggplot2::geom_ribbon(
  1318. ggplot2::aes_string(
  1319. ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
  1320. ),
  1321. alpha = 0.2,
  1322. fill = "#0072B2",
  1323. na.rm = TRUE)
  1324. }
  1325. return(gg.s)
  1326. }
  1327. #' Copy Prophet object.
  1328. #'
  1329. #' @param m Prophet model object.
  1330. #' @param cutoff Date, possibly as string. Changepoints are only retained if
  1331. #' changepoints <= cutoff.
  1332. #'
  1333. #' @return An unfitted Prophet model object with the same parameters as the
  1334. #' input model.
  1335. #'
  1336. #' @keywords internal
  1337. prophet_copy <- function(m, cutoff = NULL) {
  1338. if (m$specified.changepoints) {
  1339. changepoints <- m$changepoints
  1340. if (!is.null(cutoff)) {
  1341. cutoff <- set_date(cutoff)
  1342. changepoints <- changepoints[changepoints <= cutoff]
  1343. }
  1344. } else {
  1345. changepoints <- NULL
  1346. }
  1347. return(prophet(
  1348. growth = m$growth,
  1349. changepoints = changepoints,
  1350. n.changepoints = m$n.changepoints,
  1351. yearly.seasonality = m$yearly.seasonality,
  1352. weekly.seasonality = m$weekly.seasonality,
  1353. daily.seasonality = m$daily.seasonality,
  1354. holidays = m$holidays,
  1355. seasonality.prior.scale = m$seasonality.prior.scale,
  1356. changepoint.prior.scale = m$changepoint.prior.scale,
  1357. holidays.prior.scale = m$holidays.prior.scale,
  1358. mcmc.samples = m$mcmc.samples,
  1359. interval.width = m$interval.width,
  1360. uncertainty.samples = m$uncertainty.samples,
  1361. fit = FALSE,
  1362. ))
  1363. }
  1364. # fb-block 3