Просмотр исходного кода

Custom prior scales for holidays Py

bletham 8 лет назад
Родитель
Commit
a620a6c9f9
2 измененных файлов с 62 добавлено и 11 удалено
  1. 27 9
      python/fbprophet/forecaster.py
  2. 35 2
      python/fbprophet/tests/test_prophet.py

+ 27 - 9
python/fbprophet/forecaster.py

@@ -58,12 +58,14 @@ class Prophet(object):
     holidays: pd.DataFrame with columns holiday (string) and ds (date type)
         and optionally columns lower_window and upper_window which specify a
         range of days around the date to be included as holidays.
-        lower_window=-2 will include 2 days prior to the date as holidays.
+        lower_window=-2 will include 2 days prior to the date as holidays. The
+        DataFrame can optionally also have a column prior_scale specifying the
+        prior scale for that holiday.
     seasonality_prior_scale: Parameter modulating the strength of the
         seasonality model. Larger values allow the model to fit larger seasonal
         fluctuations, smaller values dampen the seasonality.
     holidays_prior_scale: Parameter modulating the strength of the holiday
-        components model.
+        components model, unless overridden in the holidays input.
     changepoint_prior_scale: Parameter modulating the flexibility of the
         automatic changepoint selection. Large values will allow many
         changepoints, small values will allow few changepoints.
@@ -376,10 +378,12 @@ class Prophet(object):
 
         Returns
         -------
-        pd.DataFrame with a column for each holiday.
+        holiday_features: pd.DataFrame with a column for each holiday.
+        prior_scale_list: List of prior scales for each holiday column.
         """
         # Holds columns of our future matrix.
         expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0]))
+        prior_scales = {}
         # Makes an index so we can perform `get_loc` below.
         # Strip to just dates.
         row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date()))
@@ -392,6 +396,18 @@ class Prophet(object):
             except ValueError:
                 lw = 0
                 uw = 0
+            try:
+                ps = float(row.get('prior_scale', self.holidays_prior_scale))
+            except ValueError:
+                ps = float(self.holidays_prior_scale)
+            if (
+                row.holiday in prior_scales and prior_scales[row.holiday] != ps
+            ):
+                raise ValueError(
+                    'Holiday {} does not have consistent prior scale '
+                    'specification.'.format(row.holiday))
+            prior_scales[row.holiday] = ps
+                
             for offset in range(lw, uw + 1):
                 occurrence = dt + timedelta(days=offset)
                 try:
@@ -409,9 +425,12 @@ class Prophet(object):
                 else:
                     # Access key to generate value
                     expanded_holidays[key]
-
-        # This relies pretty importantly on pandas keeping the columns in order.
-        return pd.DataFrame(expanded_holidays)
+        holiday_features = pd.DataFrame(expanded_holidays)
+        prior_scale_list = [
+            prior_scales[h.split('_delim_')[0]]
+            for h in holiday_features.columns
+        ]
+        return holiday_features, prior_scale_list
 
     def add_regressor(self, name, prior_scale=None, standardize='auto'):
         """Add an additional regressor to be used for fitting and predicting.
@@ -510,10 +529,9 @@ class Prophet(object):
 
         # Holiday features
         if self.holidays is not None:
-            features = self.make_holiday_features(df['ds'])
+            features, holiday_priors = self.make_holiday_features(df['ds'])
             seasonal_features.append(features)
-            prior_scales.extend(
-                [self.holidays_prior_scale] * features.shape[1])
+            prior_scales.extend(holiday_priors)
 
         # Additional regressors
         for name, props in self.extra_regressors.items():

+ 35 - 2
python/fbprophet/tests/test_prophet.py

@@ -242,10 +242,11 @@ class TestProphet(TestCase):
         df = pd.DataFrame({
             'ds': pd.date_range('2016-12-20', '2016-12-31')
         })
-        feats = model.make_holiday_features(df['ds'])
+        feats, priors = model.make_holiday_features(df['ds'])
         # 2 columns generated
         self.assertEqual(feats.shape, (df.shape[0], 2))
         self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
+        self.assertEqual(priors, [10., 10.])  # Default prior
 
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25']),
@@ -253,9 +254,41 @@ class TestProphet(TestCase):
             'lower_window': [-1],
             'upper_window': [10],
         })
-        feats = Prophet(holidays=holidays).make_holiday_features(df['ds'])
+        feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
         # 12 columns generated even though only 8 overlap
         self.assertEqual(feats.shape, (df.shape[0], 12))
+        self.assertEqual(priors, list(10. * np.ones(12)))
+        # Check prior specifications
+        holidays = pd.DataFrame({
+            'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
+            'holiday': ['xmas', 'xmas'],
+            'lower_window': [-1, -1],
+            'upper_window': [0, 0],
+            'prior_scale': [5., 5.],
+        })
+        feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
+        self.assertEqual(priors, [5., 5.])
+        # 2 different priors
+        holidays2 = pd.DataFrame({
+            'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
+            'holiday': ['seans-bday'] * 2,
+            'lower_window': [0] * 2,
+            'upper_window': [1] * 2,
+            'prior_scale': [8] * 2,
+        })
+        holidays2 = pd.concat((holidays, holidays2))
+        feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
+        self.assertEqual(sum(priors), 26)
+        # Check incompatible priors
+        holidays = pd.DataFrame({
+            'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
+            'holiday': ['xmas', 'xmas'],
+            'lower_window': [-1, -1],
+            'upper_window': [0, 0],
+            'prior_scale': [5., 6.],
+        })
+        with self.assertRaises(ValueError):
+            Prophet(holidays=holidays).make_holiday_features(df['ds'])
 
     def test_fit_with_holidays(self):
         holidays = pd.DataFrame({