Browse Source

Merge pull request #84 from lemonlaug/better_mpl

Refactoring mpl code to address #62, #63
Sean J. Taylor 8 years ago
parent
commit
f287a57cca
1 changed files with 150 additions and 78 deletions
  1. 150 78
      python/fbprophet/forecaster.py

+ 150 - 78
python/fbprophet/forecaster.py

@@ -2,7 +2,7 @@
 # All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree. An additional grant 
+# LICENSE file in the root directory of this source tree. An additional grant
 # of patent rights can be found in the PATENTS file in the same directory.
 
 from __future__ import absolute_import
@@ -188,7 +188,6 @@ class Prophet(object):
         else:
             self.changepoints_t = np.array([0])  # dummy changepoint
 
-
     def get_changepoint_matrix(self):
         A = np.zeros((self.history.shape[0], len(self.changepoints_t)))
         for i, t_i in enumerate(self.changepoints_t):
@@ -269,7 +268,6 @@ class Prophet(object):
         # This relies pretty importantly on pandas keeping the columns in order.
         return pd.DataFrame(expanded_holidays)
 
-
     def make_all_seasonality_features(self, df):
         seasonal_features = [
             # Add a column of zeros in case no seasonality is used.
@@ -626,16 +624,16 @@ class Prophet(object):
         -------
         a matplotlib figure.
         """
-        forecast_color = '#0072B2'
         fig = plt.figure(facecolor='w', figsize=(10, 6))
         ax = fig.add_subplot(111)
         ax.plot(self.history['ds'].values, self.history['y'], 'k.')
-        ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c=forecast_color)
+        ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c='#0072B2')
         if 'cap' in fcst:
             ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
         if uncertainty:
             ax.fill_between(fcst['ds'].values, fcst['yhat_lower'],
-                            fcst['yhat_upper'], color=forecast_color, alpha=0.2)
+                            fcst['yhat_upper'], color='#0072B2',
+                            alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.set_xlabel(xlabel)
         ax.set_ylabel(ylabel)
@@ -658,87 +656,161 @@ class Prophet(object):
         a matplotlib figure.
         """
         # Identify components to be plotted
-        plot_trend = True
-        plot_holidays = self.holidays is not None
-        plot_weekly = 'weekly' in fcst
-        plot_yearly = 'yearly' in fcst
-
-        npanel = plot_trend + plot_holidays + plot_weekly + plot_yearly
-        forecast_color = '#0072B2'
-        fig = plt.figure(facecolor='w', figsize=(9, 3 * npanel))
-        panel_num = 1
-        ax = fig.add_subplot(npanel, 1, panel_num)
-        ax.plot(fcst['ds'].values, fcst['trend'], ls='-', c=forecast_color)
+        components = [('plot_trend', True),
+                      ('plot_holidays', self.holidays is not None),
+                      ('plot_weekly', 'weekly' in fcst),
+                      ('plot_yearly', 'yearly' in fcst)]
+        components = [(plot, cond) for plot, cond in components if cond]
+        npanel = len(components)
+
+        fig, axes = plt.subplots(npanel, 1, facecolor='w',
+                                 figsize=(9, 3 * npanel))
+
+        artists = []
+        for ax, plot in zip(axes,
+                            [getattr(self, plot) for plot, _ in components]):
+            artists += plot(fcst, ax=ax, uncertainty=uncertainty)
+
+        fig.tight_layout()
+        return artists
+
+    def plot_trend(self, fcst, ax=None, uncertainty=True):
+        """Plot the trend component of the forecast.
+
+        Parameters
+        ----------
+        fcst: pd.DataFrame output of self.predict.
+        ax: Optional matplotlib Axes to plot on.
+        uncertainty: Optional boolean to plot uncertainty intervals.
+
+        Returns
+        -------
+        a list of matplotlib artists
+        """
+
+        artists = []
+        if not ax:
+            ax = fig.add_subplot(111)
+        artists += ax.plot(fcst['ds'].values, fcst['trend'], ls='-',
+                           c='#0072B2')
         if 'cap' in fcst:
-            ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
+            artists += ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
         if uncertainty:
-            ax.fill_between(
+            artists += [ax.fill_between(
                 fcst['ds'].values, fcst['trend_lower'], fcst['trend_upper'],
-                color=forecast_color, alpha=0.2)
+                color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.xaxis.set_major_locator(MaxNLocator(nbins=7))
         ax.set_xlabel('ds')
         ax.set_ylabel('trend')
+        return artists
 
-        if plot_holidays:
-            panel_num += 1
-            ax = fig.add_subplot(npanel, 1, panel_num)
-            holiday_comps = self.holidays['holiday'].unique()
-            y_holiday = fcst[holiday_comps].sum(1)
-            y_holiday_l = fcst[[h + '_lower' for h in holiday_comps]].sum(1)
-            y_holiday_u = fcst[[h + '_upper' for h in holiday_comps]].sum(1)
-            # NOTE the above CI calculation is incorrect if holidays overlap
-            # in time. Since it is just for the visualization we will not
-            # worry about it now.
-            ax.plot(fcst['ds'].values, y_holiday, ls='-', c=forecast_color)
-            if uncertainty:
-                ax.fill_between(fcst['ds'].values, y_holiday_l, y_holiday_u,
-                                color=forecast_color, alpha=0.2)
-            ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
-            ax.xaxis.set_major_locator(MaxNLocator(nbins=7))
-            ax.set_xlabel('ds')
-            ax.set_ylabel('holidays')
-
-        if plot_weekly:
-            panel_num += 1
-            ax = fig.add_subplot(npanel, 1, panel_num)
-            df_s = fcst.copy()
-            df_s['dow'] = df_s['ds'].dt.weekday_name
-            df_s = df_s.groupby('dow').first()
-            days = pd.date_range(start='2017-01-01', periods=7).weekday_name
-            y_weekly = [df_s.loc[d]['weekly'] for d in days]
-            y_weekly_l = [df_s.loc[d]['weekly_lower'] for d in days]
-            y_weekly_u = [df_s.loc[d]['weekly_upper'] for d in days]
-            ax.plot(range(len(days)), y_weekly, ls='-', c=forecast_color)
-            if uncertainty:
-                ax.fill_between(range(len(days)), y_weekly_l, y_weekly_u,
-                                color=forecast_color, alpha=0.2)
-            ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
-            ax.set_xticks(range(len(days)))
-            ax.set_xticklabels(days)
-            ax.set_xlabel('Day of week')
-            ax.set_ylabel('weekly')
-
-        if plot_yearly:
-            panel_num += 1
+    def plot_holidays(self, fcst, ax=None, uncertainty=True):
+        """Plot the holidays component of the forecast.
+
+        Parameters
+        ----------
+        fcst: pd.DataFrame output of self.predict.
+        ax: Optional matplotlib Axes to plot on. One will be created if this 
+            is not provided.
+        uncertainty: Optional boolean to plot uncertainty intervals.
+
+        Returns
+        -------
+        a list of matplotlib artists
+        """
+        artists = []
+        if not ax:
+            ax = fig.add_subplot(111)
+        holiday_comps = self.holidays['holiday'].unique()
+        y_holiday = fcst[holiday_comps].sum(1)
+        y_holiday_l = fcst[[h + '_lower' for h in holiday_comps]].sum(1)
+        y_holiday_u = fcst[[h + '_upper' for h in holiday_comps]].sum(1)
+        # NOTE the above CI calculation is incorrect if holidays overlap
+        # in time. Since it is just for the visualization we will not
+        # worry about it now.
+        artists += ax.plot(fcst['ds'].values, y_holiday, ls='-',
+                           c='#0072B2')
+        if uncertainty:
+            artists += [ax.fill_between(fcst['ds'].values,
+                                        y_holiday_l, y_holiday_u,
+                                        color='#0072B2', alpha=0.2)]
+        ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
+        ax.xaxis.set_major_locator(MaxNLocator(nbins=7))
+        ax.set_xlabel('ds')
+        ax.set_ylabel('holidays')
+        return artists
+
+    def plot_weekly(self, fcst, ax=None, uncertainty=True):
+        """Plot the weekly component of the forecast.
+
+        Parameters
+        ----------
+        fcst: pd.DataFrame output of self.predict.
+        ax: Optional matplotlib Axes to plot on. One will be created if this 
+            is not provided.
+        uncertainty: Optional boolean to plot uncertainty intervals.
+
+        Returns
+        -------
+        a list of matplotlib artists
+        """
+        artists = []
+        if not ax:
+            ax = fig.add_subplot(111)
+        df_s = fcst.copy()
+        df_s['dow'] = df_s['ds'].dt.weekday_name
+        df_s = df_s.groupby('dow').first()
+        days = pd.date_range(start='2017-01-01', periods=7).weekday_name
+        y_weekly = [df_s.loc[d]['weekly'] for d in days]
+        y_weekly_l = [df_s.loc[d]['weekly_lower'] for d in days]
+        y_weekly_u = [df_s.loc[d]['weekly_upper'] for d in days]
+        artists += ax.plot(range(len(days)), y_weekly, ls='-',
+                           c='#0072B2')
+        if uncertainty:
+            artists += [ax.fill_between(range(len(days)),
+                                        y_weekly_l, y_weekly_u,
+                                        color='#0072B2', alpha=0.2)]
+        ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
+        ax.set_xticks(range(len(days)))
+        ax.set_xticklabels(days)
+        ax.set_xlabel('Day of week')
+        ax.set_ylabel('weekly')
+        return artists
+
+    def plot_yearly(self, fcst, ax=None, uncertainty=True):
+        """Plot the yearly component of the forecast.
+
+        Parameters
+        ----------
+        fcst: pd.DataFrame output of self.predict.
+        ax: Optional matplotlib Axes to plot on. One will be created if 
+            this is not provided.
+        uncertainty: Optional boolean to plot uncertainty intervals.
+
+        Returns
+        -------
+        a list of matplotlib artists
+        """
+        artists = []
+        if not ax:
             ax = fig.add_subplot(npanel, 1, panel_num)
-            df_s = fcst.copy()
-            df_s['doy'] = df_s['ds'].map(lambda x: x.strftime('2000-%m-%d'))
-            df_s = df_s.groupby('doy').first().sort_index()
-            ax.plot(pd.to_datetime(df_s.index), df_s['yearly'], ls='-',
-                    c=forecast_color)
-            if uncertainty:
-                ax.fill_between(
-                    pd.to_datetime(df_s.index), df_s['yearly_lower'],
-                    df_s['yearly_upper'], color=forecast_color, alpha=0.2)
-            ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
-            months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
-            ax.xaxis.set_major_formatter(DateFormatter('%B %-d'))
-            ax.xaxis.set_major_locator(months)
-            ax.set_xlabel('Day of year')
-            ax.set_ylabel('yearly')
+        df_s = fcst.copy()
+        df_s['doy'] = df_s['ds'].map(lambda x: x.strftime('2000-%m-%d'))
+        df_s = df_s.groupby('doy').first().sort_index()
+        artists += ax.plot(pd.to_datetime(df_s.index), df_s['yearly'], ls='-',
+                           c='#0072B2')
+        if uncertainty:
+            artists += [ax.fill_between(
+                pd.to_datetime(df_s.index), df_s['yearly_lower'],
+                df_s['yearly_upper'], color='#0072B2', alpha=0.2)]
+        ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
+        months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
+        ax.xaxis.set_major_formatter(DateFormatter('%B %-d'))
+        ax.xaxis.set_major_locator(months)
+        ax.set_xlabel('Day of year')
+        ax.set_ylabel('yearly')
+        return artists
 
-        fig.tight_layout()
-        return fig
 
 # fb-block 9