Bladeren bron

Allow constant extra regressors

Ben Letham 7 jaren geleden
bovenliggende
commit
107f74f0f2
4 gewijzigde bestanden met toevoegingen van 12 en 11 verwijderingen
  1. 2 2
      R/R/prophet.R
  2. 5 4
      R/tests/testthat/test_prophet.R
  3. 1 1
      python/fbprophet/forecaster.py
  4. 4 4
      python/fbprophet/tests/test_prophet.py

+ 2 - 2
R/R/prophet.R

@@ -408,11 +408,11 @@ initialize_scales_fn <- function(m, initialize_scales, df) {
   m$start <- min(df$ds)
   m$t.scale <- time_diff(max(df$ds), m$start, "secs")
   for (name in names(m$extra_regressors)) {
+    standardize <- m$extra_regressors[[name]]$standardize
     n.vals <- length(unique(df[[name]]))
     if (n.vals < 2) {
-      stop('Regressor ', name, ' is constant.')
+      standardize <- FALSE
     }
-    standardize <- m$extra_regressors[[name]]$standardize
     if (standardize == 'auto') {
       if (n.vals == 2 && all(sort(unique(df[[name]])) == c(0, 1))) {
         # Don't standardize binary variables

+ 5 - 4
R/tests/testthat/test_prophet.R

@@ -573,11 +573,12 @@ test_that("added_regressors", {
     fcst$yhat[1],
     fcst$trend[1] * (1 + fcst$multiplicative_terms[1]) + fcst$additive_terms[1]
   )
-  # Check fails if constant extra regressor
-  df$constant_feature <- 5
+  # Check works with constant extra regressor of 0
+  df$constant_feature <- 0
   m <- prophet()
-  m <- add_regressor(m, 'constant_feature')
-  expect_error(fit.prophet(m, df))
+  m <- add_regressor(m, 'constant_feature', standardize = TRUE)
+  m <- fit.prophet(m, df)
+  expect_equal(m$extra_regressors$constant_feature$std, 1)
 })
 
 test_that("set_seasonality_mode", {

+ 1 - 1
python/fbprophet/forecaster.py

@@ -305,7 +305,7 @@ class Prophet(object):
             standardize = props['standardize']
             n_vals = len(df[name].unique())
             if n_vals < 2:
-                raise ValueError('Regressor {} is constant.'.format(name))
+                standardize = False
             if standardize == 'auto':
                 if set(df[name].unique()) == set([1, 0]):
                     # Don't standardize binary variables.

+ 4 - 4
python/fbprophet/tests/test_prophet.py

@@ -651,12 +651,12 @@ class TestProphet(TestCase):
             fcst['trend'][0] * (1 + fcst['multiplicative_terms'][0])
                 + fcst['additive_terms'][0],
         )
-        # Check fails if constant extra regressor
-        df['constant_feature'] = 5
+        # Check works if constant extra regressor at 0
+        df['constant_feature'] = 0
         m = Prophet()
         m.add_regressor('constant_feature')
-        with self.assertRaises(ValueError):
-            m.fit(df.copy())
+        m.fit(df)
+        self.assertEqual(m.extra_regressors['constant_feature']['std'], 1)
 
     def test_set_seasonality_mode(self):
         # Setting attribute