Browse Source

Enable seasonalities automatically depending on history length / frequency

Ben Letham 8 years ago
parent
commit
d937f47612

+ 37 - 4
R/R/prophet.R

@@ -27,8 +27,8 @@ globalVariables(c(
 #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
 #'  then n.changepoints potential changepoints are selected uniformly from the
 #'  first 80 percent of df$ds.
-#' @param yearly.seasonality Boolean, fit yearly seasonality.
-#' @param weekly.seasonality Boolean, fit weekly seasonality.
+#' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
+#' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type)and optionally columns lower_window and upper_window which specify a
 #'  range of days around the date to be included as holidays. lower_window=-2
@@ -70,8 +70,8 @@ prophet <- function(df = df,
                     growth = 'linear',
                     changepoints = NULL,
                     n.changepoints = 25,
-                    yearly.seasonality = TRUE,
-                    weekly.seasonality = TRUE,
+                    yearly.seasonality = 'auto',
+                    weekly.seasonality = 'auto',
                     holidays = NULL,
                     seasonality.prior.scale = 10,
                     holidays.prior.scale = 10,
@@ -401,6 +401,38 @@ make_all_seasonality_features <- function(m, df) {
   return(seasonal.features)
 }
 
+#' Set seasonalities that were left on auto.
+#'
+#' Turns on yearly seasonality if there is >=2 years of history.
+#' Turns on weekly seasonality if there is >=2 weeks of history, and the
+#' spacing between dates in the history is <7 days.
+#'
+#' @param m Prophet object.
+#'
+#' @return The prophet model with seasonalities set.
+#'
+set_auto_seasonalities <- function(m) {
+  first <- min(m$history$ds)
+  last <- max(m$history$ds)
+  if (m$yearly.seasonality == 'auto') {
+    if (last - first < 730) {
+      m$yearly.seasonality <- FALSE
+    } else {
+      m$yearly.seasonality <- TRUE
+    }
+  }
+  if (m$weekly.seasonality == 'auto') {
+    dt <- diff(m$history$ds)
+    min.dt <- min(dt[dt > 0])
+    if ((last - first < 14) || (min.dt >= 7)) {
+      m$weekly.seasonality <- FALSE
+    } else {
+      m$weekly.seasonality <- TRUE
+    }
+  }
+  return(m)
+}
+
 #' Initialize linear growth.
 #'
 #' Provides a strong initialization for linear growth by calculating the
@@ -484,6 +516,7 @@ fit.prophet <- function(m, df, ...) {
   history <- out$df
   m <- out$m
   m$history <- history
+  m <- set_auto_seasonalities(m)
   seasonal.features <- make_all_seasonality_features(m, history)
 
   m <- set_changepoints(m)

+ 4 - 4
R/man/prophet.Rd

@@ -5,8 +5,8 @@
 \title{Prophet forecaster.}
 \usage{
 prophet(df = df, growth = "linear", changepoints = NULL,
-  n.changepoints = 25, yearly.seasonality = TRUE,
-  weekly.seasonality = TRUE, holidays = NULL,
+  n.changepoints = 25, yearly.seasonality = "auto",
+  weekly.seasonality = "auto", holidays = NULL,
   seasonality.prior.scale = 10, holidays.prior.scale = 10,
   changepoint.prior.scale = 0.05, mcmc.samples = 0, interval.width = 0.8,
   uncertainty.samples = 1000, fit = TRUE, ...)
@@ -28,9 +28,9 @@ if input `changepoints` is supplied. If `changepoints` is not supplied,
 then n.changepoints potential changepoints are selected uniformly from the
 first 80 percent of df$ds.}
 
-\item{yearly.seasonality}{Boolean, fit yearly seasonality.}
+\item{yearly.seasonality}{Fit yearly seasonality; 'auto', TRUE, or FALSE.}
 
-\item{weekly.seasonality}{Boolean, fit weekly seasonality.}
+\item{weekly.seasonality}{Fit weekly seasonality; 'auto', TRUE, or FALSE.}
 
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type)and optionally columns lower_window and upper_window which specify a

+ 20 - 0
R/man/set_auto_seasonalities.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{set_auto_seasonalities}
+\alias{set_auto_seasonalities}
+\title{Set seasonalities that were left on auto.}
+\usage{
+set_auto_seasonalities(m)
+}
+\arguments{
+\item{m}{Prophet object.}
+}
+\value{
+The prophet model with seasonalities set.
+}
+\description{
+Turns on yearly seasonality if there is >=2 years of history.
+Turns on weekly seasonality if there is >=2 weeks of history, and the
+spacing between dates in the history is <7 days.
+}
+

+ 42 - 0
R/tests/testthat/data.csv

@@ -467,3 +467,45 @@ ds,y
 2014-03-27,60.97
 2014-03-28,60.01
 2014-03-31,60.24
+2014-04-01,62.62
+2014-04-02,62.72
+2014-04-03,59.49
+2014-04-04,56.75
+2014-04-07,56.95
+2014-04-08,58.19
+2014-04-09,62.41
+2014-04-10,59.16
+2014-04-11,58.53
+2014-04-14,58.89
+2014-04-15,59.09
+2014-04-16,59.72
+2014-04-17,58.94
+2014-04-21,61.24
+2014-04-22,63.03
+2014-04-23,61.36
+2014-04-24,60.87
+2014-04-25,57.71
+2014-04-28,56.14
+2014-04-29,58.15
+2014-04-30,59.78
+2014-05-01,61.15
+2014-05-02,60.46
+2014-05-05,61.22
+2014-05-06,58.53
+2014-05-07,57.39
+2014-05-08,56.76
+2014-05-09,57.24
+2014-05-12,59.83
+2014-05-13,59.83
+2014-05-14,59.23
+2014-05-15,57.92
+2014-05-16,58.02
+2014-05-19,59.21
+2014-05-20,58.56
+2014-05-21,60.49
+2014-05-22,60.52
+2014-05-23,61.35
+2014-05-27,63.48
+2014-05-28,63.51
+2014-05-29,63.83
+2014-05-30,63.30

+ 41 - 2
R/tests/testthat/test_prophet.R

@@ -119,7 +119,7 @@ test_that("fourier_series_yearly", {
 })
 
 test_that("growth_init", {
-  history <- DATA
+  history <- DATA[1:468, ]
   history$cap <- max(history$y)
   m <- prophet(history, growth = 'logistic', fit = FALSE)
 
@@ -209,7 +209,8 @@ test_that("fit_with_holidays", {
 
 test_that("make_future_dataframe", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  m <- prophet(train)
+  train.t <- DATA[1:234, ]
+  m <- prophet(train.t)
   future <- make_future_dataframe(m, periods = 3, freq = 'd',
                                   include_history = FALSE)
   correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
@@ -220,3 +221,41 @@ test_that("make_future_dataframe", {
   correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
   expect_equal(future$ds, correct)
 })
+
+test_that("auto_weekly_seasonality", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  # Should be True
+  N.w <- 15
+  train.w <- DATA[1:N.w, ]
+  m <- prophet(train.w, fit = FALSE)
+  expect_equal(m$weekly.seasonality, 'auto')
+  m <- prophet:::fit.prophet(m, train.w)
+  expect_equal(m$weekly.seasonality, TRUE)
+  # Should be False due to too short history
+  N.w <- 9
+  train.w <- DATA[1:N.w, ]
+  m <- prophet(train.w)
+  expect_equal(m$weekly.seasonality, FALSE)
+  m <- prophet(train.w, weekly.seasonality = TRUE)
+  expect_equal(m$weekly.seasonality, TRUE)
+  # Should be False due to weekly spacing
+  train.w <- DATA[seq(1, nrow(DATA), 7), ]
+  m <- prophet(train.w)
+  expect_equal(m$weekly.seasonality, FALSE)
+})
+
+test_that("auto_yearly_seasonality", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  # Should be True
+  m <- prophet(DATA, fit = FALSE)
+  expect_equal(m$yearly.seasonality, 'auto')
+  m <- prophet:::fit.prophet(m, DATA)
+  expect_equal(m$yearly.seasonality, TRUE)
+  # Should be False due to too short history
+  N.w <- 240
+  train.y <- DATA[1:N.w, ]
+  m <- prophet(train.y)
+  expect_equal(m$yearly.seasonality, FALSE)
+  m <- prophet(train.y, yearly.seasonality = TRUE)
+  expect_equal(m$yearly.seasonality, TRUE)
+})

+ 28 - 4
python/fbprophet/forecaster.py

@@ -46,8 +46,8 @@ class Prophet(object):
         if input `changepoints` is supplied. If `changepoints` is not supplied,
         then n.changepoints potential changepoints are selected uniformly from
         the first 80 percent of the history.
-    yearly_seasonality: Boolean, fit yearly seasonality.
-    weekly_seasonality: Boolean, fit weekly seasonality.
+    yearly_seasonality: Fit yearly seasonality. Can be 'auto', True, or False.
+    weekly_seasonality: Fit weekly seasonality. Can be 'auto', True, or False.
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
         and optionally columns lower_window and upper_window which specify a
         range of days around the date to be included as holidays.
@@ -77,8 +77,8 @@ class Prophet(object):
             growth='linear',
             changepoints=None,
             n_changepoints=25,
-            yearly_seasonality=True,
-            weekly_seasonality=True,
+            yearly_seasonality='auto',
+            weekly_seasonality='auto',
             holidays=None,
             seasonality_prior_scale=10.0,
             holidays_prior_scale=10.0,
@@ -392,6 +392,29 @@ class Prophet(object):
             seasonal_features.append(self.make_holiday_features(df['ds']))
         return pd.concat(seasonal_features, axis=1)
 
+    def set_auto_seasonalities(self):
+        """Set seasonalities that were left on auto.
+
+        Turns on yearly seasonality if there is >=2 years of history.
+        Turns on weekly seasonality if there is >=2 weeks of history, and the
+        spacing between dates in the history is <7 days.
+        """
+        first = self.history['ds'].min()
+        last = self.history['ds'].max()
+        if self.yearly_seasonality == 'auto':
+            if last - first < pd.Timedelta(days=730):
+                self.yearly_seasonality = False
+            else:
+                self.yearly_seasonality = True
+        if self.weekly_seasonality == 'auto':
+            dt = self.history['ds'].diff()
+            min_dt = dt.iloc[dt.nonzero()[0]].min()
+            if ((last - first < pd.Timedelta(weeks=2)) or
+                (min_dt >= pd.Timedelta(weeks=1))):
+                self.weekly_seasonality = False
+            else:
+                self.weekly_seasonality = True
+
     @staticmethod
     def linear_growth_init(df):
         """Initialize linear growth.
@@ -487,6 +510,7 @@ class Prophet(object):
 
         history = self.setup_dataframe(history, initialize_scales=True)
         self.history = history
+        self.set_auto_seasonalities()
         seasonal_features = self.make_all_seasonality_features(history)
 
         self.set_changepoints()

+ 83 - 2
python/fbprophet/tests/test_prophet.py

@@ -145,7 +145,7 @@ class TestProphet(TestCase):
 
     def test_growth_init(self):
         model = Prophet(growth='logistic')
-        history = DATA.copy()
+        history = DATA.iloc[:468].copy()
         history['cap'] = history['y'].max()
 
         history = model.setup_dataframe(history, initialize_scales=True)
@@ -237,7 +237,7 @@ class TestProphet(TestCase):
         model.fit(DATA).predict()
 
     def test_make_future_dataframe(self):
-        N = DATA.shape[0]
+        N = 468
         train = DATA.head(N // 2)
         forecaster = Prophet()
         forecaster.fit(train)
@@ -255,6 +255,45 @@ class TestProphet(TestCase):
         for i in range(3):
             self.assertEqual(future.iloc[i]['ds'], correct[i])
 
+    def test_auto_weekly_seasonality(self):
+        # Should be True
+        N = 15
+        train = DATA.head(N)
+        m = Prophet()
+        self.assertEqual(m.weekly_seasonality, 'auto')
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, True)
+        # Should be False due to too short history
+        N = 9
+        train = DATA.head(N)
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, False)
+        m = Prophet(weekly_seasonality=True)
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, True)
+        # Should be False due to weekly spacing
+        train = DATA.iloc[::7, :]
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, False)
+
+    def test_auto_yearly_seasonality(self):
+        # Should be True
+        m = Prophet()
+        self.assertEqual(m.yearly_seasonality, 'auto')
+        m.fit(DATA)
+        self.assertEqual(m.yearly_seasonality, True)
+        # Should be False due to too short history
+        N = 240
+        train = DATA.head(N)
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.yearly_seasonality, False)
+        m = Prophet(yearly_seasonality=True)
+        m.fit(train)
+        self.assertEqual(m.yearly_seasonality, True)
+
 
 DATA = pd.read_csv(StringIO("""
 ds,y
@@ -726,4 +765,46 @@ ds,y
 2014-03-27,60.97
 2014-03-28,60.01
 2014-03-31,60.24
+2014-04-01,62.62
+2014-04-02,62.72
+2014-04-03,59.49
+2014-04-04,56.75
+2014-04-07,56.95
+2014-04-08,58.19
+2014-04-09,62.41
+2014-04-10,59.16
+2014-04-11,58.53
+2014-04-14,58.89
+2014-04-15,59.09
+2014-04-16,59.72
+2014-04-17,58.94
+2014-04-21,61.24
+2014-04-22,63.03
+2014-04-23,61.36
+2014-04-24,60.87
+2014-04-25,57.71
+2014-04-28,56.14
+2014-04-29,58.15
+2014-04-30,59.78
+2014-05-01,61.15
+2014-05-02,60.46
+2014-05-05,61.22
+2014-05-06,58.53
+2014-05-07,57.39
+2014-05-08,56.76
+2014-05-09,57.24
+2014-05-12,59.83
+2014-05-13,59.83
+2014-05-14,59.23
+2014-05-15,57.92
+2014-05-16,58.02
+2014-05-19,59.21
+2014-05-20,58.56
+2014-05-21,60.49
+2014-05-22,60.52
+2014-05-23,61.35
+2014-05-27,63.48
+2014-05-28,63.51
+2014-05-29,63.83
+2014-05-30,63.30
 """), parse_dates=['ds'])