Ver código fonte

Make plotting interfaces consistent (return figs)

Ben Letham 8 anos atrás
pai
commit
26ca2f7af7
1 arquivos alterados com 16 adições e 21 exclusões
  1. 16 21
      python/fbprophet/forecaster.py

+ 16 - 21
python/fbprophet/forecaster.py

@@ -852,13 +852,14 @@ class Prophet(object):
 
         return pd.DataFrame({'ds': dates})
 
-    def plot(self, fcst, uncertainty=True, plot_cap=True, xlabel='ds',
+    def plot(self, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds',
              ylabel='y'):
         """Plot the Prophet forecast.
 
         Parameters
         ----------
         fcst: pd.DataFrame output of self.predict.
+        ax: Optional matplotlib axes on which to plot.
         uncertainty: Optional boolean to plot uncertainty intervals.
         plot_cap: Optional boolean indicating if the capacity should be shown
             in the figure, if available.
@@ -867,10 +868,13 @@ class Prophet(object):
 
         Returns
         -------
-        a matplotlib figure.
+        A matplotlib figure.
         """
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
-        ax = fig.add_subplot(111)
+        if ax is None:
+            fig = plt.figure(facecolor='w', figsize=(10, 6))
+            ax = fig.add_subplot(111)
+        else:
+            fig = ax.get_figure()
         ax.plot(self.history['ds'].values, self.history['y'], 'k.')
         ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c='#0072B2')
         if 'cap' in fcst and plot_cap:
@@ -892,13 +896,6 @@ class Prophet(object):
         Will plot whichever are available of: trend, holidays, weekly
         seasonality, and yearly seasonality.
 
-        This method returns a list of matplotlib artists. To show the plot,
-        >>> m.plot_components(fcst)
-        >>> from matplotlib import pyplot as plt
-        >>> plt.show()
-
-        To save the figure, replace plt.show with plt.savefig.
-
         Parameters
         ----------
         fcst: pd.DataFrame output of self.predict.
@@ -914,7 +911,7 @@ class Prophet(object):
 
         Returns
         -------
-        A list of matplotlib artists.
+        A matplotlib figure.
         """
         # Identify components to be plotted
         components = [('trend', True),
@@ -927,23 +924,21 @@ class Prophet(object):
         fig, axes = plt.subplots(npanel, 1, facecolor='w',
                                  figsize=(9, 3 * npanel))
 
-        artists = []
         for ax, plot in zip(axes, components):
             if plot == 'trend':
-                artists += self.plot_trend(
+                self.plot_trend(
                     fcst, ax=ax, uncertainty=uncertainty, plot_cap=plot_cap)
             elif plot == 'holidays':
-                artists += self.plot_holidays(fcst, ax=ax,
-                                              uncertainty=uncertainty)
+                self.plot_holidays(fcst, ax=ax, uncertainty=uncertainty)
             elif plot == 'weekly':
-                artists += self.plot_weekly(ax=ax, uncertainty=uncertainty,
-                                            weekly_start=weekly_start)
+                self.plot_weekly(
+                    ax=ax, uncertainty=uncertainty, weekly_start=weekly_start)
             elif plot == 'yearly':
-                artists += self.plot_yearly(ax=ax, uncertainty=uncertainty,
-                                            yearly_start=yearly_start)
+                self.plot_yearly(
+                    ax=ax, uncertainty=uncertainty, yearly_start=yearly_start)
 
         fig.tight_layout()
-        return artists
+        return fig
 
     def plot_trend(self, fcst, ax=None, uncertainty=True, plot_cap=True):
         """Plot the trend component of the forecast.