|
@@ -218,22 +218,25 @@ validate_column_name <- function(
|
|
#' @return Stan model.
|
|
#' @return Stan model.
|
|
#'
|
|
#'
|
|
#' @keywords internal
|
|
#' @keywords internal
|
|
-get_prophet_stan_model <- function(model) {
|
|
|
|
- fn <- paste('prophet', model, 'growth.RData', sep = '_')
|
|
|
|
|
|
+get_prophet_stan_model <- function() {
|
|
## If the cached model doesn't work, just compile a new one.
|
|
## If the cached model doesn't work, just compile a new one.
|
|
tryCatch({
|
|
tryCatch({
|
|
- binary <- system.file('libs', Sys.getenv('R_ARCH'), fn,
|
|
|
|
- package = 'prophet',
|
|
|
|
- mustWork = TRUE)
|
|
|
|
|
|
+ binary <- system.file(
|
|
|
|
+ 'libs',
|
|
|
|
+ Sys.getenv('R_ARCH'),
|
|
|
|
+ 'prophet_stan_model.RData',
|
|
|
|
+ package = 'prophet',
|
|
|
|
+ mustWork = TRUE
|
|
|
|
+ )
|
|
load(binary)
|
|
load(binary)
|
|
- obj.name <- paste(model, 'growth.stanm', sep = '.')
|
|
|
|
|
|
+ obj.name <- 'model.stanm'
|
|
stanm <- eval(parse(text = obj.name))
|
|
stanm <- eval(parse(text = obj.name))
|
|
|
|
|
|
## Should cause an error if the model doesn't work.
|
|
## Should cause an error if the model doesn't work.
|
|
stanm@mk_cppmodule(stanm)
|
|
stanm@mk_cppmodule(stanm)
|
|
stanm
|
|
stanm
|
|
}, error = function(cond) {
|
|
}, error = function(cond) {
|
|
- compile_stan_model(model)
|
|
|
|
|
|
+ compile_stan_model()
|
|
})
|
|
})
|
|
}
|
|
}
|
|
|
|
|
|
@@ -245,14 +248,13 @@ get_prophet_stan_model <- function(model) {
|
|
#' @return Stan model.
|
|
#' @return Stan model.
|
|
#'
|
|
#'
|
|
#' @keywords internal
|
|
#' @keywords internal
|
|
-compile_stan_model <- function(model) {
|
|
|
|
- fn <- paste('stan/prophet', model, 'growth.stan', sep = '_')
|
|
|
|
|
|
+compile_stan_model <- function() {
|
|
|
|
+ fn <- 'stan/prophet.stan'
|
|
|
|
|
|
stan.src <- system.file(fn, package = 'prophet', mustWork = TRUE)
|
|
stan.src <- system.file(fn, package = 'prophet', mustWork = TRUE)
|
|
stanc <- rstan::stanc(stan.src)
|
|
stanc <- rstan::stanc(stan.src)
|
|
|
|
|
|
- model.name <- paste(model, 'growth', sep = '_')
|
|
|
|
- return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
|
|
|
|
|
|
+ return(rstan::stan_model(stanc_ret = stanc, model_name = 'prophet_model'))
|
|
}
|
|
}
|
|
|
|
|
|
#' Convert date vector
|
|
#' Convert date vector
|
|
@@ -901,21 +903,23 @@ fit.prophet <- function(m, df, ...) {
|
|
t_change = array(m$changepoints.t),
|
|
t_change = array(m$changepoints.t),
|
|
X = as.matrix(seasonal.features),
|
|
X = as.matrix(seasonal.features),
|
|
sigmas = array(prior.scales),
|
|
sigmas = array(prior.scales),
|
|
- tau = m$changepoint.prior.scale
|
|
|
|
|
|
+ tau = m$changepoint.prior.scale,
|
|
|
|
+ trend_indicator = as.numeric(m$growth == 'logistic')
|
|
)
|
|
)
|
|
|
|
|
|
# Run stan
|
|
# Run stan
|
|
if (m$growth == 'linear') {
|
|
if (m$growth == 'linear') {
|
|
|
|
+ dat$cap <- rep(0, nrow(history)) # Unused inside Stan
|
|
kinit <- linear_growth_init(history)
|
|
kinit <- linear_growth_init(history)
|
|
} else {
|
|
} else {
|
|
dat$cap <- history$cap_scaled # Add capacities to the Stan data
|
|
dat$cap <- history$cap_scaled # Add capacities to the Stan data
|
|
kinit <- logistic_growth_init(history)
|
|
kinit <- logistic_growth_init(history)
|
|
}
|
|
}
|
|
|
|
|
|
- if (exists(".prophet.stan.models")) {
|
|
|
|
- model <- .prophet.stan.models[[m$growth]]
|
|
|
|
|
|
+ if (exists(".prophet.stan.model")) {
|
|
|
|
+ model <- .prophet.stan.model
|
|
} else {
|
|
} else {
|
|
- model <- get_prophet_stan_model(m$growth)
|
|
|
|
|
|
+ model <- get_prophet_stan_model()
|
|
}
|
|
}
|
|
|
|
|
|
stan_init <- function() {
|
|
stan_init <- function() {
|