Pārlūkot izejas kodu

Handle constant y in history

bletham 8 gadi atpakaļ
vecāks
revīzija
e4ec600da4

+ 9 - 1
R/R/prophet.R

@@ -286,6 +286,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
 
   if (initialize_scales) {
     m$y.scale <- max(abs(df$y))
+    if (m$y.scale == 0) {
+      m$y.scale <- 1
+    }
     m$start <- min(df$ds)
     m$t.scale <- time_diff(max(df$ds), m$start, "secs")
   }
@@ -703,7 +706,12 @@ fit.prophet <- function(m, df, ...) {
     )
   }
 
-  if (m$mcmc.samples > 0) {
+  if (min(history$y) == max(history$y)) {
+    # Nothing to fit.
+    m$params <- stan_init()
+    m$params$sigma_obs <- 0.
+    n.iteration <- 1.
+  } else if (m$mcmc.samples > 0) {
     stan.fit <- rstan::sampling(
       model,
       data = dat,

+ 13 - 0
R/tests/testthat/test_prophet.R

@@ -46,6 +46,19 @@ test_that("fit_predict_duplicates", {
   expect_error(predict(m, future), NA)
 })
 
+test_that("fit_predict_constant_history", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  train2 <- train
+  train2$y <- 20
+  m <- prophet(train2)
+  fcst <- predict(m, future)
+  expect_equal(tail(fcst$yhat, 1), 20)
+  train2$y <- 0
+  m <- prophet(train2)
+  fcst <- predict(m, future)
+  expect_equal(tail(fcst$yhat, 1), 0)
+})
+
 test_that("setup_dataframe", {
   history <- train
   m <- prophet(history, fit = FALSE)

+ 9 - 1
python/fbprophet/forecaster.py

@@ -227,6 +227,8 @@ class Prophet(object):
 
         if initialize_scales:
             self.y_scale = df['y'].abs().max()
+            if self.y_scale == 0:
+                self.y_scale = 1
             self.start = df['ds'].min()
             self.t_scale = df['ds'].max() - self.start
             for name, props in self.extra_regressors.items():
@@ -726,7 +728,13 @@ class Prophet(object):
                 'sigma_obs': 1,
             }
 
-        if self.mcmc_samples > 0:
+        if history['y'].min() == history['y'].max():
+            # Nothing to fit.
+            self.params = stan_init()
+            self.params['sigma_obs'] = 0.
+            for par in self.params:
+                self.params[par] = np.array([self.params[par]])
+        elif self.mcmc_samples > 0:
             stan_fit = model.sampling(
                 dat,
                 init=stan_init,

+ 16 - 0
python/fbprophet/tests/test_prophet.py

@@ -79,6 +79,22 @@ class TestProphet(TestCase):
         forecaster.fit(train)
         forecaster.predict(future)
 
+    def test_fit_predict_constant_history(self):
+        N = DATA.shape[0]
+        train = DATA.head(N // 2).copy()
+        train['y'] = 20
+        future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)})
+        m = Prophet()
+        m.fit(train)
+        fcst = m.predict(future)
+        self.assertEqual(fcst['yhat'].values[-1], 20)
+        train['y'] = 0
+        future = pd.DataFrame({'ds': DATA['ds'].tail(N // 2)})
+        m = Prophet()
+        m.fit(train)
+        fcst = m.predict(future)
+        self.assertEqual(fcst['yhat'].values[-1], 0)
+
     def test_setup_dataframe(self):
         m = Prophet()
         N = DATA.shape[0]