فهرست منبع

Enable seasonalities automatically depending on history length / frequency

Ben Letham 8 سال پیش
والد
کامیت
d937f47612
7 فایلهای تغییر یافته به همراه 255 افزوده شده و 16 حذف شده
  1. 37 4
      R/R/prophet.R
  2. 4 4
      R/man/prophet.Rd
  3. 20 0
      R/man/set_auto_seasonalities.Rd
  4. 42 0
      R/tests/testthat/data.csv
  5. 41 2
      R/tests/testthat/test_prophet.R
  6. 28 4
      python/fbprophet/forecaster.py
  7. 83 2
      python/fbprophet/tests/test_prophet.py

+ 37 - 4
R/R/prophet.R

@@ -27,8 +27,8 @@ globalVariables(c(
 #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
 #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
 #'  then n.changepoints potential changepoints are selected uniformly from the
 #'  then n.changepoints potential changepoints are selected uniformly from the
 #'  first 80 percent of df$ds.
 #'  first 80 percent of df$ds.
-#' @param yearly.seasonality Boolean, fit yearly seasonality.
-#' @param weekly.seasonality Boolean, fit weekly seasonality.
+#' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
+#' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
 #' @param holidays data frame with columns holiday (character) and ds (date
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type)and optionally columns lower_window and upper_window which specify a
 #'  type)and optionally columns lower_window and upper_window which specify a
 #'  range of days around the date to be included as holidays. lower_window=-2
 #'  range of days around the date to be included as holidays. lower_window=-2
@@ -70,8 +70,8 @@ prophet <- function(df = df,
                     growth = 'linear',
                     growth = 'linear',
                     changepoints = NULL,
                     changepoints = NULL,
                     n.changepoints = 25,
                     n.changepoints = 25,
-                    yearly.seasonality = TRUE,
-                    weekly.seasonality = TRUE,
+                    yearly.seasonality = 'auto',
+                    weekly.seasonality = 'auto',
                     holidays = NULL,
                     holidays = NULL,
                     seasonality.prior.scale = 10,
                     seasonality.prior.scale = 10,
                     holidays.prior.scale = 10,
                     holidays.prior.scale = 10,
@@ -401,6 +401,38 @@ make_all_seasonality_features <- function(m, df) {
   return(seasonal.features)
   return(seasonal.features)
 }
 }
 
 
+#' Set seasonalities that were left on auto.
+#'
+#' Turns on yearly seasonality if there is >=2 years of history.
+#' Turns on weekly seasonality if there is >=2 weeks of history, and the
+#' spacing between dates in the history is <7 days.
+#'
+#' @param m Prophet object.
+#'
+#' @return The prophet model with seasonalities set.
+#'
+set_auto_seasonalities <- function(m) {
+  ds <- m$history$ds
+  # Convert to explicit day units. Subtracting date-times yields a difftime
+  # whose units are picked automatically (secs/mins/hours/days), so comparing
+  # the raw difference against 730, 14 or 7 is only safe for Date input.
+  # NOTE(review): whether ds is Date or POSIXct is decided outside this block.
+  span <- as.numeric(max(ds) - min(ds), units = 'days')
+  if (identical(m$yearly.seasonality, 'auto')) {
+    # Two years of history, taken as 730 days.
+    m$yearly.seasonality <- span >= 730
+  }
+  if (identical(m$weekly.seasonality, 'auto')) {
+    gaps <- as.numeric(diff(ds), units = 'days')
+    gaps <- gaps[gaps > 0]  # Repeated dates do not count as spacing.
+    # No positive gap (e.g. a single date): use Inf directly rather than
+    # calling min() on an empty vector, which warns.
+    min.dt <- if (length(gaps) > 0) min(gaps) else Inf
+    m$weekly.seasonality <- (span >= 14) && (min.dt < 7)
+  }
+  return(m)
+}
+
 #' Initialize linear growth.
 #' Initialize linear growth.
 #'
 #'
 #' Provides a strong initialization for linear growth by calculating the
 #' Provides a strong initialization for linear growth by calculating the
@@ -484,6 +516,7 @@ fit.prophet <- function(m, df, ...) {
   history <- out$df
   history <- out$df
   m <- out$m
   m <- out$m
   m$history <- history
   m$history <- history
+  m <- set_auto_seasonalities(m)
   seasonal.features <- make_all_seasonality_features(m, history)
   seasonal.features <- make_all_seasonality_features(m, history)
 
 
   m <- set_changepoints(m)
   m <- set_changepoints(m)

+ 4 - 4
R/man/prophet.Rd

@@ -5,8 +5,8 @@
 \title{Prophet forecaster.}
 \title{Prophet forecaster.}
 \usage{
 \usage{
 prophet(df = df, growth = "linear", changepoints = NULL,
 prophet(df = df, growth = "linear", changepoints = NULL,
-  n.changepoints = 25, yearly.seasonality = TRUE,
-  weekly.seasonality = TRUE, holidays = NULL,
+  n.changepoints = 25, yearly.seasonality = "auto",
+  weekly.seasonality = "auto", holidays = NULL,
   seasonality.prior.scale = 10, holidays.prior.scale = 10,
   seasonality.prior.scale = 10, holidays.prior.scale = 10,
   changepoint.prior.scale = 0.05, mcmc.samples = 0, interval.width = 0.8,
   changepoint.prior.scale = 0.05, mcmc.samples = 0, interval.width = 0.8,
   uncertainty.samples = 1000, fit = TRUE, ...)
   uncertainty.samples = 1000, fit = TRUE, ...)
@@ -28,9 +28,9 @@ if input `changepoints` is supplied. If `changepoints` is not supplied,
 then n.changepoints potential changepoints are selected uniformly from the
 then n.changepoints potential changepoints are selected uniformly from the
 first 80 percent of df$ds.}
 first 80 percent of df$ds.}
 
 
-\item{yearly.seasonality}{Boolean, fit yearly seasonality.}
+\item{yearly.seasonality}{Fit yearly seasonality; 'auto', TRUE, or FALSE.}
 
 
-\item{weekly.seasonality}{Boolean, fit weekly seasonality.}
+\item{weekly.seasonality}{Fit weekly seasonality; 'auto', TRUE, or FALSE.}
 
 
 \item{holidays}{data frame with columns holiday (character) and ds (date
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type)and optionally columns lower_window and upper_window which specify a
 type)and optionally columns lower_window and upper_window which specify a

+ 20 - 0
R/man/set_auto_seasonalities.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{set_auto_seasonalities}
+\alias{set_auto_seasonalities}
+\title{Set seasonalities that were left on auto.}
+\usage{
+set_auto_seasonalities(m)
+}
+\arguments{
+\item{m}{Prophet object.}
+}
+\value{
+The prophet model with seasonalities set.
+}
+\description{
+Turns on yearly seasonality if there is >=2 years of history.
+Turns on weekly seasonality if there is >=2 weeks of history, and the
+spacing between dates in the history is <7 days.
+}
+

+ 42 - 0
R/tests/testthat/data.csv

@@ -467,3 +467,45 @@ ds,y
 2014-03-27,60.97
 2014-03-27,60.97
 2014-03-28,60.01
 2014-03-28,60.01
 2014-03-31,60.24
 2014-03-31,60.24
+2014-04-01,62.62
+2014-04-02,62.72
+2014-04-03,59.49
+2014-04-04,56.75
+2014-04-07,56.95
+2014-04-08,58.19
+2014-04-09,62.41
+2014-04-10,59.16
+2014-04-11,58.53
+2014-04-14,58.89
+2014-04-15,59.09
+2014-04-16,59.72
+2014-04-17,58.94
+2014-04-21,61.24
+2014-04-22,63.03
+2014-04-23,61.36
+2014-04-24,60.87
+2014-04-25,57.71
+2014-04-28,56.14
+2014-04-29,58.15
+2014-04-30,59.78
+2014-05-01,61.15
+2014-05-02,60.46
+2014-05-05,61.22
+2014-05-06,58.53
+2014-05-07,57.39
+2014-05-08,56.76
+2014-05-09,57.24
+2014-05-12,59.83
+2014-05-13,59.83
+2014-05-14,59.23
+2014-05-15,57.92
+2014-05-16,58.02
+2014-05-19,59.21
+2014-05-20,58.56
+2014-05-21,60.49
+2014-05-22,60.52
+2014-05-23,61.35
+2014-05-27,63.48
+2014-05-28,63.51
+2014-05-29,63.83
+2014-05-30,63.30

+ 41 - 2
R/tests/testthat/test_prophet.R

@@ -119,7 +119,7 @@ test_that("fourier_series_yearly", {
 })
 })
 
 
 test_that("growth_init", {
 test_that("growth_init", {
-  history <- DATA
+  history <- DATA[1:468, ]
   history$cap <- max(history$y)
   history$cap <- max(history$y)
   m <- prophet(history, growth = 'logistic', fit = FALSE)
   m <- prophet(history, growth = 'logistic', fit = FALSE)
 
 
@@ -209,7 +209,8 @@ test_that("fit_with_holidays", {
 
 
 test_that("make_future_dataframe", {
 test_that("make_future_dataframe", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  m <- prophet(train)
+  train.t <- DATA[1:234, ]
+  m <- prophet(train.t)
   future <- make_future_dataframe(m, periods = 3, freq = 'd',
   future <- make_future_dataframe(m, periods = 3, freq = 'd',
                                   include_history = FALSE)
                                   include_history = FALSE)
   correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
   correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
@@ -220,3 +221,41 @@ test_that("make_future_dataframe", {
   correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
   correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
   expect_equal(future$ds, correct)
   expect_equal(future$ds, correct)
 })
 })
+
+test_that("auto_weekly_seasonality", {
+  # Checks how weekly.seasonality left at 'auto' is resolved when fitting.
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  # First 15 rows: stays 'auto' while unfitted, becomes TRUE once fitted.
+  N.w <- 15
+  train.w <- DATA[1:N.w, ]
+  m <- prophet(train.w, fit = FALSE)
+  expect_equal(m$weekly.seasonality, 'auto')
+  m <- prophet:::fit.prophet(m, train.w)
+  expect_equal(m$weekly.seasonality, TRUE)
+  # First 9 rows: should be FALSE due to too short history.
+  N.w <- 9
+  train.w <- DATA[1:N.w, ]
+  m <- prophet(train.w)
+  expect_equal(m$weekly.seasonality, FALSE)
+  # An explicit TRUE must be kept as-is, even on the same short history.
+  m <- prophet(train.w, weekly.seasonality = TRUE)
+  expect_equal(m$weekly.seasonality, TRUE)
+  # Every 7th row of DATA: should be FALSE due to weekly spacing.
+  train.w <- DATA[seq(1, nrow(DATA), 7), ]
+  m <- prophet(train.w)
+  expect_equal(m$weekly.seasonality, FALSE)
+})
+
+test_that("auto_yearly_seasonality", {
+  # Checks how yearly.seasonality left at 'auto' is resolved when fitting.
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  # Full DATA: stays 'auto' while unfitted, becomes TRUE once fitted.
+  m <- prophet(DATA, fit = FALSE)
+  expect_equal(m$yearly.seasonality, 'auto')
+  m <- prophet:::fit.prophet(m, DATA)
+  expect_equal(m$yearly.seasonality, TRUE)
+  # First 240 rows: should be FALSE due to too short history.
+  # (N.w is just the row count here; the name is reused from the weekly test.)
+  N.w <- 240
+  train.y <- DATA[1:N.w, ]
+  m <- prophet(train.y)
+  expect_equal(m$yearly.seasonality, FALSE)
+  # An explicit TRUE must be kept as-is, even on the same short history.
+  m <- prophet(train.y, yearly.seasonality = TRUE)
+  expect_equal(m$yearly.seasonality, TRUE)
+})

+ 28 - 4
python/fbprophet/forecaster.py

@@ -46,8 +46,8 @@ class Prophet(object):
         if input `changepoints` is supplied. If `changepoints` is not supplied,
         if input `changepoints` is supplied. If `changepoints` is not supplied,
         then n.changepoints potential changepoints are selected uniformly from
         then n.changepoints potential changepoints are selected uniformly from
         the first 80 percent of the history.
         the first 80 percent of the history.
-    yearly_seasonality: Boolean, fit yearly seasonality.
-    weekly_seasonality: Boolean, fit weekly seasonality.
+    yearly_seasonality: Fit yearly seasonality. Can be 'auto', True, or False.
+    weekly_seasonality: Fit weekly seasonality. Can be 'auto', True, or False.
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
         and optionally columns lower_window and upper_window which specify a
         and optionally columns lower_window and upper_window which specify a
         range of days around the date to be included as holidays.
         range of days around the date to be included as holidays.
@@ -77,8 +77,8 @@ class Prophet(object):
             growth='linear',
             growth='linear',
             changepoints=None,
             changepoints=None,
             n_changepoints=25,
             n_changepoints=25,
-            yearly_seasonality=True,
-            weekly_seasonality=True,
+            yearly_seasonality='auto',
+            weekly_seasonality='auto',
             holidays=None,
             holidays=None,
             seasonality_prior_scale=10.0,
             seasonality_prior_scale=10.0,
             holidays_prior_scale=10.0,
             holidays_prior_scale=10.0,
@@ -392,6 +392,29 @@ class Prophet(object):
             seasonal_features.append(self.make_holiday_features(df['ds']))
             seasonal_features.append(self.make_holiday_features(df['ds']))
         return pd.concat(seasonal_features, axis=1)
         return pd.concat(seasonal_features, axis=1)
 
 
+    def set_auto_seasonalities(self):
+        """Set seasonalities that were left on auto.
+
+        Turns on yearly seasonality if there is >=2 years of history.
+        Turns on weekly seasonality if there is >=2 weeks of history, and the
+        spacing between dates in the history is <7 days.
+        """
+        first = self.history['ds'].min()
+        last = self.history['ds'].max()
+        if self.yearly_seasonality == 'auto':
+            if last - first < pd.Timedelta(days=730):
+                self.yearly_seasonality = False
+            else:
+                self.yearly_seasonality = True
+        if self.weekly_seasonality == 'auto':
+            dt = self.history['ds'].diff()
+            min_dt = dt.iloc[dt.nonzero()[0]].min()
+            if ((last - first < pd.Timedelta(weeks=2)) or
+                (min_dt >= pd.Timedelta(weeks=1))):
+                self.weekly_seasonality = False
+            else:
+                self.weekly_seasonality = True
+
     @staticmethod
     @staticmethod
     def linear_growth_init(df):
     def linear_growth_init(df):
         """Initialize linear growth.
         """Initialize linear growth.
@@ -487,6 +510,7 @@ class Prophet(object):
 
 
         history = self.setup_dataframe(history, initialize_scales=True)
         history = self.setup_dataframe(history, initialize_scales=True)
         self.history = history
         self.history = history
+        self.set_auto_seasonalities()
         seasonal_features = self.make_all_seasonality_features(history)
         seasonal_features = self.make_all_seasonality_features(history)
 
 
         self.set_changepoints()
         self.set_changepoints()

+ 83 - 2
python/fbprophet/tests/test_prophet.py

@@ -145,7 +145,7 @@ class TestProphet(TestCase):
 
 
     def test_growth_init(self):
     def test_growth_init(self):
         model = Prophet(growth='logistic')
         model = Prophet(growth='logistic')
-        history = DATA.copy()
+        history = DATA.iloc[:468].copy()
         history['cap'] = history['y'].max()
         history['cap'] = history['y'].max()
 
 
         history = model.setup_dataframe(history, initialize_scales=True)
         history = model.setup_dataframe(history, initialize_scales=True)
@@ -237,7 +237,7 @@ class TestProphet(TestCase):
         model.fit(DATA).predict()
         model.fit(DATA).predict()
 
 
     def test_make_future_dataframe(self):
     def test_make_future_dataframe(self):
-        N = DATA.shape[0]
+        N = 468
         train = DATA.head(N // 2)
         train = DATA.head(N // 2)
         forecaster = Prophet()
         forecaster = Prophet()
         forecaster.fit(train)
         forecaster.fit(train)
@@ -255,6 +255,45 @@ class TestProphet(TestCase):
         for i in range(3):
         for i in range(3):
             self.assertEqual(future.iloc[i]['ds'], correct[i])
             self.assertEqual(future.iloc[i]['ds'], correct[i])
 
 
+    def test_auto_weekly_seasonality(self):
+        # Checks how weekly_seasonality left at 'auto' is resolved by fit().
+        # First 15 rows: 'auto' before fitting, should be True afterwards.
+        N = 15
+        train = DATA.head(N)
+        m = Prophet()
+        self.assertEqual(m.weekly_seasonality, 'auto')
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, True)
+        # First 9 rows: should be False due to too short history.
+        N = 9
+        train = DATA.head(N)
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, False)
+        # An explicit True must be kept as-is, even on the same short history.
+        m = Prophet(weekly_seasonality=True)
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, True)
+        # Every 7th row of DATA: should be False due to weekly spacing.
+        train = DATA.iloc[::7, :]
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.weekly_seasonality, False)
+
+    def test_auto_yearly_seasonality(self):
+        # Checks how yearly_seasonality left at 'auto' is resolved by fit().
+        # Full DATA: 'auto' before fitting, should be True afterwards.
+        m = Prophet()
+        self.assertEqual(m.yearly_seasonality, 'auto')
+        m.fit(DATA)
+        self.assertEqual(m.yearly_seasonality, True)
+        # First 240 rows: should be False due to too short history.
+        N = 240
+        train = DATA.head(N)
+        m = Prophet()
+        m.fit(train)
+        self.assertEqual(m.yearly_seasonality, False)
+        # An explicit True must be kept as-is, even on the same short history.
+        m = Prophet(yearly_seasonality=True)
+        m.fit(train)
+        self.assertEqual(m.yearly_seasonality, True)
+
 
 
 DATA = pd.read_csv(StringIO("""
 DATA = pd.read_csv(StringIO("""
 ds,y
 ds,y
@@ -726,4 +765,46 @@ ds,y
 2014-03-27,60.97
 2014-03-27,60.97
 2014-03-28,60.01
 2014-03-28,60.01
 2014-03-31,60.24
 2014-03-31,60.24
+2014-04-01,62.62
+2014-04-02,62.72
+2014-04-03,59.49
+2014-04-04,56.75
+2014-04-07,56.95
+2014-04-08,58.19
+2014-04-09,62.41
+2014-04-10,59.16
+2014-04-11,58.53
+2014-04-14,58.89
+2014-04-15,59.09
+2014-04-16,59.72
+2014-04-17,58.94
+2014-04-21,61.24
+2014-04-22,63.03
+2014-04-23,61.36
+2014-04-24,60.87
+2014-04-25,57.71
+2014-04-28,56.14
+2014-04-29,58.15
+2014-04-30,59.78
+2014-05-01,61.15
+2014-05-02,60.46
+2014-05-05,61.22
+2014-05-06,58.53
+2014-05-07,57.39
+2014-05-08,56.76
+2014-05-09,57.24
+2014-05-12,59.83
+2014-05-13,59.83
+2014-05-14,59.23
+2014-05-15,57.92
+2014-05-16,58.02
+2014-05-19,59.21
+2014-05-20,58.56
+2014-05-21,60.49
+2014-05-22,60.52
+2014-05-23,61.35
+2014-05-27,63.48
+2014-05-28,63.51
+2014-05-29,63.83
+2014-05-30,63.30
 """), parse_dates=['ds'])
 """), parse_dates=['ds'])