Browse Source

Add custom seasonalities (Py)

bl 8 năm trước cách đây
mục cha
commit
707c885275

Những thai đổi đã bị hủy bỏ vì nó quá lớn
+ 26 - 5
notebooks/non-daily_data.ipynb


+ 76 - 5
python/fbprophet/forecaster.py

@@ -343,6 +343,25 @@ class Prophet(object):
         # This relies pretty importantly on pandas keeping the columns in order.
         return pd.DataFrame(expanded_holidays)
 
+    def add_seasonality(self, name, period, fourier_order):
+        """Add a seasonal component with specified period and number of Fourier
+        components.
+
+        Increasing the number of Fourier components allows the seasonality to
+        change more quickly (at risk of overfitting).
+
+        Parameters
+        ----------
+        name: string name of the seasonality component.
+        period: float number of days in one period.
+        fourier_order: int number of Fourier components to use.
+        """
+        if self.holidays is not None:
+            if name in set(holidays['holiday']):
+                raise ValueError(
+                    'Name "{}" already used for holiday'.format(name))
+        self.seasonalities[name] = (period, fourier_order)
+
     def make_all_seasonality_features(self, df):
         """Dataframe with seasonality features.
 
@@ -1013,11 +1032,11 @@ class Prophet(object):
         A matplotlib figure.
         """
         # Identify components to be plotted
-        components = [('trend', True),
-                      ('holidays', self.holidays is not None),
-                      ('weekly', 'weekly' in fcst),
-                      ('yearly', 'yearly' in fcst)]
-        components = [plot for plot, cond in components if cond]
+        components = ['trend']
+        if self.holidays is not None:
+            components.append('holidays')
+        components.extend([name for name in self.seasonalities
+                           if name in fcst])
         npanel = len(components)
 
         fig, axes = plt.subplots(npanel, 1, facecolor='w',
@@ -1035,6 +1054,9 @@ class Prophet(object):
             elif plot == 'yearly':
                 self.plot_yearly(
                     ax=ax, uncertainty=uncertainty, yearly_start=yearly_start)
+            else:
+                self.plot_seasonality(
+                     name=plot, ax=ax, uncertainty=uncertainty)
 
         fig.tight_layout()
         return fig
@@ -1188,3 +1210,52 @@ class Prophet(object):
         ax.set_xlabel('Day of year')
         ax.set_ylabel('yearly')
         return artists
+
+    def plot_seasonality(self, name, ax=None, uncertainty=True):
+        """Plot a custom seasonal component.
+
+        Parameters
+        ----------
+        ax: Optional matplotlib Axes to plot on. One will be created if
+            this is not provided.
+        uncertainty: Optional boolean to plot uncertainty intervals.
+
+        Returns
+        -------
+        a list of matplotlib artists
+        """
+        artists = []
+        if not ax:
+            fig = plt.figure(facecolor='w', figsize=(10, 6))
+            ax = fig.add_subplot(111)
+        # Compute seasonality from Jan 1 through a single period.
+        start = pd.to_datetime('2017-01-01 0000')
+        period = self.seasonalities[name][0]
+        end = start + pd.Timedelta(days=period)
+        plot_points = 200
+        df_y = pd.DataFrame({
+            'ds': pd.to_datetime(
+                np.linspace(start.value, end.value, plot_points)),
+            'cap': 1.,
+        })
+        df_y = self.setup_dataframe(df_y)
+        seas = self.predict_seasonal_components(df_y)
+        artists += ax.plot(df_y['ds'], seas[name], ls='-',
+                            c='#0072B2')
+        if uncertainty:
+            artists += [ax.fill_between(
+                df_y['ds'].values, seas[name + '_lower'],
+                seas[name + '_upper'], color='#0072B2', alpha=0.2)]
+        ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
+        ax.set_xticks(pd.to_datetime(np.linspace(start.value, end.value, 7)))
+        if period <= 2:
+            fmt_str = '{dt:%T}'
+        elif period < 14:
+            fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
+        else:
+            fmt_str = '{dt:%m}/{dt:%d}'
+        ax.xaxis.set_major_formatter(FuncFormatter(
+            lambda x, pos=None: fmt_str.format(dt=num2date(x))))
+        ax.set_xlabel('ds')
+        ax.set_ylabel(name)
+        return artists

+ 5 - 0
python/fbprophet/tests/test_prophet.py

@@ -336,3 +336,8 @@ class TestProphet(TestCase):
         m.fit(DATA2)
         fcst = m.predict()
         self.assertEqual(sum(fcst['new_years'] == 0), 575)
+
+    def test_custom_seasonality(self):
+        m = Prophet()
+        m.add_seasonality(name='monthly', period=30, fourier_order=5)
+        self.assertEqual(m.seasonalities['monthly'], (30, 5))