Browse Source

Generalize seasonality representation (Python)

bl 8 years ago
parent
commit
b3017c025f
2 changed files with 88 additions and 71 deletions
  1. 69 60
      python/fbprophet/forecaster.py
  2. 19 11
      python/fbprophet/tests/test_prophet.py

+ 69 - 60
python/fbprophet/forecaster.py

@@ -78,7 +78,6 @@ class Prophet(object):
         parameters, which will include uncertainty in seasonality.
         parameters, which will include uncertainty in seasonality.
     uncertainty_samples: Number of simulated draws used to estimate
     uncertainty_samples: Number of simulated draws used to estimate
         uncertainty intervals.
         uncertainty intervals.
-    daily_seasonality: Boolean, fit daily seasonality
     """
     """
 
 
     def __init__(
     def __init__(
@@ -96,7 +95,6 @@ class Prophet(object):
             mcmc_samples=0,
             mcmc_samples=0,
             interval_width=0.80,
             interval_width=0.80,
             uncertainty_samples=1000,
             uncertainty_samples=1000,
-            daily_seasonality=False,
     ):
     ):
         self.growth = growth
         self.growth = growth
 
 
@@ -134,6 +132,7 @@ class Prophet(object):
         self.y_scale = None
         self.y_scale = None
         self.t_scale = None
         self.t_scale = None
         self.changepoints_t = None
         self.changepoints_t = None
+        self.seasonalities = {}
         self.stan_fit = None
         self.stan_fit = None
         self.params = {}
         self.params = {}
         self.history = None
         self.history = None
@@ -358,81 +357,91 @@ class Prophet(object):
             # Add a column of zeros in case no seasonality is used.
             # Add a column of zeros in case no seasonality is used.
             pd.DataFrame({'zeros': np.zeros(df.shape[0])})
             pd.DataFrame({'zeros': np.zeros(df.shape[0])})
         ]
         ]
-
-        # Seasonality features
-        if self.yearly_seasonality > 0:
-            seasonal_features.append(self.make_seasonality_features(
-                df['ds'],
-                365.25,
-                self.yearly_seasonality,
-                'yearly',
-            ))
-
-        if self.weekly_seasonality > 0:
-            seasonal_features.append(self.make_seasonality_features(
-                df['ds'],
-                7,
-                self.weekly_seasonality,
-                'weekly',
-            ))
-
-        if self.daily_seasonality > 0:
+        for name, (period, series_order) in self.seasonalities.items():
             seasonal_features.append(self.make_seasonality_features(
             seasonal_features.append(self.make_seasonality_features(
                 df['ds'],
                 df['ds'],
-                1,
-                self.daily_seasonality,
-                'daily',
+                period,
+                series_order,
+                name,
             ))
             ))
 
 
         if self.holidays is not None:
         if self.holidays is not None:
             seasonal_features.append(self.make_holiday_features(df['ds']))
             seasonal_features.append(self.make_holiday_features(df['ds']))
         return pd.concat(seasonal_features, axis=1)
         return pd.concat(seasonal_features, axis=1)
 
 
+    def parse_seasonality_args(self, name, arg, auto_disable, default_order):
+        """Get number of fourier components for built-in seasonalities.
+        
+        Parameters
+        ----------
+        name: string name of the seasonality component.
+        arg: 'auto', True, False, or number of fourier components as provided.
+        auto_disable: bool if seasonality should be disabled when 'auto'.
+        default_order: int default fourier order
+
+        Returns
+        -------
+        Number of fourier components, or 0 for disabled.
+        """
+        if arg == 'auto':
+            fourier_order = 0
+            if name in self.seasonalities:
+                logger.info(
+                    'Found custom seasonality named "{name}", '
+                    'disabling built-in {name} seasonality.'.format(name=name)
+                )
+            elif auto_disable:
+                logger.info(
+                    'Disabling {name} seasonality. Run prophet with '
+                    '{name}_seasonality=True to override this.'.format(
+                        name=name)
+                )
+            else:
+                fourier_order = default_order
+        elif arg is True:
+            fourier_order = default_order
+        elif arg is False:
+            fourier_order = 0
+        else:
+            fourier_order = int(arg)
+        return fourier_order
+
     def set_auto_seasonalities(self):
     def set_auto_seasonalities(self):
         """Set seasonalities that were left on auto.
         """Set seasonalities that were left on auto.
 
 
         Turns on yearly seasonality if there is >=2 years of history.
         Turns on yearly seasonality if there is >=2 years of history.
         Turns on weekly seasonality if there is >=2 weeks of history, and the
         Turns on weekly seasonality if there is >=2 weeks of history, and the
         spacing between dates in the history is <7 days.
         spacing between dates in the history is <7 days.
+        Turns on daily seasonality if there is >=2 days of history, and the
+        spacing between dates in the history is <1 day.
         """
         """
         first = self.history['ds'].min()
         first = self.history['ds'].min()
         last = self.history['ds'].max()
         last = self.history['ds'].max()
-        if self.yearly_seasonality == 'auto':
-            if last - first < pd.Timedelta(days=730):
-                self.yearly_seasonality = 0
-                logger.info('Disabling yearly seasonality. Run prophet with '
-                      'yearly_seasonality=True to override this.')
-            else:
-                self.yearly_seasonality = 10
-        elif self.yearly_seasonality is True:
-            self.yearly_seasonality = 10
-                
-        if self.weekly_seasonality == 'auto':
-            dt = self.history['ds'].diff()
-            min_dt = dt.iloc[dt.nonzero()[0]].min()
-            if ((last - first < pd.Timedelta(weeks=2)) or
-                    (min_dt >= pd.Timedelta(weeks=1))):
-                self.weekly_seasonality = 0
-                logger.info('Disabling weekly seasonality. Run prophet with '
-                      'weekly_seasonality=True to override this.')
-            else:
-                self.weekly_seasonality = 3
-        elif self.weekly_seasonality is True:
-            self.weekly_seasonality = 3
-                
-        if self.daily_seasonality == 'auto':
-            # disabled by default but if the average difference is <1 day
-            # then we assume there is intra-day modeling
-            dt = self.history['ds'].diff()
-            min_dt = dt.iloc[dt.nonzero()[0]].min()
-            if (min_dt< pd.Timedelta(days=1)):
-                self.daily_seasonality = 4
-                logger.info('Enabling daily seasonality. Run prophet with '
-                      'daily_seasonality=False to override this.')
-            else:
-                self.daily_seasonality = 0
-        elif self.daily_seasonality is True:
-            self.daily_seasonality = 4
+        dt = self.history['ds'].diff()
+        min_dt = dt.iloc[dt.nonzero()[0]].min()
+
+        # Yearly seasonality
+        yearly_disable = last - first < pd.Timedelta(days=730)
+        fourier_order = self.parse_seasonality_args(
+            'yearly', self.yearly_seasonality, yearly_disable, 10)
+        if fourier_order > 0:
+            self.seasonalities['yearly'] = (365.25, fourier_order)
+
+        # Weekly seasonality
+        weekly_disable = ((last - first < pd.Timedelta(weeks=2)) or
+            (min_dt >= pd.Timedelta(weeks=1)))
+        fourier_order = self.parse_seasonality_args(
+            'weekly', self.weekly_seasonality, weekly_disable, 3)
+        if fourier_order > 0:
+            self.seasonalities['weekly'] = (7, fourier_order)
+
+        # Daily seasonality
+        daily_disable = ((last - first < pd.Timedelta(days=2)) or
+            (min_dt >= pd.Timedelta(days=1)))
+        fourier_order = self.parse_seasonality_args(
+            'daily', self.daily_seasonality, daily_disable, 4)
+        if fourier_order > 0:
+            self.seasonalities['daily'] = (1, fourier_order)
 
 
     @staticmethod
     @staticmethod
     def linear_growth_init(df):
     def linear_growth_init(df):

+ 19 - 11
python/fbprophet/tests/test_prophet.py

@@ -252,43 +252,51 @@ class TestProphet(TestCase):
             self.assertEqual(future.iloc[i]['ds'], correct[i])
             self.assertEqual(future.iloc[i]['ds'], correct[i])
 
 
     def test_auto_weekly_seasonality(self):
     def test_auto_weekly_seasonality(self):
-        # Should be True
+        # Should be enabled
         N = 15
         N = 15
         train = DATA.head(N)
         train = DATA.head(N)
         m = Prophet()
         m = Prophet()
         self.assertEqual(m.weekly_seasonality, 'auto')
         self.assertEqual(m.weekly_seasonality, 'auto')
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.weekly_seasonality, True)
-        # Should be False due to too short history
+        self.assertIn('weekly', m.seasonalities)
+        self.assertEqual(m.seasonalities['weekly'], (7, 3))
+        # Should be disabled due to too short history
         N = 9
         N = 9
         train = DATA.head(N)
         train = DATA.head(N)
         m = Prophet()
         m = Prophet()
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.weekly_seasonality, False)
+        self.assertNotIn('weekly', m.seasonalities)
         m = Prophet(weekly_seasonality=True)
         m = Prophet(weekly_seasonality=True)
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.weekly_seasonality, True)
+        self.assertIn('weekly', m.seasonalities)
         # Should be False due to weekly spacing
         # Should be False due to weekly spacing
         train = DATA.iloc[::7, :]
         train = DATA.iloc[::7, :]
         m = Prophet()
         m = Prophet()
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.weekly_seasonality, False)
+        self.assertNotIn('weekly', m.seasonalities)
+        m = Prophet(weekly_seasonality=2)
+        m.fit(DATA)
+        self.assertEqual(m.seasonalities['weekly'], (7, 2))
 
 
     def test_auto_yearly_seasonality(self):
     def test_auto_yearly_seasonality(self):
-        # Should be True
+        # Should be enabled
         m = Prophet()
         m = Prophet()
         self.assertEqual(m.yearly_seasonality, 'auto')
         self.assertEqual(m.yearly_seasonality, 'auto')
         m.fit(DATA)
         m.fit(DATA)
-        self.assertEqual(m.yearly_seasonality, True)
-        # Should be False due to too short history
+        self.assertIn('yearly', m.seasonalities)
+        self.assertEqual(m.seasonalities['yearly'], (365.25, 10))
+        # Should be disabled due to too short history
         N = 240
         N = 240
         train = DATA.head(N)
         train = DATA.head(N)
         m = Prophet()
         m = Prophet()
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.yearly_seasonality, False)
+        self.assertNotIn('yearly', m.seasonalities)
         m = Prophet(yearly_seasonality=True)
         m = Prophet(yearly_seasonality=True)
         m.fit(train)
         m.fit(train)
-        self.assertEqual(m.yearly_seasonality, True)
+        self.assertIn('yearly', m.seasonalities)
+        m = Prophet(yearly_seasonality=7)
+        m.fit(DATA)
+        self.assertEqual(m.seasonalities['yearly'], (365.25, 7))
 
 
 
 
 DATA = pd.read_csv(StringIO("""
 DATA = pd.read_csv(StringIO("""