Ver código fonte

Check for valid holiday lower/upper windows

Ben Letham 8 anos atrás
pai
commit
f89faf2c6a
3 arquivos alterados com 31 adições e 5 exclusões
  1. 17 2
      R/R/prophet.R
  2. 4 3
      R/man/prophet.Rd
  3. 10 0
      python/fbprophet/forecaster.py

+ 17 - 2
R/R/prophet.R

@@ -13,7 +13,7 @@ globalVariables(c(
   "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
   "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
 
-#' Prophet forecast.
+#' Prophet forecaster.
 #'
 #' @param df Data frame with columns ds (date type) and y, the time series.
 #'  If growth is logistic, then df must also have a column cap that specifies
@@ -31,7 +31,8 @@ globalVariables(c(
 #' @param weekly.seasonality Boolean, fit weekly seasonality.
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type)and optionally columns lower_window and upper_window which specify a
-#'  range of days around the date to be included as holidays.
+#'  range of days around the date to be included as holidays. lower_window=-2
+#'  will include 2 days prior to the date as holidays.
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #'  seasonality model. Larger values allow the model to fit larger seasonal
 #'  fluctuations, smaller values dampen the seasonality.
@@ -133,6 +134,20 @@ validate_inputs <- function(m) {
     if (!(exists('ds', where = m$holidays))) {
       stop('Holidays dataframe must have ds field.')
     }
+    has.lower <- exists('lower_window', where = m$holidays)
+    has.upper <- exists('upper_window', where = m$holidays)
+    if (has.lower + has.upper == 1) {
+      stop(paste('Holidays must have both lower_window and upper_window,',
+                 'or neither.'))
+    }
+    if (has.lower) {
+      if(max(m$holidays$lower_window, na.rm=TRUE) > 0) {
+        stop('Holiday lower_window should be <= 0')
+      }
+      if(min(m$holidays$upper_window, na.rm=TRUE) < 0) {
+        stop('Holiday upper_window should be >= 0')
+      }
+    }
     for (h in unique(m$holidays$holiday)) {
       if (grepl("_delim_", h)) {
         stop('Holiday name cannot contain "_delim_"')

+ 4 - 3
R/man/prophet.Rd

@@ -2,7 +2,7 @@
 % Please edit documentation in R/prophet.R
 \name{prophet}
 \alias{prophet}
-\title{Prophet forecast.}
+\title{Prophet forecaster.}
 \usage{
 prophet(df = df, growth = "linear", changepoints = NULL,
   n.changepoints = 25, yearly.seasonality = TRUE,
@@ -34,7 +34,8 @@ first 80 percent of df$ds.}
 
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type)and optionally columns lower_window and upper_window which specify a
-range of days around the date to be included as holidays.}
+range of days around the date to be included as holidays. lower_window=-2
+will include 2 days prior to the date as holidays.}
 
 \item{seasonality.prior.scale}{Parameter modulating the strength of the
 seasonality model. Larger values allow the model to fit larger seasonal
@@ -68,7 +69,7 @@ uncertainty intervals.}
 A prophet model.
 }
 \description{
-Prophet forecast.
+Prophet forecaster.
 }
 \examples{
 \dontrun{

+ 10 - 0
python/fbprophet/forecaster.py

@@ -94,6 +94,16 @@ class Prophet(object):
             raise ValueError(
                 "Parameter 'growth' should be 'linear' or 'logistic'.")
         if self.holidays is not None:
+            has_lower = 'lower_window' in self.holidays
+            has_upper = 'upper_window' in self.holidays
+            if has_lower + has_upper == 1:
+                raise ValueError('Holidays must have both lower_window and ' +
+                                 'upper_window, or neither')
+            if has_lower:
+                if max(self.holidays['lower_window']) > 0:
+                    raise ValueError('Holiday lower_window should be <= 0')
+                if min(self.holidays['upper_window']) < 0:
+                    raise ValueError('Holiday upper_window should be >= 0')
             for h in self.holidays['holiday'].unique():
                 if '_delim_' in h:
                     raise ValueError('Holiday name cannot contain "_delim_"')