浏览代码

Fix R warning with extra regressor; disallow constant extra regressors.

bl 8 年之前
父节点
当前提交
feb7be397b

+ 5 - 4
R/R/prophet.R

@@ -395,9 +395,13 @@ initialize_scales_fn <- function(m, initialize_scales, df) {
   m$start <- min(df$ds)
   m$t.scale <- time_diff(max(df$ds), m$start, "secs")
   for (name in names(m$extra_regressors)) {
+    n.vals <- length(unique(df[[name]]))
+    if (n.vals < 2) {
+      stop('Regressor ', name, ' is constant.')
+    }
     standardize <- m$extra_regressors[[name]]$standardize
     if (standardize == 'auto') {
-      if (all(sort(unique(df[[name]])) == c(0, 1))) {
+      if (n.vals == 2 && all(sort(unique(df[[name]])) == c(0, 1))) {
         # Don't standardize binary variables
         standardize <- FALSE
       } else {
@@ -407,9 +411,6 @@ initialize_scales_fn <- function(m, initialize_scales, df) {
     if (standardize) {
       mu <- mean(df[[name]])
       std <- stats::sd(df[[name]])
-      if (std == 0) {
-        std <- mu
-      }
       m$extra_regressors[[name]]$mu <- mu
       m$extra_regressors[[name]]$std <- std
     }

+ 5 - 0
R/tests/testthat/test_prophet.R

@@ -511,6 +511,11 @@ test_that("added_regressors", {
   expect_equal(fcst$seasonal[1],
                fcst$seasonalities[1] + fcst$extra_regressors[1])
   expect_equal(fcst$yhat[1], fcst$trend[1] + fcst$seasonal[1])
+  # Check fails if constant extra regressor
+  df$constant_feature <- 5
+  m <- prophet()
+  m <- add_regressor(m, 'constant_feature')
+  expect_error(fit.prophet(m, df))
 })
 
 test_that("copy", {

文件差异内容过多而无法显示
+ 1 - 1
docs/_docs/seasonality_and_holiday_effects.md


文件差异内容过多而无法显示
+ 1 - 1
notebooks/seasonality_and_holiday_effects.ipynb


+ 3 - 2
python/fbprophet/forecaster.py

@@ -278,6 +278,9 @@ class Prophet(object):
         self.t_scale = df['ds'].max() - self.start
         for name, props in self.extra_regressors.items():
             standardize = props['standardize']
+            n_vals = len(df[name].unique())
+            if n_vals < 2:
+                raise ValueError('Regressor {} is constant.'.format(name))
             if standardize == 'auto':
                 if set(df[name].unique()) == set([1, 0]):
                     # Don't standardize binary variables.
@@ -287,8 +290,6 @@ class Prophet(object):
             if standardize:
                 mu = df[name].mean()
                 std = df[name].std()
-                if std == 0:
-                    std = mu
                 self.extra_regressors[name]['mu'] = mu
                 self.extra_regressors[name]['std'] = std
 

+ 6 - 0
python/fbprophet/tests/test_prophet.py

@@ -548,6 +548,12 @@ class TestProphet(TestCase):
             fcst['yhat'][0],
             fcst['trend'][0] + fcst['seasonal'][0],
         )
+        # Check fails if constant extra regressor
+        df['constant_feature'] = 5
+        m = Prophet()
+        m.add_regressor('constant_feature')
+        with self.assertRaises(ValueError):
+            m.fit(df.copy())
 
     def test_copy(self):
         # These values are created except for its default values