浏览代码

Custom seasonality priors Py

bl 8 年之前
父节点
当前提交
8d27643339
共有 2 个文件被更改,包括 77 次插入25 次删除
  1. 44 13
      python/fbprophet/forecaster.py
  2. 33 12
      python/fbprophet/tests/test_prophet.py

+ 44 - 13
python/fbprophet/forecaster.py

@@ -63,7 +63,8 @@ class Prophet(object):
         that holiday.
     seasonality_prior_scale: Parameter modulating the strength of the
         seasonality model. Larger values allow the model to fit larger seasonal
-        fluctuations, smaller values dampen the seasonality.
+        fluctuations, smaller values dampen the seasonality. Can be specified
+        for individual seasonalities using add_seasonality.
     holidays_prior_scale: Parameter modulating the strength of the holiday
         components model, unless overriden in the holidays input.
     changepoint_prior_scale: Parameter modulating the flexibility of the
@@ -406,6 +407,8 @@ class Prophet(object):
                 raise ValueError(
                     'Holiday {} does not have consistent prior scale '
                     'specification.'.format(row.holiday))
+            if ps <= 0:
+                raise ValueError('Prior scale must be > 0')
             prior_scales[row.holiday] = ps
                 
             for offset in range(lw, uw + 1):
@@ -470,19 +473,25 @@ class Prophet(object):
         }
         return self
 
-    def add_seasonality(self, name, period, fourier_order):
-        """Add a seasonal component with specified period and number of Fourier
-        components.
+    def add_seasonality(self, name, period, fourier_order, prior_scale=None):
+        """Add a seasonal component with specified period, number of Fourier
+        components, and prior scale.
 
         Increasing the number of Fourier components allows the seasonality to
         change more quickly (at risk of overfitting). Default values for yearly
         and weekly seasonalities are 10 and 3 respectively.
 
+        Increasing prior scale will allow this seasonality component more
+        flexibility, decreasing will dampen it. If not provided, will use the
+        seasonality_prior_scale provided on Prophet initialization (defaults
+        to 10).
+
         Parameters
         ----------
         name: string name of the seasonality component.
         period: float number of days in one period.
         fourier_order: int number of Fourier components to use.
+        prior_scale: float prior scale for this component.
 
         Returns
         -------
@@ -494,7 +503,17 @@ class Prophet(object):
         if name not in ['daily', 'weekly', 'yearly']:
             # Allow overwriting built-in seasonalities
             self.validate_column_name(name, check_seasonalities=False)
-        self.seasonalities[name] = (period, fourier_order)
+        if prior_scale is None:
+            ps = self.seasonality_prior_scale
+        else:
+            ps = float(prior_scale)
+        if ps <= 0:
+            raise ValueError('Prior scale must be > 0')
+        self.seasonalities[name] = {
+            'period': period,
+            'fourier_order': fourier_order,
+            'prior_scale': ps,
+        }
         return self
 
     def make_all_seasonality_features(self, df):
@@ -516,16 +535,16 @@ class Prophet(object):
         prior_scales = []
 
         # Seasonality features
-        for name, (period, series_order) in self.seasonalities.items():
+        for name, props in self.seasonalities.items():
             features = self.make_seasonality_features(
                 df['ds'],
-                period,
-                series_order,
+                props['period'],
+                props['fourier_order'],
                 name,
             )
             seasonal_features.append(features)
             prior_scales.extend(
-                [self.seasonality_prior_scale] * features.shape[1])
+                [props['prior_scale']] * features.shape[1])
 
         # Holiday features
         if self.holidays is not None:
@@ -600,7 +619,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
             'yearly', self.yearly_seasonality, yearly_disable, 10)
         if fourier_order > 0:
-            self.seasonalities['yearly'] = (365.25, fourier_order)
+            self.seasonalities['yearly'] = {
+                'period': 365.25,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
         # Weekly seasonality
         weekly_disable = ((last - first < pd.Timedelta(weeks=2)) or
@@ -608,7 +631,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
             'weekly', self.weekly_seasonality, weekly_disable, 3)
         if fourier_order > 0:
-            self.seasonalities['weekly'] = (7, fourier_order)
+            self.seasonalities['weekly'] = {
+                'period': 7,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
         # Daily seasonality
         daily_disable = ((last - first < pd.Timedelta(days=2)) or
@@ -616,7 +643,11 @@ class Prophet(object):
         fourier_order = self.parse_seasonality_args(
             'daily', self.daily_seasonality, daily_disable, 4)
         if fourier_order > 0:
-            self.seasonalities['daily'] = (1, fourier_order)
+            self.seasonalities['daily'] = {
+                'period': 1,
+                'fourier_order': fourier_order,
+                'prior_scale': self.seasonality_prior_scale,
+            }
 
     @staticmethod
     def linear_growth_init(df):
@@ -1407,7 +1438,7 @@ class Prophet(object):
             ax = fig.add_subplot(111)
         # Compute seasonality from Jan 1 through a single period.
         start = pd.to_datetime('2017-01-01 0000')
-        period = self.seasonalities[name][0]
+        period = self.seasonalities[name]['period']
         end = start + pd.Timedelta(days=period)
         plot_points = 200
         days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))

+ 33 - 12
python/fbprophet/tests/test_prophet.py

@@ -278,7 +278,7 @@ class TestProphet(TestCase):
         })
         holidays2 = pd.concat((holidays, holidays2))
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
-        self.assertEqual(sum(priors), 26)
+        self.assertEqual(priors, [8., 8., 5., 5.])
         # Check incompatible priors
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
@@ -327,7 +327,8 @@ class TestProphet(TestCase):
         self.assertEqual(m.weekly_seasonality, 'auto')
         m.fit(train)
         self.assertIn('weekly', m.seasonalities)
-        self.assertEqual(m.seasonalities['weekly'], (7, 3))
+        self.assertEqual(m.seasonalities['weekly'],
+                         {'period': 7, 'fourier_order': 3, 'prior_scale': 10.})
         # Should be disabled due to too short history
         N = 9
         train = DATA.head(N)
@@ -342,9 +343,10 @@ class TestProphet(TestCase):
         m = Prophet()
         m.fit(train)
         self.assertNotIn('weekly', m.seasonalities)
-        m = Prophet(weekly_seasonality=2)
+        m = Prophet(weekly_seasonality=2, seasonality_prior_scale=3.)
         m.fit(DATA)
-        self.assertEqual(m.seasonalities['weekly'], (7, 2))
+        self.assertEqual(m.seasonalities['weekly'],
+                         {'period': 7, 'fourier_order': 2, 'prior_scale': 3.})
 
     def test_auto_yearly_seasonality(self):
         # Should be enabled
@@ -352,7 +354,10 @@ class TestProphet(TestCase):
         self.assertEqual(m.yearly_seasonality, 'auto')
         m.fit(DATA)
         self.assertIn('yearly', m.seasonalities)
-        self.assertEqual(m.seasonalities['yearly'], (365.25, 10))
+        self.assertEqual(
+            m.seasonalities['yearly'],
+            {'period': 365.25, 'fourier_order': 10, 'prior_scale': 10.},
+        )
         # Should be disabled due to too short history
         N = 240
         train = DATA.head(N)
@@ -362,9 +367,12 @@ class TestProphet(TestCase):
         m = Prophet(yearly_seasonality=True)
         m.fit(train)
         self.assertIn('yearly', m.seasonalities)
-        m = Prophet(yearly_seasonality=7)
+        m = Prophet(yearly_seasonality=7, seasonality_prior_scale=3.)
         m.fit(DATA)
-        self.assertEqual(m.seasonalities['yearly'], (365.25, 7))
+        self.assertEqual(
+            m.seasonalities['yearly'],
+            {'period': 365.25, 'fourier_order': 7, 'prior_scale': 3.},
+        )
 
     def test_auto_daily_seasonality(self):
         # Should be enabled
@@ -372,7 +380,8 @@ class TestProphet(TestCase):
         self.assertEqual(m.daily_seasonality, 'auto')
         m.fit(DATA2)
         self.assertIn('daily', m.seasonalities)
-        self.assertEqual(m.seasonalities['daily'], (1, 4))
+        self.assertEqual(m.seasonalities['daily'],
+                         {'period': 1, 'fourier_order': 4, 'prior_scale': 10.})
         # Should be disabled due to too short history
         N = 430
         train = DATA2.head(N)
@@ -382,9 +391,10 @@ class TestProphet(TestCase):
         m = Prophet(daily_seasonality=True)
         m.fit(train)
         self.assertIn('daily', m.seasonalities)
-        m = Prophet(daily_seasonality=7)
+        m = Prophet(daily_seasonality=7, seasonality_prior_scale=3.)
         m.fit(DATA2)
-        self.assertEqual(m.seasonalities['daily'], (1, 7))
+        self.assertEqual(m.seasonalities['daily'],
+                         {'period': 1, 'fourier_order': 7, 'prior_scale': 3.})
         m = Prophet()
         m.fit(DATA)
         self.assertNotIn('daily', m.seasonalities)
@@ -403,15 +413,26 @@ class TestProphet(TestCase):
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2017-01-02']),
             'holiday': ['special_day'],
+            'prior_scale': [4.],
         })
         m = Prophet(holidays=holidays)
-        m.add_seasonality(name='monthly', period=30, fourier_order=5)
-        self.assertEqual(m.seasonalities['monthly'], (30, 5))
+        m.add_seasonality(name='monthly', period=30, fourier_order=5,
+                          prior_scale=2.)
+        self.assertEqual(m.seasonalities['monthly'],
+                         {'period': 30, 'fourier_order': 5, 'prior_scale': 2.})
         with self.assertRaises(ValueError):
             m.add_seasonality(name='special_day', period=30, fourier_order=5)
         with self.assertRaises(ValueError):
             m.add_seasonality(name='trend', period=30, fourier_order=5)
         m.add_seasonality(name='weekly', period=30, fourier_order=5)
+        # Test priors
+        m = Prophet(holidays=holidays, yearly_seasonality=False)
+        m.add_seasonality(name='monthly', period=30, fourier_order=5,
+                          prior_scale=2.)
+        m.fit(DATA.copy())
+        seasonal_features, prior_scales = m.make_all_seasonality_features(
+            m.history)
+        self.assertEqual(prior_scales, [2.] * 10 + [10.] * 6 + [4.])
 
     def test_added_regressors(self):
         m = Prophet()