|
@@ -114,6 +114,7 @@ prophet <- function(df = NULL,
|
|
|
specified.changepoints = !is.null(changepoints),
|
|
|
start = NULL, # This and following attributes are set during fitting
|
|
|
y.scale = NULL,
|
|
|
+ logistic.floor = FALSE,
|
|
|
t.scale = NULL,
|
|
|
changepoints.t = NULL,
|
|
|
seasonalities = list(),
|
|
@@ -191,7 +192,8 @@ validate_column_name <- function(
|
|
|
)
|
|
|
rn_l = paste(reserved_names,"_lower",sep="")
|
|
|
rn_u = paste(reserved_names,"_upper",sep="")
|
|
|
- reserved_names = c(reserved_names, rn_l, rn_u, c("ds","y"))
|
|
|
+ reserved_names = c(reserved_names, rn_l, rn_u,
|
|
|
+ c("ds", "y", "cap", "floor", "y_scaled", "cap_scaled"))
|
|
|
if(name %in% reserved_names){
|
|
|
stop("Name ", name, " is reserved.")
|
|
|
}
|
|
@@ -332,7 +334,13 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|
|
dplyr::arrange(ds)
|
|
|
|
|
|
if (initialize_scales) {
|
|
|
- m$y.scale <- max(abs(df$y))
|
|
|
+ if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
|
|
|
+ m$logistic.floor <- TRUE
|
|
|
+ floor <- df$floor
|
|
|
+ } else {
|
|
|
+ floor <- 0
|
|
|
+ }
|
|
|
+ m$y.scale <- max(abs(df$y - floor))
|
|
|
if (m$y.scale == 0) {
|
|
|
m$y.scale <- 1
|
|
|
}
|
|
@@ -360,9 +368,12 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
|
|
|
- if (exists('y', where=df)) {
|
|
|
- df$y_scaled <- df$y / m$y.scale
|
|
|
+ if (m$logistic.floor) {
|
|
|
+ if (!('floor' %in% colnames(df))) {
|
|
|
+ stop("Expected column 'floor'.")
|
|
|
+ }
|
|
|
+ } else {
|
|
|
+ df$floor <- 0
|
|
|
}
|
|
|
|
|
|
if (m$growth == 'logistic') {
|
|
@@ -370,7 +381,12 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
|
|
|
stop('Capacities must be supplied for logistic growth.')
|
|
|
}
|
|
|
df <- df %>%
|
|
|
- dplyr::mutate(cap_scaled = cap / m$y.scale)
|
|
|
+ dplyr::mutate(cap_scaled = (cap - floor) / m$y.scale)
|
|
|
+ }
|
|
|
+
|
|
|
+ df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
|
|
|
+ if (exists('y', where=df)) {
|
|
|
+ df$y_scaled <- (df$y - df$floor) / m$y.scale
|
|
|
}
|
|
|
|
|
|
for (name in names(m$extra_regressors)) {
|
|
@@ -848,9 +864,16 @@ logistic_growth_init <- function(df) {
|
|
|
i0 <- which.min(df$ds)
|
|
|
i1 <- which.max(df$ds)
|
|
|
T <- df$t[i1] - df$t[i0]
|
|
|
- # Force valid values, in case y > cap.
|
|
|
- r0 <- max(1.01, df$cap_scaled[i0] / df$y_scaled[i0])
|
|
|
- r1 <- max(1.01, df$cap_scaled[i1] / df$y_scaled[i1])
|
|
|
+
|
|
|
+ # Force valid values, in case y > cap or y < 0
|
|
|
+ C0 <- df$cap_scaled[i0]
|
|
|
+ C1 <- df$cap_scaled[i1]
|
|
|
+ y0 <- max(0.01 * C0, min(0.99 * C0, df$y_scaled[i0]))
|
|
|
+ y1 <- max(0.01 * C1, min(0.99 * C1, df$y_scaled[i1]))
|
|
|
+
|
|
|
+ r0 <- C0 / y0
|
|
|
+ r1 <- C1 / y1
|
|
|
+
|
|
|
if (abs(r0 - r1) <= 0.01) {
|
|
|
r0 <- 1.05 * r0
|
|
|
}
|
|
@@ -1015,11 +1038,13 @@ predict.prophet <- function(object, df = NULL, ...) {
|
|
|
seasonal.components <- predict_seasonal_components(object, df)
|
|
|
intervals <- predict_uncertainty(object, df)
|
|
|
|
|
|
- # Drop columns except ds, cap, and trend
|
|
|
+ # Drop columns except ds, cap, floor, and trend
|
|
|
+ cols <- c('ds', 'trend')
|
|
|
if ('cap' %in% colnames(df)) {
|
|
|
- cols <- c('ds', 'cap', 'trend')
|
|
|
- } else {
|
|
|
- cols <- c('ds', 'trend')
|
|
|
+ cols <- c(cols, 'cap')
|
|
|
+ }
|
|
|
+ if (object$logistic.floor) {
|
|
|
+ cols <- c(cols, 'floor')
|
|
|
}
|
|
|
df <- df[cols]
|
|
|
df <- df %>%
|
|
@@ -1108,7 +1133,7 @@ predict_trend <- function(model, df) {
|
|
|
trend <- piecewise_logistic(
|
|
|
t, cap, deltas, k, param.m, model$changepoints.t)
|
|
|
}
|
|
|
- return(trend * model$y.scale)
|
|
|
+ return(trend * model$y.scale + df$floor)
|
|
|
}
|
|
|
|
|
|
#' Predict seasonality components, holidays, and added regressors.
|
|
@@ -1343,7 +1368,7 @@ sample_predictive_trend <- function(model, df, iteration) {
|
|
|
cap <- df$cap_scaled
|
|
|
trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
|
|
|
}
|
|
|
- return(trend * model$y.scale)
|
|
|
+ return(trend * model$y.scale + df$floor)
|
|
|
}
|
|
|
|
|
|
#' Make dataframe with future dates for forecasting.
|
|
@@ -1424,6 +1449,10 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
|
|
|
gg <- gg + ggplot2::geom_line(
|
|
|
ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
|
|
|
}
|
|
|
+ if (x$logistic.floor && exists('floor', where = df) && plot_cap) {
|
|
|
+ gg <- gg + ggplot2::geom_line(
|
|
|
+ ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
|
|
|
+ }
|
|
|
if (uncertainty && exists('yhat_lower', where = df)) {
|
|
|
gg <- gg +
|
|
|
ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
|
|
@@ -1525,6 +1554,10 @@ plot_forecast_component <- function(
|
|
|
gg.comp <- gg.comp + ggplot2::geom_line(
|
|
|
ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
|
|
|
}
|
|
|
+ if (exists('floor', where = fcst) && plot_cap) {
|
|
|
+ gg.comp <- gg.comp + ggplot2::geom_line(
|
|
|
+ ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
|
|
|
+ }
|
|
|
if (uncertainty) {
|
|
|
gg.comp <- gg.comp +
|
|
|
ggplot2::geom_ribbon(
|
|
@@ -1647,8 +1680,7 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
|
|
|
ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
|
|
|
if (period <= 2) {
|
|
|
fmt.str <- '%T'
|
|
|
- }
|
|
|
- else if (period < 14) {
|
|
|
+ } else if (period < 14) {
|
|
|
fmt.str <- '%m/%d %R'
|
|
|
} else {
|
|
|
fmt.str <- '%m/%d'
|