Browse Source

Validation and tests for setting range for automatic changepoints

Ben Letham 7 years ago
parent
commit
cb0b47994b

+ 1 - 0
R/R/diagnostics.R

@@ -175,6 +175,7 @@ prophet_copy <- function(m, cutoff = NULL) {
     growth = m$growth,
     changepoints = changepoints,
     n.changepoints = m$n.changepoints,
+    changepoint.range = m$changepoint.range,
     yearly.seasonality = FALSE,
     weekly.seasonality = FALSE,
     daily.seasonality = FALSE,

+ 12 - 13
R/R/prophet.R

@@ -28,9 +28,10 @@ utils::globalVariables(c(
 #' @param n.changepoints Number of potential changepoints to include. Not used
 #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
 #'  then n.changepoints potential changepoints are selected uniformly from the
-#'  first `changepoint.threshold` percent of df$ds.
-#' @param changepoint.threshold Parameter controling where to select the changepoints.
-#'  Not used if input `changepoints` is supplied.
+#'  first `changepoint.range` proportion of df$ds.
+#' @param changepoint.range Proportion of history in which trend changepoints
+#'  will be estimated. Defaults to 0.8 for the first 80%. Not used if
+#'  `changepoints` is specified.
 #' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
 #'  FALSE, or a number of Fourier terms to generate.
 #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
@@ -81,7 +82,7 @@ prophet <- function(df = NULL,
                     growth = 'linear',
                     changepoints = NULL,
                     n.changepoints = 25,
-                    changepoint.threshold = 0.8,
+                    changepoint.range = 0.8,
                     yearly.seasonality = 'auto',
                     weekly.seasonality = 'auto',
                     daily.seasonality = 'auto',
@@ -106,7 +107,7 @@ prophet <- function(df = NULL,
     growth = growth,
     changepoints = changepoints,
     n.changepoints = n.changepoints,
-    changepoint.threshold = changepoint.threshold,
+    changepoint.range = changepoint.range,
     yearly.seasonality = yearly.seasonality,
     weekly.seasonality = weekly.seasonality,
     daily.seasonality = daily.seasonality,
@@ -152,6 +153,9 @@ validate_inputs <- function(m) {
   if (!(m$growth %in% c('linear', 'logistic'))) {
     stop("Parameter 'growth' should be 'linear' or 'logistic'.")
   }
+  if ((m$changepoint.range < 0) | (m$changepoint.range > 1)) {
+    stop("Parameter 'changepoint.range' must be in [0, 1]")
+  }
   if (!is.null(m$holidays)) {
     if (!(exists('holiday', where = m$holidays))) {
       stop('Holidays dataframe must have holiday field.')
@@ -455,14 +459,9 @@ set_changepoints <- function(m) {
       }
     }
   } else {
-    # Place potential changepoints evenly through the first changepoint.threshold pcnt of
-    # the history.
-    if (m$changepoint.threshold > 1 || m$changepoint.threshold <= 0){
-        m$changepoint.threshold <- .8
-        message('changepoint.threshold greater than 1 or less than equal to 0. Using ',
-                m$changepoint.threshold)
-    }
-    hist.size <- floor(nrow(m$history) * m$changepoint.threshold)
+    # Place potential changepoints evenly through the first changepoint.range
+    # proportion of the history.
+    hist.size <- floor(nrow(m$history) * m$changepoint.range)
     if (m$n.changepoints + 1 > hist.size) {
       m$n.changepoints <- hist.size - 1
       message('n.changepoints greater than number of observations. Using ',

+ 2 - 1
R/tests/testthat/test_diagnostics.R

@@ -158,6 +158,7 @@ test_that("copy", {
       growth = as.character(products$growth[i]),
       changepoints = NULL,
       n.changepoints = 3,
+      changepoint.range = 0.9,
       yearly.seasonality = products$yearly.seasonality[i],
       weekly.seasonality = products$weekly.seasonality[i],
       daily.seasonality = products$daily.seasonality[i],
@@ -179,7 +180,7 @@ test_that("copy", {
     args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
               'seasonality.prior.scale', 'holidays.prior.scale',
               'changepoints.prior.scale', 'mcmc.samples', 'interval.width',
-              'uncertainty.samples', 'seasonality.mode')
+              'uncertainty.samples', 'seasonality.mode', 'changepoint.range')
     for (arg in args) {
       expect_equal(m1[[arg]], m2[[arg]])
     }

+ 18 - 1
R/tests/testthat/test_prophet.R

@@ -121,7 +121,24 @@ test_that("get_changepoints", {
   cp <- m$changepoints.t
   expect_equal(length(cp), m$n.changepoints)
   expect_true(min(cp) > 0)
-  expect_true(max(cp) < 1)
+  expect_true(max(cp) <= history$t[ceiling(0.8 * length(history$t))])
+})
+
+test_that("set_changepoint_range", {
+  history <- train
+  m <- prophet(history, fit = FALSE, changepoint.range = 0.4)
+
+  out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
+  history <- out$df
+  m <- out$m
+  m$history <- history
+
+  m <- prophet:::set_changepoints(m)
+
+  cp <- m$changepoints.t
+  expect_equal(length(cp), m$n.changepoints)
+  expect_true(min(cp) > 0)
+  expect_true(max(cp) <= history$t[ceiling(0.4 * length(history$t))])
 })
 
 test_that("get_zero_changepoints", {

+ 1 - 0
python/fbprophet/diagnostics.py

@@ -179,6 +179,7 @@ def prophet_copy(m, cutoff=None):
     m2 = m.__class__(
         growth=m.growth,
         n_changepoints=m.n_changepoints,
+        changepoint_range=m.changepoint_range,
         changepoints=changepoints,
         yearly_seasonality=False,
         weekly_seasonality=False,

+ 12 - 13
python/fbprophet/forecaster.py

@@ -57,8 +57,10 @@ class Prophet(object):
     n_changepoints: Number of potential changepoints to include. Not used
         if input `changepoints` is supplied. If `changepoints` is not supplied,
         then n_changepoints potential changepoints are selected uniformly from
-        the first `changepoint_threshold` percent of the history.
-    changepoint_threshold: Parameter controling where to select the changepoints.
+        the first `changepoint_range` proportion of the history.
+    changepoint_range: Proportion of history in which trend changepoints will
+        be estimated. Defaults to 0.8 for the first 80%. Not used if
+        `changepoints` is specified.
     Not used if input `changepoints` is supplied.
     yearly_seasonality: Fit yearly seasonality.
         Can be 'auto', True, False, or a number of Fourier terms to generate.
@@ -99,7 +101,7 @@ class Prophet(object):
             growth='linear',
             changepoints=None,
             n_changepoints=25,
-            changepoint_threshold=0.8,
+            changepoint_range=0.8,
             yearly_seasonality='auto',
             weekly_seasonality='auto',
             daily_seasonality='auto',
@@ -122,7 +124,7 @@ class Prophet(object):
             self.n_changepoints = n_changepoints
             self.specified_changepoints = False
 
-        self.changepoint_threshold = changepoint_threshold
+        self.changepoint_range = changepoint_range
         self.yearly_seasonality = yearly_seasonality
         self.weekly_seasonality = weekly_seasonality
         self.daily_seasonality = daily_seasonality
@@ -168,6 +170,8 @@ class Prophet(object):
         if self.growth not in ('linear', 'logistic'):
             raise ValueError(
                 "Parameter 'growth' should be 'linear' or 'logistic'.")
+        if ((self.changepoint_range < 0) or (self.changepoint_range > 1)):
+            raise ValueError("Parameter 'changepoint_range' must be in [0, 1]")
         if self.holidays is not None:
             has_lower = 'lower_window' in self.holidays
             has_upper = 'upper_window' in self.holidays
@@ -336,15 +340,10 @@ class Prophet(object):
                     raise ValueError(
                         'Changepoints must fall within training data.')
         else:
-            # Place potential changepoints evenly through first changepoint_threshold
-            # of history
-            if (self.changepoint_threshold > 1 or self.changepoint_threshold <= 0):
-                self.changepoint_threshold = 0.8
-                logger.info(
-                    'changepoint_threshold greater than 1 or less than equal to 0.'
-                    'Using {}.'.format(self.changepoint_threshold)
-                )
-            hist_size = np.floor(self.history.shape[0] * self.changepoint_threshold)
+            # Place potential changepoints evenly through first
+            # changepoint_range proportion of the history
+            hist_size = np.floor(
+                self.history.shape[0] * self.changepoint_range)
             if self.n_changepoints + 1 > hist_size:
                 self.n_changepoints = hist_size - 1
                 logger.info(

+ 2 - 0
python/fbprophet/tests/test_diagnostics.py

@@ -180,6 +180,7 @@ class TestDiagnostics(TestCase):
             ['linear', 'logistic'],  # growth
             [None, pd.to_datetime(['2016-12-25'])],  # changepoints
             [3],  # n_changepoints
+            [0.9],  # changepoint_range
             [True, False],  # yearly_seasonality
             [True, False],  # weekly_seasonality
             [True, False],  # daily_seasonality
@@ -201,6 +202,7 @@ class TestDiagnostics(TestCase):
             m2 = diagnostics.prophet_copy(m1)
             self.assertEqual(m1.growth, m2.growth)
             self.assertEqual(m1.n_changepoints, m2.n_changepoints)
+            self.assertEqual(m1.changepoint_range, m2.changepoint_range)
             self.assertEqual(m1.changepoints, m2.changepoints)
             self.assertEqual(False, m2.yearly_seasonality)
             self.assertEqual(False, m2.weekly_seasonality)

+ 19 - 1
python/fbprophet/tests/test_prophet.py

@@ -150,7 +150,25 @@ class TestProphet(TestCase):
         self.assertEqual(cp.shape[0], m.n_changepoints)
         self.assertEqual(len(cp.shape), 1)
         self.assertTrue(cp.min() > 0)
-        self.assertTrue(cp.max() < 1)
+        cp_indx = int(np.ceil(0.8 * history.shape[0]))
+        self.assertTrue(cp.max() <= history['t'].values[cp_indx])
+
+    def test_set_changepoint_range(self):
+        m = Prophet(changepoint_range=0.4)
+        N = DATA.shape[0]
+        history = DATA.head(N // 2).copy()
+
+        history = m.setup_dataframe(history, initialize_scales=True)
+        m.history = history
+
+        m.set_changepoints()
+
+        cp = m.changepoints_t
+        self.assertEqual(cp.shape[0], m.n_changepoints)
+        self.assertEqual(len(cp.shape), 1)
+        self.assertTrue(cp.min() > 0)
+        cp_indx = int(np.ceil(0.4 * history.shape[0]))
+        self.assertTrue(cp.max() <= history['t'].values[cp_indx])
 
     def test_get_zero_changepoints(self):
         m = Prophet(n_changepoints=0)