浏览代码

Allow _ in holiday name, fix #50

Ben Letham 8 年之前
父节点
当前提交
4b7a418121
共有 2 个文件被更改,包括 21 次插入11 次删除
  1. 5 5
      R/R/prophet.R
  2. 16 6
      python/fbprophet/forecaster.py

+ 5 - 5
R/R/prophet.R

@@ -134,8 +134,8 @@ validate_inputs <- function(m) {
       stop('Holidays dataframe must have ds field.')
     }
     for (h in unique(m$holidays$holiday)) {
-      if (grepl("_", h)) {
-        stop('Holiday name cannot contain "_"')
+      if (grepl("_delim_", h)) {
+        stop('Holiday name cannot contain "_delim_"')
       }
       if (h %in% c('zeros', 'yearly', 'weekly', 'yhat', 'seasonal', 'trend')) {
         stop(paste0('Holiday name "', h, '" reserved.'))
@@ -306,7 +306,7 @@ fourier_series <- function(dates, period, series.order) {
 #'
 make_seasonality_features <- function(dates, period, series.order, prefix) {
   features <- fourier_series(dates, period, series.order)
-  colnames(features) <- paste(prefix, 1:ncol(features), sep = '_')
+  colnames(features) <- paste(prefix, 1:ncol(features), sep = '_delim_')
   return(data.frame(features))
 }
 
@@ -332,7 +332,7 @@ make_holiday_features <- function(m, dates) {
         offsets <- c(0)
       }
       names <- paste(
-        .$holiday, '_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
+        .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
       dplyr::data_frame(ds = .$ds + offsets, holiday = names)
     }) %>%
     dplyr::mutate(x = scale.ratio) %>%
@@ -650,7 +650,7 @@ predict_seasonal_components <- function(m, df) {
   # Broken down into components
   components <- dplyr::data_frame(component = colnames(seasonal.features)) %>%
     dplyr::mutate(col = 1:n()) %>%
-    tidyr::separate(component, c('component', 'part'), sep = "_",
+    tidyr::separate(component, c('component', 'part'), sep = "_delim_",
                     extra = "merge", fill = "right") %>%
     dplyr::filter(component != 'zeros')
 

+ 16 - 6
python/fbprophet/forecaster.py

@@ -49,9 +49,6 @@ class Prophet(object):
             interval_width=0.80,
             uncertainty_samples=1000,
     ):
-        if growth not in ('linear', 'logistic'):
-            raise ValueError("growth setting must be 'linear' or 'logistic'")
-
         self.growth = growth
 
         self.changepoints = pd.to_datetime(changepoints)
@@ -90,6 +87,19 @@ class Prophet(object):
         self.stan_fit = None
         self.params = {}
         self.history = None
+        self.validate_inputs()
+
+    def validate_inputs(self):
+        if self.growth not in ('linear', 'logistic'):
+            raise ValueError(
+                "Parameter 'growth' should be 'linear' or 'logistic'.")
+        if self.holidays is not None:
+            for h in self.holidays['holiday'].unique():
+                if '_delim_' in h:
+                    raise ValueError('Holiday name cannot contain "_delim_"')
+                if h in ['zeros', 'yearly', 'weekly', 'yhat', 'seasonal',
+                         'trend']:
+                    raise ValueError('Holiday name {} reserved.'.format(h))
 
     @classmethod
     def get_linear_model(cls):
@@ -215,7 +225,7 @@ class Prophet(object):
     def make_seasonality_features(cls, dates, period, series_order, prefix):
         features = cls.fourier_series(dates, period, series_order)
         columns = [
-            '{}_{}'.format(prefix, i + 1)
+            '{}_delim_{}'.format(prefix, i + 1)
             for i in range(features.shape[1])
         ]
         return pd.DataFrame(features, columns=columns)
@@ -245,7 +255,7 @@ class Prophet(object):
                 except KeyError:
                     loc = None
 
-                key = '{}_{}{}'.format(
+                key = '{}_delim_{}{}'.format(
                     row.holiday,
                     '+' if offset >= 0 else '-',
                     abs(offset)
@@ -469,7 +479,7 @@ class Prophet(object):
 
         components = pd.DataFrame({
             'col': np.arange(seasonal_features.shape[1]),
-            'component': [x.split('_')[0] for x in seasonal_features.columns],
+            'component': [x.split('_delim_')[0] for x in seasonal_features.columns],
         })
         # Remove the placeholder
         components = components[components['component'] != 'zeros']