ソースを参照

Custom seasonality priors Py

bl 8 年 前
コミット
8d27643339
2 ファイル変更77 行追加25 行削除
  1. 44 13
      python/fbprophet/forecaster.py
  2. 33 12
      python/fbprophet/tests/test_prophet.py

+ 44 - 13
python/fbprophet/forecaster.py

@@ -63,7 +63,8 @@ class Prophet(object):
         that holiday.
         that holiday.
     seasonality_prior_scale: Parameter modulating the strength of the
     seasonality_prior_scale: Parameter modulating the strength of the
         seasonality model. Larger values allow the model to fit larger seasonal
         seasonality model. Larger values allow the model to fit larger seasonal
-        fluctuations, smaller values dampen the seasonality.
+        fluctuations, smaller values dampen the seasonality. Can be specified
+        for individual seasonalities using add_seasonality.
     holidays_prior_scale: Parameter modulating the strength of the holiday
     holidays_prior_scale: Parameter modulating the strength of the holiday
         components model, unless overriden in the holidays input.
         components model, unless overriden in the holidays input.
     changepoint_prior_scale: Parameter modulating the flexibility of the
     changepoint_prior_scale: Parameter modulating the flexibility of the
@@ -406,6 +407,8 @@ class Prophet(object):
                 raise ValueError(
                 raise ValueError(
                     'Holiday {} does not have consistent prior scale '
                     'Holiday {} does not have consistent prior scale '
                     'specification.'.format(row.holiday))
                     'specification.'.format(row.holiday))
+            if ps <= 0:
+                raise ValueError('Prior scale must be > 0')
             prior_scales[row.holiday] = ps
             prior_scales[row.holiday] = ps
                 
                 
             for offset in range(lw, uw + 1):
             for offset in range(lw, uw + 1):
@@ -470,19 +473,25 @@ class Prophet(object):
         }
         }
         return self
         return self
 
 
-    def add_seasonality(self, name, period, fourier_order):
-        """Add a seasonal component with specified period and number of Fourier
-        components.
+    def add_seasonality(self, name, period, fourier_order, prior_scale=None):
+        """Add a seasonal component with specified period, number of Fourier
+        components, and prior scale.
 
 
         Increasing the number of Fourier components allows the seasonality to
         Increasing the number of Fourier components allows the seasonality to
         change more quickly (at risk of overfitting). Default values for yearly
         change more quickly (at risk of overfitting). Default values for yearly
         and weekly seasonalities are 10 and 3 respectively.
         and weekly seasonalities are 10 and 3 respectively.
 
 
+        Increasing prior scale will allow this seasonality component more
+        flexibility, decreasing will dampen it. If not provided, will use the
+        seasonality_prior_scale provided on Prophet initialization (defaults
+        to 10).
+
         Parameters
         Parameters
         ----------
         ----------
         name: string name of the seasonality component.
         name: string name of the seasonality component.
         period: float number of days in one period.
         period: float number of days in one period.
         fourier_order: int number of Fourier components to use.
         fourier_order: int number of Fourier components to use.
+        prior_scale: float prior scale for this component.
 
 
         Returns
         Returns
         -------
         -------
@@ -494,7 +503,17 @@ class Prophet(object):
         if name not in ['daily', 'weekly', 'yearly']:
         if name not in ['daily', 'weekly', 'yearly']:
             # Allow overwriting built-in seasonalities
             # Allow overwriting built-in seasonalities
             self.validate_column_name(name, check_seasonalities=False)
             self.validate_column_name(name, check_seasonalities=False)
-        self.seasonalities[name] = (period, fourier_order)
+        if prior_scale is None:
+            ps = self.seasonality_prior_scale
+        else:
+            ps = float(prior_scale)
+        if ps <= 0:
+            raise ValueError('Prior scale must be > 0')
+        self.seasonalities[name] = {
+            'period': period,
+            'fourier_order': fourier_order,
+            'prior_scale': ps,
+        }
         return self
         return self
 
 
     def make_all_seasonality_features(self, df):
     def make_all_seasonality_features(self, df):
@@ -516,16 +535,16 @@ class Prophet(object):
         prior_scales = []
         prior_scales = []
 
 
         # Seasonality features
         # Seasonality features
-        for name, (period, series_order) in self.seasonalities.items():
+        for name, props in self.seasonalities.items():
             features = self.make_seasonality_features(
             features = self.make_seasonality_features(
                 df['ds'],
                 df['ds'],
-                period,
-                series_order,
+                props['period'],
+                props['fourier_order'],
                 name,
                 name,
             )
             )
             seasonal_features.append(features)
             seasonal_features.append(features)
             prior_scales.extend(
             prior_scales.extend(
-                [self.seasonality_prior_scale] * features.shape[1])
+                [props['prior_scale']] * features.shape[1])
 
 
         # Holiday features
         # Holiday features
         if self.holidays is not None:
         if self.holidays is not None:
@@ -600,7 +619,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
         fourier_order = self.parse_seasonality_args(
             'yearly', self.yearly_seasonality, yearly_disable, 10)
             'yearly', self.yearly_seasonality, yearly_disable, 10)
         if fourier_order > 0:
         if fourier_order > 0:
-            self.seasonalities['yearly'] = (365.25, fourier_order)
+            self.seasonalities['yearly'] = {
+                'period': 365.25,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
 
         # Weekly seasonality
         # Weekly seasonality
         weekly_disable = ((last - first < pd.Timedelta(weeks=2)) or
         weekly_disable = ((last - first < pd.Timedelta(weeks=2)) or
@@ -608,7 +631,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
         fourier_order = self.parse_seasonality_args(
             'weekly', self.weekly_seasonality, weekly_disable, 3)
             'weekly', self.weekly_seasonality, weekly_disable, 3)
         if fourier_order > 0:
         if fourier_order > 0:
-            self.seasonalities['weekly'] = (7, fourier_order)
+            self.seasonalities['weekly'] = {
+                'period': 7,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
 
         # Daily seasonality
         # Daily seasonality
         daily_disable = ((last - first < pd.Timedelta(days=2)) or
         daily_disable = ((last - first < pd.Timedelta(days=2)) or
@@ -616,7 +643,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
         fourier_order = self.parse_seasonality_args(
             'daily', self.daily_seasonality, daily_disable, 4)
             'daily', self.daily_seasonality, daily_disable, 4)
         if fourier_order > 0:
         if fourier_order > 0:
-            self.seasonalities['daily'] = (1, fourier_order)
+            self.seasonalities['daily'] = {
+                'period': 1,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
 
     @staticmethod
     @staticmethod
     def linear_growth_init(df):
     def linear_growth_init(df):
@@ -1407,7 +1438,7 @@ class Prophet(object):
             ax = fig.add_subplot(111)
             ax = fig.add_subplot(111)
         # Compute seasonality from Jan 1 through a single period.
         # Compute seasonality from Jan 1 through a single period.
         start = pd.to_datetime('2017-01-01 0000')
         start = pd.to_datetime('2017-01-01 0000')
-        period = self.seasonalities[name][0]
+        period = self.seasonalities[name]['period']
         end = start + pd.Timedelta(days=period)
         end = start + pd.Timedelta(days=period)
         plot_points = 200
         plot_points = 200
         days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
         days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))

+ 33 - 12
python/fbprophet/tests/test_prophet.py

@@ -278,7 +278,7 @@ class TestProphet(TestCase):
         })
         })
         holidays2 = pd.concat((holidays, holidays2))
         holidays2 = pd.concat((holidays, holidays2))
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
-        self.assertEqual(sum(priors), 26)
+        self.assertEqual(priors, [8., 8., 5., 5.])
         # Check incompatible priors
         # Check incompatible priors
         holidays = pd.DataFrame({
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
             'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
@@ -327,7 +327,8 @@ class TestProphet(TestCase):
         self.assertEqual(m.weekly_seasonality, 'auto')
         self.assertEqual(m.weekly_seasonality, 'auto')
         m.fit(train)
         m.fit(train)
         self.assertIn('weekly', m.seasonalities)
         self.assertIn('weekly', m.seasonalities)
-        self.assertEqual(m.seasonalities['weekly'], (7, 3))
+        self.assertEqual(m.seasonalities['weekly'],
+                         {'period': 7, 'fourier_order': 3, 'prior_scale': 10.})
         # Should be disabled due to too short history
         # Should be disabled due to too short history
         N = 9
         N = 9
         train = DATA.head(N)
         train = DATA.head(N)
@@ -342,9 +343,10 @@ class TestProphet(TestCase):
         m = Prophet()
         m = Prophet()
         m.fit(train)
         m.fit(train)
         self.assertNotIn('weekly', m.seasonalities)
         self.assertNotIn('weekly', m.seasonalities)
-        m = Prophet(weekly_seasonality=2)
+        m = Prophet(weekly_seasonality=2, seasonality_prior_scale=3.)
         m.fit(DATA)
         m.fit(DATA)
-        self.assertEqual(m.seasonalities['weekly'], (7, 2))
+        self.assertEqual(m.seasonalities['weekly'],
+                         {'period': 7, 'fourier_order': 2, 'prior_scale': 3.})
 
 
     def test_auto_yearly_seasonality(self):
     def test_auto_yearly_seasonality(self):
         # Should be enabled
         # Should be enabled
@@ -352,7 +354,10 @@ class TestProphet(TestCase):
         self.assertEqual(m.yearly_seasonality, 'auto')
         self.assertEqual(m.yearly_seasonality, 'auto')
         m.fit(DATA)
         m.fit(DATA)
         self.assertIn('yearly', m.seasonalities)
         self.assertIn('yearly', m.seasonalities)
-        self.assertEqual(m.seasonalities['yearly'], (365.25, 10))
+        self.assertEqual(
+            m.seasonalities['yearly'],
+            {'period': 365.25, 'fourier_order': 10, 'prior_scale': 10.},
+        )
         # Should be disabled due to too short history
         # Should be disabled due to too short history
         N = 240
         N = 240
         train = DATA.head(N)
         train = DATA.head(N)
@@ -362,9 +367,12 @@ class TestProphet(TestCase):
         m = Prophet(yearly_seasonality=True)
         m = Prophet(yearly_seasonality=True)
         m.fit(train)
         m.fit(train)
         self.assertIn('yearly', m.seasonalities)
         self.assertIn('yearly', m.seasonalities)
-        m = Prophet(yearly_seasonality=7)
+        m = Prophet(yearly_seasonality=7, seasonality_prior_scale=3.)
         m.fit(DATA)
         m.fit(DATA)
-        self.assertEqual(m.seasonalities['yearly'], (365.25, 7))
+        self.assertEqual(
+            m.seasonalities['yearly'],
+            {'period': 365.25, 'fourier_order': 7, 'prior_scale': 3.},
+        )
 
 
     def test_auto_daily_seasonality(self):
     def test_auto_daily_seasonality(self):
         # Should be enabled
         # Should be enabled
@@ -372,7 +380,8 @@ class TestProphet(TestCase):
         self.assertEqual(m.daily_seasonality, 'auto')
         self.assertEqual(m.daily_seasonality, 'auto')
         m.fit(DATA2)
         m.fit(DATA2)
         self.assertIn('daily', m.seasonalities)
         self.assertIn('daily', m.seasonalities)
-        self.assertEqual(m.seasonalities['daily'], (1, 4))
+        self.assertEqual(m.seasonalities['daily'],
+                         {'period': 1, 'fourier_order': 4, 'prior_scale': 10.})
         # Should be disabled due to too short history
         # Should be disabled due to too short history
         N = 430
         N = 430
         train = DATA2.head(N)
         train = DATA2.head(N)
@@ -382,9 +391,10 @@ class TestProphet(TestCase):
         m = Prophet(daily_seasonality=True)
         m = Prophet(daily_seasonality=True)
         m.fit(train)
         m.fit(train)
         self.assertIn('daily', m.seasonalities)
         self.assertIn('daily', m.seasonalities)
-        m = Prophet(daily_seasonality=7)
+        m = Prophet(daily_seasonality=7, seasonality_prior_scale=3.)
         m.fit(DATA2)
         m.fit(DATA2)
-        self.assertEqual(m.seasonalities['daily'], (1, 7))
+        self.assertEqual(m.seasonalities['daily'],
+                         {'period': 1, 'fourier_order': 7, 'prior_scale': 3.})
         m = Prophet()
         m = Prophet()
         m.fit(DATA)
         m.fit(DATA)
         self.assertNotIn('daily', m.seasonalities)
         self.assertNotIn('daily', m.seasonalities)
@@ -403,15 +413,26 @@ class TestProphet(TestCase):
         holidays = pd.DataFrame({
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2017-01-02']),
             'ds': pd.to_datetime(['2017-01-02']),
             'holiday': ['special_day'],
             'holiday': ['special_day'],
+            'prior_scale': [4.],
         })
         })
         m = Prophet(holidays=holidays)
         m = Prophet(holidays=holidays)
-        m.add_seasonality(name='monthly', period=30, fourier_order=5)
-        self.assertEqual(m.seasonalities['monthly'], (30, 5))
+        m.add_seasonality(name='monthly', period=30, fourier_order=5,
+                          prior_scale=2.)
+        self.assertEqual(m.seasonalities['monthly'],
+                         {'period': 30, 'fourier_order': 5, 'prior_scale': 2.})
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             m.add_seasonality(name='special_day', period=30, fourier_order=5)
             m.add_seasonality(name='special_day', period=30, fourier_order=5)
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             m.add_seasonality(name='trend', period=30, fourier_order=5)
             m.add_seasonality(name='trend', period=30, fourier_order=5)
         m.add_seasonality(name='weekly', period=30, fourier_order=5)
         m.add_seasonality(name='weekly', period=30, fourier_order=5)
+        # Test priors
+        m = Prophet(holidays=holidays, yearly_seasonality=False)
+        m.add_seasonality(name='monthly', period=30, fourier_order=5,
+                          prior_scale=2.)
+        m.fit(DATA.copy())
+        seasonal_features, prior_scales = m.make_all_seasonality_features(
+            m.history)
+        self.assertEqual(prior_scales, [2.] * 10 + [10.] * 6 + [4.])
 
 
     def test_added_regressors(self):
     def test_added_regressors(self):
         m = Prophet()
         m = Prophet()