prophet.R 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746
  1. ## Copyright (c) 2017-present, Facebook, Inc.
  2. ## All rights reserved.
  3. ## This source code is licensed under the BSD-style license found in the
  4. ## LICENSE file in the root directory of this source tree. An additional grant
  5. ## of patent rights can be found in the PATENTS file in the same directory.
  ## Register the bare column names that appear inside dplyr expressions in
  ## this file, so that R CMD CHECK does not report them as undefined globals.
  globalVariables(
    c(
      "ds", "y", "cap", ".",
      "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
      "holidays_upper", "ix",
      "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
      "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower",
      "weekly_upper",
      "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower",
      "yhat_upper"
    )
  )
  #' Prophet forecaster.
  #'
  #' Builds a prophet model object and, when a history dataframe is supplied
  #' and \code{fit} is TRUE, fits it immediately via \code{fit.prophet}.
  #'
  #' @param df (optional) Dataframe containing the history. Must have columns ds
  #'  (date type) and y, the time series. If growth is logistic, then df must
  #'  also have a column cap that specifies the capacity at each ds. If not
  #'  provided, then the model object will be instantiated but not fit; use
  #'  fit.prophet(m, df) to fit the model.
  #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #' @param changepoints Vector of dates at which to include potential
  #'  changepoints. If not specified, potential changepoints are selected
  #'  automatically.
  #' @param n.changepoints Number of potential changepoints to include. Not used
  #'  if input `changepoints` is supplied (it is then replaced by
  #'  length(changepoints)). If `changepoints` is not supplied, then
  #'  n.changepoints potential changepoints are selected uniformly from the
  #'  first 80 percent of df$ds.
  #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
  #'  FALSE, or a number of Fourier terms to generate.
  #' @param holidays data frame with columns holiday (character) and ds (date
  #'  type) and optionally columns lower_window and upper_window which specify
  #'  a range of days around the date to be included as holidays.
  #'  lower_window=-2 will include 2 days prior to the date as holidays. Also
  #'  optionally can have a column prior_scale specifying the prior scale for
  #'  each holiday.
  #' @param seasonality.prior.scale Parameter modulating the strength of the
  #'  seasonality model. Larger values allow the model to fit larger seasonal
  #'  fluctuations, smaller values dampen the seasonality. Can be specified for
  #'  individual seasonalities using add_seasonality.
  #' @param holidays.prior.scale Parameter modulating the strength of the holiday
  #'  components model, unless overridden in the holidays input.
  #' @param changepoint.prior.scale Parameter modulating the flexibility of the
  #'  automatic changepoint selection. Large values will allow many changepoints,
  #'  small values will allow few changepoints.
  #' @param mcmc.samples Integer, if greater than 0, will do full Bayesian
  #'  inference with the specified number of MCMC samples. If 0, will do MAP
  #'  estimation.
  #' @param interval.width Numeric, width of the uncertainty intervals provided
  #'  for the forecast. If mcmc.samples=0, this will be only the uncertainty
  #'  in the trend using the MAP estimate of the extrapolated generative model.
  #'  If mcmc.samples>0, this will be integrated over all model parameters,
  #'  which will include uncertainty in seasonality.
  #' @param uncertainty.samples Number of simulated draws used to estimate
  #'  uncertainty intervals.
  #' @param fit Boolean, if FALSE the model is initialized but not fit.
  #' @param ... Additional arguments, passed to \code{\link{fit.prophet}}
  #'
  #' @return A prophet model.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' }
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  #' @import Rcpp
  prophet <- function(df = NULL,
                      growth = 'linear',
                      changepoints = NULL,
                      n.changepoints = 25,
                      yearly.seasonality = 'auto',
                      weekly.seasonality = 'auto',
                      daily.seasonality = 'auto',
                      holidays = NULL,
                      seasonality.prior.scale = 10,
                      holidays.prior.scale = 10,
                      changepoint.prior.scale = 0.05,
                      mcmc.samples = 0,
                      interval.width = 0.80,
                      uncertainty.samples = 1000,
                      fit = TRUE,
                      ...) {
    # Explicit changepoints override the requested count.
    changepoints.given <- !is.null(changepoints)
    if (changepoints.given) {
      n.changepoints <- length(changepoints)
    }

    # Settings supplied by the caller.
    settings <- list(
      growth = growth,
      changepoints = changepoints,
      n.changepoints = n.changepoints,
      yearly.seasonality = yearly.seasonality,
      weekly.seasonality = weekly.seasonality,
      daily.seasonality = daily.seasonality,
      holidays = holidays,
      seasonality.prior.scale = seasonality.prior.scale,
      changepoint.prior.scale = changepoint.prior.scale,
      holidays.prior.scale = holidays.prior.scale,
      mcmc.samples = mcmc.samples,
      interval.width = interval.width,
      uncertainty.samples = uncertainty.samples,
      specified.changepoints = changepoints.given
    )
    # Placeholders for attributes that are set during fitting.
    fit.state <- list(
      start = NULL,
      y.scale = NULL,
      logistic.floor = FALSE,
      t.scale = NULL,
      changepoints.t = NULL,
      seasonalities = list(),
      extra_regressors = list(),
      stan.fit = NULL,
      params = list(),
      history = NULL,
      history.dates = NULL
    )
    m <- c(settings, fit.state)

    validate_inputs(m)
    class(m) <- c("prophet", class(m))

    if (fit && !is.null(df)) {
      m <- fit.prophet(m, df, ...)
    }
    m
  }
  #' Validates the inputs to Prophet.
  #'
  #' Stops with an error if `growth` is not 'linear' or 'logistic', or if the
  #' holidays dataframe (when present) lacks the holiday/ds columns, has only
  #' one of lower_window/upper_window, has a positive lower_window or a
  #' negative upper_window, or contains a holiday name rejected by
  #' validate_column_name.
  #'
  #' @param m Prophet object.
  #'
  #' @keywords internal
  validate_inputs <- function(m) {
    if (!(m$growth %in% c('linear', 'logistic'))) {
      stop("Parameter 'growth' should be 'linear' or 'logistic'.")
    }
    # Nothing else to check when no holidays were supplied.
    if (is.null(m$holidays)) {
      return(invisible(NULL))
    }

    has_column <- function(column) exists(column, where = m$holidays)

    if (!has_column('holiday')) {
      stop('Holidays dataframe must have holiday field.')
    }
    if (!has_column('ds')) {
      stop('Holidays dataframe must have ds field.')
    }

    lower.present <- has_column('lower_window')
    upper.present <- has_column('upper_window')
    # The window columns only make sense as a pair.
    if (xor(lower.present, upper.present)) {
      stop(paste('Holidays must have both lower_window and upper_window,',
                 'or neither.'))
    }
    if (lower.present) {
      if (max(m$holidays$lower_window, na.rm = TRUE) > 0) {
        stop('Holiday lower_window should be <= 0')
      }
      if (min(m$holidays$upper_window, na.rm = TRUE) < 0) {
        stop('Holiday upper_window should be >= 0')
      }
    }

    # Holiday names become column names, so they must not clash with
    # reserved names, seasonalities, or regressors.
    for (holiday.name in unique(m$holidays$holiday)) {
      validate_column_name(m, holiday.name, check_holidays = FALSE)
    }
    invisible(NULL)
  }
  #' Validates the name of a seasonality, holiday, or regressor.
  #'
  #' Stops with an error if the name contains "_delim_" (used in this file as
  #' the separator in generated feature column names), is one of the reserved
  #' column names, or is already in use by a holiday, seasonality, or added
  #' regressor. Each of the three "already in use" checks can be switched off.
  #'
  #' @param m Prophet object.
  #' @param name string
  #' @param check_holidays bool check if name already used for holiday
  #' @param check_seasonalities bool check if name already used for seasonality
  #' @param check_regressors bool check if name already used for regressor
  #'
  #' @keywords internal
  validate_column_name <- function(
    m, name, check_holidays = TRUE, check_seasonalities = TRUE,
    check_regressors = TRUE
  ) {
    if (grepl("_delim_", name)) {
      stop('Holiday name cannot contain "_delim_"')
    }
    base_names <- c(
      'trend', 'seasonal', 'seasonalities', 'daily', 'weekly', 'yearly',
      'holidays', 'zeros', 'extra_regressors', 'yhat'
    )
    # Each component also has _lower / _upper companion columns.
    reserved_names <- c(
      base_names,
      paste0(base_names, "_lower"),
      paste0(base_names, "_upper"),
      "ds", "y", "cap", "floor", "y_scaled", "cap_scaled"
    )
    if (name %in% reserved_names) {
      stop("Name ", name, " is reserved.")
    }
    if (check_holidays && !is.null(m$holidays) &&
          (name %in% unique(m$holidays$holiday))) {
      stop("Name ", name, " already used for a holiday.")
    }
    if (check_seasonalities && !is.null(m$seasonalities[[name]])) {
      stop("Name ", name, " already used for a seasonality.")
    }
    # Regressors live in m$extra_regressors; looking in m$seasonalities here
    # would miss real clashes and misreport seasonality overrides.
    if (check_regressors && !is.null(m$extra_regressors[[name]])) {
      stop("Name ", name, " already used for an added regressor.")
    }
  }
  #' Load compiled Stan model
  #'
  #' Looks for the cached model file 'prophet_<model>_growth.RData' under the
  #' installed package's 'libs' directory and returns the object named
  #' '<model>.growth.stanm' stored in it. If any step fails (file missing,
  #' object missing, or the sanity check on the model raising an error), the
  #' model is compiled from source with compile_stan_model instead.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  get_prophet_stan_model <- function(model) {
    fn <- paste('prophet', model, 'growth.RData', sep = '_')
    ## If the cached model doesn't work, just compile a new one.
    tryCatch({
      binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
                            package = 'prophet',
                            mustWork = TRUE)
      # Load into a private environment and fetch the object by name, rather
      # than evaluating parsed text; inherits = FALSE ensures that only an
      # object that really came from the cache file is used.
      cache <- new.env()
      load(binary, envir = cache)
      obj.name <- paste(model, 'growth.stanm', sep = '.')
      stanm <- get(obj.name, envir = cache, inherits = FALSE)
      ## Should cause an error if the model doesn't work.
      stanm@mk_cppmodule(stanm)
      stanm
    }, error = function(cond) {
      compile_stan_model(model)
    })
  }
  #' Compile Stan model
  #'
  #' Locates 'stan/prophet_<model>_growth.stan' inside the installed prophet
  #' package, translates it with rstan::stanc, and builds the model with
  #' rstan::stan_model under the name '<model>_growth'.
  #'
  #' @param model String 'linear' or 'logistic' to specify a linear or logistic
  #'  trend.
  #'
  #' @return Stan model.
  #'
  #' @keywords internal
  compile_stan_model <- function(model) {
    source.path <- system.file(
      paste0('stan/prophet_', model, '_growth.stan'),
      package = 'prophet',
      mustWork = TRUE
    )
    translated <- rstan::stanc(source.path)
    rstan::stan_model(
      stanc_ret = translated,
      model_name = paste0(model, '_growth')
    )
  }
  #' Convert date vector
  #'
  #' Convert the date to POSIXct object. Factors are converted to character
  #' first. If the shortest non-missing entry has fewer than 12 characters the
  #' whole vector is parsed as "%Y-%m-%d", otherwise as "%Y-%m-%d %H:%M:%S".
  #' Missing entries are ignored when choosing the format and come back as NA,
  #' so callers can detect them with anyNA().
  #'
  #' @param ds Date vector, can be consisted of characters
  #' @param tz string time zone
  #'
  #' @return vector of POSIXct object converted from date, or NULL if ds has
  #'  length zero.
  #'
  #' @keywords internal
  set_date <- function(ds = NULL, tz = "GMT") {
    if (length(ds) == 0) {
      return(NULL)
    }
    if (is.factor(ds)) {
      ds <- as.character(ds)
    }
    # nchar() is NA for missing strings; drop those so that the format test
    # below never sees NA (which would make `if` fail).
    widths <- nchar(ds)
    widths <- widths[!is.na(widths)]
    if (length(widths) == 0 || min(widths) < 12) {
      ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
    } else {
      ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
    }
    attr(ds, "tzone") <- tz
    return(ds)
  }
  #' Time difference between datetimes
  #'
  #' Computes ds1 - ds2 with difftime in the requested units and strips the
  #' difftime class so that a plain numeric is returned.
  #'
  #' @param ds1 POSIXct object
  #' @param ds2 POSIXct object
  #' @param units string units of difference, e.g. 'days' or 'secs'.
  #'
  #' @return numeric time difference
  #'
  #' @keywords internal
  time_diff <- function(ds1, ds2, units = "days") {
    elapsed <- difftime(ds1, ds2, units = units)
    as.numeric(elapsed)
  }
  #' Prepare dataframe for fitting or predicting.
  #'
  #' Parses ds, sorts by ds, optionally initializes the model scales, and adds
  #' the auxiliary columns 'floor' (0 unless the model uses a logistic floor),
  #' 't', 'y_scaled' (when y is present), and 'cap_scaled' (for logistic
  #' growth). Extra regressor columns are converted to numeric and
  #' standardized with the mu/std stored on the model.
  #'
  #' @param m Prophet object.
  #' @param df Data frame with columns ds, y, and cap if logistic growth. Any
  #'  specified additional regressors must also be present.
  #' @param initialize_scales Boolean set scaling factors in m from df.
  #'
  #' @return list with items 'df' and 'm'.
  #'
  #' @keywords internal
  setup_dataframe <- function(m, df, initialize_scales = FALSE) {
    # y is absent when the frame only holds dates to predict on.
    has.y <- exists('y', where = df)
    if (has.y) {
      df$y <- as.numeric(df$y)
    }
    if (any(is.infinite(df$y))) {
      stop("Found infinity in column y.")
    }

    df$ds <- set_date(df$ds)
    if (anyNA(df$ds)) {
      stop(paste('Unable to parse date format in column ds. Convert to date ',
                 'format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S'))
    }

    regressor.names <- names(m$extra_regressors)
    absent <- regressor.names[!(regressor.names %in% colnames(df))]
    if (length(absent) > 0) {
      stop('Regressor "', absent[1], '" missing from dataframe')
    }

    df <- dplyr::arrange(df, ds)
    m <- initialize_scales_fn(m, initialize_scales, df)

    if (!m$logistic.floor) {
      df$floor <- 0
    } else if (!('floor' %in% colnames(df))) {
      stop("Expected column 'floor'.")
    }

    if (m$growth == 'logistic') {
      if (!(exists('cap', where = df))) {
        stop('Capacities must be supplied for logistic growth.')
      }
      df <- dplyr::mutate(df, cap_scaled = (cap - floor) / m$y.scale)
    }

    # Time measured from the model's start, in units of m$t.scale.
    df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
    if (has.y) {
      df$y_scaled <- (df$y - df$floor) / m$y.scale
    }

    for (reg in regressor.names) {
      reg.props <- m$extra_regressors[[reg]]
      df[[reg]] <- (as.numeric(df[[reg]]) - reg.props$mu) / reg.props$std
      if (anyNA(df[[reg]])) {
        stop('Found NaN in column ', reg)
      }
    }

    list("m" = m, "df" = df)
  }
  #' Initialize model scales.
  #'
  #' Sets model scaling factors using df: y.scale (largest absolute value of
  #' y minus floor, or 1 if that is 0), start (earliest ds), t.scale (seconds
  #' between earliest and latest ds), and the mu/std of every extra regressor
  #' that is to be standardized. Also flags logistic.floor when growth is
  #' logistic and df has a 'floor' column. Does nothing when
  #' initialize_scales is FALSE.
  #'
  #' @param m Prophet object.
  #' @param initialize_scales Boolean set the scales or not.
  #' @param df Dataframe for setting scales.
  #'
  #' @return Prophet object with scales set.
  #'
  #' @keywords internal
  initialize_scales_fn <- function(m, initialize_scales, df) {
    if (!initialize_scales) {
      return(m)
    }

    floor.values <- 0
    if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
      m$logistic.floor <- TRUE
      floor.values <- df$floor
    }

    y.scale <- max(abs(df$y - floor.values))
    m$y.scale <- if (y.scale == 0) 1 else y.scale

    m$start <- min(df$ds)
    m$t.scale <- time_diff(max(df$ds), m$start, "secs")

    for (reg in names(m$extra_regressors)) {
      values <- df[[reg]]
      distinct <- unique(values)
      if (length(distinct) < 2) {
        stop('Regressor ', reg, ' is constant.')
      }
      standardize <- m$extra_regressors[[reg]]$standardize
      if (standardize == 'auto') {
        # Don't standardize binary variables
        is.binary <- length(distinct) == 2 && all(sort(distinct) == c(0, 1))
        standardize <- !is.binary
      }
      if (standardize) {
        m$extra_regressors[[reg]]$mu <- mean(values)
        m$extra_regressors[[reg]]$std <- stats::sd(values)
      }
    }
    m
  }
  #' Set changepoints
  #'
  #' Sets m$changepoints to the dates of changepoints and m$changepoints.t to
  #' their scaled times. Either:
  #' 1) The changepoints were passed in explicitly.
  #'    A) They are empty.
  #'    B) They are not empty, and need validation.
  #' 2) We are generating a grid of them.
  #' 3) The user prefers no changepoints be used.
  #' When there are no changepoints, changepoints.t is a single dummy 0.
  #'
  #' @param m Prophet object.
  #'
  #' @return m with changepoints set.
  #'
  #' @keywords internal
  set_changepoints <- function(m) {
    if (is.null(m$changepoints)) {
      # Place potential changepoints evenly through the first 80 pcnt of
      # the history.
      usable.rows <- floor(nrow(m$history) * .8)
      if (m$n.changepoints + 1 > usable.rows) {
        m$n.changepoints <- usable.rows - 1
        message('n.changepoints greater than number of observations. Using ',
                m$n.changepoints)
      }
      if (m$n.changepoints > 0) {
        grid <- seq.int(1, usable.rows, length.out = (m$n.changepoints + 1))
        # The first grid point is dropped.
        m$changepoints <- m$history$ds[round(grid[-1])]
      } else {
        m$changepoints <- NULL
      }
    } else if (length(m$changepoints) > 0) {
      m$changepoints <- set_date(m$changepoints)
      too.early <- min(m$changepoints) < min(m$history$ds)
      if (too.early || max(m$changepoints) > max(m$history$ds)) {
        stop('Changepoints must fall within training data.')
      }
    }

    if (length(m$changepoints) > 0) {
      offsets <- time_diff(m$changepoints, m$start, "secs")
      m$changepoints.t <- sort(offsets) / m$t.scale
    } else {
      m$changepoints.t <- c(0)  # dummy changepoint
    }
    m
  }
  #' Gets changepoint matrix for history dataframe.
  #'
  #' Builds a 0/1 matrix with one row per row of m$history and one column per
  #' entry of m$changepoints.t; entry (r, c) is 1 when m$history$t[r] is at or
  #' after the c-th changepoint time. An empty m$changepoints.t yields a
  #' matrix with zero columns.
  #'
  #' @param m Prophet object.
  #'
  #' @return array of indexes.
  #'
  #' @keywords internal
  get_changepoint_matrix <- function(m) {
    A <- matrix(0, nrow(m$history), length(m$changepoints.t))
    # seq_along (not 1:length) so that no iteration runs for zero changepoints.
    for (i in seq_along(m$changepoints.t)) {
      A[m$history$t >= m$changepoints.t[i], i] <- 1
    }
    return(A)
  }
  #' Provides Fourier series components with the specified frequency and order.
  #'
  #' Time is measured in days since 1970-01-01 00:00:00 (as parsed by
  #' set_date). Columns alternate sin and cos: columns 2i-1 and 2i hold
  #' sin(2*i*pi*t/period) and cos(2*i*pi*t/period). A series.order of 0
  #' yields a matrix with zero columns.
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #'
  #' @return Matrix with seasonality features.
  #'
  #' @keywords internal
  fourier_series <- function(dates, period, series.order) {
    days <- time_diff(dates, set_date('1970-01-01 00:00:00'))
    features <- matrix(0, length(days), 2 * series.order)
    # seq_len (not 1:series.order) so that order 0 runs no iterations.
    for (i in seq_len(series.order)) {
      x <- as.numeric(2 * i * pi * days / period)
      features[, i * 2 - 1] <- sin(x)
      features[, i * 2] <- cos(x)
    }
    return(features)
  }
  #' Data frame with seasonality features.
  #'
  #' Wraps fourier_series and names the columns '<prefix>_delim_<k>' for
  #' k = 1..ncol before converting to a data frame. (data.frame() may adjust
  #' names that are not syntactically valid.)
  #'
  #' @param dates Vector of dates.
  #' @param period Number of days of the period.
  #' @param series.order Number of components.
  #' @param prefix Column name prefix.
  #'
  #' @return Dataframe with seasonality.
  #'
  #' @keywords internal
  make_seasonality_features <- function(dates, period, series.order, prefix) {
    features <- fourier_series(dates, period, series.order)
    # sprintf over seq_len gives character(0) for a zero-column matrix, where
    # paste(prefix, 1:0, ...) would produce two names and fail.
    colnames(features) <- sprintf('%s_delim_%d', prefix,
                                  seq_len(ncol(features)))
    return(data.frame(features))
  }
  #' Construct a matrix of holiday features.
  #'
  #' Each (holiday, offset) pair within the holiday's lower/upper window
  #' becomes an indicator column named '<holiday>_delim_<+|-><abs(offset)>'.
  #' The prior scale of each column is that of the holiday it was derived
  #' from, falling back to m$holidays.prior.scale when none is given.
  #'
  #' @param m Prophet object.
  #' @param dates Vector with dates used for computing seasonality.
  #'
  #' @return A list with entries
  #'  holiday.features: dataframe with a column for each holiday.
  #'  prior.scales: array of prior scales for each holiday column.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  make_holiday_features <- function(m, dates) {
    # Strip dates to be just days, for joining on holidays
    dates <- set_date(format(dates, "%Y-%m-%d"))

    # One row per (date, holiday-offset name), spread to one column per name.
    wide <- m$holidays %>%
      dplyr::mutate(ds = set_date(ds)) %>%
      dplyr::group_by(holiday, ds) %>%
      dplyr::filter(row_number() == 1) %>%
      dplyr::do({
        if (exists('lower_window', where = .) && !is.na(.$lower_window)
            && !is.na(.$upper_window)) {
          offsets <- seq(.$lower_window, .$upper_window)
        } else {
          offsets <- c(0)
        }
        names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
                       abs(offsets), sep = '')
        dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
      }) %>%
      dplyr::mutate(x = 1.) %>%
      tidyr::spread(holiday, x, fill = 0)

    holiday.features <- data.frame(ds = set_date(dates)) %>%
      dplyr::left_join(wide, by = 'ds') %>%
      dplyr::select(-ds)
    # Dates that matched no holiday row get 0 in every column.
    holiday.features[is.na(holiday.features)] <- 0

    # Prior scales
    holidays <- m$holidays
    if (!('prior_scale' %in% colnames(holidays))) {
      holidays$prior_scale <- m$holidays.prior.scale
    }
    scale.by.holiday <- list()
    for (holiday.name in unique(holidays$holiday)) {
      rows <- holidays[holidays$holiday == holiday.name, ]
      ps <- unique(rows$prior_scale)
      if (length(ps) > 1) {
        stop('Holiday ', holiday.name, ' does not have a consistent prior scale ',
             'specification')
      }
      if (is.na(ps)) {
        ps <- m$holidays.prior.scale
      }
      if (ps <= 0) {
        stop('Prior scale must be > 0.')
      }
      scale.by.holiday[[holiday.name]] <- ps
    }

    # Map every feature column back to the holiday it came from.
    column.scales <- lapply(colnames(holiday.features), function(column) {
      source.holiday <- strsplit(column, '_delim_', fixed = TRUE)[[1]][1]
      scale.by.holiday[[source.holiday]]
    })
    prior.scales <- unlist(column.scales, use.names = FALSE)

    list(holiday.features = holiday.features,
         prior.scales = prior.scales)
  }
  #' Add an additional regressor to be used for fitting and predicting.
  #'
  #' The dataframe passed to `fit` and `predict` will have a column with the
  #' specified name to be used as a regressor. When standardize='auto', the
  #' regressor will be standardized unless it is binary. The regression
  #' coefficient is given a prior with the specified scale parameter.
  #' Decreasing the prior scale will add additional regularization. If no
  #' prior scale is provided, holidays.prior.scale will be used.
  #'
  #' The regressor is stored with mu = 0 and std = 1 until scales are
  #' initialized. Fails if the model already has a history or if the name is
  #' rejected by validate_column_name.
  #'
  #' @param m Prophet object.
  #' @param name String name of the regressor
  #' @param prior.scale Float scale for the normal prior. If not provided,
  #'  holidays.prior.scale will be used.
  #' @param standardize Bool, specify whether this regressor will be standardized
  #'  prior to fitting. Can be 'auto' (standardize if not binary), True, or
  #'  False.
  #'
  #' @return The prophet model with the regressor added.
  #'
  #' @export
  add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto') {
    if (!is.null(m$history)) {
      stop('Regressors must be added prior to model fitting.')
    }
    validate_column_name(m, name, check_regressors = FALSE)

    scale <- if (is.null(prior.scale)) m$holidays.prior.scale else prior.scale
    if (scale <= 0) {
      stop("Prior scale must be > 0")
    }

    m$extra_regressors[[name]] <- list(
      prior.scale = scale,
      standardize = standardize,
      mu = 0,
      std = 1.0
    )
    m
  }
  #' Add a seasonal component with specified period, number of Fourier
  #' components, and prior scale.
  #'
  #' Increasing the number of Fourier components allows the seasonality to change
  #' more quickly (at risk of overfitting). Default values for yearly and weekly
  #' seasonalities are 10 and 3 respectively.
  #'
  #' Increasing prior scale will allow this seasonality component more
  #' flexibility, decreasing will dampen it. If not provided, will use the
  #' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
  #'
  #' The names 'daily', 'weekly' and 'yearly' skip name validation so that the
  #' built-in seasonalities can be overridden.
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param period Float number of days in one period.
  #' @param fourier.order Int number of Fourier components to use.
  #' @param prior.scale Float prior scale for this component.
  #'
  #' @return The prophet model with the seasonality added.
  #'
  #' @importFrom dplyr "%>%"
  #' @export
  add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) {
    if (!is.null(m$history)) {
      stop("Seasonality must be added prior to model fitting.")
    }

    # Allow overriding built-in seasonalities
    is.builtin <- name %in% c('daily', 'weekly', 'yearly')
    if (!is.builtin) {
      validate_column_name(m, name, check_seasonalities = FALSE)
    }

    scale <- if (is.null(prior.scale)) m$seasonality.prior.scale else prior.scale
    if (scale <= 0) {
      stop('Prior scale must be > 0')
    }

    m$seasonalities[[name]] <- list(
      period = period,
      fourier.order = fourier.order,
      prior.scale = scale
    )
    m
  }
  #' Dataframe with seasonality features.
  #' Includes seasonality features, holiday features, and added regressors.
  #'
  #' Columns are appended in that order. If the model has none of them, a
  #' single all-zero column named 'zeros' with prior scale 1 is returned
  #' instead.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for computing seasonality features and any
  #'  added regressors.
  #'
  #' @return List with items
  #'  seasonal.features: Dataframe with regressor features,
  #'  prior.scales: Array of prior scales for each column of the features
  #'  dataframe.
  #'
  #' @keywords internal
  make_all_seasonality_features <- function(m, df) {
    # seq_len (not 1:nrow) so that an empty df gives a frame with zero rows
    # rather than two.
    seasonal.features <- data.frame(row.names = seq_len(nrow(df)))
    prior.scales <- c()

    # Seasonality features
    for (name in names(m$seasonalities)) {
      props <- m$seasonalities[[name]]
      features <- make_seasonality_features(
        df$ds, props$period, props$fourier.order, name)
      seasonal.features <- cbind(seasonal.features, features)
      # Every Fourier column of a seasonality shares its prior scale.
      prior.scales <- c(prior.scales,
                        props$prior.scale * rep(1, ncol(features)))
    }

    # Holiday features
    if (!is.null(m$holidays)) {
      hf <- make_holiday_features(m, df$ds)
      seasonal.features <- cbind(seasonal.features, hf$holiday.features)
      prior.scales <- c(prior.scales, hf$prior.scales)
    }

    # Additional regressors
    for (name in names(m$extra_regressors)) {
      seasonal.features[[name]] <- df[[name]]
      prior.scales <- c(prior.scales, m$extra_regressors[[name]]$prior.scale)
    }

    # Fall back to a placeholder column so that there is always one feature.
    if (ncol(seasonal.features) == 0) {
      seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
      prior.scales <- c(1.)
    }
    return(list(seasonal.features = seasonal.features,
                prior.scales = prior.scales))
  }
  #' Get number of Fourier components for built-in seasonalities.
  #'
  #' With 'auto', the seasonality is disabled when a custom seasonality of the
  #' same name already exists or when auto.disable is TRUE; otherwise the
  #' default order is used. A numeric arg is returned as provided, so a
  #' request for a single Fourier term (arg = 1) is honoured rather than being
  #' mistaken for TRUE.
  #'
  #' @param m Prophet object.
  #' @param name String name of the seasonality component.
  #' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
  #'  provided.
  #' @param auto.disable Bool if seasonality should be disabled when 'auto'.
  #' @param default.order Int default Fourier order.
  #'
  #' @return Number of Fourier components, or 0 for disabled.
  #'
  #' @keywords internal
  parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
    if (identical(arg, 'auto')) {
      fourier.order <- 0
      if (name %in% names(m$seasonalities)) {
        message('Found custom seasonality named "', name,
                '", disabling built-in ', name, ' seasonality.')
      } else if (auto.disable) {
        message('Disabling ', name, ' seasonality. Run prophet with ', name,
                '.seasonality=TRUE to override this.')
      } else {
        fourier.order <- default.order
      }
    } else if (isTRUE(arg)) {
      # isTRUE() instead of `arg == TRUE`: in R, 1 == TRUE, which would turn
      # an explicit order of 1 into the default order.
      fourier.order <- default.order
    } else if (identical(arg, FALSE)) {
      fourier.order <- 0
    } else {
      fourier.order <- arg
    }
    return(fourier.order)
  }
  #' Resolve the built-in seasonalities that were left on 'auto'.
  #'
  #' Yearly seasonality is auto-enabled with >=2 years (730 days) of history.
  #' Weekly seasonality is auto-enabled with >=2 weeks of history and a
  #' smallest spacing between dates of <7 days.
  #' Daily seasonality is auto-enabled with >=2 days of history and a smallest
  #' spacing between dates of <1 day.
  #'
  #' @param m Prophet object.
  #'
  #' @return The prophet model with seasonalities set.
  #'
  #' @keywords internal
  set_auto_seasonalities <- function(m) {
    first <- min(m$history$ds)
    last <- max(m$history$ds)
    gaps <- diff(time_diff(m$history$ds, m$start))
    min.dt <- min(gaps[gaps > 0])

    # Registers one built-in seasonality on `model` when its resolved Fourier
    # order is positive. Arguments are forced up front so they are evaluated
    # before the seasonality arguments are parsed.
    register <- function(model, name, setting, disable, default.order, period) {
      force(setting)
      force(disable)
      fourier.order <- parse_seasonality_args(
        model, name, setting, disable, default.order)
      if (fourier.order > 0) {
        model$seasonalities[[name]] <- list(
          period = period,
          fourier.order = fourier.order,
          prior.scale = model$seasonality.prior.scale
        )
      }
      model
    }

    m <- register(
      m, 'yearly', m$yearly.seasonality,
      time_diff(last, first) < 730,
      10, 365.25)
    m <- register(
      m, 'weekly', m$weekly.seasonality,
      ((time_diff(last, first) < 14) || (min.dt >= 7)),
      3, 7)
    m <- register(
      m, 'daily', m$daily.seasonality,
      ((time_diff(last, first) < 2) || (min.dt >= 1)),
      4, 1)
    m
  }
  #' Initialize linear growth.
  #'
  #' Computes the rate and offset of the straight line that passes through
  #' the earliest and the latest point (by ds) of the time series, as a
  #' starting point for the linear growth parameters.
  #'
  #' @param df Data frame with columns ds (date), y_scaled (scaled time series),
  #'  and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the linear
  #'  growth function.
  #'
  #' @keywords internal
  linear_growth_init <- function(df) {
    earliest <- which.min(df$ds)
    latest <- which.max(df$ds)
    elapsed <- df$t[latest] - df$t[earliest]
    # Rate: rise over run between the two end points.
    rate <- (df$y_scaled[latest] - df$y_scaled[earliest]) / elapsed
    # Offset: value of the line at t = 0.
    offset <- df$y_scaled[earliest] - rate * df$t[earliest]
    c(rate, offset)
  }
  #' Initialize logistic growth.
  #'
  #' Computes the rate and offset of the logistic curve that passes through
  #' the earliest and the latest point (by ds) of the time series, as a
  #' starting point for the logistic growth parameters.
  #'
  #' @param df Data frame with columns ds (date), cap_scaled (scaled capacity),
  #'  y_scaled (scaled time series), and t (scaled time).
  #'
  #' @return A vector (k, m) with the rate (k) and offset (m) of the logistic
  #'  growth function.
  #'
  #' @keywords internal
  logistic_growth_init <- function(df) {
    earliest <- which.min(df$ds)
    latest <- which.max(df$ds)
    elapsed <- df$t[latest] - df$t[earliest]

    # Ratio cap / y at one row, with y forced into [1%, 99%] of the capacity
    # in case y > cap or y < 0.
    cap_ratio <- function(row) {
      capacity <- df$cap_scaled[row]
      y <- max(0.01 * capacity, min(0.99 * capacity, df$y_scaled[row]))
      capacity / y
    }
    ratio.first <- cap_ratio(earliest)
    ratio.last <- cap_ratio(latest)
    # Nudge the first ratio when the two nearly coincide, so that the
    # difference of logs below is not (close to) zero.
    if (abs(ratio.first - ratio.last) <= 0.01) {
      ratio.first <- 1.05 * ratio.first
    }
    logit.first <- log(ratio.first - 1)
    logit.last <- log(ratio.last - 1)

    # Offset first, then rate, both derived from the two end points.
    offset <- logit.first * elapsed / (logit.first - logit.last)
    rate <- (logit.first - logit.last) / elapsed
    c(rate, offset)
  }
  #' Fit the prophet model.
  #'
  #' Stores the fitted model parameters in m$params, a list with elements:
  #'  k (M array): M posterior samples of the initial slope.
  #'  m (M array): The initial intercept.
  #'  delta (MxN matrix): The slope change at each of N changepoints.
  #'  beta (MxK matrix): Coefficients for K seasonality features.
  #'  sigma_obs (M array): Noise level.
  #' M is 1 for MAP estimation. MCMC sampling is used when m$mcmc.samples > 0;
  #' when y is constant no Stan call is made and the initial values are used.
  #'
  #' @param m Prophet object.
  #' @param df Data frame.
  #' @param ... Additional arguments passed to the \code{optimizing} or
  #'  \code{sampling} functions in Stan.
  #'
  #' @export
  fit.prophet <- function(m, df, ...) {
    if (!is.null(m$history)) {
      stop("Prophet object can only be fit once. Instantiate a new object.")
    }
    # Only rows with an observed y are used for fitting.
    observed <- dplyr::filter(df, !is.na(y))
    if (nrow(observed) < 2) {
      stop("Dataframe has less than 2 non-NA rows.")
    }
    # All input dates are remembered, including those whose y is NA.
    m$history.dates <- sort(set_date(df$ds))

    prepared <- setup_dataframe(m, observed, initialize_scales = TRUE)
    history <- prepared$df
    m <- prepared$m
    m$history <- history
    m <- set_auto_seasonalities(m)

    feature.set <- make_all_seasonality_features(m, history)
    seasonal.features <- feature.set$seasonal.features
    prior.scales <- feature.set$prior.scales

    m <- set_changepoints(m)
    A <- get_changepoint_matrix(m)

    # Construct input to stan
    dat <- list(
      T = nrow(history),
      K = ncol(seasonal.features),
      S = length(m$changepoints.t),
      y = history$y_scaled,
      t = history$t,
      A = A,
      t_change = array(m$changepoints.t),
      X = as.matrix(seasonal.features),
      sigmas = array(prior.scales),
      tau = m$changepoint.prior.scale
    )

    # Starting values for the growth parameters; the logistic model also
    # needs the capacities in its Stan data.
    if (m$growth == 'linear') {
      kinit <- linear_growth_init(history)
    } else {
      dat$cap <- history$cap_scaled
      kinit <- logistic_growth_init(history)
    }

    # Use the .prophet.stan.models object when one exists, otherwise fetch
    # the model for this growth type.
    model <- if (exists(".prophet.stan.models")) {
      .prophet.stan.models[[m$growth]]
    } else {
      get_prophet_stan_model(m$growth)
    }

    stan_init <- function() {
      list(
        k = kinit[1],
        m = kinit[2],
        delta = array(rep(0, length(m$changepoints.t))),
        beta = array(rep(0, ncol(seasonal.features))),
        sigma_obs = 1
      )
    }

    if (min(history$y) == max(history$y)) {
      # Nothing to fit.
      m$params <- stan_init()
      m$params$sigma_obs <- 0.
      n.iteration <- 1.
    } else if (m$mcmc.samples > 0) {
      # Full Bayesian sampling
      stan.fit <- rstan::sampling(
        model,
        data = dat,
        init = stan_init,
        iter = m$mcmc.samples,
        ...
      )
      m$params <- rstan::extract(stan.fit)
      n.iteration <- length(m$params$k)
    } else {
      # MAP estimate
      stan.fit <- rstan::optimizing(
        model,
        data = dat,
        init = stan_init,
        iter = 1e4,
        as_vector = FALSE,
        ...
      )
      m$params <- stan.fit$par
      n.iteration <- 1
    }

    # Cast the parameters to have consistent form, whether full bayes or MAP:
    # delta and beta become matrices with one row per iteration...
    m$params[['delta']] <- matrix(m$params[['delta']], nrow = n.iteration)
    m$params[['beta']] <- matrix(m$params[['beta']], nrow = n.iteration)
    # ...and k, m, sigma_obs become atomic vectors (rstan::sampling returns
    # 1d arrays).
    m$params[['k']] <- c(m$params[['k']])
    m$params[['m']] <- c(m$params[['m']])
    m$params[['sigma_obs']] <- c(m$params[['sigma_obs']])

    # If no changepoints were requested, fold delta into the base rate k and
    # replace delta with 0s.
    if (m$n.changepoints == 0) {
      m$params$k <- m$params$k + m$params$delta[, 1]
      m$params$delta <- matrix(rep(0, length(m$params$delta)),
                               nrow = n.iteration)
    }
    m
  }
  #' Predict using the prophet model.
  #'
  #' @param object Prophet object.
  #' @param df Dataframe with dates for predictions (column ds), and capacity
  #'  (column cap) if logistic growth. If not provided, predictions are made on
  #'  the history.
  #' @param ... additional arguments.
  #'
  #' @return A dataframe with the forecast components: ds, trend, cap and
  #'  floor where applicable, the seasonal components, the uncertainty
  #'  intervals, and yhat (trend plus the total seasonal component).
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  predict.prophet <- function(object, df = NULL, ...) {
    if (is.null(df)) {
      df <- object$history
    } else if (nrow(df) == 0) {
      stop("Dataframe has no rows.")
    } else {
      df <- setup_dataframe(object, df)$df
    }

    df$trend <- predict_trend(object, df)
    seasonal.components <- predict_seasonal_components(object, df)
    intervals <- predict_uncertainty(object, df)

    # Drop columns except ds, cap, floor, and trend
    keep <- c(
      'ds', 'trend',
      if ('cap' %in% colnames(df)) 'cap',
      if (object$logistic.floor) 'floor'
    )
    fcst <- dplyr::bind_cols(
      dplyr::bind_cols(df[keep], seasonal.components),
      intervals
    )
    fcst$yhat <- fcst$trend + fcst$seasonal
    fcst
  }
  #' Evaluate the piecewise linear function.
  #'
  #' With no changepoints (zero-length changepoint.ts) the result is the plain
  #' line k * t + m.
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  #' @keywords internal
  piecewise_linear <- function(t, deltas, k, m, changepoint.ts) {
    # Intercept changes: each one cancels its rate change at the changepoint
    # itself, so the function stays continuous there.
    gammas <- -changepoint.ts * deltas
    # Get cumulative slope and intercept at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    # seq_along() so that an empty changepoint vector skips the loop;
    # 1:length(x) would iterate over c(1, 0).
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- k_t * t + m_t
    return(y)
  }
  #' Evaluate the piecewise logistic function.
  #'
  #' With no changepoints (zero-length changepoint.ts) the result is the plain
  #' logistic curve cap / (1 + exp(-k * (t - m))).
  #'
  #' @param t Vector of times on which the function is evaluated.
  #' @param cap Vector of capacities at each t.
  #' @param deltas Vector of rate changes at each changepoint.
  #' @param k Float initial rate.
  #' @param m Float initial offset.
  #' @param changepoint.ts Vector of changepoint times.
  #'
  #' @return Vector y(t).
  #'
  #' @keywords internal
  piecewise_logistic <- function(t, cap, deltas, k, m, changepoint.ts) {
    # Compute offset changes. k.cum[i] is the rate before changepoint i and
    # k.cum[i + 1] the rate after it; each gamma depends on the gammas
    # already computed, so this loop is inherently sequential.
    k.cum <- c(k, cumsum(deltas) + k)
    gammas <- rep(0, length(changepoint.ts))
    # seq_along() so that an empty changepoint vector skips the loops;
    # 1:length(x) would iterate over c(1, 0).
    for (i in seq_along(changepoint.ts)) {
      gammas[i] <- ((changepoint.ts[i] - m - sum(gammas))
                    * (1 - k.cum[i] / k.cum[i + 1]))
    }
    # Get cumulative rate and offset at each t
    k_t <- rep(k, length(t))
    m_t <- rep(m, length(t))
    for (s in seq_along(changepoint.ts)) {
      indx <- t >= changepoint.ts[s]
      k_t[indx] <- k_t[indx] + deltas[s]
      m_t[indx] <- m_t[indx] + gammas[s]
    }
    y <- cap / (1 + exp(-k_t * (t - m_t)))
    return(y)
  }
  #' Predict trend using the prophet model.
  #'
  #' Evaluates the growth function at df$t using the means of the fitted k, m
  #' and delta parameters, then multiplies by model$y.scale and adds df$floor.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Vector with trend on prediction dates.
  #'
  #' @keywords internal
  predict_trend <- function(model, df) {
    rate <- mean(model$params$k, na.rm = TRUE)
    offset <- mean(model$params$m, na.rm = TRUE)
    mean.deltas <- colMeans(model$params$delta, na.rm = TRUE)

    scaled.trend <- if (model$growth == 'linear') {
      piecewise_linear(df$t, mean.deltas, rate, offset, model$changepoints.t)
    } else {
      piecewise_logistic(
        df$t, df$cap_scaled, mean.deltas, rate, offset, model$changepoints.t)
    }
    scaled.trend * model$y.scale + df$floor
  }
  #' Predict seasonality components, holidays, and added regressors.
  #'
  #' Each feature column is assigned to the component named by the part of
  #' its column name before "_delim_". Every column is additionally counted
  #' towards the total 'seasonal' component and, where applicable, towards the
  #' 'seasonalities', 'holidays' and 'extra_regressors' group totals. For each
  #' component the mean and the lower/upper quantiles implied by
  #' m$interval.width are computed across the rows of m$params$beta.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with seasonal components.
  #'
  #' @keywords internal
  predict_seasonal_components <- function(m, df) {
    seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features
    lower.p <- (1 - m$interval.width)/2
    upper.p <- (1 + m$interval.width)/2
    n.rows <- nrow(seasonal.features)
    # The feature frame always has at least one column (a "zeros" placeholder
    # is supplied when there are no real features).
    n.cols <- ncol(seasonal.features)

    # Column index -> component name
    column.map <- dplyr::data_frame(component = colnames(seasonal.features))
    column.map <- dplyr::mutate(column.map, col = 1:n())
    column.map <- tidyr::separate(
      column.map, component, c('component', 'part'), sep = "_delim_",
      extra = "merge", fill = "right")
    column.map <- dplyr::select(column.map, col, component)

    # Add total for all regression components
    components <- rbind(
      column.map,
      data.frame(col = seq_len(n.cols), component = 'seasonal'))
    # Add totals for seasonality, holiday, and extra regressors
    components <- add_group_component(
      components, 'seasonalities', names(m$seasonalities))
    if (!is.null(m$holidays)) {
      components <- add_group_component(
        components, 'holidays', unique(m$holidays$holiday))
    }
    components <- add_group_component(
      components, 'extra_regressors', names(m$extra_regressors))
    # Remove the placeholder
    components <- dplyr::filter(components, component != 'zeros')

    # Per component: features %*% t(beta), one column per parameter draw,
    # summarised row-wise.
    by.component <- dplyr::group_by(components, component)
    summaries <- dplyr::do(by.component, {
      comp <- (as.matrix(seasonal.features[, .$col]) %*%
                 t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
      dplyr::data_frame(
        ix = 1:n.rows,
        mean = rowMeans(comp, na.rm = TRUE),
        lower = apply(comp, 1, stats::quantile, lower.p, na.rm = TRUE),
        upper = apply(comp, 1, stats::quantile, upper.p, na.rm = TRUE))
    })

    # Reshape to one column per component / statistic: the mean keeps the
    # bare component name, the bounds get a "_lower" / "_upper" suffix.
    long <- tidyr::gather(summaries, stat, value, mean, lower, upper)
    long <- dplyr::mutate(
      long, stat = ifelse(stat == 'mean', '', paste0('_', stat)))
    long <- tidyr::unite(long, component, component, stat, sep = "")
    wide <- tidyr::spread(long, component, value)
    dplyr::select(wide, -ix)
  }
  #' Adds a component with given name that contains all of the components
  #' in group.
  #'
  #' The rows whose component is in group are duplicated under the new name
  #' and appended; when no row matches, components is returned unchanged.
  #'
  #' @param components Dataframe with components.
  #' @param name Name of new group component.
  #' @param group List of components that form the group.
  #'
  #' @return Dataframe with components.
  #'
  #' @keywords internal
  add_group_component <- function(components, name, group) {
    members <- components[(components$component %in% group), ]
    if (nrow(members) == 0) {
      return(components)
    }
    members$component <- name
    rbind(components, members)
  }
  #' Prophet posterior predictive samples.
  #'
  #' Runs enough simulations per parameter set that the total is at least
  #' m$uncertainty.samples, and at least one per parameter set.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return List with posterior predictive samples for each component: items
  #'  "trend", "seasonal" and "yhat", each a matrix with one row per row of df
  #'  and one column per simulation.
  #'
  #' @keywords internal
  sample_posterior_predictive <- function(m, df) {
    # Sample trend, seasonality, and yhat from the extrapolation model.
    n.iterations <- length(m$params$k)
    samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
    nsamp <- n.iterations * samp.per.iter  # The actual number of samples
    seasonal.features <- make_all_seasonality_features(m, df)$seasonal.features

    keys <- c("trend", "seasonal", "yhat")
    sim.values <- list(
      "trend" = matrix(NA, nrow = nrow(df), ncol = nsamp),
      "seasonal" = matrix(NA, nrow = nrow(df), ncol = nsamp),
      "yhat" = matrix(NA, nrow = nrow(df), ncol = nsamp))

    # Columns are filled left to right: all simulations for the first set of
    # parameters from MCMC (or the single MAP set), then the next set, ...
    column <- 0
    for (i in seq_len(n.iterations)) {
      for (j in seq_len(samp.per.iter)) {
        column <- column + 1
        # Do a simulation with this set of parameters and store the results
        sim <- sample_model(m, df, seasonal.features, i)
        for (key in keys) {
          sim.values[[key]][, column] <- sim[[key]]
        }
      }
    }
    sim.values
  }
  #' Sample from the posterior predictive distribution.
  #'
  #' Prepares df with setup_dataframe and then draws the samples.
  #'
  #' @param m Prophet object.
  #' @param df Dataframe with dates for predictions (column ds), and capacity
  #'  (column cap) if logistic growth.
  #'
  #' @return A list with items "trend", "seasonal", and "yhat" containing
  #'  posterior predictive samples for that component. "seasonal" is the sum
  #'  of seasonalities, holidays, and added regressors.
  #'
  #' @export
  predictive_samples <- function(m, df) {
    prepared <- setup_dataframe(m, df)
    sample_posterior_predictive(m, prepared$df)
  }
  #' Prophet uncertainty intervals for yhat and trend
  #'
  #' The bounds are the (1 - interval.width)/2 and (1 + interval.width)/2
  #' quantiles, per prediction row, of the posterior predictive samples.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #'
  #' @return Dataframe with uncertainty intervals, columns yhat_lower,
  #'  yhat_upper, trend_lower and trend_upper.
  #'
  #' @keywords internal
  predict_uncertainty <- function(m, df) {
    sim.values <- sample_posterior_predictive(m, df)
    # Add uncertainty estimates
    probs <- c((1 - m$interval.width)/2, (1 + m$interval.width)/2)

    # Quantiles across the simulations (columns) for every prediction row;
    # the result has one row per prediction row and one column per bound.
    row_bounds <- function(samples) {
      t(apply(samples, 1, stats::quantile, probs, na.rm = TRUE))
    }
    bounds <- cbind(row_bounds(sim.values$yhat), row_bounds(sim.values$trend))
    intervals <- dplyr::as_data_frame(bounds)
    colnames(intervals) <- c('yhat_lower', 'yhat_upper',
                             'trend_lower', 'trend_upper')
    intervals
  }
  #' Simulate observations from the extrapolated generative model.
  #'
  #' The trend is simulated first and the observation noise is drawn
  #' afterwards, so the random number stream is consumed in that order.
  #'
  #' @param m Prophet object.
  #' @param df Prediction dataframe.
  #' @param seasonal.features Data frame of seasonal features
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return List of trend, seasonality, and yhat, each a vector like df$t.
  #'
  #' @keywords internal
  sample_model <- function(m, df, seasonal.features, iteration) {
    trend <- sample_predictive_trend(m, df, iteration)

    coefficients <- m$params$beta[iteration,]
    seasonal <- (as.matrix(seasonal.features) %*% coefficients) * m$y.scale

    noise.sd <- m$params$sigma_obs[iteration]
    noise <- stats::rnorm(nrow(df), mean = 0, sd = noise.sd) * m$y.scale

    list("yhat" = trend + seasonal + noise,
         "trend" = trend,
         "seasonal" = seasonal)
  }
  #' Simulate the trend using the extrapolated generative model.
  #'
  #' When df extends past t = 1, new changepoints are sampled on (1, max(t)]
  #' and given rate changes drawn from a Laplace distribution whose scale is
  #' the mean absolute value of the fitted deltas. The trend is then evaluated
  #' with the fitted and the sampled changepoints combined.
  #'
  #' @param model Prophet object.
  #' @param df Prediction dataframe.
  #' @param iteration Int sampling iteration to use parameters from.
  #'
  #' @return Vector of simulated trend over df$t.
  #'
  #' @keywords internal
  sample_predictive_trend <- function(model, df, iteration) {
    k <- model$params$k[iteration]
    param.m <- model$params$m[iteration]
    deltas <- model$params$delta[iteration,]

    t <- df$t
    t.max <- max(t)

    # Defaults for predictions that do not go past t = 1.
    n.changes <- 0
    changepoint.ts.new <- c()
    if (t.max > 1) {
      # Get the time discretization of the history
      steps <- diff(model$history$t)
      dt <- min(steps[steps > 0])
      # Number of time periods in the future
      N <- ceiling((t.max - 1) / dt)
      S <- length(model$changepoints.t)
      # The history had S split points, over t = [0, 1].
      # The forecast is on [1, T], and should have the same average frequency
      # of rate changes. Thus for N time periods in the future, we want an
      # average of S * (T - 1) changepoints in expectation.
      prob.change <- min(1, (S * (t.max - 1)) / N)
      # This calculation works for both history and df not uniformly spaced.
      n.changes <- stats::rbinom(1, N, prob.change)
      # Sample ts
      if (n.changes > 0) {
        changepoint.ts.new <- sort(
          stats::runif(n.changes, min = 1, max = t.max))
      }
    }

    # Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
    lambda <- mean(abs(c(deltas))) + 1e-8
    # Sample deltas
    deltas.new <- extraDistr::rlaplace(n.changes, mu = 0, sigma = lambda)

    # Combine with changepoints from the history
    all.changepoints <- c(model$changepoints.t, changepoint.ts.new)
    all.deltas <- c(deltas, deltas.new)

    # Get the corresponding trend
    trend <- if (model$growth == 'linear') {
      piecewise_linear(t, all.deltas, k, param.m, all.changepoints)
    } else {
      piecewise_logistic(
        t, df$cap_scaled, all.deltas, k, param.m, all.changepoints)
    }
    trend * model$y.scale + df$floor
  }
  #' Make dataframe with future dates for forecasting.
  #'
  #' @param m Prophet model object.
  #' @param periods Int number of periods to forecast forward. 0 adds no
  #'  future dates.
  #' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
  #' @param include_history Boolean to include the historical dates in the data
  #'  frame for predictions.
  #'
  #' @return Dataframe that extends forward from the end of m$history for the
  #'  requested number of periods.
  #'
  #' @export
  make_future_dataframe <- function(m, periods, freq = 'day',
                                    include_history = TRUE) {
    # For backwards compatibility with previous zoo date type,
    if (freq == 'm') {
      freq <- 'month'
    }
    if (is.null(m$history.dates)) {
      stop('Model must be fit before this can be used.')
    }
    dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
    # Drop the first, which is max(history$ds). Negative indexing is used
    # because dates[2:(periods + 1)] becomes dates[2:1] when periods is 0 and
    # would return an NA followed by the last history date.
    dates <- dates[-1]
    if (include_history) {
      dates <- c(m$history.dates, dates)
      attr(dates, "tzone") <- "GMT"
    }
    return(data.frame(ds = dates))
  }
  #' Merge history and forecast for plotting.
  #'
  #' Full-joins the ds and y columns of m$history with the forecast on ds and
  #' sorts the result by ds. Any y column already present in fcst is dropped
  #' first.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  df_for_plotting <- function(m, fcst) {
    # Make sure there is no y in fcst
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Layers are added in this order: capacity and floor lines (dashed), the
  #' yhat uncertainty ribbon, the observed points, and the yhat line.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
                           xlabel = 'ds', ylabel = 'y', ...) {
    df <- df_for_plotting(x, fcst)
    forecast.colour <- "#0072B2"
    # Dashed reference line (used for capacity and floor).
    dashed_line <- function(mapping) {
      ggplot2::geom_line(mapping, linetype = 'dashed', na.rm = TRUE)
    }

    gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)
    if (exists('cap', where = df) && plot_cap) {
      gg <- gg + dashed_line(ggplot2::aes(y = cap))
    }
    if (x$logistic.floor && exists('floor', where = df) && plot_cap) {
      gg <- gg + dashed_line(ggplot2::aes(y = floor))
    }
    if (uncertainty && exists('yhat_lower', where = df)) {
      gg <- gg + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2,
        fill = forecast.colour,
        na.rm = TRUE)
    }
    gg +
      ggplot2::geom_point(na.rm=TRUE) +
      ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.colour,
                         na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }
  #' Plot the components of a prophet forecast.
  #' Prints a ggplot2 with panels for trend, holidays if present, weekly and
  #' yearly seasonalities if present, any other seasonalities, and extra
  #' regressors if present. The panels are stacked in a single column.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #'  plotted for the trend, from fcst columns trend_lower and trend_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #' @param weekly_start Integer specifying the start day of the weekly
  #'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #'  to Monday, and so on.
  #' @param yearly_start Integer specifying the start day of the yearly
  #'  seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #'  to Jan 2, and so on.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(
    m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
    yearly_start = 0
  ) {
    fcst.columns <- colnames(fcst)

    # Plot the trend
    panels <- list(plot_forecast_component(fcst, 'trend', uncertainty, plot_cap))
    # Plot holiday components, if present.
    if (!is.null(m$holidays) && ('holidays' %in% fcst.columns)) {
      panels <- c(panels, list(plot_forecast_component(
        fcst, 'holidays', uncertainty, FALSE)))
    }
    # Plot weekly seasonality, if present
    if ("weekly" %in% fcst.columns) {
      panels <- c(panels, list(plot_weekly(m, uncertainty, weekly_start)))
    }
    # Plot yearly seasonality, if present
    if ("yearly" %in% fcst.columns) {
      panels <- c(panels, list(plot_yearly(m, uncertainty, yearly_start)))
    }
    # Plot other seasonalities
    for (name in setdiff(names(m$seasonalities), c('weekly', 'yearly'))) {
      if (name %in% fcst.columns) {
        panels <- c(panels, list(plot_seasonality(m, name, uncertainty)))
      }
    }
    # Plot extra regressors
    if ((length(m$extra_regressors) > 0) &&
        ('extra_regressors' %in% fcst.columns)) {
      panels <- c(panels, list(plot_forecast_component(
        fcst, 'extra_regressors', uncertainty, FALSE)))
    }

    # Make the plot: one layout row per panel on a fresh page.
    grid::grid.newpage()
    page.layout <- grid::grid.layout(length(panels), 1)
    grid::pushViewport(grid::viewport(layout = page.layout))
    for (i in seq_along(panels)) {
      print(panels[[i]], vp = grid::viewport(layout.pos.row = i,
                                             layout.pos.col = 1))
    }
    invisible(panels)
  }
  #' Plot a particular component of the forecast.
  #'
  #' Draws the component against ds as a line, optionally with dashed cap and
  #' floor lines and with a ribbon between the component's "_lower" and
  #' "_upper" columns.
  #'
  #' @param fcst Dataframe output of `predict`.
  #' @param name String name of the component to plot (column of fcst).
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @export
  plot_forecast_component <- function(
    fcst, name, uncertainty = TRUE, plot_cap = FALSE
  ) {
    component.colour <- "#0072B2"
    # Dashed reference line (used for capacity and floor).
    dashed_line <- function(mapping) {
      ggplot2::geom_line(mapping, linetype = 'dashed', na.rm = TRUE)
    }

    gg.comp <- ggplot2::ggplot(
      fcst, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
      ggplot2::geom_line(color = component.colour, na.rm = TRUE)
    if (exists('cap', where = fcst) && plot_cap) {
      gg.comp <- gg.comp + dashed_line(ggplot2::aes(y = cap))
    }
    if (exists('floor', where = fcst) && plot_cap) {
      gg.comp <- gg.comp + dashed_line(ggplot2::aes(y = floor))
    }
    if (uncertainty) {
      bounds <- ggplot2::aes_string(
        ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
      )
      gg.comp <- gg.comp + ggplot2::geom_ribbon(
        bounds,
        alpha = 0.2,
        fill = component.colour,
        na.rm = TRUE)
    }
    gg.comp
  }
  #' Prepare dataframe for plotting seasonal components.
  #'
  #' Builds a frame with the given dates, cap = 1, floor = 0 and every extra
  #' regressor set to 0, and passes it through setup_dataframe.
  #'
  #' @param m Prophet object.
  #' @param ds Array of dates for column ds.
  #'
  #' @return A dataframe with seasonal components on ds.
  #'
  #' @keywords internal
  seasonality_plot_df <- function(m, ds) {
    columns <- list(ds = ds, cap = 1, floor = 0)
    # One constant-zero column per added regressor.
    columns[names(m$extra_regressors)] <- 0
    setup_dataframe(m, as.data.frame(columns))$df
  }
  1486. #' Plot the weekly component of the forecast.
  1487. #'
  1488. #' @param m Prophet model object
  1489. #' @param uncertainty Boolean to plot uncertainty intervals.
  1490. #' @param weekly_start Integer specifying the start day of the weekly
  1491. #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  1492. #' to Monday, and so on.
  1493. #'
  1494. #' @return A ggplot2 plot.
  1495. #'
  1496. #' @keywords internal
  1497. plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
  1498. # Compute weekly seasonality for a Sun-Sat sequence of dates.
  1499. days <- seq(set_date('2017-01-01'), by='d', length.out=7) + as.difftime(
  1500. weekly_start, units = "days")
  1501. df.w <- seasonality_plot_df(m, days)
  1502. seas <- predict_seasonal_components(m, df.w)
  1503. seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
  1504. gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
  1505. group = 1)) +
  1506. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
  1507. ggplot2::labs(x = "Day of week")
  1508. if (uncertainty) {
  1509. gg.weekly <- gg.weekly +
  1510. ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
  1511. ymax = weekly_upper),
  1512. alpha = 0.2,
  1513. fill = "#0072B2",
  1514. na.rm = TRUE)
  1515. }
  1516. return(gg.weekly)
  1517. }
  1518. #' Plot the yearly component of the forecast.
  1519. #'
  1520. #' @param m Prophet model object.
  1521. #' @param uncertainty Boolean to plot uncertainty intervals.
  1522. #' @param yearly_start Integer specifying the start day of the yearly
  1523. #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  1524. #' to Jan 2, and so on.
  1525. #'
  1526. #' @return A ggplot2 plot.
  1527. #'
  1528. #' @keywords internal
  1529. plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
  1530. # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  1531. days <- seq(set_date('2017-01-01'), by='d', length.out=365) + as.difftime(
  1532. yearly_start, units = "days")
  1533. df.y <- seasonality_plot_df(m, days)
  1534. seas <- predict_seasonal_components(m, df.y)
  1535. seas$ds <- df.y$ds
  1536. gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
  1537. group = 1)) +
  1538. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
  1539. ggplot2::labs(x = "Day of year") +
  1540. ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
  1541. if (uncertainty) {
  1542. gg.yearly <- gg.yearly +
  1543. ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
  1544. ymax = yearly_upper),
  1545. alpha = 0.2,
  1546. fill = "#0072B2",
  1547. na.rm = TRUE)
  1548. }
  1549. return(gg.yearly)
  1550. }
  1551. #' Plot a custom seasonal component.
  1552. #'
  1553. #' @param m Prophet model object.
  1554. #' @param name String name of the seasonality.
  1555. #' @param uncertainty Boolean to plot uncertainty intervals.
  1556. #'
  1557. #' @return A ggplot2 plot.
  1558. #'
  1559. #' @keywords internal
  1560. plot_seasonality <- function(m, name, uncertainty = TRUE) {
  1561. # Compute seasonality from Jan 1 through a single period.
  1562. start <- set_date('2017-01-01')
  1563. period <- m$seasonalities[[name]]$period
  1564. end <- start + period * 24 * 3600
  1565. plot.points <- 200
  1566. days <- seq(from=start, to=end, length.out=plot.points)
  1567. df.y <- seasonality_plot_df(m, days)
  1568. seas <- predict_seasonal_components(m, df.y)
  1569. seas$ds <- df.y$ds
  1570. gg.s <- ggplot2::ggplot(
  1571. seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
  1572. ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
  1573. if (period <= 2) {
  1574. fmt.str <- '%T'
  1575. } else if (period < 14) {
  1576. fmt.str <- '%m/%d %R'
  1577. } else {
  1578. fmt.str <- '%m/%d'
  1579. }
  1580. gg.s <- gg.s +
  1581. ggplot2::scale_x_datetime(labels = scales::date_format(fmt.str))
  1582. if (uncertainty) {
  1583. gg.s <- gg.s +
  1584. ggplot2::geom_ribbon(
  1585. ggplot2::aes_string(
  1586. ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
  1587. ),
  1588. alpha = 0.2,
  1589. fill = "#0072B2",
  1590. na.rm = TRUE)
  1591. }
  1592. return(gg.s)
  1593. }
  1594. #' Copy Prophet object.
  1595. #'
  1596. #' @param m Prophet model object.
  1597. #' @param cutoff Date, possibly as string. Changepoints are only retained if
  1598. #' changepoints <= cutoff.
  1599. #'
  1600. #' @return An unfitted Prophet model object with the same parameters as the
  1601. #' input model.
  1602. #'
  1603. #' @keywords internal
  1604. prophet_copy <- function(m, cutoff = NULL) {
  1605. if (is.null(m$history)) {
  1606. stop("This is for copying a fitted Prophet object.")
  1607. }
  1608. if (m$specified.changepoints) {
  1609. changepoints <- m$changepoints
  1610. if (!is.null(cutoff)) {
  1611. cutoff <- set_date(cutoff)
  1612. changepoints <- changepoints[changepoints <= cutoff]
  1613. }
  1614. } else {
  1615. changepoints <- NULL
  1616. }
  1617. # Auto seasonalities are set to FALSE because they are already set in
  1618. # m$seasonalities.
  1619. m2 <- prophet(
  1620. growth = m$growth,
  1621. changepoints = changepoints,
  1622. n.changepoints = m$n.changepoints,
  1623. yearly.seasonality = FALSE,
  1624. weekly.seasonality = FALSE,
  1625. daily.seasonality = FALSE,
  1626. holidays = m$holidays,
  1627. seasonality.prior.scale = m$seasonality.prior.scale,
  1628. changepoint.prior.scale = m$changepoint.prior.scale,
  1629. holidays.prior.scale = m$holidays.prior.scale,
  1630. mcmc.samples = m$mcmc.samples,
  1631. interval.width = m$interval.width,
  1632. uncertainty.samples = m$uncertainty.samples,
  1633. fit = FALSE
  1634. )
  1635. m2$extra_regressors <- m$extra_regressors
  1636. m2$seasonalities <- m$seasonalities
  1637. return(m2)
  1638. }
  1639. # fb-block 3