prophet.R 44 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Makes R CMD CHECK happy due to dplyr syntax below.
  ## Bare column names used inside dplyr/tidyr verbs would otherwise be
  ## reported as undefined global variables; registering them here silences
  ## those NOTEs.
  utils::globalVariables(
    c(
      "ds", "y", "cap", ".",
      "component", "dow", "doy",
      "holiday", "holidays", "holidays_lower", "holidays_upper",
      "ix", "lower", "n", "stat", "trend", "row_number",
      "extra_regressors", "col",
      "trend_lower", "trend_upper", "upper", "value",
      "weekly", "weekly_lower", "weekly_upper",
      "x", "yearly", "yearly_lower", "yearly_upper",
      "yhat", "yhat_lower", "yhat_upper"
    )
  )
  #' Prophet forecaster.
  #'
  #' @param df (optional) Dataframe containing the history. Must have columns ds
  #' (date type) and y, the time series. If growth is logistic, then df must
  #' also have a column cap that specifies the capacity at each ds. If not
  #' provided, then the model object will be instantiated but not fit; use
  #' fit.prophet(m, df) to fit the model.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #' trend.
  #' @param changepoints Vector of dates at which to include potential
  #' changepoints. If not specified, potential changepoints are selected
  #' automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  #' then n.changepoints potential changepoints are selected uniformly from the
  #' first 80 percent of df$ds.
  #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
  #' FALSE, or a number of Fourier terms to generate.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #' type) and optionally columns lower_window and upper_window which specify a
  #' range of days around the date to be included as holidays. lower_window=-2
  #' will include 2 days prior to the date as holidays. Also optionally can have
  #' a column prior_scale specifying the prior scale for each holiday.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #' seasonality model. Larger values allow the model to fit larger seasonal
  #' fluctuations, smaller values dampen the seasonality. Can be specified for
  #' individual seasonalities using add_seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #' components model, unless overridden in the holidays input.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #' automatic changepoint selection. Large values will allow many changepoints,
  #' small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #' inference with the specified number of MCMC samples. If 0, will do MAP
  #' estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #' in the trend using the MAP estimate of the extrapolated generative model.
  #' If mcmc.samples>0, this will be integrated over all model parameters,
  #' which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #' uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model: a list of settings plus slots that are filled in
  #' during fitting, with class "prophet" prepended.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = 'auto',
                      weekly.seasonality = 'auto',
                      daily.seasonality = 'auto',
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
  ) {
    # fb-block 1
    has.changepoints <- !is.null(changepoints)
    # Explicitly supplied changepoints override the requested count.
    if (has.changepoints) {
      n.changepoints <- length(changepoints)
    }
    m <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      daily.seasonality = daily.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      specified.changepoints = has.changepoints,
      # The remaining slots start empty and are populated during fitting.
      start = NULL,
      y.scale = NULL,
      logistic.floor = FALSE,
      t.scale = NULL,
      changepoints.t = NULL,
      seasonalities = list(),
      extra_regressors = list(),
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    validate_inputs(m)
    class(m) <- c("prophet", class(m))
    if (fit && !is.null(df)) {
      m <- fit.prophet(m, df, ...)
    }
    # fb-block 2
    m
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error unless growth is 'linear' or 'logistic'. When a
  #' holidays dataframe is present it must have holiday and ds columns, must
  #' have either both or neither of lower_window / upper_window (with
  #' lower_window <= 0 and upper_window >= 0), and each distinct holiday name
  #' must pass validate_column_name.
  #'
  #' @param m Prophet object.
  #'
  #' @keywords internal
  validate_inputs <- function(m) {
    if (!(m$growth %in% c('linear', 'logistic'))) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }
    hdays <- m$holidays
    if (is.null(hdays)) {
      return(invisible(NULL))
    }
    has_col <- function(col) exists(col, where = hdays)
    if (!has_col('holiday')) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!has_col('ds')) {
      stop('Holidays dataframe must have ds field.')
    }
    has.lower <- has_col('lower_window')
    has.upper <- has_col('upper_window')
    if (xor(has.lower, has.upper)) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (has.lower) {
      if (max(hdays$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(hdays$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }
    # Holiday names may not collide with reserved or already-used names.
    for (h in unique(hdays$holiday)) {
      validate_column_name(m, h, check_holidays = FALSE)
    }
  }
  #' Validates the name of a seasonality, holiday, or regressor.
  #'
  #' Stops with an error if `name` contains "_delim_", is one of the reserved
  #' column names (including their _lower / _upper variants), or, depending on
  #' the check_* flags, is already used by a holiday, a seasonality, or an
  #' added regressor.
  #'
  #' @param m Prophet object.
  #' @param name string
  #' @param check_holidays bool check if name already used for holiday
  #' @param check_seasonalities bool check if name already used for seasonality
  #' @param check_regressors bool check if name already used for regressor
  #'
  #' @keywords internal
  validate_column_name <- function(
    m, name, check_holidays = TRUE, check_seasonalities = TRUE,
    check_regressors = TRUE
  ) {
    # "_delim_" separates a component name from its suffix in generated
    # feature column names, so it cannot appear inside the name itself.
    if (grepl("_delim_", name)) {
      stop('Holiday name cannot contain "_delim_"')
    }
    reserved_names <- c(
      'trend', 'seasonal', 'seasonalities', 'daily', 'weekly', 'yearly',
      'holidays', 'zeros', 'extra_regressors', 'yhat'
    )
    reserved_names <- c(
      reserved_names,
      paste0(reserved_names, "_lower"),
      paste0(reserved_names, "_upper"),
      c("ds", "y", "cap", "floor", "y_scaled", "cap_scaled")
    )
    if (name %in% reserved_names) {
      stop("Name ", name, " is reserved.")
    }
    if (check_holidays && !is.null(m$holidays) &&
        (name %in% unique(m$holidays$holiday))) {
      stop("Name ", name, " already used for a holiday.")
    }
    if (check_seasonalities && !is.null(m$seasonalities[[name]])) {
      stop("Name ", name, " already used for a seasonality.")
    }
    # Regressors live in m$extra_regressors (see add_regressor), not in
    # m$seasonalities.
    if (check_regressors && !is.null(m$extra_regressors[[name]])) {
      stop("Name ", name, " already used for an added regressor.")
    }
  }
  #' Load compiled Stan model
  #'
  #' Loads the object `model.stanm` from the package file
  #' libs/<R_ARCH>/prophet_stan_model.RData and checks that its C++ module can
  #' be built. If any step fails (file missing, object missing, or
  #' mk_cppmodule erroring), the model is compiled from source with
  #' compile_stan_model() instead.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  get_prophet_stan_model <- function() {
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file(
        'libs',
        Sys.getenv('R_ARCH'),
        'prophet_stan_model.RData',
        package = 'prophet',
        mustWork = TRUE
      )
      # Load into a private environment and fetch the object by name, rather
      # than loading into this frame and eval(parse())-ing the name.
      cache <- new.env(parent = emptyenv())
      load(binary, envir = cache)
      stanm <- get('model.stanm', envir = cache, inherits = FALSE)
      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model()
    })
  }
  #' Compile Stan model
  #'
  #' Compiles the Stan program shipped with the package at stan/prophet.stan
  #' into a Stan model named 'prophet_model'.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  compile_stan_model <- function() {
    stan.src <- system.file(
      'stan/prophet.stan', package = 'prophet', mustWork = TRUE)
    stanc <- rstan::stanc(stan.src)
    return(rstan::stan_model(stanc_ret = stanc, model_name = 'prophet_model'))
  }
  #' Convert date vector
  #'
  #' Convert the date to POSIXct object. Factors are first converted to
  #' character. The parse format is "%Y-%m-%d" when the shortest non-missing
  #' entry has fewer than 12 characters, and "%Y-%m-%d %H:%M:%S" otherwise.
  #' Missing entries do not affect the format choice and come back as NA.
  #'
  #' @param ds Date vector, can be consisted of characters
  #' @param tz string time zone
  #'
  #' @return vector of POSIXct object converted from date, or NULL when ds is
  #'   empty.
  #'
  #' @keywords internal
  set_date <- function(ds = NULL, tz = "GMT") {
    if (length(ds) == 0) {
      return(NULL)
    }
    if (is.factor(ds)) {
      ds <- as.character(ds)
    }
    # nchar() is NA for missing entries; drop them so the format choice stays
    # well-defined and NAs pass through to the caller instead of making the
    # if() below fail.
    widths <- nchar(ds)
    widths <- widths[!is.na(widths)]
    fmt <- if (length(widths) == 0 || min(widths) < 12) {
      "%Y-%m-%d"
    } else {
      "%Y-%m-%d %H:%M:%S"
    }
    ds <- as.POSIXct(ds, format = fmt, tz = tz)
    attr(ds, "tzone") <- tz
    return(ds)
  }
  #' Time difference between datetimes
  #'
  #' Computes ds1 - ds2 as a plain number expressed in the requested units.
  #'
  #' @param ds1 POSIXct object
  #' @param ds2 POSIXct object
  #' @param units string units of difference, e.g. 'days' or 'secs'.
  #'
  #' @return numeric time difference
  #'
  #' @keywords internal
  time_diff <- function(ds1, ds2, units = "days") {
    delta <- difftime(ds1, ds2, units = units)
    as.numeric(delta)
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Adds a time index and scales y. Creates auxiliary columns 't',
  #' 'y_scaled' (when y is present), 'cap_scaled' (logistic growth) and
  #' 'floor' (0 unless a logistic floor is in use). Registered extra
  #' regressors are standardized with their stored mu / std. These columns are
  #' used during both fitting and predicting.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth. Any
  #' specified additional regressors must also be present.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  #' @keywords internal
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    has.y <- exists('y', where = df)
    if (has.y) {
      df$y <- as.numeric(df$y)
    }
    if (any(is.infinite(df$y))) {
      stop("Found infinity in column y.")
    }

    # Values set_date cannot parse come back as NA.
    df$ds <- set_date(df$ds)
    if (anyNA(df$ds)) {
      stop(paste('Unable to parse date format in column ds. Convert to date ',
                 'format (%Y-%m-%d or %Y-%m-%d %H:%M:%S) and check that there',
                 'are no NAs.'))
    }

    # Report the first registered regressor that df does not provide.
    absent <- setdiff(names(m$extra_regressors), colnames(df))
    if (length(absent) > 0) {
      stop('Regressor "', absent[1], '" missing from dataframe')
    }

    df <- dplyr::arrange(df, ds)
    m <- initialize_scales_fn(m, initialize_scales, df)

    if (!m$logistic.floor) {
      df$floor <- 0
    } else if (!('floor' %in% colnames(df))) {
      stop("Expected column 'floor'.")
    }

    if (m$growth == 'logistic') {
      if (!exists('cap', where = df)) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df <- dplyr::mutate(df, cap_scaled = (cap - floor) / m$y.scale)
    }

    df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
    if (has.y) {
      df$y_scaled <- (df$y - df$floor) / m$y.scale
    }

    for (reg in names(m$extra_regressors)) {
      props <- m$extra_regressors[[reg]]
      scaled <- (as.numeric(df[[reg]]) - props$mu) / props$std
      if (anyNA(scaled)) {
        stop('Found NaN in column ', reg)
      }
      df[[reg]] <- scaled
    }
    list(m = m, df = df)
  }
  #' Initialize model scales.
  #'
  #' When initialize_scales is TRUE, sets y.scale (max |y - floor|, or 1 if
  #' that is 0), start (earliest ds), t.scale (history span in seconds), marks
  #' logistic.floor for logistic growth with a floor column, and computes
  #' mu / std for each extra regressor that is to be standardized.
  #'
  #' @param m Prophet object.
  #' @param initialize_scales Boolean set the scales or not.
  #' @param df Dataframe for setting scales.
  #'
  #' @return Prophet object with scales set.
  #'
  #' @keywords internal
  initialize_scales_fn <- function(m, initialize_scales, df) {
    if (!initialize_scales) {
      return(m)
    }
    # A floor column is only honoured for logistic growth; otherwise y is
    # scaled relative to 0.
    use.floor <- (m$growth == 'logistic') && ('floor' %in% colnames(df))
    if (use.floor) {
      m$logistic.floor <- TRUE
    }
    baseline <- if (use.floor) df$floor else 0
    m$y.scale <- max(abs(df$y - baseline))
    if (m$y.scale == 0) {
      # Guard against dividing by a zero scale later on.
      m$y.scale <- 1
    }
    m$start <- min(df$ds)
    m$t.scale <- time_diff(max(df$ds), m$start, "secs")
    for (reg in names(m$extra_regressors)) {
      vals <- df[[reg]]
      distinct <- unique(vals)
      if (length(distinct) < 2) {
        stop('Regressor ', reg, ' is constant.')
      }
      standardize <- m$extra_regressors[[reg]]$standardize
      if (standardize == 'auto') {
        # Binary 0/1 regressors keep their original scale.
        is.binary <- length(distinct) == 2 && all(sort(distinct) == c(0, 1))
        standardize <- !is.binary
      }
      if (standardize) {
        m$extra_regressors[[reg]]$mu <- mean(vals)
        m$extra_regressors[[reg]]$std <- stats::sd(vals)
      }
    }
    m
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints and m$changepoints.t to
  #' their scaled times. Either:
  #' 1) No changepoints were passed in, so a grid of n.changepoints dates is
  #'    placed evenly through the first 80 percent of the history (capped so
  #'    it fits in the available rows).
  #' 2) The changepoints were passed in explicitly:
  #'    A) They are empty.
  #'    B) They are not empty, and must lie within the history.
  #' When there are no changepoints, changepoints.t is a single dummy 0.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  #' @keywords internal
  set_changepoints <- function(m) {
    if (is.null(m$changepoints)) {
      hist.size <- floor(nrow(m$history) * .8)
      if (m$n.changepoints + 1 > hist.size) {
        m$n.changepoints <- hist.size - 1
        message('n.changepoints greater than number of observations. Using ',
                m$n.changepoints)
      }
      if (m$n.changepoints > 0) {
        # Drop the first grid point so no changepoint sits on the first row.
        grid <- seq.int(1, hist.size, length.out = (m$n.changepoints + 1))
        m$changepoints <- m$history$ds[round(grid[-1])]
      } else {
        m$changepoints <- c()
      }
    } else if (length(m$changepoints) > 0) {
      m$changepoints <- set_date(m$changepoints)
      outside <- min(m$changepoints) < min(m$history$ds) ||
        max(m$changepoints) > max(m$history$ds)
      if (outside) {
        stop('Changepoints must fall within training data.')
      }
    }
    m$changepoints.t <- if (length(m$changepoints) > 0) {
      sort(time_diff(m$changepoints, m$start, "secs")) / m$t.scale
    } else {
      0 # dummy changepoint
    }
    m
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' For each order k in 1..series.order, column 2k-1 holds
  #' sin(2 k pi t / period) and column 2k holds cos(2 k pi t / period), where
  #' t is the time in days since 1970-01-01.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features.
  #'
  #' @keywords internal
  fourier_series <- function(dates, period, series.order) {
    days <- time_diff(dates, set_date('1970-01-01 00:00:00'))
    orders <- seq_len(series.order)
    # angles[j, k] is the phase of order k at date j.
    angles <- outer(days, orders, function(d, k) 2 * k * pi * d / period)
    features <- matrix(0, length(days), 2 * series.order)
    features[, 2 * orders - 1] <- sin(angles)
    features[, 2 * orders] <- cos(angles)
    features
  }
  #' Data frame with seasonality features.
  #'
  #' Wraps fourier_series() and names its columns "<prefix>_delim_<i>".
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality.
  #'
  #' @keywords internal
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    fourier <- fourier_series(dates, period, series.order)
    colnames(fourier) <- paste0(prefix, '_delim_', seq_len(ncol(fourier)))
    data.frame(fourier)
  }
  #' Construct a matrix of holiday features.
  #'
  #' Builds one 0/1 indicator column per holiday and day offset (named
  #' "<holiday>_delim_<sign><offset>") aligned with `dates`, and the prior
  #' scale to use for each of those columns.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A list with entries
  #' holiday.features: dataframe with a column for each holiday.
  #' prior.scales: array of prior scales for each holiday column.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  make_holiday_features <- function(m, dates) {
    # Strip dates to be just days, for joining on holidays
    dates <- set_date(format(dates, "%Y-%m-%d"))
    # Expand each (holiday, ds) pair over its lower/upper window into one row
    # per offset, then pivot to one indicator column per holiday+offset.
    wide <- m$holidays %>%
      dplyr::mutate(ds = set_date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      dplyr::filter(row_number() == 1) %>%
      dplyr::do({
        if (exists('lower_window', where = .) && !is.na(.$lower_window)
            && !is.na(.$upper_window)) {
          offsets <- seq(.$lower_window, .$upper_window)
        } else {
          offsets <- 0
        }
        names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
                       abs(offsets), sep = '')
        dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
      }) %>%
      dplyr::mutate(x = 1) %>%
      tidyr::spread(holiday, x, fill = 0)
    holiday.features <- data.frame(ds = set_date(dates)) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    holiday.features[is.na(holiday.features)] <- 0

    # Resolve one prior scale per holiday, defaulting to holidays.prior.scale.
    hdf <- m$holidays
    if (!('prior_scale' %in% colnames(hdf))) {
      hdf$prior_scale <- m$holidays.prior.scale
    }
    scale.by.holiday <- list()
    for (hname in unique(hdf$holiday)) {
      ps <- unique(hdf$prior_scale[hdf$holiday == hname])
      if (length(ps) > 1) {
        stop('Holiday ', hname, ' does not have a consistent prior scale ',
             'specification')
      }
      if (is.na(ps)) {
        ps <- m$holidays.prior.scale
      }
      if (ps <= 0) {
        stop('Prior scale must be > 0.')
      }
      scale.by.holiday[[hname]] <- ps
    }

    # Map each feature column back to its holiday via the "_delim_" prefix.
    column.holidays <- vapply(
      strsplit(colnames(holiday.features), '_delim_', fixed = TRUE),
      function(parts) parts[1],
      character(1)
    )
    prior.scales <- unlist(
      lapply(column.holidays, function(h) scale.by.holiday[[h]]),
      use.names = FALSE
    )
    list(holiday.features = holiday.features, prior.scales = prior.scales)
  }
  #' Add an additional regressor to be used for fitting and predicting.
  #'
  #' The dataframe passed to `fit` and `predict` will have a column with the
  #' specified name to be used as a regressor. When standardize='auto', the
  #' regressor will be standardized unless it is binary. The regression
  #' coefficient is given a prior with the specified scale parameter.
  #' Decreasing the prior scale will add additional regularization. If no
  #' prior scale is provided, holidays.prior.scale will be used.
  #'
  #' @param m Prophet object.
  #' @param name String name of the regressor
  #' @param prior.scale Float scale for the normal prior. If not provided,
  #' holidays.prior.scale will be used.
  #' @param standardize Bool, specify whether this regressor will be standardized
  #' prior to fitting. Can be 'auto' (standardize if not binary), TRUE, or
  #' FALSE.
  #'
  #' @return The prophet model with the regressor added.
  #'
  #' @export
  add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto'){
    if (!is.null(m$history)) {
      stop('Regressors must be added prior to model fitting.')
    }
    # Re-adding an existing regressor is allowed, so skip that check.
    validate_column_name(m, name, check_regressors = FALSE)
    scale <- if (is.null(prior.scale)) m$holidays.prior.scale else prior.scale
    if (scale <= 0) {
      stop("Prior scale must be > 0")
    }
    # mu/std start as the identity transform; they are replaced when the
    # scales are initialized from the training data.
    m$extra_regressors[[name]] <- list(
      prior.scale = scale,
      standardize = standardize,
      mu = 0,
      std = 1.0
    )
    m
  }
  #' Add a seasonal component with specified period, number of Fourier
  #' components, and prior scale.
  #'
  #' Increasing the number of Fourier components allows the seasonality to change
  #' more quickly (at risk of overfitting). Default values for yearly and weekly
  #' seasonalities are 10 and 3 respectively.
  #'
  #' Increasing prior scale will allow this seasonality component more
  #' flexibility, decreasing will dampen it. If not provided, will use the
  #' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param period Float number of days in one period.
  #' @param fourier.order Int number of Fourier components to use.
  #' @param prior.scale Float prior scale for this component.
  #'
  #' @return The prophet model with the seasonality added.
  #'
  #' @importFrom dplyr "%>%"
  #' @export
  add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) {
    if (!is.null(m$history)) {
      stop("Seasonality must be added prior to model fitting.")
    }
    # Built-in names may be overridden, so they bypass name validation.
    is.builtin <- name %in% c('daily', 'weekly', 'yearly')
    if (!is.builtin) {
      validate_column_name(m, name, check_seasonalities = FALSE)
    }
    scale <- if (is.null(prior.scale)) m$seasonality.prior.scale else prior.scale
    if (scale <= 0) {
      stop('Prior scale must be > 0')
    }
    m$seasonalities[[name]] <- list(
      period = period,
      fourier.order = fourier.order,
      prior.scale = scale
    )
    m
  }
  #' Dataframe with seasonality features.
  #' Includes seasonality features, holiday features, and added regressors.
  #'
  #' Columns appear in that order: Fourier features for each seasonality,
  #' then holiday indicators, then one column per extra regressor copied from
  #' df. If there are none, a single all-zero column "zeros" with prior scale
  #' 1 is returned.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features and any
  #' added regressors.
  #'
  #' @return List with items
  #' seasonal.features: Dataframe with regressor features,
  #' prior.scales: Array of prior scales for each column of the features
  #' dataframe.
  #'
  #' @keywords internal
  make_all_seasonality_features <- function(m, df) {
    n.rows <- nrow(df)
    features <- data.frame(row.names = seq_len(n.rows))
    scales <- c()

    for (sname in names(m$seasonalities)) {
      spec <- m$seasonalities[[sname]]
      fourier <- make_seasonality_features(
        df$ds, spec$period, spec$fourier.order, sname)
      features <- cbind(features, fourier)
      scales <- c(scales, spec$prior.scale * rep(1, ncol(fourier)))
    }

    if (!is.null(m$holidays)) {
      hol <- make_holiday_features(m, df$ds)
      features <- cbind(features, hol$holiday.features)
      scales <- c(scales, hol$prior.scales)
    }

    for (rname in names(m$extra_regressors)) {
      features[[rname]] <- df[[rname]]
      scales <- c(scales, m$extra_regressors[[rname]]$prior.scale)
    }

    # Fall back to a single all-zero column when there are no features.
    if (ncol(features) == 0) {
      features <- data.frame(zeros = rep(0, n.rows))
      scales <- 1
    }
    list(seasonal.features = features, prior.scales = scales)
  }
  #' Get number of Fourier components for built-in seasonalities.
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
  #' provided.
  #' @param auto.disable Bool if seasonality should be disabled when 'auto'.
  #' @param default.order Int default Fourier order.
  #'
  #' @return Number of Fourier components, or 0 for disabled. With 'auto',
  #' returns 0 (with a message) if a custom seasonality of the same name exists
  #' or auto.disable is TRUE, and default.order otherwise.
  #'
  #' @keywords internal
  parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
    if (identical(arg, 'auto')) {
      if (name %in% names(m$seasonalities)) {
        message('Found custom seasonality named "', name,
                '", disabling built-in ', name, ' seasonality.')
        return(0)
      }
      if (auto.disable) {
        message('Disabling ', name, ' seasonality. Run prophet with ', name,
                '.seasonality=TRUE to override this.')
        return(0)
      }
      return(default.order)
    }
    # Match the logical flags exactly: `arg == TRUE` is also TRUE for the
    # number 1, which would turn a request for one Fourier term into
    # default.order terms.
    if (isTRUE(arg)) {
      return(default.order)
    }
    if (identical(arg, FALSE)) {
      return(0)
    }
    if (!is.numeric(arg) || length(arg) != 1 || is.na(arg)) {
      stop('Seasonality argument for ', name, " must be 'auto', TRUE, FALSE,",
           ' or a number of Fourier terms.')
    }
    arg
  }
  #' Set seasonalities that were left on auto.
  #'
  #' Turns on yearly seasonality if there is >=2 years of history.
  #' Turns on weekly seasonality if there is >=2 weeks of history, and the
  #' spacing between dates in the history is <7 days.
  #' Turns on daily seasonality if there is >=2 days of history, and the spacing
  #' between dates in the history is <1 day.
  #' Explicit TRUE / FALSE / numeric settings are resolved by
  #' parse_seasonality_args; enabled built-ins use seasonality.prior.scale.
  #'
  #' @param m Prophet object.
  #'
  #' @return The prophet model with seasonalities set.
  #'
  #' @keywords internal
  set_auto_seasonalities <- function(m) {
    first <- min(m$history$ds)
    last <- max(m$history$ds)
    span <- time_diff(last, first)
    dt <- diff(time_diff(m$history$ds, m$start))
    # Smallest positive spacing between consecutive history dates, in days.
    min.dt <- min(dt[dt > 0])
    # Per built-in: user argument, auto-disable rule, period (days), default
    # Fourier order. Processed in this order: yearly, weekly, daily.
    builtins <- list(
      list(name = 'yearly', arg = m$yearly.seasonality,
           disable = span < 730, period = 365.25, default = 10),
      list(name = 'weekly', arg = m$weekly.seasonality,
           disable = (span < 14) || (min.dt >= 7), period = 7, default = 3),
      list(name = 'daily', arg = m$daily.seasonality,
           disable = (span < 2) || (min.dt >= 1), period = 1, default = 4)
    )
    for (spec in builtins) {
      n.terms <- parse_seasonality_args(
        m, spec$name, spec$arg, spec$disable, spec$default)
      if (n.terms > 0) {
        m$seasonalities[[spec$name]] <- list(
          period = spec$period,
          fourier.order = n.terms,
          prior.scale = m$seasonality.prior.scale
        )
      }
    }
    m
  }
  748. #' Initialize linear growth.
  749. #'
  750. #' Provides a strong initialization for linear growth by calculating the
  751. #' growth and offset parameters that pass the function through the first and
  752. #' last points in the time series.
  753. #'
  754. #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  755. #' and t (scaled time).
  756. #'
  757. #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  758. #' growth function.
  759. #'
  760. #' @keywords internal
  761. linear_growth_init <- function(df) {
  762. i0 <- which.min(df$ds)
  763. i1 <- which.max(df$ds)
  764. T <- df$t[i1] - df$t[i0]
  765. # Initialize the rate
  766. k <- (df$y_scaled[i1] - df$y_scaled[i0]) / T
  767. # And the offset
  768. m <- df$y_scaled[i0] - k * df$t[i0]
  769. return(c(k, m))
  770. }
  771. #' Initialize logistic growth.
  772. #'
  773. #' Provides a strong initialization for logistic growth by calculating the
  774. #' growth and offset parameters that pass the function through the first and
  775. #' last points in the time series.
  776. #'
  777. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  778. #' y_scaled (scaled time series), and t (scaled time).
  779. #'
  780. #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  781. #' growth function.
  782. #'
  783. #' @keywords internal
  784. logistic_growth_init <- function(df) {
  785. i0 <- which.min(df$ds)
  786. i1 <- which.max(df$ds)
  787. T <- df$t[i1] - df$t[i0]
  788. # Force valid values, in case y > cap or y < 0
  789. C0 <- df$cap_scaled[i0]
  790. C1 <- df$cap_scaled[i1]
  791. y0 <- max(0.01 * C0, min(0.99 * C0, df$y_scaled[i0]))
  792. y1 <- max(0.01 * C1, min(0.99 * C1, df$y_scaled[i1]))
  793. r0 <- C0 / y0
  794. r1 <- C1 / y1
  795. if (abs(r0 - r1) <= 0.01) {
  796. r0 <- 1.05 * r0
  797. }
  798. L0 <- log(r0 - 1)
  799. L1 <- log(r1 - 1)
  800. # Initialize the offset
  801. m <- L0 * T / (L0 - L1)
  802. # And the rate
  803. k <- (L0 - L1) / T
  804. return(c(k, m))
  805. }
  806. #' Fit the prophet model.
  807. #'
  808. #' This sets m$params to contain the fitted model parameters. It is a list
  809. #' with the following elements:
  810. #' k (M array): M posterior samples of the initial slope.
  811. #' m (M array): The initial intercept.
  812. #' delta (MxN matrix): The slope change at each of N changepoints.
  813. #' beta (MxK matrix): Coefficients for K seasonality features.
  814. #' sigma_obs (M array): Noise level.
  815. #' Note that M=1 if MAP estimation.
  816. #'
  817. #' @param m Prophet object.
  818. #' @param df Data frame.
  819. #' @param ... Additional arguments passed to the \code{optimizing} or
  820. #' \code{sampling} functions in Stan.
  821. #'
  822. #' @export
  823. fit.prophet <- function(m, df, ...) {
  824. if (!is.null(m$history)) {
  825. stop("Prophet object can only be fit once. Instantiate a new object.")
  826. }
  827. history <- df %>%
  828. dplyr::filter(!is.na(y))
  829. if (nrow(history) < 2) {
  830. stop("Dataframe has less than 2 non-NA rows.")
  831. }
  832. m$history.dates <- sort(set_date(df$ds))
  833. out <- setup_dataframe(m, history, initialize_scales = TRUE)
  834. history <- out$df
  835. m <- out$m
  836. m$history <- history
  837. m <- set_auto_seasonalities(m)
  838. out2 <- make_all_seasonality_features(m, history)
  839. seasonal.features <- out2$seasonal.features
  840. prior.scales <- out2$prior.scales
  841. m <- set_changepoints(m)
  842. # Construct input to stan
  843. dat <- list(
  844. T = nrow(history),
  845. K = ncol(seasonal.features),
  846. S = length(m$changepoints.t),
  847. y = history$y_scaled,
  848. t = history$t,
  849. t_change = array(m$changepoints.t),
  850. X = as.matrix(seasonal.features),
  851. sigmas = array(prior.scales),
  852. tau = m$changepoint.prior.scale,
  853. trend_indicator = as.numeric(m$growth == 'logistic')
  854. )
  855. # Run stan
  856. if (m$growth == 'linear') {
  857. dat$cap <- rep(0, nrow(history)) # Unused inside Stan
  858. kinit <- linear_growth_init(history)
  859. } else {
  860. dat$cap <- history$cap_scaled # Add capacities to the Stan data
  861. kinit <- logistic_growth_init(history)
  862. }
  863. if (exists(".prophet.stan.model")) {
  864. model <- .prophet.stan.model
  865. } else {
  866. model <- get_prophet_stan_model()
  867. }
  868. stan_init <- function() {
  869. list(k = kinit[1],
  870. m = kinit[2],
  871. delta = array(rep(0, length(m$changepoints.t))),
  872. beta = array(rep(0, ncol(seasonal.features))),
  873. sigma_obs = 1
  874. )
  875. }
  876. if (min(history$y) == max(history$y)) {
  877. # Nothing to fit.
  878. m$params <- stan_init()
  879. m$params$sigma_obs <- 0.
  880. n.iteration <- 1.
  881. } else if (m$mcmc.samples > 0) {
  882. stan.fit <- rstan::sampling(
  883. model,
  884. data = dat,
  885. init = stan_init,
  886. iter = m$mcmc.samples,
  887. ...
  888. )
  889. m$params <- rstan::extract(stan.fit)
  890. n.iteration <- length(m$params$k)
  891. } else {
  892. stan.fit <- rstan::optimizing(
  893. model,
  894. data = dat,
  895. init = stan_init,
  896. iter = 1e4,
  897. as_vector = FALSE,
  898. ...
  899. )
  900. m$params <- stan.fit$par
  901. n.iteration <- 1
  902. }
  903. # Cast the parameters to have consistent form, whether full bayes or MAP
  904. for (name in c('delta', 'beta')){
  905. m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
  906. }
  907. # rstan::sampling returns 1d arrays; converts to atomic vectors.
  908. for (name in c('k', 'm', 'sigma_obs')){
  909. m$params[[name]] <- c(m$params[[name]])
  910. }
  911. # If no changepoints were requested, replace delta with 0s
  912. if (m$n.changepoints == 0) {
  913. # Fold delta into the base rate k
  914. m$params$k <- m$params$k + m$params$delta[, 1]
  915. m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
  916. }
  917. return(m)
  918. }
  919. #' Predict using the prophet model.
  920. #'
  921. #' @param object Prophet object.
  922. #' @param df Dataframe with dates for predictions (column ds), and capacity
  923. #' (column cap) if logistic growth. If not provided, predictions are made on
  924. #' the history.
  925. #' @param ... additional arguments.
  926. #'
  927. #' @return A dataframe with the forecast components.
  928. #'
  929. #' @examples
  930. #' \dontrun{
  931. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  932. #' y = sin(1:366/200) + rnorm(366)/10)
  933. #' m <- prophet(history)
  934. #' future <- make_future_dataframe(m, periods = 365)
  935. #' forecast <- predict(m, future)
  936. #' plot(m, forecast)
  937. #' }
  938. #'
  939. #' @export
  940. predict.prophet <- function(object, df = NULL, ...) {
  941. if (is.null(df)) {
  942. df <- object$history
  943. } else {
  944. if (nrow(df) == 0) {
  945. stop("Dataframe has no rows.")
  946. }
  947. out <- setup_dataframe(object, df)
  948. df <- out$df
  949. }
  950. df$trend <- predict_trend(object, df)
  951. seasonal.components <- predict_seasonal_components(object, df)
  952. intervals <- predict_uncertainty(object, df)
  953. # Drop columns except ds, cap, floor, and trend
  954. cols <- c('ds', 'trend')
  955. if ('cap' %in% colnames(df)) {
  956. cols <- c(cols, 'cap')
  957. }
  958. if (object$logistic.floor) {
  959. cols <- c(cols, 'floor')
  960. }
  961. df <- df[cols]
  962. df <- dplyr::bind_cols(df, seasonal.components, intervals)
  963. df$yhat <- df$trend + df$seasonal
  964. return(df)
  965. }
  966. #' Evaluate the piecewise linear function.
  967. #'
  968. #' @param t Vector of times on which the function is evaluated.
  969. #' @param deltas Vector of rate changes at each changepoint.
  970. #' @param k Float initial rate.
  971. #' @param m Float initial offset.
  972. #' @param changepoint.ts Vector of changepoint times.
  973. #'
  974. #' @return Vector y(t).
  975. #'
  976. #' @keywords internal
  977. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  978. # Intercept changes
  979. gammas <- -changepoint.ts * deltas
  980. # Get cumulative slope and intercept at each t
  981. k_t <- rep(k, length(t))
  982. m_t <- rep(m, length(t))
  983. for (s in seq_along(changepoint.ts)) {
  984. indx <- t >= changepoint.ts[s]
  985. k_t[indx] <- k_t[indx] + deltas[s]
  986. m_t[indx] <- m_t[indx] + gammas[s]
  987. }
  988. y <- k_t * t + m_t
  989. return(y)
  990. }
  991. #' Evaluate the piecewise logistic function.
  992. #'
  993. #' @param t Vector of times on which the function is evaluated.
  994. #' @param cap Vector of capacities at each t.
  995. #' @param deltas Vector of rate changes at each changepoint.
  996. #' @param k Float initial rate.
  997. #' @param m Float initial offset.
  998. #' @param changepoint.ts Vector of changepoint times.
  999. #'
  1000. #' @return Vector y(t).
  1001. #'
  1002. #' @keywords internal
  1003. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  1004. # Compute offset changes
  1005. k.cum <- c(k, cumsum(deltas) + k)
  1006. gammas <- rep(0, length(changepoint.ts))
  1007. for (i in seq_along(changepoint.ts)) {
  1008. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  1009. * (1 - k.cum[i] / k.cum[i + 1]))
  1010. }
  1011. # Get cumulative rate and offset at each t
  1012. k_t <- rep(k, length(t))
  1013. m_t <- rep(m, length(t))
  1014. for (s in seq_along(changepoint.ts)) {
  1015. indx <- t >= changepoint.ts[s]
  1016. k_t[indx] <- k_t[indx] + deltas[s]
  1017. m_t[indx] <- m_t[indx] + gammas[s]
  1018. }
  1019. y <- cap / (1 + exp(-k_t * (t - m_t)))
  1020. return(y)
  1021. }
  1022. #' Predict trend using the prophet model.
  1023. #'
  1024. #' @param model Prophet object.
  1025. #' @param df Prediction dataframe.
  1026. #'
  1027. #' @return Vector with trend on prediction dates.
  1028. #'
  1029. #' @keywords internal
  1030. predict_trend <- function(model, df) {
  1031. k <- mean(model$params$k, na.rm = TRUE)
  1032. param.m <- mean(model$params$m, na.rm = TRUE)
  1033. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  1034. t <- df$t
  1035. if (model$growth == 'linear') {
  1036. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  1037. } else {
  1038. cap <- df$cap_scaled
  1039. trend <- piecewise_logistic(
  1040. t, cap, deltas, k, param.m, model$changepoints.t)
  1041. }
  1042. return(trend * model$y.scale + df$floor)
  1043. }
#' Predict seasonality components, holidays, and added regressors.
#'
#' Each seasonal feature column is mapped to a component name by splitting
#' its column name on "_delim_". The function also adds a 'seasonal' total
#' over all features and group totals 'seasonalities', 'holidays' and
#' 'extra_regressors'. For every component it reports the mean and the
#' (1 - interval.width)/2 and (1 + interval.width)/2 quantiles of
#' features \%*\% t(beta) * y.scale across the rows of m$params$beta.
#'
#' @param m Prophet object.
#' @param df Prediction dataframe.
#'
#' @return Dataframe with seasonal components: one column per component,
#'  plus <component>_lower and <component>_upper columns.
#'
#' @keywords internal
predict_seasonal_components <- function(m, df) {
  seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features
  lower.p <- (1 - m$interval.width)/2
  upper.p <- (1 + m$interval.width)/2
  # Map each feature column index (col) to the component it belongs to: the
  # part of the column name before the first "_delim_".
  components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
    dplyr::mutate(col = seq_len(n())) %>%
    tidyr::separate(component, c('component', 'part'), sep = "_delim_",
                    extra = "merge", fill = "right") %>%
    dplyr::select(col, component)
  # Add total for all regression components
  components <- rbind(
    components,
    data.frame(col = seq_len(ncol(seasonal.features)), component = 'seasonal'))
  # Add totals for seasonality, holiday, and extra regressors
  components <- add_group_component(
    components, 'seasonalities', names(m$seasonalities))
  if(!is.null(m$holidays)){
    components <- add_group_component(
      components, 'holidays', unique(m$holidays$holiday))
  }
  components <- add_group_component(
    components, 'extra_regressors', names(m$extra_regressors))
  # Remove the placeholder
  components <- dplyr::filter(components, component != 'zeros')
  # For each component, multiply its feature columns by the matching beta
  # columns (one result column per parameter draw) and summarise per row.
  # NOTE(review): data_frame(), do(), gather() and spread() are superseded in
  # current dplyr/tidyr; a migration should verify the output column order.
  component.predictions <- components %>%
    dplyr::group_by(component) %>% dplyr::do({
      comp <- (as.matrix(seasonal.features[, .$col])
               %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
      dplyr::data_frame(ix = seq_len(nrow(seasonal.features)),
                        mean = rowMeans(comp, na.rm = TRUE),
                        lower = apply(comp, 1, stats::quantile, lower.p,
                                      na.rm = TRUE),
                        upper = apply(comp, 1, stats::quantile, upper.p,
                                      na.rm = TRUE))
    }) %>%
    # Reshape to wide: columns <component>, <component>_lower,
    # <component>_upper, one row per prediction row (ix).
    tidyr::gather(stat, value, mean, lower, upper) %>%
    dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
    tidyr::unite(component, component, stat, sep="") %>%
    tidyr::spread(component, value) %>%
    dplyr::select(-ix)
  return(component.predictions)
}
  1094. #' Adds a component with given name that contains all of the components
  1095. #' in group.
  1096. #'
  1097. #' @param components Dataframe with components.
  1098. #' @param name Name of new group component.
  1099. #' @param group List of components that form the group.
  1100. #'
  1101. #' @return Dataframe with components.
  1102. #'
  1103. #' @keywords internal
  1104. add_group_component <- function(components, name, group) {
  1105. new_comp <- components[(components$component %in% group), ]
  1106. if (nrow(new_comp) > 0) {
  1107. new_comp$component <- name
  1108. components <- rbind(components, new_comp)
  1109. }
  1110. return(components)
  1111. }
  1112. #' Prophet posterior predictive samples.
  1113. #'
  1114. #' @param m Prophet object.
  1115. #' @param df Prediction dataframe.
  1116. #'
  1117. #' @return List with posterior predictive samples for each component.
  1118. #'
  1119. #' @keywords internal
  1120. sample_posterior_predictive <- function(m, df) {
  1121. # Sample trend, seasonality, and yhat from the extrapolation model.
  1122. n.iterations <- length(m$params$k)
  1123. samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
  1124. nsamp <- n.iterations * samp.per.iter # The actual number of samples
  1125. seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features
  1126. sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
  1127. "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
  1128. "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
  1129. for (i in seq_len(n.iterations)) {
  1130. # For each set of parameters from MCMC (or just 1 set for MAP),
  1131. for (j in seq_len(samp.per.iter)) {
  1132. # Do a simulation with this set of parameters,
  1133. sim <- sample_model(m, df, seasonal.features, i)
  1134. # Store the results
  1135. for (key in c("trend", "seasonal", "yhat")) {
  1136. sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
  1137. }
  1138. }
  1139. }
  1140. return(sim.values)
  1141. }
  1142. #' Sample from the posterior predictive distribution.
  1143. #'
  1144. #' @param m Prophet object.
  1145. #' @param df Dataframe with dates for predictions (column ds), and capacity
  1146. #' (column cap) if logistic growth.
  1147. #'
  1148. #' @return A list with items "trend", "seasonal", and "yhat" containing
  1149. #' posterior predictive samples for that component. "seasonal" is the sum
  1150. #' of seasonalities, holidays, and added regressors.
  1151. #'
  1152. #' @export
  1153. predictive_samples <- function(m, df) {
  1154. df <- setup_dataframe(m, df)$df
  1155. sim.values <- sample_posterior_predictive(m, df)
  1156. return(sim.values)
  1157. }
  1158. #' Prophet uncertainty intervals for yhat and trend
  1159. #'
  1160. #' @param m Prophet object.
  1161. #' @param df Prediction dataframe.
  1162. #'
  1163. #' @return Dataframe with uncertainty intervals.
  1164. #'
  1165. #' @keywords internal
  1166. predict_uncertainty <- function(m, df) {
  1167. sim.values <- sample_posterior_predictive(m, df)
  1168. # Add uncertainty estimates
  1169. lower.p <- (1 - m$interval.width)/2
  1170. upper.p <- (1 + m$interval.width)/2
  1171. intervals <- cbind(
  1172. t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
  1173. na.rm = TRUE)),
  1174. t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
  1175. na.rm = TRUE))
  1176. ) %>% dplyr::as_data_frame()
  1177. colnames(intervals) <- paste(rep(c('yhat', 'trend'), each=2),
  1178. c('lower', 'upper'), sep = "_")
  1179. return(intervals)
  1180. }
  1181. #' Simulate observations from the extrapolated generative model.
  1182. #'
  1183. #' @param m Prophet object.
  1184. #' @param df Prediction dataframe.
  1185. #' @param seasonal.features Data frame of seasonal features
  1186. #' @param iteration Int sampling iteration to use parameters from.
  1187. #'
  1188. #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  1189. #'
  1190. #' @keywords internal
  1191. sample_model <- function(m, df, seasonal.features, iteration) {
  1192. trend <- sample_predictive_trend(m, df, iteration)
  1193. beta <- m$params$beta[iteration,]
  1194. seasonal <- (as.matrix(seasonal.features) %*% beta) * m$y.scale
  1195. sigma <- m$params$sigma_obs[iteration]
  1196. noise <- stats::rnorm(nrow(df), mean = 0, sd = sigma) * m$y.scale
  1197. return(list("yhat" = trend + seasonal + noise,
  1198. "trend" = trend,
  1199. "seasonal" = seasonal))
  1200. }
  1201. #' Simulate the trend using the extrapolated generative model.
  1202. #'
  1203. #' @param model Prophet object.
  1204. #' @param df Prediction dataframe.
  1205. #' @param iteration Int sampling iteration to use parameters from.
  1206. #'
  1207. #' @return Vector of simulated trend over df$t.
  1208. #'
  1209. #' @keywords internal
  1210. sample_predictive_trend <- function(model, df, iteration) {
  1211. k <- model$params$k[iteration]
  1212. param.m <- model$params$m[iteration]
  1213. deltas <- model$params$delta[iteration,]
  1214. t <- df$t
  1215. T <- max(t)
  1216. if (T > 1) {
  1217. # Get the time discretization of the history
  1218. dt <- diff(model$history$t)
  1219. dt <- min(dt[dt > 0])
  1220. # Number of time periods in the future
  1221. N <- ceiling((T - 1) / dt)
  1222. S <- length(model$changepoints.t)
  1223. # The history had S split points, over t = [0, 1].
  1224. # The forecast is on [1, T], and should have the same average frequency of
  1225. # rate changes. Thus for N time periods in the future, we want an average
  1226. # of S * (T - 1) changepoints in expectation.
  1227. prob.change <- min(1, (S * (T - 1)) / N)
  1228. # This calculation works for both history and df not uniformly spaced.
  1229. n.changes <- stats::rbinom(1, N, prob.change)
  1230. # Sample ts
  1231. if (n.changes == 0) {
  1232. changepoint.ts.new <- c()
  1233. } else {
  1234. changepoint.ts.new <- sort(stats::runif(n.changes, min = 1, max = T))
  1235. }
  1236. } else {
  1237. changepoint.ts.new <- c()
  1238. n.changes <- 0
  1239. }
  1240. # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  1241. lambda <- mean(abs(c(deltas))) + 1e-8
  1242. # Sample deltas
  1243. deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  1244. # Combine with changepoints from the history
  1245. changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  1246. deltas <- c(deltas, deltas.new)
  1247. # Get the corresponding trend
  1248. if (model$growth == 'linear') {
  1249. trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  1250. } else {
  1251. cap <- df$cap_scaled
  1252. trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
  1253. }
  1254. return(trend * model$y.scale + df$floor)
  1255. }
  1256. #' Make dataframe with future dates for forecasting.
  1257. #'
  1258. #' @param m Prophet model object.
  1259. #' @param periods Int number of periods to forecast forward.
  1260. #' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
  1261. #' @param include_history Boolean to include the historical dates in the data
  1262. #' frame for predictions.
  1263. #'
  1264. #' @return Dataframe that extends forward from the end of m$history for the
  1265. #' requested number of periods.
  1266. #'
  1267. #' @export
  1268. make_future_dataframe <- function(m, periods, freq = 'day',
  1269. include_history = TRUE) {
  1270. # For backwards compatability with previous zoo date type,
  1271. if (freq == 'm') {
  1272. freq <- 'month'
  1273. }
  1274. if (is.null(m$history.dates)) {
  1275. stop('Model must be fit before this can be used.')
  1276. }
  1277. dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
  1278. dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
  1279. if (include_history) {
  1280. dates <- c(m$history.dates, dates)
  1281. attr(dates, "tzone") <- "GMT"
  1282. }
  1283. return(data.frame(ds = dates))
  1284. }
  1285. # fb-block 3