## prophet.R (45 KB)
  ## Copyright (c) 2017-present, Facebook, Inc.
  ## All rights reserved.
  ## This source code is licensed under the BSD-style license found in the
  ## LICENSE file in the root directory of this source tree. An additional grant
  ## of patent rights can be found in the PATENTS file in the same directory.
  ## Register the column names that are referenced through dplyr syntax below,
  ## which keeps R CMD CHECK from reporting them as undefined globals.
  utils::globalVariables(
    c(
      "ds", "y", "cap", ".",
      "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
      "holidays_upper", "ix",
      "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
      "trend_lower", "trend_upper", "upper", "value", "weekly",
      "weekly_lower", "weekly_upper",
      "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower",
      "yhat_upper"
    )
  )
  #' Prophet forecaster.
  #'
  #' Builds a prophet model object. When a history dataframe is supplied and
  #' \code{fit} is TRUE the model is also fit before it is returned.
  #'
  #' @param df (optional) Dataframe containing the history. Must have columns ds
  #'  (date type) and y, the time series. If growth is logistic, then df must
  #'  also have a column cap that specifies the capacity at each ds. If not
  #'  provided, then the model object will be instantiated but not fit; use
  #'  fit.prophet(m, df) to fit the model.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #' @param changepoints Vector of dates at which to include potential
  #'  changepoints. If not specified, potential changepoints are selected
  #'  automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #'  if input `changepoints` is supplied (it is then replaced by the length of
  #'  `changepoints`). If `changepoints` is not supplied, then n.changepoints
  #'  potential changepoints are selected uniformly from the first 80 percent
  #'  of df$ds.
  #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #'  type) and optionally columns lower_window and upper_window which specify
  #'  a range of days around the date to be included as holidays.
  #'  lower_window=-2 will include 2 days prior to the date as holidays. Also
  #'  optionally can have a column prior_scale specifying the prior scale for
  #'  each holiday.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #'  seasonality model. Larger values allow the model to fit larger seasonal
  #'  fluctuations, smaller values dampen the seasonality. Can be specified for
  #'  individual seasonalities using add_seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #'  components model, unless overridden in the holidays input.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #'  automatic changepoint selection. Large values will allow many
  #'  changepoints, small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #'  inference with the specified number of MCMC samples. If 0, will do MAP
  #'  estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #'  for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #'  in the trend using the MAP estimate of the extrapolated generative model.
  #'  If mcmc.samples>0, this will be integrated over all model parameters,
  #'  which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #'  uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = 'auto',
                      weekly.seasonality = 'auto',
                      daily.seasonality = 'auto',
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...
  ) {
    # fb-block 1

    # Explicit changepoints determine how many changepoints there are.
    user.changepoints <- !is.null(changepoints)
    if (user.changepoints) {
      n.changepoints <- length(changepoints)
    }

    # Settings chosen by the caller.
    settings <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      daily.seasonality = daily.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      specified.changepoints = user.changepoints
    )
    # Attributes that are only set during fitting.
    fit.state <- list(
      start = NULL,
      y.scale = NULL,
      logistic.floor = FALSE,
      t.scale = NULL,
      changepoints.t = NULL,
      seasonalities = list(),
      extra_regressors = list(),
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    m <- c(settings, fit.state)

    validate_inputs(m)
    class(m) <- c("prophet", class(m))

    if (fit && !is.null(df)) {
      m <- fit.prophet(m, df, ...)
    }
    # fb-block 2
    m
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error when the growth setting is not recognised or when the
  #' holidays dataframe is malformed: missing holiday/ds columns, only one of
  #' lower_window/upper_window present, windows with the wrong sign, or a
  #' holiday name rejected by validate_column_name.
  #'
  #' @param m Prophet object.
  #'
  #' @keywords internal
  validate_inputs <- function(m) {
    if (!(m$growth %in% c('linear', 'logistic'))) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }

    holidays <- m$holidays
    if (is.null(holidays)) {
      # Nothing further to check without holidays.
      return(invisible(NULL))
    }

    fields <- names(holidays)
    if (!('holiday' %in% fields)) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!('ds' %in% fields)) {
      stop('Holidays dataframe must have ds field.')
    }

    has.lower <- 'lower_window' %in% fields
    has.upper <- 'upper_window' %in% fields
    if (xor(has.lower, has.upper)) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (has.lower) {
      # Both window columns are present at this point.
      if (max(holidays$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(holidays$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }

    # Holiday names must not clash with reserved or already used names.
    for (h in unique(holidays$holiday)) {
      validate_column_name(m, h, check_holidays = FALSE)
    }
  }
  #' Validates the name of a seasonality, holiday, or regressor.
  #'
  #' Stops with an error when the name contains "_delim_", is one of the
  #' reserved column names (including their "_lower" / "_upper" variants), or
  #' is already in use by a holiday, seasonality, or added regressor. Each of
  #' the "already in use" checks can be switched off individually.
  #'
  #' @param m Prophet object.
  #' @param name string
  #' @param check_holidays bool check if name already used for holiday
  #' @param check_seasonalities bool check if name already used for seasonality
  #' @param check_regressors bool check if name already used for regressor
  #'
  #' @keywords internal
  validate_column_name <- function(
    m, name, check_holidays = TRUE, check_seasonalities = TRUE,
    check_regressors = TRUE
  ) {
    if (grepl("_delim_", name, fixed = TRUE)) {
      stop('Holiday name cannot contain "_delim_"')
    }

    component.names <- c(
      'trend', 'seasonal', 'seasonalities', 'daily', 'weekly', 'yearly',
      'holidays', 'zeros', 'extra_regressors', 'yhat'
    )
    reserved.names <- c(
      component.names,
      paste0(component.names, "_lower"),
      paste0(component.names, "_upper"),
      "ds", "y", "cap", "floor", "y_scaled", "cap_scaled"
    )
    if (name %in% reserved.names) {
      stop("Name ", name, " is reserved.")
    }

    if (check_holidays && !is.null(m$holidays) &&
        (name %in% unique(m$holidays$holiday))) {
      stop("Name ", name, " already used for a holiday.")
    }
    if (check_seasonalities && !is.null(m$seasonalities[[name]])) {
      stop("Name ", name, " already used for a seasonality.")
    }
    # Look the name up among the added regressors, not the seasonalities;
    # otherwise a clash with an existing regressor is never detected.
    if (check_regressors && !is.null(m$extra_regressors[[name]])) {
      stop("Name ", name, " already used for an added regressor.")
    }
  }
  #' Load compiled Stan model
  #'
  #' Tries to load the cached model binary shipped in the package's libs
  #' directory. If loading or instantiating the cached model raises an error,
  #' the model is compiled from source instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  get_prophet_stan_model <- function(model) {
    rdata.name <- paste('prophet', model, 'growth.RData', sep = '_')

    # Reads the cached model object; any failure surfaces as an error.
    load_cached <- function() {
      rdata.path <- system.file('libs', Sys.getenv('R_ARCH'), rdata.name,
                                package = 'prophet',
                                mustWork = TRUE)
      cache <- new.env()
      load(rdata.path, envir = cache)
      cached.model <- get(paste(model, 'growth.stanm', sep = '.'),
                          envir = cache)
      ## Should cause an error if the model doesn't work.
      cached.model@mk_cppmodule(cached.model)
      cached.model
    }

    ## If the cached model doesn't work, just compile a new one.
    tryCatch(
      load_cached(),
      error = function(cond) {
        compile_stan_model(model)
      }
    )
  }
  #' Compile Stan model
  #'
  #' Locates the Stan source file for the requested growth type inside the
  #' installed package and compiles it with rstan.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  compile_stan_model <- function(model) {
    stan.path <- system.file(
      paste('stan/prophet', model, 'growth.stan', sep = '_'),
      package = 'prophet',
      mustWork = TRUE
    )
    translated <- rstan::stanc(stan.path)
    rstan::stan_model(
      stanc_ret = translated,
      model_name = paste(model, 'growth', sep = '_')
    )
  }
  #' Convert date vector
  #'
  #' Convert the date to POSIXct object. Values are parsed as "%Y-%m-%d" when
  #' the shortest non-missing entry has fewer than 12 characters, and as
  #' "%Y-%m-%d %H:%M:%S" otherwise. Missing entries are returned as NA rather
  #' than raising an error, so that callers can report them themselves.
  #'
  #' @param ds Date vector, can be consisted of characters
  #' @param tz string time zone
  #'
  #' @return vector of POSIXct object converted from date, or NULL if ds is
  #'  empty.
  #'
  #' @keywords internal
  set_date <- function(ds = NULL, tz = "GMT") {
    if (length(ds) == 0) {
      return(NULL)
    }
    if (is.factor(ds)) {
      ds <- as.character(ds)
    }

    # nchar() is NA for missing strings; ignore those when choosing the
    # format, otherwise the comparison below would be NA and `if` would fail.
    n.chars <- nchar(ds)
    n.chars <- n.chars[!is.na(n.chars)]
    date.only <- length(n.chars) == 0 || min(n.chars) < 12

    fmt <- if (date.only) "%Y-%m-%d" else "%Y-%m-%d %H:%M:%S"
    ds <- as.POSIXct(ds, format = fmt, tz = tz)
    attr(ds, "tzone") <- tz
    return(ds)
  }
  #' Time difference between datetimes
  #'
  #' Returns ds1 - ds2 as a plain number expressed in the requested units.
  #'
  #' @param ds1 POSIXct object
  #' @param ds2 POSIXct object
  #' @param units string units of difference, e.g. 'days' or 'secs'.
  #'
  #' @return numeric time difference
  #'
  #' @keywords internal
  time_diff <- function(ds1, ds2, units = "days") {
    elapsed <- difftime(ds1, ds2, units = units)
    # Drop the difftime class so callers get a bare numeric.
    as.numeric(elapsed)
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Parses ds, sorts the rows by ds, optionally initializes the model scales,
  #' and adds the auxiliary columns 't', 'floor' (when no floor is in use),
  #' 'y_scaled' (when y is present) and 'cap_scaled' (logistic growth). Added
  #' regressors are converted to numeric and standardized with the stored mu
  #' and std.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth. Any
  #'  specified additional regressors must also be present.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  #' @keywords internal
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    has.y <- 'y' %in% colnames(df)
    if (has.y) {
      df$y <- as.numeric(df$y)
    }
    if (any(is.infinite(df$y))) {
      stop("Found infinity in column y.")
    }

    df$ds <- set_date(df$ds)
    if (anyNA(df$ds)) {
      stop(paste('Unable to parse date format in column ds. Convert to date ',
                 'format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S'))
    }

    regressor.names <- names(m$extra_regressors)
    absent <- setdiff(regressor.names, colnames(df))
    if (length(absent) > 0) {
      # Report the first missing regressor, in the order they were added.
      stop('Regressor "', absent[1], '" missing from dataframe')
    }

    df <- dplyr::arrange(df, ds)
    m <- initialize_scales_fn(m, initialize_scales, df)

    if (!m$logistic.floor) {
      df$floor <- 0
    } else if (!('floor' %in% colnames(df))) {
      stop("Expected column 'floor'.")
    }

    if (m$growth == 'logistic') {
      if (!('cap' %in% colnames(df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df$cap_scaled <- (df[['cap']] - df[['floor']]) / m$y.scale
    }

    df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
    if (has.y) {
      df$y_scaled <- (df$y - df$floor) / m$y.scale
    }

    for (name in regressor.names) {
      props <- m$extra_regressors[[name]]
      df[[name]] <- (as.numeric(df[[name]]) - props$mu) / props$std
      if (anyNA(df[[name]])) {
        stop('Found NaN in column ', name)
      }
    }
    list("m" = m, "df" = df)
  }
  #' Initialize model scales.
  #'
  #' Sets model scaling factors using df: the y scale, the start time and time
  #' scale, whether a logistic floor is in use, and the mean and standard
  #' deviation of each added regressor that is to be standardized. Returns m
  #' untouched when initialize_scales is FALSE.
  #'
  #' @param m Prophet object.
  #' @param initialize_scales Boolean set the scales or not.
  #' @param df Dataframe for setting scales.
  #'
  #' @return Prophet object with scales set.
  #'
  #' @keywords internal
  initialize_scales_fn <- function(m, initialize_scales, df) {
    if (!initialize_scales) {
      return(m)
    }

    uses.floor <- (m$growth == 'logistic') && ('floor' %in% colnames(df))
    if (uses.floor) {
      m$logistic.floor <- TRUE
    }
    offset <- if (uses.floor) df$floor else 0

    y.scale <- max(abs(df$y - offset))
    # Avoid dividing by zero later when the (floored) series is all zeros.
    m$y.scale <- if (y.scale == 0) 1 else y.scale

    m$start <- min(df$ds)
    m$t.scale <- time_diff(max(df$ds), m$start, "secs")

    for (name in names(m$extra_regressors)) {
      values <- df[[name]]
      distinct <- unique(values)
      if (length(distinct) < 2) {
        stop('Regressor ', name, ' is constant.')
      }
      standardize <- m$extra_regressors[[name]]$standardize
      if (standardize == 'auto') {
        # Don't standardize binary variables
        is.binary <- length(distinct) == 2 && all(sort(distinct) == c(0, 1))
        standardize <- !is.binary
      }
      if (standardize) {
        m$extra_regressors[[name]]$mu <- mean(values)
        m$extra_regressors[[name]]$std <- stats::sd(values)
      }
    }
    m
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints, and m$changepoints.t to
  #' their positions on the scaled time axis. Either:
  #' 1) The changepoints were passed in explicitly.
  #'   A) They are empty.
  #'   B) They are not empty, and need validation.
  #' 2) We are generating a grid of them.
  #' 3) The user prefers no changepoints be used.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  #' @keywords internal
  set_changepoints <- function(m) {
    if (is.null(m$changepoints)) {
      # Place potential changepoints evenly through the first 80 pcnt of
      # the history.
      hist.size <- floor(nrow(m$history) * .8)
      if (m$n.changepoints + 1 > hist.size) {
        m$n.changepoints <- hist.size - 1
        message('n.changepoints greater than number of observations. Using ',
                m$n.changepoints)
      }
      if (m$n.changepoints > 0) {
        grid <- seq.int(1, hist.size, length.out = (m$n.changepoints + 1))
        # The first grid point is the start of the history; skip it.
        cp.indexes <- round(grid[-1])
        m$changepoints <- m$history$ds[cp.indexes]
      } else {
        m$changepoints <- c()
      }
    } else if (length(m$changepoints) > 0) {
      m$changepoints <- set_date(m$changepoints)
      too.early <- min(m$changepoints) < min(m$history$ds)
      if (too.early || max(m$changepoints) > max(m$history$ds)) {
        stop('Changepoints must fall within training data.')
      }
    }

    if (length(m$changepoints) > 0) {
      offsets <- time_diff(m$changepoints, m$start, "secs")
      m$changepoints.t <- sort(offsets) / m$t.scale
    } else {
      m$changepoints.t <- 0  # dummy changepoint
    }
    m
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' Columns come in (sin, cos) pairs, one pair per order, evaluated on the
  #' number of days elapsed since 1970-01-01 00:00:00.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features.
  #'
  #' @keywords internal
  fourier_series <- function(dates, period, series.order) {
    epoch <- set_date('1970-01-01 00:00:00')
    days <- time_diff(dates, epoch)
    out <- matrix(0, nrow = length(days), ncol = 2 * series.order)
    for (k in seq_len(series.order)) {
      angle <- as.numeric(2 * k * pi * days / period)
      cos.col <- 2 * k
      out[, cos.col - 1] <- sin(angle)
      out[, cos.col] <- cos(angle)
    }
    out
  }
  #' Data frame with seasonality features.
  #'
  #' Wraps the Fourier series matrix in a data frame whose columns are named
  #' prefix_delim_1, prefix_delim_2, and so on.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality.
  #'
  #' @keywords internal
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    fourier <- fourier_series(dates, period, series.order)
    col.ids <- seq_len(ncol(fourier))
    colnames(fourier) <- paste(prefix, col.ids, sep = '_delim_')
    data.frame(fourier)
  }
  #' Construct a matrix of holiday features.
  #'
  #' Each holiday is expanded into one indicator column per day of its window,
  #' named holiday_delim_+k or holiday_delim_-k for the day offset k. The
  #' prior scale of every column is taken from its holiday's prior_scale,
  #' falling back to m$holidays.prior.scale.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A list with entries
  #'  holiday.features: dataframe with a column for each holiday.
  #'  prior.scales: array of prior scales for each holiday column.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  make_holiday_features <- function(m, dates) {
    # Strip dates to be just days, for joining on holidays
    dates <- set_date(format(dates, "%Y-%m-%d"))

    # Expands one (holiday, ds) row into a row per day of its window.
    expand_window <- function(row) {
      has.window <- exists('lower_window', where = row) &&
        !is.na(row$lower_window) && !is.na(row$upper_window)
      offsets <- if (has.window) {
        seq(row$lower_window, row$upper_window)
      } else {
        0
      }
      labels <- paste0(row$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
                       abs(offsets))
      dplyr::data_frame(ds = row$ds + offsets * 24 * 3600, holiday = labels)
    }

    wide <- m$holidays %>%
      dplyr::mutate(ds = set_date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      dplyr::filter(row_number() == 1) %>%
      dplyr::do(expand_window(.)) %>%
      dplyr::mutate(x = 1) %>%
      tidyr::spread(holiday, x, fill = 0)

    holiday.features <- data.frame(ds = set_date(dates)) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    # Dates that match no holiday row come back as NA from the join.
    holiday.features[is.na(holiday.features)] <- 0

    # Prior scales
    if (!('prior_scale' %in% colnames(m$holidays))) {
      m$holidays$prior_scale <- m$holidays.prior.scale
    }
    scale.by.holiday <- list()
    for (name in unique(m$holidays$holiday)) {
      rows <- m$holidays[m$holidays$holiday == name, ]
      ps <- unique(rows$prior_scale)
      if (length(ps) > 1) {
        stop('Holiday ', name, ' does not have a consistent prior scale ',
             'specification')
      }
      if (is.na(ps)) {
        ps <- m$holidays.prior.scale
      }
      if (ps <= 0) {
        stop('Prior scale must be > 0.')
      }
      scale.by.holiday[[name]] <- ps
    }

    # Every feature column inherits the scale of the holiday it came from.
    name.parts <- strsplit(colnames(holiday.features), '_delim_', fixed = TRUE)
    base.names <- vapply(name.parts, function(parts) parts[1], character(1))
    prior.scales <- unlist(
      lapply(base.names, function(sn) scale.by.holiday[[sn]])
    )

    list(holiday.features = holiday.features,
         prior.scales = prior.scales)
  }
  #' Add an additional regressor to be used for fitting and predicting.
  #'
  #' The dataframe passed to `fit` and `predict` will have a column with the
  #' specified name to be used as a regressor. When standardize='auto', the
  #' regressor will be standardized unless it is binary. The regression
  #' coefficient is given a prior with the specified scale parameter.
  #' Decreasing the prior scale will add additional regularization. If no
  #' prior scale is provided, holidays.prior.scale will be used.
  #'
  #' @param m Prophet object.
  #' @param name String name of the regressor
  #' @param prior.scale Float scale for the normal prior. If not provided,
  #'  holidays.prior.scale will be used.
  #' @param standardize Bool, specify whether this regressor will be standardized
  #'  prior to fitting. Can be 'auto' (standardize if not binary), True, or
  #'  False.
  #'
  #' @return The prophet model with the regressor added.
  #'
  #' @export
  add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto') {
    already.fit <- !is.null(m$history)
    if (already.fit) {
      stop('Regressors must be added prior to model fitting.')
    }
    validate_column_name(m, name, check_regressors = FALSE)

    scale <- if (is.null(prior.scale)) m$holidays.prior.scale else prior.scale
    if (scale <= 0) {
      stop("Prior scale must be > 0")
    }

    # mu and std start as the identity transform.
    m$extra_regressors[[name]] <- list(
      prior.scale = scale,
      standardize = standardize,
      mu = 0,
      std = 1.0
    )
    m
  }
  #' Add a seasonal component with specified period, number of Fourier
  #' components, and prior scale.
  #'
  #' Increasing the number of Fourier components allows the seasonality to change
  #' more quickly (at risk of overfitting). Default values for yearly and weekly
  #' seasonalities are 10 and 3 respectively.
  #'
  #' Increasing prior scale will allow this seasonality component more
  #' flexibility, decreasing will dampen it. If not provided, will use the
  #' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param period Float number of days in one period.
  #' @param fourier.order Int number of Fourier components to use.
  #' @param prior.scale Float prior scale for this component.
  #'
  #' @return The prophet model with the seasonality added.
  #'
  #' @importFrom dplyr "%>%"
  #' @export
  add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) {
    if (!is.null(m$history)) {
      stop("Seasonality must be added prior to model fitting.")
    }

    # Allow overriding built-in seasonalities
    builtin <- name %in% c('daily', 'weekly', 'yearly')
    if (!builtin) {
      validate_column_name(m, name, check_seasonalities = FALSE)
    }

    ps <- if (is.null(prior.scale)) m$seasonality.prior.scale else prior.scale
    if (ps <= 0) {
      stop('Prior scale must be > 0')
    }

    m$seasonalities[[name]] <- list(
      period = period,
      fourier.order = fourier.order,
      prior.scale = ps
    )
    m
  }
  #' Dataframe with seasonality features.
  #' Includes seasonality features, holiday features, and added regressors.
  #'
  #' When the model has none of these, a single all-zero column named 'zeros'
  #' with prior scale 1 is returned instead.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features and any
  #'  added regressors.
  #'
  #' @return List with items
  #'  seasonal.features: Dataframe with regressor features,
  #'  prior.scales: Array of prior scales for each column of the features
  #'   dataframe.
  #'
  #' @keywords internal
  make_all_seasonality_features <- function(m, df) {
    n.rows <- nrow(df)
    all.features <- data.frame(row.names = seq_len(n.rows))
    all.scales <- NULL

    # Seasonality features
    for (s.name in names(m$seasonalities)) {
      spec <- m$seasonalities[[s.name]]
      block <- make_seasonality_features(
        df$ds, spec$period, spec$fourier.order, s.name)
      all.features <- cbind(all.features, block)
      # One copy of the component's prior scale per generated column.
      all.scales <- c(all.scales, spec$prior.scale * rep(1, ncol(block)))
    }

    # Holiday features
    if (!is.null(m$holidays)) {
      holiday <- make_holiday_features(m, df$ds)
      all.features <- cbind(all.features, holiday$holiday.features)
      all.scales <- c(all.scales, holiday$prior.scales)
    }

    # Additional regressors
    for (r.name in names(m$extra_regressors)) {
      all.features[[r.name]] <- df[[r.name]]
      all.scales <- c(all.scales, m$extra_regressors[[r.name]]$prior.scale)
    }

    if (ncol(all.features) == 0) {
      all.features <- data.frame(zeros = rep(0, n.rows))
      all.scales <- 1
    }
    list(seasonal.features = all.features,
         prior.scales = all.scales)
  }
  #' Get number of Fourier components for built-in seasonalities.
  #'
  #' A numeric arg is always taken literally as the number of Fourier
  #' components, so that requesting exactly 1 component is not confused with
  #' TRUE (in R, 1 == TRUE holds).
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
  #'  provided.
  #' @param auto.disable Bool if seasonality should be disabled when 'auto'.
  #' @param default.order Int default Fourier order.
  #'
  #' @return Number of Fourier components, or 0 for disabled.
  #'
  #' @keywords internal
  parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
    if (is.numeric(arg)) {
      # Handle numbers first: `arg == TRUE` below would also match 1.
      fourier.order <- arg
    } else if (arg == 'auto') {
      fourier.order <- 0
      if (name %in% names(m$seasonalities)) {
        message('Found custom seasonality named "', name,
                '", disabling built-in ', name, ' seasonality.')
      } else if (auto.disable) {
        message('Disabling ', name, ' seasonality. Run prophet with ', name,
                '.seasonality=TRUE to override this.')
      } else {
        fourier.order <- default.order
      }
    } else if (arg == TRUE) {
      fourier.order <- default.order
    } else if (arg == FALSE) {
      fourier.order <- 0
    } else {
      fourier.order <- arg
    }
    return(fourier.order)
  }
  #' Set seasonalities that were left on auto.
  #'
  #' Turns on yearly seasonality if there is >=2 years of history.
  #' Turns on weekly seasonality if there is >=2 weeks of history, and the
  #' spacing between dates in the history is <7 days.
  #' Turns on daily seasonality if there is >=2 days of history, and the spacing
  #' between dates in the history is <1 day.
  #'
  #' @param m Prophet object.
  #'
  #' @return The prophet model with seasonalities set.
  #'
  #' @keywords internal
  set_auto_seasonalities <- function(m) {
    # Resolves one built-in seasonality and registers it when it is enabled.
    enable_builtin <- function(m, name, arg, auto.disable, default.order,
                               period) {
      fourier.order <- parse_seasonality_args(
        m, name, arg, auto.disable, default.order)
      if (fourier.order > 0) {
        m$seasonalities[[name]] <- list(
          period = period,
          fourier.order = fourier.order,
          prior.scale = m$seasonality.prior.scale
        )
      }
      m
    }

    first <- min(m$history$ds)
    last <- max(m$history$ds)
    # Smallest positive spacing between consecutive history dates, in days.
    gaps <- diff(time_diff(m$history$ds, m$start))
    min.gap <- min(gaps[gaps > 0])
    span <- time_diff(last, first)

    yearly.disable <- span < 730
    m <- enable_builtin(
      m, 'yearly', m$yearly.seasonality, yearly.disable, 10, 365.25)

    weekly.disable <- ((span < 14) || (min.gap >= 7))
    m <- enable_builtin(
      m, 'weekly', m$weekly.seasonality, weekly.disable, 3, 7)

    daily.disable <- ((span < 2) || (min.gap >= 1))
    m <- enable_builtin(
      m, 'daily', m$daily.seasonality, daily.disable, 4, 1)

    m
  }
  745. #' Initialize linear growth.
  746. #'
  747. #' Provides a strong initialization for linear growth by calculating the
  748. #' growth and offset parameters that pass the function through the first and
  749. #' last points in the time series.
  750. #'
  751. #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  752. #' and t (scaled time).
  753. #'
  754. #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  755. #' growth function.
  756. #'
  757. #' @keywords internal
  758. linear_growth_init <- function(df) {
  759. i0 <- which.min(df$ds)
  760. i1 <- which.max(df$ds)
  761. T <- df$t[i1] - df$t[i0]
  762. # Initialize the rate
  763. k <- (df$y_scaled[i1] - df$y_scaled[i0]) / T
  764. # And the offset
  765. m <- df$y_scaled[i0] - k * df$t[i0]
  766. return(c(k, m))
  767. }
  768. #' Initialize logistic growth.
  769. #'
  770. #' Provides a strong initialization for logistic growth by calculating the
  771. #' growth and offset parameters that pass the function through the first and
  772. #' last points in the time series.
  773. #'
  774. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  775. #' y_scaled (scaled time series), and t (scaled time).
  776. #'
  777. #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  778. #' growth function.
  779. #'
  780. #' @keywords internal
  781. logistic_growth_init <- function(df) {
  782. i0 <- which.min(df$ds)
  783. i1 <- which.max(df$ds)
  784. T <- df$t[i1] - df$t[i0]
  785. # Force valid values, in case y > cap or y < 0
  786. C0 <- df$cap_scaled[i0]
  787. C1 <- df$cap_scaled[i1]
  788. y0 <- max(0.01 * C0, min(0.99 * C0, df$y_scaled[i0]))
  789. y1 <- max(0.01 * C1, min(0.99 * C1, df$y_scaled[i1]))
  790. r0 <- C0 / y0
  791. r1 <- C1 / y1
  792. if (abs(r0 - r1) <= 0.01) {
  793. r0 <- 1.05 * r0
  794. }
  795. L0 <- log(r0 - 1)
  796. L1 <- log(r1 - 1)
  797. # Initialize the offset
  798. m <- L0 * T / (L0 - L1)
  799. # And the rate
  800. k <- (L0 - L1) / T
  801. return(c(k, m))
  802. }
  803. #' Fit the prophet model.
  804. #'
  805. #' This sets m$params to contain the fitted model parameters. It is a list
  806. #' with the following elements:
  807. #' k (M array): M posterior samples of the initial slope.
  808. #' m (M array): The initial intercept.
  809. #' delta (MxN matrix): The slope change at each of N changepoints.
  810. #' beta (MxK matrix): Coefficients for K seasonality features.
  811. #' sigma_obs (M array): Noise level.
  812. #' Note that M=1 if MAP estimation.
  813. #'
  814. #' @param m Prophet object.
  815. #' @param df Data frame.
  816. #' @param ... Additional arguments passed to the \code{optimizing} or
  817. #' \code{sampling} functions in Stan.
  818. #'
  819. #' @export
  820. fit.prophet <- function(m, df, ...) {
  821. if (!is.null(m$history)) {
  822. stop("Prophet object can only be fit once. Instantiate a new object.")
  823. }
  824. history <- df %>%
  825. dplyr::filter(!is.na(y))
  826. if (nrow(history) < 2) {
  827. stop("Dataframe has less than 2 non-NA rows.")
  828. }
  829. m$history.dates <- sort(set_date(df$ds))
  830. out <- setup_dataframe(m, history, initialize_scales = TRUE)
  831. history <- out$df
  832. m <- out$m
  833. m$history <- history
  834. m <- set_auto_seasonalities(m)
  835. out2 <- make_all_seasonality_features(m, history)
  836. seasonal.features <- out2$seasonal.features
  837. prior.scales <- out2$prior.scales
  838. m <- set_changepoints(m)
  839. # Construct input to stan
  840. dat <- list(
  841. T = nrow(history),
  842. K = ncol(seasonal.features),
  843. S = length(m$changepoints.t),
  844. y = history$y_scaled,
  845. t = history$t,
  846. t_change = array(m$changepoints.t),
  847. X = as.matrix(seasonal.features),
  848. sigmas = array(prior.scales),
  849. tau = m$changepoint.prior.scale
  850. )
  851. # Run stan
  852. if (m$growth == 'linear') {
  853. kinit <- linear_growth_init(history)
  854. } else {
  855. dat$cap <- history$cap_scaled # Add capacities to the Stan data
  856. kinit <- logistic_growth_init(history)
  857. }
  858. if (exists(".prophet.stan.models")) {
  859. model <- .prophet.stan.models[[m$growth]]
  860. } else {
  861. model <- get_prophet_stan_model(m$growth)
  862. }
  863. stan_init <- function() {
  864. list(k = kinit[1],
  865. m = kinit[2],
  866. delta = array(rep(0, length(m$changepoints.t))),
  867. beta = array(rep(0, ncol(seasonal.features))),
  868. sigma_obs = 1
  869. )
  870. }
  871. if (min(history$y) == max(history$y)) {
  872. # Nothing to fit.
  873. m$params <- stan_init()
  874. m$params$sigma_obs <- 0.
  875. n.iteration <- 1.
  876. } else if (m$mcmc.samples > 0) {
  877. stan.fit <- rstan::sampling(
  878. model,
  879. data = dat,
  880. init = stan_init,
  881. iter = m$mcmc.samples,
  882. ...
  883. )
  884. m$params <- rstan::extract(stan.fit)
  885. n.iteration <- length(m$params$k)
  886. } else {
  887. stan.fit <- rstan::optimizing(
  888. model,
  889. data = dat,
  890. init = stan_init,
  891. iter = 1e4,
  892. as_vector = FALSE,
  893. ...
  894. )
  895. m$params <- stan.fit$par
  896. n.iteration <- 1
  897. }
  898. # Cast the parameters to have consistent form, whether full bayes or MAP
  899. for (name in c('delta', 'beta')){
  900. m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
  901. }
  902. # rstan::sampling returns 1d arrays; converts to atomic vectors.
  903. for (name in c('k', 'm', 'sigma_obs')){
  904. m$params[[name]] <- c(m$params[[name]])
  905. }
  906. # If no changepoints were requested, replace delta with 0s
  907. if (m$n.changepoints == 0) {
  908. # Fold delta into the base rate k
  909. m$params$k <- m$params$k + m$params$delta[, 1]
  910. m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
  911. }
  912. return(m)
  913. }
  914. #' Predict using the prophet model.
  915. #'
  916. #' @param object Prophet object.
  917. #' @param df Dataframe with dates for predictions (column ds), and capacity
  918. #' (column cap) if logistic growth. If not provided, predictions are made on
  919. #' the history.
  920. #' @param ... additional arguments.
  921. #'
  922. #' @return A dataframe with the forecast components.
  923. #'
  924. #' @examples
  925. #' \dontrun{
  926. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  927. #' y = sin(1:366/200) + rnorm(366)/10)
  928. #' m <- prophet(history)
  929. #' future <- make_future_dataframe(m, periods = 365)
  930. #' forecast <- predict(m, future)
  931. #' plot(m, forecast)
  932. #' }
  933. #'
  934. #' @export
  935. predict.prophet <- function(object, df = NULL, ...) {
  936. if (is.null(df)) {
  937. df <- object$history
  938. } else {
  939. if (nrow(df) == 0) {
  940. stop("Dataframe has no rows.")
  941. }
  942. out <- setup_dataframe(object, df)
  943. df <- out$df
  944. }
  945. df$trend <- predict_trend(object, df)
  946. seasonal.components <- predict_seasonal_components(object, df)
  947. intervals <- predict_uncertainty(object, df)
  948. # Drop columns except ds, cap, floor, and trend
  949. cols <- c('ds', 'trend')
  950. if ('cap' %in% colnames(df)) {
  951. cols <- c(cols, 'cap')
  952. }
  953. if (object$logistic.floor) {
  954. cols <- c(cols, 'floor')
  955. }
  956. df <- df[cols]
  957. df <- dplyr::bind_cols(df, seasonal.components, intervals)
  958. df$yhat <- df$trend + df$seasonal
  959. return(df)
  960. }
  961. #' Evaluate the piecewise linear function.
  962. #'
  963. #' @param t Vector of times on which the function is evaluated.
  964. #' @param deltas Vector of rate changes at each changepoint.
  965. #' @param k Float initial rate.
  966. #' @param m Float initial offset.
  967. #' @param changepoint.ts Vector of changepoint times.
  968. #'
  969. #' @return Vector y(t).
  970. #'
  971. #' @keywords internal
  972. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  973. # Intercept changes
  974. gammas <- -changepoint.ts * deltas
  975. # Get cumulative slope and intercept at each t
  976. k_t <- rep(k, length(t))
  977. m_t <- rep(m, length(t))
  978. for (s in seq_along(changepoint.ts)) {
  979. indx <- t >= changepoint.ts[s]
  980. k_t[indx] <- k_t[indx] + deltas[s]
  981. m_t[indx] <- m_t[indx] + gammas[s]
  982. }
  983. y <- k_t * t + m_t
  984. return(y)
  985. }
  986. #' Evaluate the piecewise logistic function.
  987. #'
  988. #' @param t Vector of times on which the function is evaluated.
  989. #' @param cap Vector of capacities at each t.
  990. #' @param deltas Vector of rate changes at each changepoint.
  991. #' @param k Float initial rate.
  992. #' @param m Float initial offset.
  993. #' @param changepoint.ts Vector of changepoint times.
  994. #'
  995. #' @return Vector y(t).
  996. #'
  997. #' @keywords internal
  998. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  999. # Compute offset changes
  1000. k.cum <- c(k, cumsum(deltas) + k)
  1001. gammas <- rep(0, length(changepoint.ts))
  1002. for (i in seq_along(changepoint.ts)) {
  1003. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  1004. * (1 - k.cum[i] / k.cum[i + 1]))
  1005. }
  1006. # Get cumulative rate and offset at each t
  1007. k_t <- rep(k, length(t))
  1008. m_t <- rep(m, length(t))
  1009. for (s in seq_along(changepoint.ts)) {
  1010. indx <- t >= changepoint.ts[s]
  1011. k_t[indx] <- k_t[indx] + deltas[s]
  1012. m_t[indx] <- m_t[indx] + gammas[s]
  1013. }
  1014. y <- cap / (1 + exp(-k_t * (t - m_t)))
  1015. return(y)
  1016. }
  1017. #' Predict trend using the prophet model.
  1018. #'
  1019. #' @param model Prophet object.
  1020. #' @param df Prediction dataframe.
  1021. #'
  1022. #' @return Vector with trend on prediction dates.
  1023. #'
  1024. #' @keywords internal
  1025. predict_trend <- function(model, df) {
  1026. k <- mean(model$params$k, na.rm = TRUE)
  1027. param.m <- mean(model$params$m, na.rm = TRUE)
  1028. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  1029. t <- df$t
  1030. if (model$growth == 'linear') {
  1031. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  1032. } else {
  1033. cap <- df$cap_scaled
  1034. trend <- piecewise_logistic(
  1035. t, cap, deltas, k, param.m, model$changepoints.t)
  1036. }
  1037. return(trend * model$y.scale + df$floor)
  1038. }
#' Predict seasonality components, holidays, and added regressors.
#'
#' Every column of the seasonal feature matrix is assigned to a named
#' component (the part of the column name before "_delim_"). On top of these
#' the function adds a 'seasonal' total over all columns and, when they have
#' members, the group totals 'seasonalities', 'holidays' and
#' 'extra_regressors'. For each component the mean and the lower/upper
#' quantiles implied by \code{m$interval.width} are computed across the rows
#' of \code{m$params$beta}.
#'
#' @param m Prophet object.
#' @param df Prediction dataframe.
#'
#' @return Dataframe with seasonal components: one row per row of the feature
#'  matrix and, per component, the columns <component>, <component>_lower and
#'  <component>_upper.
#'
#' @keywords internal
predict_seasonal_components <- function(m, df) {
  seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features
  # Tail probabilities of the central interval of width m$interval.width
  lower.p <- (1 - m$interval.width)/2
  upper.p <- (1 + m$interval.width)/2

  # Map each feature column index (col) to its component name: the column
  # name is split at the first "_delim_" and only the leading part is kept.
  components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
    dplyr::mutate(col = seq_len(n())) %>%
    tidyr::separate(component, c('component', 'part'), sep = "_delim_",
                    extra = "merge", fill = "right") %>%
    dplyr::select(col, component)
  # Add total for all regression components
  components <- rbind(
    components,
    data.frame(col = seq_len(ncol(seasonal.features)), component = 'seasonal'))
  # Add totals for seasonality, holiday, and extra regressors
  components <- add_group_component(
    components, 'seasonalities', names(m$seasonalities))
  if(!is.null(m$holidays)){
    components <- add_group_component(
      components, 'holidays', unique(m$holidays$holiday))
  }
  components <- add_group_component(
    components, 'extra_regressors', names(m$extra_regressors))
  # Remove the placeholder
  components <- dplyr::filter(components, component != 'zeros')

  component.predictions <- components %>%
    dplyr::group_by(component) %>% dplyr::do({
      # Features of this component times their coefficients, rescaled by
      # y.scale. comp has one row per row of seasonal.features and one column
      # per row of m$params$beta.
      comp <- (as.matrix(seasonal.features[, .$col])
               %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
      # Summarise across the columns of comp; ix keeps the row order.
      dplyr::data_frame(ix = seq_len(nrow(seasonal.features)),
                        mean = rowMeans(comp, na.rm = TRUE),
                        lower = apply(comp, 1, stats::quantile, lower.p,
                                      na.rm = TRUE),
                        upper = apply(comp, 1, stats::quantile, upper.p,
                                      na.rm = TRUE))
    }) %>%
    # Reshape to one column per component/statistic: the mean keeps the bare
    # component name, the bounds get the suffixes _lower and _upper.
    tidyr::gather(stat, value, mean, lower, upper) %>%
    dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
    tidyr::unite(component, component, stat, sep="") %>%
    tidyr::spread(component, value) %>%
    dplyr::select(-ix)

  return(component.predictions)
}
  1089. #' Adds a component with given name that contains all of the components
  1090. #' in group.
  1091. #'
  1092. #' @param components Dataframe with components.
  1093. #' @param name Name of new group component.
  1094. #' @param group List of components that form the group.
  1095. #'
  1096. #' @return Dataframe with components.
  1097. #'
  1098. #' @keywords internal
  1099. add_group_component <- function(components, name, group) {
  1100. new_comp <- components[(components$component %in% group), ]
  1101. if (nrow(new_comp) > 0) {
  1102. new_comp$component <- name
  1103. components <- rbind(components, new_comp)
  1104. }
  1105. return(components)
  1106. }
  1107. #' Prophet posterior predictive samples.
  1108. #'
  1109. #' @param m Prophet object.
  1110. #' @param df Prediction dataframe.
  1111. #'
  1112. #' @return List with posterior predictive samples for each component.
  1113. #'
  1114. #' @keywords internal
  1115. sample_posterior_predictive <- function(m, df) {
  1116. # Sample trend, seasonality, and yhat from the extrapolation model.
  1117. n.iterations <- length(m$params$k)
  1118. samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
  1119. nsamp <- n.iterations * samp.per.iter # The actual number of samples
  1120. seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features
  1121. sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
  1122. "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
  1123. "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
  1124. for (i in seq_len(n.iterations)) {
  1125. # For each set of parameters from MCMC (or just 1 set for MAP),
  1126. for (j in seq_len(samp.per.iter)) {
  1127. # Do a simulation with this set of parameters,
  1128. sim <- sample_model(m, df, seasonal.features, i)
  1129. # Store the results
  1130. for (key in c("trend", "seasonal", "yhat")) {
  1131. sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
  1132. }
  1133. }
  1134. }
  1135. return(sim.values)
  1136. }
  1137. #' Sample from the posterior predictive distribution.
  1138. #'
  1139. #' @param m Prophet object.
  1140. #' @param df Dataframe with dates for predictions (column ds), and capacity
  1141. #' (column cap) if logistic growth.
  1142. #'
  1143. #' @return A list with items "trend", "seasonal", and "yhat" containing
  1144. #' posterior predictive samples for that component. "seasonal" is the sum
  1145. #' of seasonalities, holidays, and added regressors.
  1146. #'
  1147. #' @export
  1148. predictive_samples <- function(m, df) {
  1149. df <- setup_dataframe(m, df)$df
  1150. sim.values <- sample_posterior_predictive(m, df)
  1151. return(sim.values)
  1152. }
  1153. #' Prophet uncertainty intervals for yhat and trend
  1154. #'
  1155. #' @param m Prophet object.
  1156. #' @param df Prediction dataframe.
  1157. #'
  1158. #' @return Dataframe with uncertainty intervals.
  1159. #'
  1160. #' @keywords internal
  1161. predict_uncertainty <- function(m, df) {
  1162. sim.values <- sample_posterior_predictive(m, df)
  1163. # Add uncertainty estimates
  1164. lower.p <- (1 - m$interval.width)/2
  1165. upper.p <- (1 + m$interval.width)/2
  1166. intervals <- cbind(
  1167. t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
  1168. na.rm = TRUE)),
  1169. t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
  1170. na.rm = TRUE))
  1171. ) %>% dplyr::as_data_frame()
  1172. colnames(intervals) <- paste(rep(c('yhat', 'trend'), each=2),
  1173. c('lower', 'upper'), sep = "_")
  1174. return(intervals)
  1175. }
  1176. #' Simulate observations from the extrapolated generative model.
  1177. #'
  1178. #' @param m Prophet object.
  1179. #' @param df Prediction dataframe.
  1180. #' @param seasonal.features Data frame of seasonal features
  1181. #' @param iteration Int sampling iteration to use parameters from.
  1182. #'
  1183. #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  1184. #'
  1185. #' @keywords internal
  1186. sample_model <- function(m, df, seasonal.features, iteration) {
  1187. trend <- sample_predictive_trend(m, df, iteration)
  1188. beta <- m$params$beta[iteration,]
  1189. seasonal <- (as.matrix(seasonal.features) %*% beta) * m$y.scale
  1190. sigma <- m$params$sigma_obs[iteration]
  1191. noise <- stats::rnorm(nrow(df), mean = 0, sd = sigma) * m$y.scale
  1192. return(list("yhat" = trend + seasonal + noise,
  1193. "trend" = trend,
  1194. "seasonal" = seasonal))
  1195. }
  1196. #' Simulate the trend using the extrapolated generative model.
  1197. #'
  1198. #' @param model Prophet object.
  1199. #' @param df Prediction dataframe.
  1200. #' @param iteration Int sampling iteration to use parameters from.
  1201. #'
  1202. #' @return Vector of simulated trend over df$t.
  1203. #'
  1204. #' @keywords internal
  1205. sample_predictive_trend <- function(model, df, iteration) {
  1206. k <- model$params$k[iteration]
  1207. param.m <- model$params$m[iteration]
  1208. deltas <- model$params$delta[iteration,]
  1209. t <- df$t
  1210. T <- max(t)
  1211. if (T > 1) {
  1212. # Get the time discretization of the history
  1213. dt <- diff(model$history$t)
  1214. dt <- min(dt[dt > 0])
  1215. # Number of time periods in the future
  1216. N <- ceiling((T - 1) / dt)
  1217. S <- length(model$changepoints.t)
  1218. # The history had S split points, over t = [0, 1].
  1219. # The forecast is on [1, T], and should have the same average frequency of
  1220. # rate changes. Thus for N time periods in the future, we want an average
  1221. # of S * (T - 1) changepoints in expectation.
  1222. prob.change <- min(1, (S * (T - 1)) / N)
  1223. # This calculation works for both history and df not uniformly spaced.
  1224. n.changes <- stats::rbinom(1, N, prob.change)
  1225. # Sample ts
  1226. if (n.changes == 0) {
  1227. changepoint.ts.new <- c()
  1228. } else {
  1229. changepoint.ts.new <- sort(stats::runif(n.changes, min = 1, max = T))
  1230. }
  1231. } else {
  1232. changepoint.ts.new <- c()
  1233. n.changes <- 0
  1234. }
  1235. # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  1236. lambda <- mean(abs(c(deltas))) + 1e-8
  1237. # Sample deltas
  1238. deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  1239. # Combine with changepoints from the history
  1240. changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  1241. deltas <- c(deltas, deltas.new)
  1242. # Get the corresponding trend
  1243. if (model$growth == 'linear') {
  1244. trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  1245. } else {
  1246. cap <- df$cap_scaled
  1247. trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
  1248. }
  1249. return(trend * model$y.scale + df$floor)
  1250. }
  1251. #' Make dataframe with future dates for forecasting.
  1252. #'
  1253. #' @param m Prophet model object.
  1254. #' @param periods Int number of periods to forecast forward.
  1255. #' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
  1256. #' @param include_history Boolean to include the historical dates in the data
  1257. #' frame for predictions.
  1258. #'
  1259. #' @return Dataframe that extends forward from the end of m$history for the
  1260. #' requested number of periods.
  1261. #'
  1262. #' @export
  1263. make_future_dataframe <- function(m, periods, freq = 'day',
  1264. include_history = TRUE) {
  1265. # For backwards compatability with previous zoo date type,
  1266. if (freq == 'm') {
  1267. freq <- 'month'
  1268. }
  1269. dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
  1270. dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
  1271. if (include_history) {
  1272. dates <- c(m$history.dates, dates)
  1273. attr(dates, "tzone") <- "GMT"
  1274. }
  1275. return(data.frame(ds = dates))
  1276. }
  1277. #' Copy Prophet object.
  1278. #'
  1279. #' @param m Prophet model object.
  1280. #' @param cutoff Date, possibly as string. Changepoints are only retained if
  1281. #' changepoints <= cutoff.
  1282. #'
  1283. #' @return An unfitted Prophet model object with the same parameters as the
  1284. #' input model.
  1285. #'
  1286. #' @keywords internal
  1287. prophet_copy <- function(m, cutoff = NULL) {
  1288. if (is.null(m$history)) {
  1289. stop("This is for copying a fitted Prophet object.")
  1290. }
  1291. if (m$specified.changepoints) {
  1292. changepoints <- m$changepoints
  1293. if (!is.null(cutoff)) {
  1294. cutoff <- set_date(cutoff)
  1295. changepoints <- changepoints[changepoints <= cutoff]
  1296. }
  1297. } else {
  1298. changepoints <- NULL
  1299. }
  1300. # Auto seasonalities are set to FALSE because they are already set in
  1301. # m$seasonalities.
  1302. m2 <- prophet(
  1303. growth = m$growth,
  1304. changepoints = changepoints,
  1305. n.changepoints = m$n.changepoints,
  1306. yearly.seasonality = FALSE,
  1307. weekly.seasonality = FALSE,
  1308. daily.seasonality = FALSE,
  1309. holidays = m$holidays,
  1310. seasonality.prior.scale = m$seasonality.prior.scale,
  1311. changepoint.prior.scale = m$changepoint.prior.scale,
  1312. holidays.prior.scale = m$holidays.prior.scale,
  1313. mcmc.samples = m$mcmc.samples,
  1314. interval.width = m$interval.width,
  1315. uncertainty.samples = m$uncertainty.samples,
  1316. fit = FALSE
  1317. )
  1318. m2$extra_regressors <- m$extra_regressors
  1319. m2$seasonalities <- m$seasonalities
  1320. return(m2)
  1321. }
  1322. # fb-block 3