prophet.R 49 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  6. ## Makes R CMD CHECK happy due to dplyr syntax below
  7. utils::globalVariables(c(
  8. "ds", "y", "cap", ".",
  9. "component", "dow", "doy", "holiday", "holidays", "holidays_lower", "holidays_upper", "ix",
  10. "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
  11. "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
  12. "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
  13. #' Prophet forecaster.
  14. #'
  15. #' @param df (optional) Dataframe containing the history. Must have columns ds
  16. #' (date type) and y, the time series. If growth is logistic, then df must
  17. #' also have a column cap that specifies the capacity at each ds. If not
  18. #' provided, then the model object will be instantiated but not fit; use
  19. #' fit.prophet(m, df) to fit the model.
  20. #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  21. #' trend.
  22. #' @param changepoints Vector of dates at which to include potential
  23. #' changepoints. If not specified, potential changepoints are selected
  24. #' automatically.
  25. #' @param n.changepoints Number of potential changepoints to include. Not used
  26. #' if input `changepoints` is supplied. If `changepoints` is not supplied,
  27. #' then n.changepoints potential changepoints are selected uniformly from the
  28. #' first `changepoint.range` proportion of df$ds.
  29. #' @param changepoint.range Proportion of history in which trend changepoints
  30. #' will be estimated. Defaults to 0.8 for the first 80%. Not used if
  31. #' `changepoints` is specified.
  32. #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
  33. #' FALSE, or a number of Fourier terms to generate.
  34. #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
  35. #' FALSE, or a number of Fourier terms to generate.
  36. #' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
  37. #' FALSE, or a number of Fourier terms to generate.
  38. #' @param holidays data frame with columns holiday (character) and ds (date
  39. #' type)and optionally columns lower_window and upper_window which specify a
  40. #' range of days around the date to be included as holidays. lower_window=-2
  41. #' will include 2 days prior to the date as holidays. Also optionally can have
  42. #' a column prior_scale specifying the prior scale for each holiday.
  43. #' @param seasonality.mode 'additive' (default) or 'multiplicative'.
  44. #' @param seasonality.prior.scale Parameter modulating the strength of the
  45. #' seasonality model. Larger values allow the model to fit larger seasonal
  46. #' fluctuations, smaller values dampen the seasonality. Can be specified for
  47. #' individual seasonalities using add_seasonality.
  48. #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  49. #' components model, unless overridden in the holidays input.
  50. #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  51. #' automatic changepoint selection. Large values will allow many changepoints,
  52. #' small values will allow few changepoints.
  53. #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  54. #' inference with the specified number of MCMC samples. If 0, will do MAP
  55. #' estimation.
  56. #' @param interval.width Numeric, width of the uncertainty intervals provided
  57. #' for the forecast. If mcmc.samples=0, this will be only the uncertainty
  58. #' in the trend using the MAP estimate of the extrapolated generative model.
  59. #' If mcmc.samples>0, this will be integrated over all model parameters,
  60. #' which will include uncertainty in seasonality.
  61. #' @param uncertainty.samples Number of simulated draws used to estimate
  62. #' uncertainty intervals.
  63. #' @param fit Boolean, if FALSE the model is initialized but not fit.
  64. #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  65. #'
  66. #' @return A prophet model.
  67. #'
  68. #' @examples
  69. #' \dontrun{
  70. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  71. #' y = sin(1:366/200) + rnorm(366)/10)
  72. #' m <- prophet(history)
  73. #' }
  74. #'
  75. #' @export
  76. #' @importFrom dplyr "%>%"
  77. #' @import Rcpp
  78. prophet <- function(df = NULL,
  79. growth = 'linear',
  80. changepoints = NULL,
  81. n.changepoints = 25,
  82. changepoint.range = 0.8,
  83. yearly.seasonality = 'auto',
  84. weekly.seasonality = 'auto',
  85. daily.seasonality = 'auto',
  86. holidays = NULL,
  87. seasonality.mode = 'additive',
  88. seasonality.prior.scale = 10,
  89. holidays.prior.scale = 10,
  90. changepoint.prior.scale = 0.05,
  91. mcmc.samples = 0,
  92. interval.width = 0.80,
  93. uncertainty.samples = 1000,
  94. fit = TRUE,
  95. ...
  96. ) {
  97. if (!is.null(changepoints)) {
  98. n.changepoints <- length(changepoints)
  99. }
  100. m <- list(
  101. growth = growth,
  102. changepoints = changepoints,
  103. n.changepoints = n.changepoints,
  104. changepoint.range = changepoint.range,
  105. yearly.seasonality = yearly.seasonality,
  106. weekly.seasonality = weekly.seasonality,
  107. daily.seasonality = daily.seasonality,
  108. holidays = holidays,
  109. seasonality.mode = seasonality.mode,
  110. seasonality.prior.scale = seasonality.prior.scale,
  111. changepoint.prior.scale = changepoint.prior.scale,
  112. holidays.prior.scale = holidays.prior.scale,
  113. mcmc.samples = mcmc.samples,
  114. interval.width = interval.width,
  115. uncertainty.samples = uncertainty.samples,
  116. specified.changepoints = !is.null(changepoints),
  117. start = NULL, # This and following attributes are set during fitting
  118. y.scale = NULL,
  119. logistic.floor = FALSE,
  120. t.scale = NULL,
  121. changepoints.t = NULL,
  122. seasonalities = list(),
  123. extra_regressors = list(),
  124. stan.fit = NULL,
  125. params = list(),
  126. history = NULL,
  127. history.dates = NULL,
  128. train.component.cols = NULL,
  129. component.modes = NULL
  130. )
  131. validate_inputs(m)
  132. class(m) <- append("prophet", class(m))
  133. if ((fit) && (!is.null(df))) {
  134. m <- fit.prophet(m, df, ...)
  135. }
  136. return(m)
  137. }
  138. #' Validates the inputs to Prophet.
  139. #'
  140. #' @param m Prophet object.
  141. #'
  142. #' @keywords internal
  143. validate_inputs <- function(m) {
  144. if (!(m$growth %in% c('linear', 'logistic'))) {
  145. stop("Parameter 'growth' should be 'linear' or 'logistic'.")
  146. }
  147. if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) {
  148. stop("Parameter 'changepoint.range' must be in [0, 1]")
  149. }
  150. if (!is.null(m$holidays)) {
  151. if (!(exists('holiday', where = m$holidays))) {
  152. stop('Holidays dataframe must have holiday field.')
  153. }
  154. if (!(exists('ds', where = m$holidays))) {
  155. stop('Holidays dataframe must have ds field.')
  156. }
  157. has.lower <- exists('lower_window', where = m$holidays)
  158. has.upper <- exists('upper_window', where = m$holidays)
  159. if (has.lower + has.upper == 1) {
  160. stop(paste('Holidays must have both lower_window and upper_window,',
  161. 'or neither.'))
  162. }
  163. if (has.lower) {
  164. if(max(m$holidays$lower_window, na.rm=TRUE) > 0) {
  165. stop('Holiday lower_window should be <= 0')
  166. }
  167. if(min(m$holidays$upper_window, na.rm=TRUE) < 0) {
  168. stop('Holiday upper_window should be >= 0')
  169. }
  170. }
  171. for (h in unique(m$holidays$holiday)) {
  172. validate_column_name(m, h, check_holidays = FALSE)
  173. }
  174. }
  175. if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
  176. stop("seasonality.mode must be 'additive' or 'multiplicative'")
  177. }
  178. }
  179. #' Validates the name of a seasonality, holiday, or regressor.
  180. #'
  181. #' @param m Prophet object.
  182. #' @param name string
  183. #' @param check_holidays bool check if name already used for holiday
  184. #' @param check_seasonalities bool check if name already used for seasonality
  185. #' @param check_regressors bool check if name already used for regressor
  186. #'
  187. #' @keywords internal
  188. validate_column_name <- function(
  189. m, name, check_holidays = TRUE, check_seasonalities = TRUE,
  190. check_regressors = TRUE
  191. ) {
  192. if (grepl("_delim_", name)) {
  193. stop('Holiday name cannot contain "_delim_"')
  194. }
  195. reserved_names = c(
  196. 'trend', 'additive_terms', 'daily', 'weekly', 'yearly',
  197. 'holidays', 'zeros', 'extra_regressors_additive', 'yhat',
  198. 'extra_regressors_multiplicative', 'multiplicative_terms'
  199. )
  200. rn_l = paste0(reserved_names,"_lower")
  201. rn_u = paste0(reserved_names,"_upper")
  202. reserved_names = c(reserved_names, rn_l, rn_u,
  203. c("ds", "y", "cap", "floor", "y_scaled", "cap_scaled"))
  204. if(name %in% reserved_names){
  205. stop("Name ", name, " is reserved.")
  206. }
  207. if(check_holidays && !is.null(m$holidays) &&
  208. (name %in% unique(m$holidays$holiday))){
  209. stop("Name ", name, " already used for a holiday.")
  210. }
  211. if(check_seasonalities && (!is.null(m$seasonalities[[name]]))){
  212. stop("Name ", name, " already used for a seasonality.")
  213. }
  214. if(check_regressors && (!is.null(m$seasonalities[[name]]))){
  215. stop("Name ", name, " already used for an added regressor.")
  216. }
  217. }
  218. #' Load compiled Stan model
  219. #'
  220. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  221. #' trend.
  222. #'
  223. #' @return Stan model.
  224. #'
  225. #' @keywords internal
  226. get_prophet_stan_model <- function() {
  227. ## If the cached model doesn't work, just compile a new one.
  228. tryCatch({
  229. binary <- system.file(
  230. 'libs',
  231. Sys.getenv('R_ARCH'),
  232. 'prophet_stan_model.RData',
  233. package = 'prophet',
  234. mustWork = TRUE
  235. )
  236. load(binary)
  237. obj.name <- 'model.stanm'
  238. stanm <- eval(parse(text = obj.name))
  239. ## Should cause an error if the model doesn't work.
  240. stanm@mk_cppmodule(stanm)
  241. stanm
  242. }, error = function(cond) {
  243. compile_stan_model()
  244. })
  245. }
  246. #' Compile Stan model
  247. #'
  248. #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  249. #' trend.
  250. #'
  251. #' @return Stan model.
  252. #'
  253. #' @keywords internal
  254. compile_stan_model <- function() {
  255. fn <- 'stan/prophet.stan'
  256. stan.src <- system.file(fn, package = 'prophet', mustWork = TRUE)
  257. stanc <- rstan::stanc(stan.src)
  258. return(rstan::stan_model(stanc_ret = stanc, model_name = 'prophet_model'))
  259. }
  260. #' Convert date vector
  261. #'
  262. #' Convert the date to POSIXct object
  263. #'
  264. #' @param ds Date vector, can be consisted of characters
  265. #' @param tz string time zone
  266. #'
  267. #' @return vector of POSIXct object converted from date
  268. #'
  269. #' @keywords internal
  270. set_date <- function(ds = NULL, tz = "GMT") {
  271. if (length(ds) == 0) {
  272. return(NULL)
  273. }
  274. if (is.factor(ds)) {
  275. ds <- as.character(ds)
  276. }
  277. fmt <- if (min(nchar(ds)) < 12) "%Y-%m-%d" else "%Y-%m-%d %H:%M:%S"
  278. ds <- as.POSIXct(ds, format = fmt, tz = tz)
  279. attr(ds, "tzone") <- tz
  280. return(ds)
  281. }
  282. #' Time difference between datetimes
  283. #'
  284. #' Compute time difference of two POSIXct objects
  285. #'
  286. #' @param ds1 POSIXct object
  287. #' @param ds2 POSIXct object
  288. #' @param units string units of difference, e.g. 'days' or 'secs'.
  289. #'
  290. #' @return numeric time difference
  291. #'
  292. #' @keywords internal
  293. time_diff <- function(ds1, ds2, units = "days") {
  294. return(as.numeric(difftime(ds1, ds2, units = units)))
  295. }
  296. #' Prepare dataframe for fitting or predicting.
  297. #'
  298. #' Adds a time index and scales y. Creates auxillary columns 't', 't_ix',
  299. #' 'y_scaled', and 'cap_scaled'. These columns are used during both fitting
  300. #' and predicting.
  301. #'
  302. #' @param m Prophet object.
  303. #' @param df Data frame with columns ds, y, and cap if logistic growth. Any
  304. #' specified additional regressors must also be present.
  305. #' @param initialize_scales Boolean set scaling factors in m from df.
  306. #'
  307. #' @return list with items 'df' and 'm'.
  308. #'
  309. #' @keywords internal
  310. setup_dataframe <- function(m, df, initialize_scales = FALSE) {
  311. if (exists('y', where=df)) {
  312. df$y <- as.numeric(df$y)
  313. }
  314. if (any(is.infinite(df$y))) {
  315. stop("Found infinity in column y.")
  316. }
  317. df$ds <- set_date(df$ds)
  318. if (anyNA(df$ds)) {
  319. stop(paste('Unable to parse date format in column ds. Convert to date ',
  320. 'format (%Y-%m-%d or %Y-%m-%d %H:%M:%S) and check that there',
  321. 'are no NAs.'))
  322. }
  323. for (name in names(m$extra_regressors)) {
  324. if (!(name %in% colnames(df))) {
  325. stop('Regressor "', name, '" missing from dataframe')
  326. }
  327. }
  328. df <- df %>%
  329. dplyr::arrange(ds)
  330. m <- initialize_scales_fn(m, initialize_scales, df)
  331. if (m$logistic.floor) {
  332. if (!('floor' %in% colnames(df))) {
  333. stop("Expected column 'floor'.")
  334. }
  335. } else {
  336. df$floor <- 0
  337. }
  338. if (m$growth == 'logistic') {
  339. if (!(exists('cap', where=df))) {
  340. stop('Capacities must be supplied for logistic growth.')
  341. }
  342. df <- df %>%
  343. dplyr::mutate(cap_scaled = (cap - floor) / m$y.scale)
  344. }
  345. df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
  346. if (exists('y', where=df)) {
  347. df$y_scaled <- (df$y - df$floor) / m$y.scale
  348. }
  349. for (name in names(m$extra_regressors)) {
  350. df[[name]] <- as.numeric(df[[name]])
  351. props <- m$extra_regressors[[name]]
  352. df[[name]] <- (df[[name]] - props$mu) / props$std
  353. if (anyNA(df[[name]])) {
  354. stop('Found NaN in column ', name)
  355. }
  356. }
  357. return(list("m" = m, "df" = df))
  358. }
  359. #' Initialize model scales.
  360. #'
  361. #' Sets model scaling factors using df.
  362. #'
  363. #' @param m Prophet object.
  364. #' @param initialize_scales Boolean set the scales or not.
  365. #' @param df Dataframe for setting scales.
  366. #'
  367. #' @return Prophet object with scales set.
  368. #'
  369. #' @keywords internal
  370. initialize_scales_fn <- function(m, initialize_scales, df) {
  371. if (!initialize_scales) {
  372. return(m)
  373. }
  374. if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
  375. m$logistic.floor <- TRUE
  376. floor <- df$floor
  377. } else {
  378. floor <- 0
  379. }
  380. m$y.scale <- max(abs(df$y - floor))
  381. if (m$y.scale == 0) {
  382. m$y.scale <- 1
  383. }
  384. m$start <- min(df$ds)
  385. m$t.scale <- time_diff(max(df$ds), m$start, "secs")
  386. for (name in names(m$extra_regressors)) {
  387. n.vals <- length(unique(df[[name]]))
  388. if (n.vals < 2) {
  389. stop('Regressor ', name, ' is constant.')
  390. }
  391. standardize <- m$extra_regressors[[name]]$standardize
  392. if (standardize == 'auto') {
  393. if (n.vals == 2 && all(sort(unique(df[[name]])) == c(0, 1))) {
  394. # Don't standardize binary variables
  395. standardize <- FALSE
  396. } else {
  397. standardize <- TRUE
  398. }
  399. }
  400. if (standardize) {
  401. mu <- mean(df[[name]])
  402. std <- stats::sd(df[[name]])
  403. m$extra_regressors[[name]]$mu <- mu
  404. m$extra_regressors[[name]]$std <- std
  405. }
  406. }
  407. return(m)
  408. }
  409. #' Set changepoints
  410. #'
  411. #' Sets m$changepoints to the dates of changepoints. Either:
  412. #' 1) The changepoints were passed in explicitly.
  413. #' A) They are empty.
  414. #' B) They are not empty, and need validation.
  415. #' 2) We are generating a grid of them.
  416. #' 3) The user prefers no changepoints be used.
  417. #'
  418. #' @param m Prophet object.
  419. #'
  420. #' @return m with changepoints set.
  421. #'
  422. #' @keywords internal
  423. set_changepoints <- function(m) {
  424. if (!is.null(m$changepoints)) {
  425. if (length(m$changepoints) > 0) {
  426. m$changepoints <- set_date(m$changepoints)
  427. if (min(m$changepoints) < min(m$history$ds)
  428. || max(m$changepoints) > max(m$history$ds)) {
  429. stop('Changepoints must fall within training data.')
  430. }
  431. }
  432. } else {
  433. # Place potential changepoints evenly through the first changepoint.range
  434. # proportion of the history.
  435. hist.size <- floor(nrow(m$history) * m$changepoint.range)
  436. if (m$n.changepoints + 1 > hist.size) {
  437. m$n.changepoints <- hist.size - 1
  438. message('n.changepoints greater than number of observations. Using ',
  439. m$n.changepoints)
  440. }
  441. if (m$n.changepoints > 0) {
  442. cp.indexes <- round(seq.int(1, hist.size,
  443. length.out = (m$n.changepoints + 1))[-1])
  444. m$changepoints <- m$history$ds[cp.indexes]
  445. } else {
  446. m$changepoints <- c()
  447. }
  448. }
  449. if (length(m$changepoints) > 0) {
  450. m$changepoints.t <- sort(
  451. time_diff(m$changepoints, m$start, "secs")) / m$t.scale
  452. } else {
  453. m$changepoints.t <- 0 # dummy changepoint
  454. }
  455. return(m)
  456. }
  457. #' Provides Fourier series components with the specified frequency and order.
  458. #'
  459. #' @param dates Vector of dates.
  460. #' @param period Number of days of the period.
  461. #' @param series.order Number of components.
  462. #'
  463. #' @return Matrix with seasonality features.
  464. #'
  465. #' @keywords internal
  466. fourier_series <- function(dates, period, series.order) {
  467. t <- time_diff(dates, set_date('1970-01-01 00:00:00'))
  468. features <- matrix(0, length(t), 2 * series.order)
  469. for (i in seq_len(series.order)) {
  470. x <- as.numeric(2 * i * pi * t / period)
  471. features[, i * 2 - 1] <- sin(x)
  472. features[, i * 2] <- cos(x)
  473. }
  474. return(features)
  475. }
  476. #' Data frame with seasonality features.
  477. #'
  478. #' @param dates Vector of dates.
  479. #' @param period Number of days of the period.
  480. #' @param series.order Number of components.
  481. #' @param prefix Column name prefix.
  482. #'
  483. #' @return Dataframe with seasonality.
  484. #'
  485. #' @keywords internal
  486. make_seasonality_features <- function(dates, period, series.order, prefix) {
  487. features <- fourier_series(dates, period, series.order)
  488. colnames(features) <- paste(prefix, seq_len(ncol(features)), sep = '_delim_')
  489. return(data.frame(features))
  490. }
  491. #' Construct a matrix of holiday features.
  492. #'
  493. #' @param m Prophet object.
  494. #' @param dates Vector with dates used for computing seasonality.
  495. #'
  496. #' @return A list with entries
  497. #' holiday.features: dataframe with a column for each holiday.
  498. #' prior.scales: array of prior scales for each holiday column.
  499. #' holiday.names: array of names of all holidays.
  500. #'
  501. #' @importFrom dplyr "%>%"
  502. #' @keywords internal
  503. make_holiday_features <- function(m, dates) {
  504. # Strip dates to be just days, for joining on holidays
  505. dates <- set_date(format(dates, "%Y-%m-%d"))
  506. wide <- m$holidays %>%
  507. dplyr::mutate(ds = set_date(ds)) %>%
  508. dplyr::group_by(holiday, ds) %>%
  509. dplyr::filter(row_number() == 1) %>%
  510. dplyr::do({
  511. if (exists('lower_window', where = .) && !is.na(.$lower_window)
  512. && !is.na(.$upper_window)) {
  513. offsets <- seq(.$lower_window, .$upper_window)
  514. } else {
  515. offsets <- 0
  516. }
  517. names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
  518. abs(offsets), sep = '')
  519. dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
  520. }) %>%
  521. dplyr::mutate(x = 1) %>%
  522. tidyr::spread(holiday, x, fill = 0)
  523. holiday.features <- data.frame(ds = set_date(dates)) %>%
  524. dplyr::left_join(wide, by = 'ds') %>%
  525. dplyr::select(-ds)
  526. holiday.features[is.na(holiday.features)] <- 0
  527. # Prior scales
  528. if (!('prior_scale' %in% colnames(m$holidays))) {
  529. m$holidays$prior_scale <- m$holidays.prior.scale
  530. }
  531. prior.scales.list <- list()
  532. for (name in unique(m$holidays$holiday)) {
  533. df.h <- m$holidays[m$holidays$holiday == name, ]
  534. ps <- unique(df.h$prior_scale)
  535. if (length(ps) > 1) {
  536. stop('Holiday ', name, ' does not have a consistent prior scale ',
  537. 'specification')
  538. }
  539. if (is.na(ps)) {
  540. ps <- m$holidays.prior.scale
  541. }
  542. if (ps <= 0) {
  543. stop('Prior scale must be > 0.')
  544. }
  545. prior.scales.list[[name]] <- ps
  546. }
  547. prior.scales <- c()
  548. for (name in colnames(holiday.features)) {
  549. sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
  550. prior.scales <- c(prior.scales, prior.scales.list[[sn]])
  551. }
  552. return(list(holiday.features = holiday.features,
  553. prior.scales = prior.scales,
  554. holiday.names = names(prior.scales.list)))
  555. }
  556. #' Add an additional regressor to be used for fitting and predicting.
  557. #'
  558. #' The dataframe passed to `fit` and `predict` will have a column with the
  559. #' specified name to be used as a regressor. When standardize='auto', the
  560. #' regressor will be standardized unless it is binary. The regression
  561. #' coefficient is given a prior with the specified scale parameter.
  562. #' Decreasing the prior scale will add additional regularization. If no
  563. #' prior scale is provided, holidays.prior.scale will be used.
  564. #' Mode can be specified as either 'additive' or 'multiplicative'. If not
  565. #' specified, m$seasonality.mode will be used. 'additive' means the effect of
  566. #' the regressor will be added to the trend, 'multiplicative' means it will
  567. #' multiply the trend.
  568. #'
  569. #' @param m Prophet object.
  570. #' @param name String name of the regressor
  571. #' @param prior.scale Float scale for the normal prior. If not provided,
  572. #' holidays.prior.scale will be used.
  573. #' @param standardize Bool, specify whether this regressor will be standardized
  574. #' prior to fitting. Can be 'auto' (standardize if not binary), True, or
  575. #' False.
  576. #' @param mode Optional, 'additive' or 'multiplicative'. Defaults to
  577. #' m$seasonality.mode.
  578. #'
  579. #' @return The prophet model with the regressor added.
  580. #'
  581. #' @export
  582. add_regressor <- function(
  583. m, name, prior.scale = NULL, standardize = 'auto', mode = NULL
  584. ){
  585. if (!is.null(m$history)) {
  586. stop('Regressors must be added prior to model fitting.')
  587. }
  588. validate_column_name(m, name, check_regressors = FALSE)
  589. if (is.null(prior.scale)) {
  590. prior.scale <- m$holidays.prior.scale
  591. }
  592. if (is.null(mode)) {
  593. mode <- m$seasonality.mode
  594. }
  595. if(prior.scale <= 0) {
  596. stop("Prior scale must be > 0")
  597. }
  598. if (!(mode %in% c('additive', 'multiplicative'))) {
  599. stop("mode must be 'additive' or 'multiplicative'")
  600. }
  601. m$extra_regressors[[name]] <- list(
  602. prior.scale = prior.scale,
  603. standardize = standardize,
  604. mu = 0,
  605. std = 1.0,
  606. mode = mode
  607. )
  608. return(m)
  609. }
  610. #' Add a seasonal component with specified period, number of Fourier
  611. #' components, and prior scale.
  612. #'
  613. #' Increasing the number of Fourier components allows the seasonality to change
  614. #' more quickly (at risk of overfitting). Default values for yearly and weekly
  615. #' seasonalities are 10 and 3 respectively.
  616. #'
  617. #' Increasing prior scale will allow this seasonality component more
  618. #' flexibility, decreasing will dampen it. If not provided, will use the
  619. #' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
  620. #'
  621. #' Mode can be specified as either 'additive' or 'multiplicative'. If not
  622. #' specified, m$seasonality.mode will be used (defaults to 'additive').
  623. #' Additive means the seasonality will be added to the trend, multiplicative
  624. #' means it will multiply the trend.
  625. #'
  626. #' @param m Prophet object.
  627. #' @param name String name of the seasonality component.
  628. #' @param period Float number of days in one period.
  629. #' @param fourier.order Int number of Fourier components to use.
  630. #' @param prior.scale Optional float prior scale for this component.
  631. #' @param mode Optional 'additive' or 'multiplicative'.
  632. #'
  633. #' @return The prophet model with the seasonality added.
  634. #'
  635. #' @importFrom dplyr "%>%"
  636. #' @export
  637. add_seasonality <- function(
  638. m, name, period, fourier.order, prior.scale = NULL, mode = NULL
  639. ) {
  640. if (!is.null(m$history)) {
  641. stop("Seasonality must be added prior to model fitting.")
  642. }
  643. if (!(name %in% c('daily', 'weekly', 'yearly'))) {
  644. # Allow overriding built-in seasonalities
  645. validate_column_name(m, name, check_seasonalities = FALSE)
  646. }
  647. if (is.null(prior.scale)) {
  648. ps <- m$seasonality.prior.scale
  649. } else {
  650. ps <- prior.scale
  651. }
  652. if (ps <= 0) {
  653. stop('Prior scale must be > 0')
  654. }
  655. if (is.null(mode)) {
  656. mode <- m$seasonality.mode
  657. }
  658. if (!(mode %in% c('additive', 'multiplicative'))) {
  659. stop("mode must be 'additive' or 'multiplicative'")
  660. }
  661. m$seasonalities[[name]] <- list(
  662. period = period,
  663. fourier.order = fourier.order,
  664. prior.scale = ps,
  665. mode = mode
  666. )
  667. return(m)
  668. }
  669. #' Dataframe with seasonality features.
  670. #' Includes seasonality features, holiday features, and added regressors.
  671. #'
  672. #' @param m Prophet object.
  673. #' @param df Dataframe with dates for computing seasonality features and any
  674. #' added regressors.
  675. #'
  676. #' @return List with items
  677. #' seasonal.features: Dataframe with regressor features,
  678. #' prior.scales: Array of prior scales for each colum of the features
  679. #' dataframe.
  680. #' component.cols: Dataframe with indicators for which regression components
  681. #' correspond to which columns.
  682. #' modes: List with keys 'additive' and 'multiplicative' with arrays of
  683. #' component names for each mode of seasonality.
  684. #'
  685. #' @keywords internal
  686. make_all_seasonality_features <- function(m, df) {
  687. seasonal.features <- data.frame(row.names = seq_len(nrow(df)))
  688. prior.scales <- c()
  689. modes <- list(additive = c(), multiplicative = c())
  690. # Seasonality features
  691. for (name in names(m$seasonalities)) {
  692. props <- m$seasonalities[[name]]
  693. features <- make_seasonality_features(
  694. df$ds, props$period, props$fourier.order, name)
  695. seasonal.features <- cbind(seasonal.features, features)
  696. prior.scales <- c(prior.scales,
  697. props$prior.scale * rep(1, ncol(features)))
  698. modes[[props$mode]] <- c(modes[[props$mode]], name)
  699. }
  700. # Holiday features
  701. if (!is.null(m$holidays)) {
  702. hf <- make_holiday_features(m, df$ds)
  703. seasonal.features <- cbind(seasonal.features, hf$holiday.features)
  704. prior.scales <- c(prior.scales, hf$prior.scales)
  705. modes[[m$seasonality.mode]] <- c(
  706. modes[[m$seasonality.mode]], hf$holiday.names)
  707. }
  708. # Additional regressors
  709. for (name in names(m$extra_regressors)) {
  710. props <- m$extra_regressors[[name]]
  711. seasonal.features[[name]] <- df[[name]]
  712. prior.scales <- c(prior.scales, props$prior.scale)
  713. modes[[props$mode]] <- c(modes[[props$mode]], name)
  714. }
  715. # Dummy to prevent empty X
  716. if (ncol(seasonal.features) == 0) {
  717. seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
  718. prior.scales <- 1
  719. }
  720. components.list <- regressor_column_matrix(m, seasonal.features, modes)
  721. return(list(seasonal.features = seasonal.features,
  722. prior.scales = prior.scales,
  723. component.cols = components.list$component.cols,
  724. modes = components.list$modes))
  725. }
  726. #' Dataframe indicating which columns of the feature matrix correspond to
  727. #' which seasonality/regressor components.
  728. #'
  729. #' Includes combination components, like 'additive_terms'. These combination
  730. #' components will be added to the 'modes' input.
  731. #'
  732. #' @param m Prophet object.
  733. #' @param seasonal.features Constructed seasonal features dataframe.
  734. #' @param modes List with keys 'additive' and 'multiplicative' with arrays of
  735. #' component names for each mode of seasonality.
  736. #'
  737. #' @return List with items
  738. #' component.cols: A binary indicator dataframe with columns seasonal
  739. #' components and rows columns in seasonal.features. Entry is 1 if that
  740. #' column is used in that component.
  741. #' modes: Updated input with combination components.
  742. #'
  743. #' @keywords internal
  744. regressor_column_matrix <- function(m, seasonal.features, modes) {
  745. components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
  746. dplyr::mutate(col = seq_len(n())) %>%
  747. tidyr::separate(component, c('component', 'part'), sep = "_delim_",
  748. extra = "merge", fill = "right") %>%
  749. dplyr::select(col, component)
  750. # Add total for holidays
  751. if(!is.null(m$holidays)){
  752. components <- add_group_component(
  753. components, 'holidays', unique(m$holidays$holiday))
  754. }
  755. # Add totals for additive and multiplicative components, and regressors
  756. for (mode in c('additive', 'multiplicative')) {
  757. components <- add_group_component(
  758. components, paste0(mode, '_terms'), modes[[mode]])
  759. regressors_by_mode <- c()
  760. for (name in names(m$extra_regressors)) {
  761. if (m$extra_regressors[[name]]$mode == mode) {
  762. regressors_by_mode <- c(regressors_by_mode, name)
  763. }
  764. }
  765. components <- add_group_component(
  766. components, paste0('extra_regressors_', mode), regressors_by_mode)
  767. # Add combination components to modes
  768. modes[[mode]] <- c(modes[[mode]], paste0(mode, '_terms'))
  769. modes[[mode]] <- c(modes[[mode]], paste0('extra_regressors_', mode))
  770. }
  771. # After all of the additive/multiplicative groups have been added,
  772. modes[[m$seasonality.mode]] <- c(modes[[m$seasonality.mode]], 'holidays')
  773. # Convert to a binary matrix
  774. component.cols <- as.data.frame.matrix(
  775. table(components$col, components$component)
  776. )
  777. component.cols <- (
  778. component.cols[order(as.numeric(row.names(component.cols))), ,
  779. drop = FALSE]
  780. )
  781. # Add columns for additive and multiplicative terms, if missing
  782. for (name in c('additive_terms', 'multiplicative_terms')) {
  783. if (!(name %in% colnames(component.cols))) {
  784. component.cols[[name]] <- 0
  785. }
  786. }
  787. # Remove the placeholder
  788. components <- dplyr::filter(components, component != 'zeros')
  789. # Validation
  790. if (
  791. max(component.cols$additive_terms
  792. + component.cols$multiplicative_terms) > 1
  793. ) {
  794. stop('A bug occurred in seasonal components.')
  795. }
  796. # Compare to training, if set.
  797. if (!is.null(m$train.component.cols)) {
  798. component.cols <- component.cols[, colnames(m$train.component.cols)]
  799. if (!all(component.cols == m$train.component.cols)) {
  800. stop('A bug occurred in constructing regressors.')
  801. }
  802. }
  803. return(list(component.cols = component.cols, modes = modes))
  804. }
  805. #' Adds a component with given name that contains all of the components
  806. #' in group.
  807. #'
  808. #' @param components Dataframe with components.
  809. #' @param name Name of new group component.
  810. #' @param group List of components that form the group.
  811. #'
  812. #' @return Dataframe with components.
  813. #'
  814. #' @keywords internal
  815. add_group_component <- function(components, name, group) {
  816. new_comp <- components[(components$component %in% group), ]
  817. group_cols <- unique(new_comp$col)
  818. if (length(group_cols) > 0) {
  819. new_comp <- data.frame(col=group_cols, component=name)
  820. components <- rbind(components, new_comp)
  821. }
  822. return(components)
  823. }
  824. #' Get number of Fourier components for built-in seasonalities.
  825. #'
  826. #' @param m Prophet object.
  827. #' @param name String name of the seasonality component.
  828. #' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
  829. #' provided.
  830. #' @param auto.disable Bool if seasonality should be disabled when 'auto'.
  831. #' @param default.order Int default Fourier order.
  832. #'
  833. #' @return Number of Fourier components, or 0 for disabled.
  834. #'
  835. #' @keywords internal
  836. parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
  837. if (arg == 'auto') {
  838. fourier.order <- 0
  839. if (name %in% names(m$seasonalities)) {
  840. message('Found custom seasonality named "', name,
  841. '", disabling built-in ', name, ' seasonality.')
  842. } else if (auto.disable) {
  843. message('Disabling ', name, ' seasonality. Run prophet with ', name,
  844. '.seasonality=TRUE to override this.')
  845. } else {
  846. fourier.order <- default.order
  847. }
  848. } else if (arg == TRUE) {
  849. fourier.order <- default.order
  850. } else if (arg == FALSE) {
  851. fourier.order <- 0
  852. } else {
  853. fourier.order <- arg
  854. }
  855. return(fourier.order)
  856. }
  857. #' Set seasonalities that were left on auto.
  858. #'
  859. #' Turns on yearly seasonality if there is >=2 years of history.
  860. #' Turns on weekly seasonality if there is >=2 weeks of history, and the
  861. #' spacing between dates in the history is <7 days.
  862. #' Turns on daily seasonality if there is >=2 days of history, and the spacing
  863. #' between dates in the history is <1 day.
  864. #'
  865. #' @param m Prophet object.
  866. #'
  867. #' @return The prophet model with seasonalities set.
  868. #'
  869. #' @keywords internal
  870. set_auto_seasonalities <- function(m) {
  871. first <- min(m$history$ds)
  872. last <- max(m$history$ds)
  873. dt <- diff(time_diff(m$history$ds, m$start))
  874. min.dt <- min(dt[dt > 0])
  875. yearly.disable <- time_diff(last, first) < 730
  876. fourier.order <- parse_seasonality_args(
  877. m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
  878. if (fourier.order > 0) {
  879. m$seasonalities[['yearly']] <- list(
  880. period = 365.25,
  881. fourier.order = fourier.order,
  882. prior.scale = m$seasonality.prior.scale,
  883. mode = m$seasonality.mode
  884. )
  885. }
  886. weekly.disable <- ((time_diff(last, first) < 14) || (min.dt >= 7))
  887. fourier.order <- parse_seasonality_args(
  888. m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
  889. if (fourier.order > 0) {
  890. m$seasonalities[['weekly']] <- list(
  891. period = 7,
  892. fourier.order = fourier.order,
  893. prior.scale = m$seasonality.prior.scale,
  894. mode = m$seasonality.mode
  895. )
  896. }
  897. daily.disable <- ((time_diff(last, first) < 2) || (min.dt >= 1))
  898. fourier.order <- parse_seasonality_args(
  899. m, 'daily', m$daily.seasonality, daily.disable, 4)
  900. if (fourier.order > 0) {
  901. m$seasonalities[['daily']] <- list(
  902. period = 1,
  903. fourier.order = fourier.order,
  904. prior.scale = m$seasonality.prior.scale,
  905. mode = m$seasonality.mode
  906. )
  907. }
  908. return(m)
  909. }
  910. #' Initialize linear growth.
  911. #'
  912. #' Provides a strong initialization for linear growth by calculating the
  913. #' growth and offset parameters that pass the function through the first and
  914. #' last points in the time series.
  915. #'
  916. #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  917. #' and t (scaled time).
  918. #'
  919. #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  920. #' growth function.
  921. #'
  922. #' @keywords internal
  923. linear_growth_init <- function(df) {
  924. i0 <- which.min(df$ds)
  925. i1 <- which.max(df$ds)
  926. T <- df$t[i1] - df$t[i0]
  927. # Initialize the rate
  928. k <- (df$y_scaled[i1] - df$y_scaled[i0]) / T
  929. # And the offset
  930. m <- df$y_scaled[i0] - k * df$t[i0]
  931. return(c(k, m))
  932. }
  933. #' Initialize logistic growth.
  934. #'
  935. #' Provides a strong initialization for logistic growth by calculating the
  936. #' growth and offset parameters that pass the function through the first and
  937. #' last points in the time series.
  938. #'
  939. #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  940. #' y_scaled (scaled time series), and t (scaled time).
  941. #'
  942. #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  943. #' growth function.
  944. #'
  945. #' @keywords internal
  946. logistic_growth_init <- function(df) {
  947. i0 <- which.min(df$ds)
  948. i1 <- which.max(df$ds)
  949. T <- df$t[i1] - df$t[i0]
  950. # Force valid values, in case y > cap or y < 0
  951. C0 <- df$cap_scaled[i0]
  952. C1 <- df$cap_scaled[i1]
  953. y0 <- max(0.01 * C0, min(0.99 * C0, df$y_scaled[i0]))
  954. y1 <- max(0.01 * C1, min(0.99 * C1, df$y_scaled[i1]))
  955. r0 <- C0 / y0
  956. r1 <- C1 / y1
  957. if (abs(r0 - r1) <= 0.01) {
  958. r0 <- 1.05 * r0
  959. }
  960. L0 <- log(r0 - 1)
  961. L1 <- log(r1 - 1)
  962. # Initialize the offset
  963. m <- L0 * T / (L0 - L1)
  964. # And the rate
  965. k <- (L0 - L1) / T
  966. return(c(k, m))
  967. }
  968. #' Fit the prophet model.
  969. #'
  970. #' This sets m$params to contain the fitted model parameters. It is a list
  971. #' with the following elements:
  972. #' k (M array): M posterior samples of the initial slope.
  973. #' m (M array): The initial intercept.
  974. #' delta (MxN matrix): The slope change at each of N changepoints.
  975. #' beta (MxK matrix): Coefficients for K seasonality features.
  976. #' sigma_obs (M array): Noise level.
  977. #' Note that M=1 if MAP estimation.
  978. #'
  979. #' @param m Prophet object.
  980. #' @param df Data frame.
  981. #' @param ... Additional arguments passed to the \code{optimizing} or
  982. #' \code{sampling} functions in Stan.
  983. #'
  984. #' @export
  985. fit.prophet <- function(m, df, ...) {
  986. if (!is.null(m$history)) {
  987. stop("Prophet object can only be fit once. Instantiate a new object.")
  988. }
  989. if (!(exists('ds', where = df)) | !(exists('y', where = df))) {
  990. stop(paste(
  991. "Dataframe must have columns 'ds' and 'y' with the dates and values",
  992. "respectively."
  993. ))
  994. }
  995. history <- df %>%
  996. dplyr::filter(!is.na(y))
  997. if (nrow(history) < 2) {
  998. stop("Dataframe has less than 2 non-NA rows.")
  999. }
  1000. m$history.dates <- sort(set_date(df$ds))
  1001. out <- setup_dataframe(m, history, initialize_scales = TRUE)
  1002. history <- out$df
  1003. m <- out$m
  1004. m$history <- history
  1005. m <- set_auto_seasonalities(m)
  1006. out2 <- make_all_seasonality_features(m, history)
  1007. seasonal.features <- out2$seasonal.features
  1008. prior.scales <- out2$prior.scales
  1009. component.cols <- out2$component.cols
  1010. m$train.component.cols <- component.cols
  1011. m$component.modes <- out2$modes
  1012. m <- set_changepoints(m)
  1013. # Construct input to stan
  1014. dat <- list(
  1015. T = nrow(history),
  1016. K = ncol(seasonal.features),
  1017. S = length(m$changepoints.t),
  1018. y = history$y_scaled,
  1019. t = history$t,
  1020. t_change = array(m$changepoints.t),
  1021. X = as.matrix(seasonal.features),
  1022. sigmas = array(prior.scales),
  1023. tau = m$changepoint.prior.scale,
  1024. trend_indicator = as.numeric(m$growth == 'logistic'),
  1025. s_a = array(component.cols$additive_terms),
  1026. s_m = array(component.cols$multiplicative_terms)
  1027. )
  1028. # Run stan
  1029. if (m$growth == 'linear') {
  1030. dat$cap <- rep(0, nrow(history)) # Unused inside Stan
  1031. kinit <- linear_growth_init(history)
  1032. } else {
  1033. dat$cap <- history$cap_scaled # Add capacities to the Stan data
  1034. kinit <- logistic_growth_init(history)
  1035. }
  1036. if (exists(".prophet.stan.model")) {
  1037. model <- .prophet.stan.model
  1038. } else {
  1039. model <- get_prophet_stan_model()
  1040. }
  1041. stan_init <- function() {
  1042. list(k = kinit[1],
  1043. m = kinit[2],
  1044. delta = array(rep(0, length(m$changepoints.t))),
  1045. beta = array(rep(0, ncol(seasonal.features))),
  1046. sigma_obs = 1
  1047. )
  1048. }
  1049. if (min(history$y) == max(history$y)) {
  1050. # Nothing to fit.
  1051. m$params <- stan_init()
  1052. m$params$sigma_obs <- 0.
  1053. n.iteration <- 1.
  1054. } else if (m$mcmc.samples > 0) {
  1055. stan.fit <- rstan::sampling(
  1056. model,
  1057. data = dat,
  1058. init = stan_init,
  1059. iter = m$mcmc.samples,
  1060. ...
  1061. )
  1062. m$params <- rstan::extract(stan.fit)
  1063. n.iteration <- length(m$params$k)
  1064. } else {
  1065. stan.fit <- rstan::optimizing(
  1066. model,
  1067. data = dat,
  1068. init = stan_init,
  1069. iter = 1e4,
  1070. as_vector = FALSE,
  1071. ...
  1072. )
  1073. m$params <- stan.fit$par
  1074. n.iteration <- 1
  1075. }
  1076. # Cast the parameters to have consistent form, whether full bayes or MAP
  1077. for (name in c('delta', 'beta')){
  1078. m$params[[name]] <- matrix(m$params[[name]], nrow = n.iteration)
  1079. }
  1080. # rstan::sampling returns 1d arrays; converts to atomic vectors.
  1081. for (name in c('k', 'm', 'sigma_obs')){
  1082. m$params[[name]] <- c(m$params[[name]])
  1083. }
  1084. # If no changepoints were requested, replace delta with 0s
  1085. if (m$n.changepoints == 0) {
  1086. # Fold delta into the base rate k
  1087. m$params$k <- m$params$k + m$params$delta[, 1]
  1088. m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
  1089. }
  1090. return(m)
  1091. }
  1092. #' Predict using the prophet model.
  1093. #'
  1094. #' @param object Prophet object.
  1095. #' @param df Dataframe with dates for predictions (column ds), and capacity
  1096. #' (column cap) if logistic growth. If not provided, predictions are made on
  1097. #' the history.
  1098. #' @param ... additional arguments.
  1099. #'
  1100. #' @return A dataframe with the forecast components.
  1101. #'
  1102. #' @examples
  1103. #' \dontrun{
  1104. #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  1105. #' y = sin(1:366/200) + rnorm(366)/10)
  1106. #' m <- prophet(history)
  1107. #' future <- make_future_dataframe(m, periods = 365)
  1108. #' forecast <- predict(m, future)
  1109. #' plot(m, forecast)
  1110. #' }
  1111. #'
  1112. #' @export
  1113. predict.prophet <- function(object, df = NULL, ...) {
  1114. if (is.null(df)) {
  1115. df <- object$history
  1116. } else {
  1117. if (nrow(df) == 0) {
  1118. stop("Dataframe has no rows.")
  1119. }
  1120. out <- setup_dataframe(object, df)
  1121. df <- out$df
  1122. }
  1123. df$trend <- predict_trend(object, df)
  1124. seasonal.components <- predict_seasonal_components(object, df)
  1125. intervals <- predict_uncertainty(object, df)
  1126. # Drop columns except ds, cap, floor, and trend
  1127. cols <- c('ds', 'trend')
  1128. if ('cap' %in% colnames(df)) {
  1129. cols <- c(cols, 'cap')
  1130. }
  1131. if (object$logistic.floor) {
  1132. cols <- c(cols, 'floor')
  1133. }
  1134. df <- df[cols]
  1135. df <- dplyr::bind_cols(df, seasonal.components, intervals)
  1136. df$yhat <- df$trend * (1 + df$multiplicative_terms) + df$additive_terms
  1137. return(df)
  1138. }
  1139. #' Evaluate the piecewise linear function.
  1140. #'
  1141. #' @param t Vector of times on which the function is evaluated.
  1142. #' @param deltas Vector of rate changes at each changepoint.
  1143. #' @param k Float initial rate.
  1144. #' @param m Float initial offset.
  1145. #' @param changepoint.ts Vector of changepoint times.
  1146. #'
  1147. #' @return Vector y(t).
  1148. #'
  1149. #' @keywords internal
  1150. piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
  1151. # Intercept changes
  1152. gammas <- -changepoint.ts * deltas
  1153. # Get cumulative slope and intercept at each t
  1154. k_t <- rep(k, length(t))
  1155. m_t <- rep(m, length(t))
  1156. for (s in seq_along(changepoint.ts)) {
  1157. indx <- t >= changepoint.ts[s]
  1158. k_t[indx] <- k_t[indx] + deltas[s]
  1159. m_t[indx] <- m_t[indx] + gammas[s]
  1160. }
  1161. y <- k_t * t + m_t
  1162. return(y)
  1163. }
  1164. #' Evaluate the piecewise logistic function.
  1165. #'
  1166. #' @param t Vector of times on which the function is evaluated.
  1167. #' @param cap Vector of capacities at each t.
  1168. #' @param deltas Vector of rate changes at each changepoint.
  1169. #' @param k Float initial rate.
  1170. #' @param m Float initial offset.
  1171. #' @param changepoint.ts Vector of changepoint times.
  1172. #'
  1173. #' @return Vector y(t).
  1174. #'
  1175. #' @keywords internal
  1176. piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
  1177. # Compute offset changes
  1178. k.cum <- c(k, cumsum(deltas) + k)
  1179. gammas <- rep(0, length(changepoint.ts))
  1180. for (i in seq_along(changepoint.ts)) {
  1181. gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
  1182. * (1 - k.cum[i] / k.cum[i + 1]))
  1183. }
  1184. # Get cumulative rate and offset at each t
  1185. k_t <- rep(k, length(t))
  1186. m_t <- rep(m, length(t))
  1187. for (s in seq_along(changepoint.ts)) {
  1188. indx <- t >= changepoint.ts[s]
  1189. k_t[indx] <- k_t[indx] + deltas[s]
  1190. m_t[indx] <- m_t[indx] + gammas[s]
  1191. }
  1192. y <- cap / (1 + exp(-k_t * (t - m_t)))
  1193. return(y)
  1194. }
  1195. #' Predict trend using the prophet model.
  1196. #'
  1197. #' @param model Prophet object.
  1198. #' @param df Prediction dataframe.
  1199. #'
  1200. #' @return Vector with trend on prediction dates.
  1201. #'
  1202. #' @keywords internal
  1203. predict_trend <- function(model, df) {
  1204. k <- mean(model$params$k, na.rm = TRUE)
  1205. param.m <- mean(model$params$m, na.rm = TRUE)
  1206. deltas <- colMeans(model$params$delta, na.rm = TRUE)
  1207. t <- df$t
  1208. if (model$growth == 'linear') {
  1209. trend <- piecewise_linear(t, deltas, k, param.m, model$changepoints.t)
  1210. } else {
  1211. cap <- df$cap_scaled
  1212. trend <- piecewise_logistic(
  1213. t, cap, deltas, k, param.m, model$changepoints.t)
  1214. }
  1215. return(trend * model$y.scale + df$floor)
  1216. }
  1217. #' Predict seasonality components, holidays, and added regressors.
  1218. #'
  1219. #' @param m Prophet object.
  1220. #' @param df Prediction dataframe.
  1221. #'
  1222. #' @return Dataframe with seasonal components.
  1223. #'
  1224. #' @keywords internal
  1225. predict_seasonal_components <- function(m, df) {
  1226. out <- make_all_seasonality_features(m, df)
  1227. seasonal.features <- out$seasonal.features
  1228. component.cols <- out$component.cols
  1229. lower.p <- (1 - m$interval.width)/2
  1230. upper.p <- (1 + m$interval.width)/2
  1231. X <- as.matrix(seasonal.features)
  1232. component.predictions <- data.frame(matrix(ncol = 0, nrow = nrow(X)))
  1233. for (component in colnames(component.cols)) {
  1234. beta.c <- m$params$beta * component.cols[[component]]
  1235. comp <- X %*% t(beta.c)
  1236. if (component %in% m$component.modes$additive) {
  1237. comp <- comp * m$y.scale
  1238. }
  1239. component.predictions[[component]] <- rowMeans(comp, na.rm = TRUE)
  1240. component.predictions[[paste0(component, '_lower')]] <- apply(
  1241. comp, 1, stats::quantile, lower.p, na.rm = TRUE)
  1242. component.predictions[[paste0(component, '_upper')]] <- apply(
  1243. comp, 1, stats::quantile, upper.p, na.rm = TRUE)
  1244. }
  1245. return(component.predictions)
  1246. }
  1247. #' Prophet posterior predictive samples.
  1248. #'
  1249. #' @param m Prophet object.
  1250. #' @param df Prediction dataframe.
  1251. #'
  1252. #' @return List with posterior predictive samples for the forecast yhat and
  1253. #' for the trend component.
  1254. #'
  1255. #' @keywords internal
  1256. sample_posterior_predictive <- function(m, df) {
  1257. # Sample trend, seasonality, and yhat from the extrapolation model.
  1258. n.iterations <- length(m$params$k)
  1259. samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
  1260. nsamp <- n.iterations * samp.per.iter # The actual number of samples
  1261. out <- make_all_seasonality_features(m, df)
  1262. seasonal.features <- out$seasonal.features
  1263. component.cols <- out$component.cols
  1264. sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
  1265. "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
  1266. for (i in seq_len(n.iterations)) {
  1267. # For each set of parameters from MCMC (or just 1 set for MAP),
  1268. for (j in seq_len(samp.per.iter)) {
  1269. # Do a simulation with this set of parameters,
  1270. sim <- sample_model(
  1271. m = m,
  1272. df = df,
  1273. seasonal.features = seasonal.features,
  1274. iteration = i,
  1275. s_a = component.cols$additive_terms,
  1276. s_m = component.cols$multiplicative_terms
  1277. )
  1278. # Store the results
  1279. for (key in c("trend", "yhat")) {
  1280. sim.values[[key]][,(i - 1) * samp.per.iter + j] <- sim[[key]]
  1281. }
  1282. }
  1283. }
  1284. return(sim.values)
  1285. }
  1286. #' Sample from the posterior predictive distribution.
  1287. #'
  1288. #' @param m Prophet object.
  1289. #' @param df Dataframe with dates for predictions (column ds), and capacity
  1290. #' (column cap) if logistic growth.
  1291. #'
  1292. #' @return A list with items "trend" and "yhat" containing
  1293. #' posterior predictive samples for that component.
  1294. #'
  1295. #' @export
  1296. predictive_samples <- function(m, df) {
  1297. df <- setup_dataframe(m, df)$df
  1298. sim.values <- sample_posterior_predictive(m, df)
  1299. return(sim.values)
  1300. }
  1301. #' Prophet uncertainty intervals for yhat and trend
  1302. #'
  1303. #' @param m Prophet object.
  1304. #' @param df Prediction dataframe.
  1305. #'
  1306. #' @return Dataframe with uncertainty intervals.
  1307. #'
  1308. #' @keywords internal
  1309. predict_uncertainty <- function(m, df) {
  1310. sim.values <- sample_posterior_predictive(m, df)
  1311. # Add uncertainty estimates
  1312. lower.p <- (1 - m$interval.width)/2
  1313. upper.p <- (1 + m$interval.width)/2
  1314. intervals <- cbind(
  1315. t(apply(t(sim.values$yhat), 2, stats::quantile, c(lower.p, upper.p),
  1316. na.rm = TRUE)),
  1317. t(apply(t(sim.values$trend), 2, stats::quantile, c(lower.p, upper.p),
  1318. na.rm = TRUE))
  1319. ) %>% dplyr::as_data_frame()
  1320. colnames(intervals) <- paste(rep(c('yhat', 'trend'), each=2),
  1321. c('lower', 'upper'), sep = "_")
  1322. return(intervals)
  1323. }
  1324. #' Simulate observations from the extrapolated generative model.
  1325. #'
  1326. #' @param m Prophet object.
  1327. #' @param df Prediction dataframe.
  1328. #' @param seasonal.features Data frame of seasonal features
  1329. #' @param iteration Int sampling iteration to use parameters from.
  1330. #' @param s_a Indicator vector for additive components
  1331. #' @param s_m Indicator vector for multiplicative components
  1332. #'
  1333. #' @return List of trend and yhat, each a vector like df$t.
  1334. #'
  1335. #' @keywords internal
  1336. sample_model <- function(m, df, seasonal.features, iteration, s_a, s_m) {
  1337. trend <- sample_predictive_trend(m, df, iteration)
  1338. beta <- m$params$beta[iteration,]
  1339. Xb_a = as.matrix(seasonal.features) %*% (beta * s_a) * m$y.scale
  1340. Xb_m = as.matrix(seasonal.features) %*% (beta * s_m)
  1341. sigma <- m$params$sigma_obs[iteration]
  1342. noise <- stats::rnorm(nrow(df), mean = 0, sd = sigma) * m$y.scale
  1343. return(list("yhat" = trend * (1 + Xb_m) + Xb_a + noise,
  1344. "trend" = trend))
  1345. }
  1346. #' Simulate the trend using the extrapolated generative model.
  1347. #'
  1348. #' @param model Prophet object.
  1349. #' @param df Prediction dataframe.
  1350. #' @param iteration Int sampling iteration to use parameters from.
  1351. #'
  1352. #' @return Vector of simulated trend over df$t.
  1353. #'
  1354. #' @keywords internal
  1355. sample_predictive_trend <- function(model, df, iteration) {
  1356. k <- model$params$k[iteration]
  1357. param.m <- model$params$m[iteration]
  1358. deltas <- model$params$delta[iteration,]
  1359. t <- df$t
  1360. T <- max(t)
  1361. # New changepoints from a Poisson process with rate S on [1, T]
  1362. if (T > 1) {
  1363. S <- length(model$changepoints.t)
  1364. n.changes <- stats::rpois(1, S * (T - 1))
  1365. } else {
  1366. n.changes <- 0
  1367. }
  1368. if (n.changes > 0) {
  1369. changepoint.ts.new <- 1 + stats::runif(n.changes) * (T - 1)
  1370. changepoint.ts.new <- sort(changepoint.ts.new)
  1371. } else {
  1372. changepoint.ts.new <- c()
  1373. }
  1374. # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
  1375. lambda <- mean(abs(c(deltas))) + 1e-8
  1376. # Sample deltas
  1377. deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)
  1378. # Combine with changepoints from the history
  1379. changepoint.ts <- c(model$changepoints.t, changepoint.ts.new)
  1380. deltas <- c(deltas, deltas.new)
  1381. # Get the corresponding trend
  1382. if (model$growth == 'linear') {
  1383. trend <- piecewise_linear(t, deltas, k, param.m, changepoint.ts)
  1384. } else {
  1385. cap <- df$cap_scaled
  1386. trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
  1387. }
  1388. return(trend * model$y.scale + df$floor)
  1389. }
  1390. #' Make dataframe with future dates for forecasting.
  1391. #'
  1392. #' @param m Prophet model object.
  1393. #' @param periods Int number of periods to forecast forward.
  1394. #' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
  1395. #' @param include_history Boolean to include the historical dates in the data
  1396. #' frame for predictions.
  1397. #'
  1398. #' @return Dataframe that extends forward from the end of m$history for the
  1399. #' requested number of periods.
  1400. #'
  1401. #' @export
  1402. make_future_dataframe <- function(m, periods, freq = 'day',
  1403. include_history = TRUE) {
  1404. # For backwards compatability with previous zoo date type,
  1405. if (freq == 'm') {
  1406. freq <- 'month'
  1407. }
  1408. if (is.null(m$history.dates)) {
  1409. stop('Model must be fit before this can be used.')
  1410. }
  1411. dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
  1412. dates <- dates[2:(periods + 1)] # Drop the first, which is max(history$ds)
  1413. if (include_history) {
  1414. dates <- c(m$history.dates, dates)
  1415. attr(dates, "tzone") <- "GMT"
  1416. }
  1417. return(data.frame(ds = dates))
  1418. }