Ben Letham 8 سال پیش
والد
کامیت
4523315ffc
2فایلهای تغییر یافته به همراه41 افزوده شده و 38 حذف شده
  1. 38 35
      python/fbprophet/forecaster.py
  2. 3 3
      python/fbprophet/tests/test_prophet.py

+ 38 - 35
python/fbprophet/forecaster.py

@@ -66,7 +66,7 @@ class Prophet(object):
         fluctuations, smaller values dampen the seasonality. Can be specified
         for individual seasonalities using add_seasonality.
     holidays_prior_scale: Parameter modulating the strength of the holiday
-        components model, unless overriden in the holidays input.
+        components model, unless overridden in the holidays input.
     changepoint_prior_scale: Parameter modulating the flexibility of the
         automatic changepoint selection. Large values will allow many
         changepoints, small values will allow few changepoints.
@@ -115,8 +115,8 @@ class Prophet(object):
         if holidays is not None:
             if not (
                 isinstance(holidays, pd.DataFrame)
-                and 'ds' in holidays
-                and 'holiday' in holidays
+                and 'ds' in holidays  # noqa W503
+                and 'holiday' in holidays  # noqa W503
             ):
                 raise ValueError("holidays must be a DataFrame with 'ds' and "
                                  "'holiday' columns.")
@@ -232,32 +232,7 @@ class Prophet(object):
         df = df.sort_values('ds')
         df.reset_index(inplace=True, drop=True)
 
-        if initialize_scales:
-            if self.growth == 'logistic' and 'floor' in df:
-                self.logistic_floor = True
-                floor = df['floor']
-            else:
-                floor = 0.
-            self.y_scale = (df['y'] - floor).abs().max()
-            if self.y_scale == 0:
-                self.y_scale = 1
-            self.start = df['ds'].min()
-            self.t_scale = df['ds'].max() - self.start
-            for name, props in self.extra_regressors.items():
-                standardize = props['standardize']
-                if standardize == 'auto':
-                    if set(df[name].unique()) == set([1, 0]):
-                        # Don't standardize binary variables.
-                        standardize = False
-                    else:
-                        standardize = True
-                if standardize:
-                    mu = df[name].mean()
-                    std = df[name].std()
-                    if std == 0:
-                        std = mu
-                    self.extra_regressors[name]['mu'] = mu
-                    self.extra_regressors[name]['std'] = std
+        self.initialize_scales(initialize_scales, df)
 
         if self.logistic_floor:
             if 'floor' not in df:
@@ -279,6 +254,35 @@ class Prophet(object):
                 raise ValueError('Found NaN in column ' + name)
         return df
 
+    def initialize_scales(self, initialize_scales, df):
+        if not initialize_scales:
+            return
+        if self.growth == 'logistic' and 'floor' in df:
+            self.logistic_floor = True
+            floor = df['floor']
+        else:
+            floor = 0.
+        self.y_scale = (df['y'] - floor).abs().max()
+        if self.y_scale == 0:
+            self.y_scale = 1
+        self.start = df['ds'].min()
+        self.t_scale = df['ds'].max() - self.start
+        for name, props in self.extra_regressors.items():
+            standardize = props['standardize']
+            if standardize == 'auto':
+                if set(df[name].unique()) == set([1, 0]):
+                    # Don't standardize binary variables.
+                    standardize = False
+                else:
+                    standardize = True
+            if standardize:
+                mu = df[name].mean()
+                std = df[name].std()
+                if std == 0:
+                    std = mu
+                self.extra_regressors[name]['mu'] = mu
+                self.extra_regressors[name]['std'] = std
+
     def set_changepoints(self):
         """Set changepoints
 
@@ -422,7 +426,7 @@ class Prophet(object):
             if ps <= 0:
                 raise ValueError('Prior scale must be > 0')
             prior_scales[row.holiday] = ps
-                
+
             for offset in range(lw, uw + 1):
                 occurrence = dt + timedelta(days=offset)
                 try:
@@ -918,7 +922,7 @@ class Prophet(object):
         for i, t_s in enumerate(changepoint_ts):
             gammas[i] = (
                 (t_s - m - np.sum(gammas))
-                * (1 - k_cum[i] / k_cum[i + 1])
+                * (1 - k_cum[i] / k_cum[i + 1])  # noqa W503
             )
         # Get cumulative rate and offset at each t
         k_t = k * np.ones_like(t)
@@ -997,7 +1001,7 @@ class Prophet(object):
             comp_features = X[:, cols]
             comp = (
                 np.matmul(comp_features, comp_beta.transpose())
-                * self.y_scale
+                * self.y_scale  # noqa W503
             )
             data[component] = np.nanmean(comp, axis=1)
             data[component + '_lower'] = np.nanpercentile(comp, lower_p,
@@ -1025,7 +1029,6 @@ class Prophet(object):
         components = components.append(new_comp)
         return components
 
-
     def sample_posterior_predictive(self, df):
         """Prophet posterior predictive samples.
 
@@ -1237,7 +1240,7 @@ class Prophet(object):
         ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c='#0072B2')
         if 'cap' in fcst and plot_cap:
             ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
-        if self.logistic_floor and 'floor' in fcst and plot_cap :
+        if self.logistic_floor and 'floor' in fcst and plot_cap:
             ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
         if uncertainty:
             ax.fill_between(fcst['ds'].values, fcst['yhat_lower'],
@@ -1333,7 +1336,7 @@ class Prophet(object):
         artists += ax.plot(fcst['ds'].values, fcst[name], ls='-', c='#0072B2')
         if 'cap' in fcst and plot_cap:
             artists += ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
-        if self.logistic_floor and 'floor' in fcst and plot_cap :
+        if self.logistic_floor and 'floor' in fcst and plot_cap:
             ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
         if uncertainty:
             artists += [ax.fill_between(

+ 3 - 3
python/fbprophet/tests/test_prophet.py

@@ -521,15 +521,15 @@ class TestProphet(TestCase):
             fcst['extra_regressors'][0],
             fcst['numeric_feature'][0] + fcst['binary_feature2'][0],
         )
-        self.assertEqual(
+        self.assertAlmostEqual(
             fcst['seasonalities'][0],
             fcst['yearly'][0] + fcst['weekly'][0],
         )
-        self.assertEqual(
+        self.assertAlmostEqual(
             fcst['seasonal'][0],
             fcst['seasonalities'][0] + fcst['extra_regressors'][0],
         )
-        self.assertEqual(
+        self.assertAlmostEqual(
             fcst['yhat'][0],
             fcst['trend'][0] + fcst['seasonal'][0],
         )