Explorar o código

Fix plotting in pandas 0.21 by using pydatetime instead of numpy

bl %!s(int64=7) %!d(string=hai) anos
pai
achega
d3d1fcd1a1
Modificáronse 1 ficheiros con 21 adicións e 18 borrados
  1. 21 18
      python/fbprophet/forecaster.py

+ 21 - 18
python/fbprophet/forecaster.py

@@ -1250,16 +1250,16 @@ class Prophet(object):
             ax = fig.add_subplot(111)
             ax = fig.add_subplot(111)
         else:
         else:
             fig = ax.get_figure()
             fig = ax.get_figure()
-        ax.plot(self.history['ds'].values, self.history['y'], 'k.')
-        ax.plot(fcst['ds'].values, fcst['yhat'], ls='-', c='#0072B2')
+        fcst_t = fcst['ds'].dt.to_pydatetime()
+        ax.plot(self.history['ds'].dt.to_pydatetime(), self.history['y'], 'k.')
+        ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
         if 'cap' in fcst and plot_cap:
         if 'cap' in fcst and plot_cap:
-            ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
+            ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
         if self.logistic_floor and 'floor' in fcst and plot_cap:
         if self.logistic_floor and 'floor' in fcst and plot_cap:
-            ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
+            ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
         if uncertainty:
         if uncertainty:
-            ax.fill_between(fcst['ds'].values, fcst['yhat_lower'],
-                            fcst['yhat_upper'], color='#0072B2',
-                            alpha=0.2)
+            ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
+                            color='#0072B2', alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.set_xlabel(xlabel)
         ax.set_xlabel(xlabel)
         ax.set_ylabel(ylabel)
         ax.set_ylabel(ylabel)
@@ -1347,15 +1347,16 @@ class Prophet(object):
         if not ax:
         if not ax:
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             ax = fig.add_subplot(111)
             ax = fig.add_subplot(111)
-        artists += ax.plot(fcst['ds'].values, fcst[name], ls='-', c='#0072B2')
+        fcst_t = fcst['ds'].dt.to_pydatetime()
+        artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
         if 'cap' in fcst and plot_cap:
         if 'cap' in fcst and plot_cap:
-            artists += ax.plot(fcst['ds'].values, fcst['cap'], ls='--', c='k')
+            artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
         if self.logistic_floor and 'floor' in fcst and plot_cap:
         if self.logistic_floor and 'floor' in fcst and plot_cap:
-            ax.plot(fcst['ds'].values, fcst['floor'], ls='--', c='k')
+            ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
         if uncertainty:
         if uncertainty:
             artists += [ax.fill_between(
             artists += [ax.fill_between(
-                fcst['ds'].values, fcst[name + '_lower'],
-                fcst[name + '_upper'], color='#0072B2', alpha=0.2)]
+                fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
+                color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.set_xlabel('ds')
         ax.set_xlabel('ds')
         ax.set_ylabel(name)
         ax.set_ylabel(name)
@@ -1443,11 +1444,11 @@ class Prophet(object):
                 pd.Timedelta(days=yearly_start))
                 pd.Timedelta(days=yearly_start))
         df_y = self.seasonality_plot_df(days)
         df_y = self.seasonality_plot_df(days)
         seas = self.predict_seasonal_components(df_y)
         seas = self.predict_seasonal_components(df_y)
-        artists += ax.plot(df_y['ds'], seas['yearly'], ls='-',
-                           c='#0072B2')
+        artists += ax.plot(
+            df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
         if uncertainty:
         if uncertainty:
             artists += [ax.fill_between(
             artists += [ax.fill_between(
-                df_y['ds'].values, seas['yearly_lower'],
+                df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
                 seas['yearly_upper'], color='#0072B2', alpha=0.2)]
                 seas['yearly_upper'], color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
         months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
@@ -1483,14 +1484,16 @@ class Prophet(object):
         days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
         days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
         df_y = self.seasonality_plot_df(days)
         df_y = self.seasonality_plot_df(days)
         seas = self.predict_seasonal_components(df_y)
         seas = self.predict_seasonal_components(df_y)
-        artists += ax.plot(df_y['ds'], seas[name], ls='-',
+        artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
                             c='#0072B2')
                             c='#0072B2')
         if uncertainty:
         if uncertainty:
             artists += [ax.fill_between(
             artists += [ax.fill_between(
-                df_y['ds'].values, seas[name + '_lower'],
+                df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
                 seas[name + '_upper'], color='#0072B2', alpha=0.2)]
                 seas[name + '_upper'], color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
-        ax.set_xticks(pd.to_datetime(np.linspace(start.value, end.value, 7)))
+        xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
+            ).to_pydatetime()
+        ax.set_xticks(xticks)
         if period <= 2:
         if period <= 2:
             fmt_str = '{dt:%T}'
             fmt_str = '{dt:%T}'
         elif period < 14:
         elif period < 14: