ソースを参照

Custom prior scales for holidays Py

bletham 8 年 前
コミット
a620a6c9f9
2 ファイル変更62 行追加11 行削除
  1. 27 9
      python/fbprophet/forecaster.py
  2. 35 2
      python/fbprophet/tests/test_prophet.py

+ 27 - 9
python/fbprophet/forecaster.py

@@ -58,12 +58,14 @@ class Prophet(object):
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
         and optionally columns lower_window and upper_window which specify a
         and optionally columns lower_window and upper_window which specify a
         range of days around the date to be included as holidays.
         range of days around the date to be included as holidays.
-        lower_window=-2 will include 2 days prior to the date as holidays.
+        lower_window=-2 will include 2 days prior to the date as holidays. An
+        optional prior_scale column specifies the prior scale for that
+        holiday.
     seasonality_prior_scale: Parameter modulating the strength of the
     seasonality_prior_scale: Parameter modulating the strength of the
         seasonality model. Larger values allow the model to fit larger seasonal
         seasonality model. Larger values allow the model to fit larger seasonal
         fluctuations, smaller values dampen the seasonality.
         fluctuations, smaller values dampen the seasonality.
     holidays_prior_scale: Parameter modulating the strength of the holiday
     holidays_prior_scale: Parameter modulating the strength of the holiday
-        components model.
+        components model, unless overridden in the holidays input.
     changepoint_prior_scale: Parameter modulating the flexibility of the
     changepoint_prior_scale: Parameter modulating the flexibility of the
         automatic changepoint selection. Large values will allow many
         automatic changepoint selection. Large values will allow many
         changepoints, small values will allow few changepoints.
         changepoints, small values will allow few changepoints.
@@ -376,10 +378,12 @@ class Prophet(object):
 
 
         Returns
         Returns
         -------
         -------
-        pd.DataFrame with a column for each holiday.
+        holiday_features: pd.DataFrame with a column for each holiday.
+        prior_scale_list: List of prior scales for each holiday column.
         """
         """
         # Holds columns of our future matrix.
         # Holds columns of our future matrix.
         expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0]))
         expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0]))
+        prior_scales = {}
         # Makes an index so we can perform `get_loc` below.
         # Makes an index so we can perform `get_loc` below.
         # Strip to just dates.
         # Strip to just dates.
         row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date()))
         row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date()))
@@ -392,6 +396,18 @@ class Prophet(object):
             except ValueError:
             except ValueError:
                 lw = 0
                 lw = 0
                 uw = 0
                 uw = 0
+            try:
+                ps = float(row.get('prior_scale', self.holidays_prior_scale))
+            except ValueError:
+                ps = float(self.holidays_prior_scale)
+            if (
+                row.holiday in prior_scales and prior_scales[row.holiday] != ps
+            ):
+                raise ValueError(
+                    'Holiday {} does not have consistent prior scale '
+                    'specification.'.format(row.holiday))
+            prior_scales[row.holiday] = ps
+                
             for offset in range(lw, uw + 1):
             for offset in range(lw, uw + 1):
                 occurrence = dt + timedelta(days=offset)
                 occurrence = dt + timedelta(days=offset)
                 try:
                 try:
@@ -409,9 +425,12 @@ class Prophet(object):
                 else:
                 else:
                     # Access key to generate value
                     # Access key to generate value
                     expanded_holidays[key]
                     expanded_holidays[key]
-
-        # This relies pretty importantly on pandas keeping the columns in order.
-        return pd.DataFrame(expanded_holidays)
+        holiday_features = pd.DataFrame(expanded_holidays)
+        prior_scale_list = [
+            prior_scales[h.split('_delim_')[0]]
+            for h in holiday_features.columns
+        ]
+        return holiday_features, prior_scale_list
 
 
     def add_regressor(self, name, prior_scale=None, standardize='auto'):
     def add_regressor(self, name, prior_scale=None, standardize='auto'):
         """Add an additional regressor to be used for fitting and predicting.
         """Add an additional regressor to be used for fitting and predicting.
@@ -510,10 +529,9 @@ class Prophet(object):
 
 
         # Holiday features
         # Holiday features
         if self.holidays is not None:
         if self.holidays is not None:
-            features = self.make_holiday_features(df['ds'])
+            features, holiday_priors = self.make_holiday_features(df['ds'])
             seasonal_features.append(features)
             seasonal_features.append(features)
-            prior_scales.extend(
-                [self.holidays_prior_scale] * features.shape[1])
+            prior_scales.extend(holiday_priors)
 
 
         # Additional regressors
         # Additional regressors
         for name, props in self.extra_regressors.items():
         for name, props in self.extra_regressors.items():

+ 35 - 2
python/fbprophet/tests/test_prophet.py

@@ -242,10 +242,11 @@ class TestProphet(TestCase):
         df = pd.DataFrame({
         df = pd.DataFrame({
             'ds': pd.date_range('2016-12-20', '2016-12-31')
             'ds': pd.date_range('2016-12-20', '2016-12-31')
         })
         })
-        feats = model.make_holiday_features(df['ds'])
+        feats, priors = model.make_holiday_features(df['ds'])
         # 11 columns generated even though only 8 overlap
         # 11 columns generated even though only 8 overlap
         self.assertEqual(feats.shape, (df.shape[0], 2))
         self.assertEqual(feats.shape, (df.shape[0], 2))
         self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
         self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
+        self.assertEqual(priors, [10., 10.])  # Default prior
 
 
         holidays = pd.DataFrame({
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25']),
             'ds': pd.to_datetime(['2016-12-25']),
@@ -253,9 +254,41 @@ class TestProphet(TestCase):
             'lower_window': [-1],
             'lower_window': [-1],
             'upper_window': [10],
             'upper_window': [10],
         })
         })
-        feats = Prophet(holidays=holidays).make_holiday_features(df['ds'])
+        feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
         # 12 columns generated even though only 8 overlap
         # 12 columns generated even though only 8 overlap
         self.assertEqual(feats.shape, (df.shape[0], 12))
         self.assertEqual(feats.shape, (df.shape[0], 12))
+        self.assertEqual(priors, list(10. * np.ones(12)))
+        # Check prior specifications
+        holidays = pd.DataFrame({
+            'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
+            'holiday': ['xmas', 'xmas'],
+            'lower_window': [-1, -1],
+            'upper_window': [0, 0],
+            'prior_scale': [5., 5.],
+        })
+        feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
+        self.assertEqual(priors, [5., 5.])
+        # 2 different priors
+        holidays2 = pd.DataFrame({
+            'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
+            'holiday': ['seans-bday'] * 2,
+            'lower_window': [0] * 2,
+            'upper_window': [1] * 2,
+            'prior_scale': [8] * 2,
+        })
+        holidays2 = pd.concat((holidays, holidays2))
+        feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
+        self.assertEqual(sum(priors), 26)
+        # Check incompatible priors
+        holidays = pd.DataFrame({
+            'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
+            'holiday': ['xmas', 'xmas'],
+            'lower_window': [-1, -1],
+            'upper_window': [0, 0],
+            'prior_scale': [5., 6.],
+        })
+        with self.assertRaises(ValueError):
+            Prophet(holidays=holidays).make_holiday_features(df['ds'])
 
 
     def test_fit_with_holidays(self):
     def test_fit_with_holidays(self):
         holidays = pd.DataFrame({
         holidays = pd.DataFrame({