Browse Source

Limit n_changepoints to number of observations.

bletham 8 năm trước
mục cha
commit
0b4ec4a9b3

+ 9 - 3
R/R/prophet.R

@@ -329,10 +329,16 @@ set_changepoints <- function(m) {
       }
     }
   } else {
+    # Place potential changepoints evenly through the first 80 pcnt of
+    # the history.
+    hist.size <- floor(nrow(m$history) * .8)
+    if (m$n.changepoints + 1 > hist.size) {
+      m$n.changepoints <- hist.size - 1
+      warning('n.changepoints greater than number of observations. Using ',
+              m$n.changepoints)
+    }
     if (m$n.changepoints > 0) {
-      # Place potential changepoints evenly through the first 80 pcnt of
-      # the history.
-      cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
+      cp.indexes <- round(seq.int(1, hist.size,
                           length.out = (m$n.changepoints + 1))[-1])
       m$changepoints <- m$history$ds[cp.indexes]
     } else {

+ 15 - 0
R/tests/testthat/test_prophet.R

@@ -101,6 +101,21 @@ test_that("get_zero_changepoints", {
   expect_equal(ncol(mat), 1)
 })
 
+test_that("override_n_changepoints", {  # n.changepoints is capped on short history
+  history <- train[1:20,]  # 20 rows, so floor(20 * .8) = 16 usable rows
+  m <- prophet(history, fit = FALSE)  # presumably builds without fitting; confirm
+
+  out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
+  m <- out$m  # setup_dataframe hands back the model as $m ...
+  history <- out$df  # ... and the prepared data frame as $df
+  m$history <- history  # set_changepoints sizes its grid from nrow(m$history)
+
+  m <- prophet:::set_changepoints(m)
+  expect_equal(m$n.changepoints, 15)  # 16 - 1; assumes the default is >= 16
+  cp <- m$changepoints.t
+  expect_equal(length(cp), 15)  # NOTE(review): assumes one entry per changepoint
+})
+
 test_that("fourier_series_weekly", {
   mat <- prophet:::fourier_series(DATA$ds, 7, 3)
   true.values <- c(0.9165623, 0.3998920, 0.7330519, -0.6801727, -0.3302791,

+ 20 - 12
python/fbprophet/forecaster.py

@@ -277,19 +277,27 @@ class Prophet(object):
                 too_low = min(self.changepoints) < self.history['ds'].min()
                 too_high = max(self.changepoints) > self.history['ds'].max()
                 if too_low or too_high:
-                    raise ValueError('Changepoints must fall within training data.')
-        elif self.n_changepoints > 0:
-            # Place potential changepoints evenly through first 80% of history
-            max_ix = np.floor(self.history.shape[0] * 0.8)
-            cp_indexes = (
-                np.linspace(0, max_ix, self.n_changepoints + 1)
-                .round()
-                .astype(np.int)
-            )
-            self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1)
+                    raise ValueError(
+                        'Changepoints must fall within training data.')
         else:
-            # set empty changepoints
-            self.changepoints = []
+            # Place potential changepoints evenly through first 80% of history
+            hist_size = np.floor(self.history.shape[0] * 0.8)
+            if self.n_changepoints + 1 > hist_size:
+                self.n_changepoints = hist_size - 1
+                logger.info(
+                    'n_changepoints greater than number of observations.'
+                    'Using {}.'.format(self.n_changepoints)
+                )
+            if self.n_changepoints > 0:
+                cp_indexes = (
+                    np.linspace(0, hist_size, self.n_changepoints + 1)
+                    .round()
+                    .astype(np.int)
+                )
+                self.changepoints = self.history.ix[cp_indexes]['ds'].tail(-1)
+            else:
+                # set empty changepoints
+                self.changepoints = []
         if len(self.changepoints) > 0:
             self.changepoints_t = np.sort(np.array(
                 (self.changepoints - self.start) / self.t_scale))

+ 12 - 0
python/fbprophet/tests/test_prophet.py

@@ -130,6 +130,18 @@ class TestProphet(TestCase):
         self.assertEqual(mat.shape[0], N // 2)
         self.assertEqual(mat.shape[1], 1)
 
+    def test_override_n_changepoints(self):  # n_changepoints is capped on short history
+        m = Prophet()
+        history = DATA.head(20).copy()  # 20 rows, so floor(20 * 0.8) = 16 usable rows
+
+        history = m.setup_dataframe(history, initialize_scales=True)
+        m.history = history  # set_changepoints reads self.history.shape[0]
+
+        m.set_changepoints()
+        self.assertEqual(m.n_changepoints, 15)  # 16 - 1; assumes the default is >= 16
+        cp = m.changepoints_t
+        self.assertEqual(cp.shape[0], 15)  # one scaled time per retained changepoint
+
     def test_fourier_series_weekly(self):
         mat = Prophet.fourier_series(DATA['ds'], 7, 3)
         # These are from the R forecast package directly.