Selaa lähdekoodia

Figsize argument (#706)

* Add 'figsize' argument to all plotting methods

* Add docstrings for 'figsize' arguments
Florian Schäfer 6 vuotta sitten
vanhempi
commit
55ca33891e
1 muutettua tiedostoa jossa 26 lisäystä ja 16 poistoa
  1. 26 16
      python/fbprophet/plot.py

+ 26 - 16
python/fbprophet/plot.py

@@ -31,6 +31,7 @@ except ImportError:
 
 def plot(
     m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
+    figsize=(10, 6)
 ):
     """Plot the Prophet forecast.
 
@@ -44,13 +45,14 @@ def plot(
         in the figure, if available.
     xlabel: Optional label name on X-axis
     ylabel: Optional label name on Y-axis
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
     A matplotlib figure.
     """
     if ax is None:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize)
         ax = fig.add_subplot(111)
     else:
         fig = ax.get_figure()
@@ -73,6 +75,7 @@ def plot(
 
 def plot_components(
     m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0,
+    figsize=None
 ):
     """Plot the Prophet forecast components.
 
@@ -93,6 +96,7 @@ def plot_components(
     yearly_start: Optional int specifying the start day of the yearly
         seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
         by 1 day to Jan 2, and so on.
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
@@ -112,8 +116,8 @@ def plot_components(
             components.append('extra_regressors_{}'.format(mode))
     npanel = len(components)
 
-    fig, axes = plt.subplots(npanel, 1, facecolor='w',
-                            figsize=(9, 3 * npanel))
+    figsize = figsize if figsize else (9, 3 * npanel)
+    fig, axes = plt.subplots(npanel, 1, facecolor='w', figsize=figsize)
 
     if npanel == 1:
         axes = [axes]
@@ -158,7 +162,7 @@ def plot_components(
 
 
 def plot_forecast_component(
-    m, fcst, name, ax=None, uncertainty=True, plot_cap=False,
+    m, fcst, name, ax=None, uncertainty=True, plot_cap=False, figsize=(10, 6)
 ):
     """Plot a particular component of the forecast.
 
@@ -171,6 +175,7 @@ def plot_forecast_component(
     uncertainty: Optional boolean to plot uncertainty intervals.
     plot_cap: Optional boolean indicating if the capacity should be shown
         in the figure, if available.
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
@@ -178,7 +183,7 @@ def plot_forecast_component(
     """
     artists = []
     if not ax:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize=figsize)
         ax = fig.add_subplot(111)
     fcst_t = fcst['ds'].dt.to_pydatetime()
     artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
@@ -218,7 +223,7 @@ def seasonality_plot_df(m, ds):
     return df
 
 
-def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
+def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0, figsize=(10, 6)):
     """Plot the weekly component of the forecast.
 
     Parameters
@@ -230,6 +235,7 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
     weekly_start: Optional int specifying the start day of the weekly
         seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
         by 1 day to Monday, and so on.
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
@@ -237,7 +243,7 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
     """
     artists = []
     if not ax:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize=figsize)
         ax = fig.add_subplot(111)
     # Compute weekly seasonality for a Sun-Sat sequence of dates.
     days = (pd.date_range(start='2017-01-01', periods=7) +
@@ -261,7 +267,7 @@ def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
     return artists
 
 
-def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
+def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0, figsize=(10, 6)):
     """Plot the yearly component of the forecast.
 
     Parameters
@@ -273,6 +279,7 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
     yearly_start: Optional int specifying the start day of the yearly
         seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
         by 1 day to Jan 2, and so on.
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
@@ -280,7 +287,7 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
     """
     artists = []
     if not ax:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize=figsize)
         ax = fig.add_subplot(111)
     # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
     days = (pd.date_range(start='2017-01-01', periods=365) +
@@ -305,7 +312,7 @@ def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
     return artists
 
 
-def plot_seasonality(m, name, ax=None, uncertainty=True):
+def plot_seasonality(m, name, ax=None, uncertainty=True, figsize=(10, 6)):
     """Plot a custom seasonal component.
 
     Parameters
@@ -315,6 +322,7 @@ def plot_seasonality(m, name, ax=None, uncertainty=True):
     ax: Optional matplotlib Axes to plot on. One will be created if
         this is not provided.
     uncertainty: Optional boolean to plot uncertainty intervals.
+    figsize: Optional tuple width, height in inches.
 
     Returns
     -------
@@ -322,7 +330,7 @@ def plot_seasonality(m, name, ax=None, uncertainty=True):
     """
     artists = []
     if not ax:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize=figsize)
         ax = fig.add_subplot(111)
     # Compute seasonality from Jan 1 through a single period.
     start = pd.to_datetime('2017-01-01 0000')
@@ -368,11 +376,11 @@ def add_changepoints_to_plot(
     ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
 ):
     """Add markers for significant changepoints to prophet forecast plot.
-    
+
     Example:
     fig = m.plot(forecast)
     add_changepoints_to_plot(fig.gca(), m, forecast)
-    
+
     Parameters
     ----------
     ax: axis on which to overlay changepoint markers.
@@ -382,7 +390,7 @@ def add_changepoints_to_plot(
     cp_color: Color of changepoint markers.
     cp_linestyle: Linestyle for changepoint markers.
     trend: If True, will also overlay the trend.
-    
+
     Returns
     -------
     a list of matplotlib artists
@@ -398,7 +406,9 @@ def add_changepoints_to_plot(
     return artists
 
 
-def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
+def plot_cross_validation_metric(
+    df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6)
+):
     """Plot a performance metric vs. forecast horizon from cross validation.
 
     Cross validation produces a collection of out-of-sample model predictions
@@ -431,7 +441,7 @@ def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
     a matplotlib figure.
     """
     if ax is None:
-        fig = plt.figure(facecolor='w', figsize=(10, 6))
+        fig = plt.figure(facecolor='w', figsize=figsize)
         ax = fig.add_subplot(111)
     else:
         fig = ax.get_figure()