Prechádzať zdrojové kódy

Fix copy with extra seasonalities / regressors Py

bl 8 rokov pred
rodič
commit
5dbffbaa18

+ 14 - 5
python/fbprophet/forecaster.py

@@ -11,6 +11,7 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from collections import defaultdict
+from copy import deepcopy
 from datetime import timedelta
 import logging
 
@@ -1515,6 +1516,9 @@ class Prophet(object):
         -------
         Prophet class object with the same parameter with model variable
         """
+        if self.history is None:
+            raise Exception('This is for copying a fitted Prophet object.')
+
         if self.specified_changepoints:
             changepoints = self.changepoints
             if cutoff is not None:
@@ -1523,18 +1527,23 @@ class Prophet(object):
         else:
             changepoints = None
 
-        return Prophet(
+        # Auto seasonalities are set to False because they are already set in
+        # self.seasonalities.
+        m = Prophet(
             growth=self.growth,
             n_changepoints=self.n_changepoints,
             changepoints=changepoints,
-            yearly_seasonality=self.yearly_seasonality,
-            weekly_seasonality=self.weekly_seasonality,
-            daily_seasonality=self.daily_seasonality,
+            yearly_seasonality=False,
+            weekly_seasonality=False,
+            daily_seasonality=False,
             holidays=self.holidays,
             seasonality_prior_scale=self.seasonality_prior_scale,
             changepoint_prior_scale=self.changepoint_prior_scale,
             holidays_prior_scale=self.holidays_prior_scale,
             mcmc_samples=self.mcmc_samples,
             interval_width=self.interval_width,
-            uncertainty_samples=self.uncertainty_samples
+            uncertainty_samples=self.uncertainty_samples,
         )
+        m.extra_regressors = deepcopy(self.extra_regressors)
+        m.seasonalities = deepcopy(self.seasonalities)
+        return m

+ 21 - 5
python/fbprophet/tests/test_prophet.py

@@ -555,6 +555,9 @@ class TestProphet(TestCase):
             m.fit(df.copy())
 
     def test_copy(self):
+        df = DATA.copy()
+        df['cap'] = 200.
+        df['binary_feature'] = [0] * 255 + [1] * 255
         # These values are created except for its default values
         holiday = pd.DataFrame(
             {'ds': pd.to_datetime(['2016-12-25']), 'holiday': ['x']})
@@ -576,13 +579,22 @@ class TestProphet(TestCase):
         # Values should be copied correctly
         for product in products:
             m1 = Prophet(*product)
+            m1.history = m1.setup_dataframe(
+                df.copy(), initialize_scales=True)
+            m1.set_auto_seasonalities()
             m2 = m1.copy()
             self.assertEqual(m1.growth, m2.growth)
             self.assertEqual(m1.n_changepoints, m2.n_changepoints)
             self.assertEqual(m1.changepoints, m2.changepoints)
-            self.assertEqual(m1.yearly_seasonality, m2.yearly_seasonality)
-            self.assertEqual(m1.weekly_seasonality, m2.weekly_seasonality)
-            self.assertEqual(m1.daily_seasonality, m2.daily_seasonality)
+            self.assertEqual(False, m2.yearly_seasonality)
+            self.assertEqual(False, m2.weekly_seasonality)
+            self.assertEqual(False, m2.daily_seasonality)
+            self.assertEqual(
+                m1.yearly_seasonality, 'yearly' in m2.seasonalities)
+            self.assertEqual(
+                m1.weekly_seasonality, 'weekly' in m2.seasonalities)
+            self.assertEqual(
+                m1.daily_seasonality, 'daily' in m2.seasonalities)
             if m1.holidays is None:
                 self.assertEqual(m1.holidays, m2.holidays)
             else:
@@ -594,11 +606,15 @@ class TestProphet(TestCase):
             self.assertEqual(m1.interval_width, m2.interval_width)
             self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
 
-        # Check for cutoff
+        # Check for cutoff and custom seasonality and extra regressors
         changepoints = pd.date_range('2012-06-15', '2012-09-15')
         cutoff = pd.Timestamp('2012-07-25')
         m1 = Prophet(changepoints=changepoints)
-        m1.fit(DATA)
+        m1.add_seasonality('custom', 10, 5)
+        m1.add_regressor('binary_feature')
+        m1.fit(df)
         m2 = m1.copy(cutoff=cutoff)
         changepoints = changepoints[changepoints <= cutoff]
         self.assertTrue((changepoints == m2.changepoints).all())
+        self.assertTrue('custom' in m2.seasonalities)
+        self.assertTrue('binary_feature' in m2.extra_regressors)