|
@@ -242,10 +242,11 @@ class TestProphet(TestCase):
|
|
|
df = pd.DataFrame({
|
|
|
'ds': pd.date_range('2016-12-20', '2016-12-31')
|
|
|
})
|
|
|
- feats = model.make_holiday_features(df['ds'])
|
|
|
+ feats, priors = model.make_holiday_features(df['ds'])
|
|
|
# 11 columns generated even though only 8 overlap
|
|
|
self.assertEqual(feats.shape, (df.shape[0], 2))
|
|
|
self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
|
|
|
+ self.assertEqual(priors, [10., 10.]) # Default prior
|
|
|
|
|
|
holidays = pd.DataFrame({
|
|
|
'ds': pd.to_datetime(['2016-12-25']),
|
|
@@ -253,9 +254,41 @@ class TestProphet(TestCase):
|
|
|
'lower_window': [-1],
|
|
|
'upper_window': [10],
|
|
|
})
|
|
|
- feats = Prophet(holidays=holidays).make_holiday_features(df['ds'])
|
|
|
+ feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
|
|
|
# 12 columns generated even though only 8 overlap
|
|
|
self.assertEqual(feats.shape, (df.shape[0], 12))
|
|
|
+ self.assertEqual(priors, list(10. * np.ones(12)))
|
|
|
+ # Check prior specifications
|
|
|
+ holidays = pd.DataFrame({
|
|
|
+ 'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
|
|
|
+ 'holiday': ['xmas', 'xmas'],
|
|
|
+ 'lower_window': [-1, -1],
|
|
|
+ 'upper_window': [0, 0],
|
|
|
+ 'prior_scale': [5., 5.],
|
|
|
+ })
|
|
|
+ feats, priors = Prophet(holidays=holidays).make_holiday_features(df['ds'])
|
|
|
+ self.assertEqual(priors, [5., 5.])
|
|
|
+ # 2 different priors
|
|
|
+ holidays2 = pd.DataFrame({
|
|
|
+ 'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
|
|
|
+ 'holiday': ['seans-bday'] * 2,
|
|
|
+ 'lower_window': [0] * 2,
|
|
|
+ 'upper_window': [1] * 2,
|
|
|
+ 'prior_scale': [8] * 2,
|
|
|
+ })
|
|
|
+ holidays2 = pd.concat((holidays, holidays2))
|
|
|
+ feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
|
|
|
+ self.assertEqual(sum(priors), 26)
|
|
|
+ # Check incompatible priors
|
|
|
+ holidays = pd.DataFrame({
|
|
|
+ 'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
|
|
|
+ 'holiday': ['xmas', 'xmas'],
|
|
|
+ 'lower_window': [-1, -1],
|
|
|
+ 'upper_window': [0, 0],
|
|
|
+ 'prior_scale': [5., 6.],
|
|
|
+ })
|
|
|
+ with self.assertRaises(ValueError):
|
|
|
+ Prophet(holidays=holidays).make_holiday_features(df['ds'])
|
|
|
|
|
|
def test_fit_with_holidays(self):
|
|
|
holidays = pd.DataFrame({
|