Преглед изворни кода

Move built-in country holidays to a method

Ben Letham пре 6 година
родитељ
комит
92f955d25a
2 измењених фајлова са 104 додато и 54 уклоњено
  1. 90 40
      python/fbprophet/forecaster.py
  2. 14 14
      python/fbprophet/tests/test_prophet.py

+ 90 - 40
python/fbprophet/forecaster.py

@@ -68,7 +68,6 @@ class Prophet(object):
         lower_window=-2 will include 2 days prior to the date as holidays. Also
         optionally can have a column prior_scale specifying the prior scale for
         that holiday.
-    append_holidays: country name or abbreviation; must be string
     seasonality_mode: 'additive' (default) or 'multiplicative'.
     seasonality_prior_scale: Parameter modulating the strength of the
         seasonality model. Larger values allow the model to fit larger seasonal
@@ -101,7 +100,6 @@ class Prophet(object):
             weekly_seasonality='auto',
             daily_seasonality='auto',
             holidays=None,
-            append_holidays=None,
             seasonality_mode='additive',
             seasonality_prior_scale=10.0,
             holidays_prior_scale=10.0,
@@ -136,13 +134,6 @@ class Prophet(object):
             holidays['ds'] = pd.to_datetime(holidays['ds'])
         self.holidays = holidays
 
-        if append_holidays is not None:
-            if not (
-                    isinstance(append_holidays, str)
-            ):
-                raise ValueError("append_holidays must be a string")
-        self.append_holidays = append_holidays
-
         self.seasonality_mode = seasonality_mode
         self.seasonality_prior_scale = float(seasonality_prior_scale)
         self.changepoint_prior_scale = float(changepoint_prior_scale)
@@ -152,7 +143,7 @@ class Prophet(object):
         self.interval_width = interval_width
         self.uncertainty_samples = uncertainty_samples
 
-        # Set during fitting
+        # Set during fitting or by other methods
         self.start = None
         self.y_scale = None
         self.logistic_floor = False
@@ -160,6 +151,7 @@ class Prophet(object):
         self.changepoints_t = None
         self.seasonalities = {}
         self.extra_regressors = OrderedDict({})
+        self.country_holidays = None
         self.stan_fit = None
         self.params = {}
         self.history = None
@@ -224,10 +216,10 @@ class Prophet(object):
                 name in self.holidays['holiday'].unique()):
             raise ValueError(
                 'Name "{}" already used for a holiday.'.format(name))
-        if (check_holidays and self.append_holidays is not None and
-                name in get_holiday_names(self.append_holidays)):
+        if (check_holidays and self.country_holidays is not None and
+                name in get_holiday_names(self.country_holidays)):
             raise ValueError(
-                'Name "{}" is a holiday name in {}.'.format(name, self.append_holidays))
+                'Name "{}" is a holiday name in {}.'.format(name, self.country_holidays))
         if check_seasonalities and name in self.seasonalities:
             raise ValueError(
                 'Name "{}" already used for a seasonality.'.format(name))
@@ -430,45 +422,64 @@ class Prophet(object):
         ]
         return pd.DataFrame(features, columns=columns)
 
-    def make_holiday_features(self, dates):
-        """Construct a dataframe of holiday features.
-
+    def construct_holiday_dataframe(self, dates):
+        """Construct a dataframe of holiday dates.
+        
+        Will combine self.holidays with the built-in country holidays
+        corresponding to input dates, if self.country_holidays is set.
+        
         Parameters
         ----------
         dates: pd.Series containing timestamps used for computing seasonality.
-
+        
         Returns
         -------
-        holiday_features: pd.DataFrame with a column for each holiday.
-        prior_scale_list: List of prior scales for each holiday column.
-        holiday_names: List of names of holidays
         """
-        # Concatenate holidays and append_holidays
-        all_holidays = self.holidays
-        if self.append_holidays is not None:
+        all_holidays = pd.DataFrame()
+        if self.holidays is not None:
+            all_holidays = pd.concat((all_holidays, self.holidays))
+        if self.country_holidays is not None:
             year_list = list({x.year for x in dates})
-            append_holidays_df = make_holidays_df(
-                                    year_list=year_list,
-                                    country=self.append_holidays)
-            all_holidays = pd.concat((all_holidays, append_holidays_df), sort=False)
+            country_holidays_df = make_holidays_df(
+                year_list=year_list, country=self.country_holidays
+            )
+            all_holidays = pd.concat((all_holidays, country_holidays_df), sort=False)
             all_holidays.reset_index(drop=True, inplace=True)
-        # Make fit and predict holidays components match
+        # If the model has already been fit with a certain set of holidays,
+        # make sure we are using those same ones.
         if self.train_holiday_names is not None:
-            train_holidays = self.train_holiday_names
             # Remove holiday names didn't show up in fit
             index_to_drop = all_holidays.index[
-                                np.logical_not(
-                                    all_holidays.holiday.isin(train_holidays))]
+                np.logical_not(
+                    all_holidays.holiday.isin(self.train_holiday_names)
+                )
+            ]
             all_holidays = all_holidays.drop(index_to_drop)
-            # Add holiday names show up in fit but not in predict with ds as NA
-            holidays_to_add = pd.DataFrame(
-                                {'holiday':
-                                    train_holidays[
-                                        np.logical_not(
-                                            train_holidays.isin(all_holidays.holiday))]})
+            # Add holiday names in fit but not in predict with ds as NA
+            holidays_to_add = pd.DataFrame({
+                'holiday': self.train_holiday_names[
+                    np.logical_not(self.train_holiday_names.isin(all_holidays.holiday))
+                ]
+            })
             all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
             all_holidays.reset_index(drop=True, inplace=True)
+        return all_holidays
+        
+    def make_holiday_features(self, dates, holidays):
+        """Construct a dataframe of holiday features.
+
+        Parameters
+        ----------
+        dates: pd.Series containing timestamps used for computing seasonality.
+        holidays: pd.Dataframe containing holidays, as returned by
+            construct_holiday_dataframe.
 
+        Returns
+        -------
+        holiday_features: pd.DataFrame with a column for each holiday.
+        prior_scale_list: List of prior scales for each holiday column.
+        holiday_names: List of names of holidays
+        """
         # Holds columns of our future matrix.
         expanded_holidays = defaultdict(lambda: np.zeros(dates.shape[0]))
         prior_scales = {}
@@ -476,7 +487,7 @@ class Prophet(object):
         # Strip to just dates.
         row_index = pd.DatetimeIndex(dates.apply(lambda x: x.date()))
 
-        for _ix, row in all_holidays.iterrows():
+        for _ix, row in holidays.iterrows():
             dt = row.ds.date()
             try:
                 lw = int(row.get('lower_window', 0))
@@ -635,6 +646,44 @@ class Prophet(object):
         }
         return self
 
+    def add_country_holidays(self, country_name):
+        """Add in built-in holidays for the specified country.
+
+        These holidays will be included in addition to any specified on model
+        initialization.
+
+        Holidays will be calculated for arbitrary date ranges in the history
+        and future. See the online documentation for the list of countries with
+        built-in holidays.
+
+        Built-in country holidays can only be set for a single country.
+
+        Parameters
+        ----------
+        country_name: Name of the country, like 'UnitedStates' or 'US'
+
+        Returns
+        -------
+        The prophet object.
+        """
+        if self.history is not None:
+            raise Exception(
+                "Country holidays must be added prior to model fitting."
+            )
+        # Validate names.
+        for name in get_holiday_names(country_name):
+            # Allow merging with existing holidays
+            self.validate_column_name(name, check_holidays=False)
+        # Set the holidays.
+        if self.country_holidays is not None:
+            logger.warning(
+                'Changing country holidays from {} to {}'.format(
+                    self.country_holidays, country_name
+                )
+            )
+        self.country_holidays = country_name
+        return self
+
     def make_all_seasonality_features(self, df):
         """Dataframe with seasonality features.
 
@@ -672,9 +721,10 @@ class Prophet(object):
             modes[props['mode']].append(name)
 
         # Holiday features
-        if self.holidays is not None or self.append_holidays is not None:
+        holidays = self.construct_holiday_dataframe(df['ds'])
+        if len(holidays) > 0:
             features, holiday_priors, holiday_names = (
-                self.make_holiday_features(df['ds'])
+                self.make_holiday_features(df['ds'], holidays)
             )
             seasonal_features.append(features)
             prior_scales.extend(holiday_priors)

+ 14 - 14
python/fbprophet/tests/test_prophet.py

@@ -281,7 +281,7 @@ class TestProphet(TestCase):
         df = pd.DataFrame({
             'ds': pd.date_range('2016-12-20', '2016-12-31')
         })
-        feats, priors, names = model.make_holiday_features(df['ds'])
+        feats, priors, names = model.make_holiday_features(df['ds'], model.holidays)
         # 11 columns generated even though only 8 overlap
         self.assertEqual(feats.shape, (df.shape[0], 2))
         self.assertEqual((feats.sum(0) - np.array([1.0, 1.0])).sum(), 0)
@@ -295,7 +295,7 @@ class TestProphet(TestCase):
             'upper_window': [10],
         })
         m = Prophet(holidays=holidays)
-        feats, priors, names = m.make_holiday_features(df['ds'])
+        feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
         # 12 columns generated even though only 8 overlap
         self.assertEqual(feats.shape, (df.shape[0], 12))
         self.assertEqual(priors, list(10. * np.ones(12)))
@@ -309,7 +309,7 @@ class TestProphet(TestCase):
             'prior_scale': [5., 5.],
         })
         m = Prophet(holidays=holidays)
-        feats, priors, names = m.make_holiday_features(df['ds'])
+        feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
         self.assertEqual(priors, [5., 5.])
         self.assertEqual(names, ['xmas'])
         # 2 different priors
@@ -322,7 +322,7 @@ class TestProphet(TestCase):
         })
         holidays2 = pd.concat((holidays, holidays2))
         m = Prophet(holidays=holidays2)
-        feats, priors, names = m.make_holiday_features(df['ds'])
+        feats, priors, names = m.make_holiday_features(df['ds'], m.holidays)
         pn = zip(priors, [s.split('_delim_')[0] for s in feats.columns])
         for t in pn:
             self.assertIn(t, [(8., 'seans-bday'), (5., 'xmas')])
@@ -335,7 +335,7 @@ class TestProphet(TestCase):
         holidays2 = pd.concat((holidays, holidays2))
         feats, priors, names = Prophet(
             holidays=holidays2, holidays_prior_scale=4
-        ).make_holiday_features(df['ds'])
+        ).make_holiday_features(df['ds'], holidays2)
         self.assertEqual(set(priors), {4., 5.})
         # Check incompatible priors
         holidays = pd.DataFrame({
@@ -346,7 +346,7 @@ class TestProphet(TestCase):
             'prior_scale': [5., 6.],
         })
         with self.assertRaises(ValueError):
-            Prophet(holidays=holidays).make_holiday_features(df['ds'])
+            Prophet(holidays=holidays).make_holiday_features(df['ds'], holidays)
 
     def test_fit_with_holidays(self):
         holidays = pd.DataFrame({
@@ -358,28 +358,28 @@ class TestProphet(TestCase):
         model = Prophet(holidays=holidays, uncertainty_samples=0)
         model.fit(DATA).predict()
 
-    def test_fit_predict_with_append_holidays(self):
+    def test_fit_predict_with_country_holidays(self):
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
             'holiday': ['seans-bday'] * 2,
             'lower_window': [0] * 2,
             'upper_window': [1] * 2,
         })
-        append_holidays = 'US'
-        # Test with holidays and append_holidays
-        model = Prophet(holidays=holidays,
-                        append_holidays=append_holidays,
-                        uncertainty_samples=0)
+        # Test with holidays and country_holidays
+        model = Prophet(holidays=holidays, uncertainty_samples=0)
+        model.add_country_holidays(country_name='US')
         model.fit(DATA).predict()
         # There are training holidays missing in the test set
         train = DATA.head(154)
         future = DATA.tail(355)
-        model = Prophet(append_holidays=append_holidays, uncertainty_samples=0)
+        model = Prophet(uncertainty_samples=0)
+        model.add_country_holidays(country_name='US')
         model.fit(train).predict(future)
         # There are test holidays missing in the training set
         train = DATA.tail(355)
         future = DATA2
-        model = Prophet(append_holidays=append_holidays, uncertainty_samples=0)
+        model = Prophet(uncertainty_samples=0)
+        model.add_country_holidays(country_name='US')
         model.fit(train).predict(future)
 
     def test_make_future_dataframe(self):