Browse Source

Saturating minimum R

Ben Letham 8 years ago
parent
commit
2ddcf54930
4 changed files with 87 additions and 21 deletions
  1. 6 3
      R/R/diagnostics.R
  2. 49 17
      R/R/prophet.R
  3. 31 0
      R/tests/testthat/test_prophet.R
  4. 1 1
      python/fbprophet/forecaster.py

+ 6 - 3
R/R/diagnostics.R

@@ -80,11 +80,14 @@ simulated_historical_forecasts <- function(model, horizon, units, k,
     m <- fit.prophet(m, history.c)
     # Calculate yhat
     df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon)
+    columns <- c('ds')
     if (m$growth == 'logistic') {
-      future <- dplyr::select(df.predict, ds, cap)
-    } else{
-      future <- dplyr::select(df.predict, ds)
+      columns <- c(columns, 'cap')
+      if (m$logistic.floor) {
+        columns <- c(columns, 'floor')
+      }
     }
+    future <- df[columns]
     yhat <- stats::predict(m, future)
     # Merge yhat, y, and cutoff.
     df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")

+ 49 - 17
R/R/prophet.R

@@ -114,6 +114,7 @@ prophet <- function(df = NULL,
     specified.changepoints = !is.null(changepoints),
     start = NULL,  # This and following attributes are set during fitting
     y.scale = NULL,
+    logistic.floor = FALSE,
     t.scale = NULL,
     changepoints.t = NULL,
     seasonalities = list(),
@@ -191,7 +192,8 @@ validate_column_name <- function(
   )
   rn_l = paste(reserved_names,"_lower",sep="")
   rn_u = paste(reserved_names,"_upper",sep="")
-  reserved_names = c(reserved_names, rn_l, rn_u, c("ds","y"))
+  reserved_names = c(reserved_names, rn_l, rn_u,
+    c("ds", "y", "cap", "floor", "y_scaled", "cap_scaled"))
   if(name %in% reserved_names){
     stop("Name ", name, " is reserved.")
   }
@@ -332,7 +334,13 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
     dplyr::arrange(ds)
 
   if (initialize_scales) {
-    m$y.scale <- max(abs(df$y))
+    if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
+      m$logistic.floor <- TRUE
+      floor <- df$floor
+    } else {
+      floor <- 0
+    }
+    m$y.scale <- max(abs(df$y - floor))
     if (m$y.scale == 0) {
       m$y.scale <- 1
     }
@@ -360,9 +368,12 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
     }
   }
 
-  df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
-  if (exists('y', where=df)) {
-    df$y_scaled <- df$y / m$y.scale
+  if (m$logistic.floor) {
+    if (!('floor' %in% colnames(df))) {
+      stop("Expected column 'floor'.")
+    }
+  } else {
+    df$floor <- 0
   }
 
   if (m$growth == 'logistic') {
@@ -370,7 +381,12 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
       stop('Capacities must be supplied for logistic growth.')
     }
     df <- df %>%
-      dplyr::mutate(cap_scaled = cap / m$y.scale)
+      dplyr::mutate(cap_scaled = (cap - floor) / m$y.scale)
+  }
+
+  df$t <- time_diff(df$ds, m$start, "secs") / m$t.scale
+  if (exists('y', where=df)) {
+    df$y_scaled <- (df$y - df$floor) / m$y.scale
   }
 
   for (name in names(m$extra_regressors)) {
@@ -848,9 +864,16 @@ logistic_growth_init <- function(df) {
   i0 <- which.min(df$ds)
   i1 <- which.max(df$ds)
   T <- df$t[i1] - df$t[i0]
-  # Force valid values, in case y > cap.
-  r0 <- max(1.01, df$cap_scaled[i0] / df$y_scaled[i0])
-  r1 <- max(1.01, df$cap_scaled[i1] / df$y_scaled[i1])
+
+  # Force valid values, in case y > cap or y < 0
+  C0 <- df$cap_scaled[i0]
+  C1 <- df$cap_scaled[i1]
+  y0 <- max(0.01 * C0, min(0.99 * C0, df$y_scaled[i0]))
+  y1 <- max(0.01 * C1, min(0.99 * C1, df$y_scaled[i1]))
+
+  r0 <- C0 / y0
+  r1 <- C1 / y1
+
   if (abs(r0 - r1) <= 0.01) {
     r0 <- 1.05 * r0
   }
@@ -1015,11 +1038,13 @@ predict.prophet <- function(object, df = NULL, ...) {
   seasonal.components <- predict_seasonal_components(object, df)
   intervals <- predict_uncertainty(object, df)
 
-  # Drop columns except ds, cap, and trend
+  # Drop columns except ds, cap, floor, and trend
+  cols <- c('ds', 'trend')
   if ('cap' %in% colnames(df)) {
-    cols <- c('ds', 'cap', 'trend')
-  } else {
-    cols <- c('ds', 'trend')
+    cols <- c(cols, 'cap')
+  }
+  if (object$logistic.floor) {
+    cols <- c(cols, 'floor')
   }
   df <- df[cols]
   df <- df %>%
@@ -1108,7 +1133,7 @@ predict_trend <- function(model, df) {
     trend <- piecewise_logistic(
       t, cap, deltas, k, param.m, model$changepoints.t)
   }
-  return(trend * model$y.scale)
+  return(trend * model$y.scale + df$floor)
 }
 
 #' Predict seasonality components, holidays, and added regressors.
@@ -1343,7 +1368,7 @@ sample_predictive_trend <- function(model, df, iteration) {
     cap <- df$cap_scaled
     trend <- piecewise_logistic(t, cap, deltas, k, param.m, changepoint.ts)
   }
-  return(trend * model$y.scale)
+  return(trend * model$y.scale + df$floor)
 }
 
 #' Make dataframe with future dates for forecasting.
@@ -1424,6 +1449,10 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
     gg <- gg + ggplot2::geom_line(
       ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
   }
+  if (x$logistic.floor && exists('floor', where = df) && plot_cap) {
+    gg <- gg + ggplot2::geom_line(
+      ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
+  }
   if (uncertainty && exists('yhat_lower', where = df)) {
     gg <- gg +
       ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
@@ -1525,6 +1554,10 @@ plot_forecast_component <- function(
     gg.comp <- gg.comp + ggplot2::geom_line(
       ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
   }
+  if (exists('floor', where = fcst) && plot_cap) {
+    gg.comp <- gg.comp + ggplot2::geom_line(
+      ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
+  }
   if (uncertainty) {
     gg.comp <- gg.comp +
       ggplot2::geom_ribbon(
@@ -1647,8 +1680,7 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
   if (period <= 2) {
     fmt.str <- '%T'
-  }
-  else if (period < 14) {
+  } else if (period < 14) {
     fmt.str <- '%m/%d %R'
   } else {
     fmt.str <- '%m/%d'

+ 31 - 0
R/tests/testthat/test_prophet.R

@@ -74,6 +74,36 @@ test_that("setup_dataframe", {
   expect_equal(max(history$y_scaled), 1)
 })
 
+test_that("logistic_floor", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(growth = 'logistic')
+  history <- train
+  history$floor <- 10.
+  history$cap <- 80.
+  future1 <- future
+  future1$cap <- 80.
+  future1$floor <- 10.
+  m <- fit.prophet(m, history)
+  expect_true(m$logistic.floor)
+  expect_true('floor' %in% colnames(m$history))
+  expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
+  fcst1 <- predict(m, future1)
+
+  m2 <- prophet(growth = 'logistic')
+  history2 <- history
+  history2$y <- history2$y + 10.
+  history2$floor <- history2$floor + 10.
+  history2$cap <- history2$cap + 10.
+  future1$cap <- future1$cap + 10.
+  future1$floor <- future1$floor + 10.
+  m2 <- fit.prophet(m2, history2)
+  expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
+  fcst2 <- predict(m, future1)
+  fcst2$yhat <- fcst2$yhat - 10.
+  # Check for approximate shift invariance
+  expect_true(all(abs(fcst1$yhat - fcst2$yhat) < 1))
+})
+
 test_that("get_changepoints", {
   history <- train
   m <- prophet(history, fit = FALSE)
@@ -481,6 +511,7 @@ test_that("added_regressors", {
 })
 
 test_that("copy", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   inputs <- list(
     growth = c('linear', 'logistic'),
     changepoints = c(NULL, c('2016-12-25')),

+ 1 - 1
python/fbprophet/forecaster.py

@@ -1312,7 +1312,7 @@ class Prophet(object):
         return fig
 
     def plot_forecast_component(
-            self, fcst, name, ax=None, uncertainty=True, plot_cap=True):
+            self, fcst, name, ax=None, uncertainty=True, plot_cap=False):
         """Plot a particular component of the forecast.
 
         Parameters