plot.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import logging
  12. import numpy as np
  13. import pandas as pd
  14. from fbprophet.diagnostics import performance_metrics
# Configure root logging at import time (presumably so the matplotlib import
# error below is actually shown to the user).
logging.basicConfig()
logger = logging.getLogger(__name__)
# matplotlib is an optional dependency. If it cannot be imported, only an
# error is logged: plt, MonthLocator, num2date and FuncFormatter then stay
# undefined, so the plotting functions in this module raise NameError as soon
# as they reach one of those names.
try:
    from matplotlib import pyplot as plt
    from matplotlib.dates import MonthLocator, num2date
    from matplotlib.ticker import FuncFormatter
except ImportError:
    logger.error('Importing matplotlib failed. Plotting will not work.')
  23. def plot(
  24. m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
  25. ):
  26. """Plot the Prophet forecast.
  27. Parameters
  28. ----------
  29. m: Prophet model.
  30. fcst: pd.DataFrame output of m.predict.
  31. ax: Optional matplotlib axes on which to plot.
  32. uncertainty: Optional boolean to plot uncertainty intervals.
  33. plot_cap: Optional boolean indicating if the capacity should be shown
  34. in the figure, if available.
  35. xlabel: Optional label name on X-axis
  36. ylabel: Optional label name on Y-axis
  37. Returns
  38. -------
  39. A matplotlib figure.
  40. """
  41. if ax is None:
  42. fig = plt.figure(facecolor='w', figsize=(10, 6))
  43. ax = fig.add_subplot(111)
  44. else:
  45. fig = ax.get_figure()
  46. fcst_t = fcst['ds'].dt.to_pydatetime()
  47. ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
  48. ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
  49. if 'cap' in fcst and plot_cap:
  50. ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  51. if m.logistic_floor and 'floor' in fcst and plot_cap:
  52. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  53. if uncertainty:
  54. ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
  55. color='#0072B2', alpha=0.2)
  56. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  57. ax.set_xlabel(xlabel)
  58. ax.set_ylabel(ylabel)
  59. fig.tight_layout()
  60. return fig
  61. def plot_components(
  62. m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0,
  63. ):
  64. """Plot the Prophet forecast components.
  65. Will plot whichever are available of: trend, holidays, weekly
  66. seasonality, yearly seasonality, and additive and multiplicative extra
  67. regressors.
  68. Parameters
  69. ----------
  70. m: Prophet model.
  71. fcst: pd.DataFrame output of m.predict.
  72. uncertainty: Optional boolean to plot uncertainty intervals.
  73. plot_cap: Optional boolean indicating if the capacity should be shown
  74. in the figure, if available.
  75. weekly_start: Optional int specifying the start day of the weekly
  76. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  77. by 1 day to Monday, and so on.
  78. yearly_start: Optional int specifying the start day of the yearly
  79. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  80. by 1 day to Jan 2, and so on.
  81. Returns
  82. -------
  83. A matplotlib figure.
  84. """
  85. # Identify components to be plotted
  86. components = ['trend']
  87. if m.holidays is not None and 'holidays' in fcst:
  88. components.append('holidays')
  89. components.extend([name for name in m.seasonalities
  90. if name in fcst])
  91. regressors = {'additive': False, 'multiplicative': False}
  92. for name, props in m.extra_regressors.items():
  93. regressors[props['mode']] = True
  94. for mode in ['additive', 'multiplicative']:
  95. if regressors[mode] and 'extra_regressors_{}'.format(mode) in fcst:
  96. components.append('extra_regressors_{}'.format(mode))
  97. npanel = len(components)
  98. fig, axes = plt.subplots(npanel, 1, facecolor='w',
  99. figsize=(9, 3 * npanel))
  100. if npanel == 1:
  101. axes = [axes]
  102. multiplicative_axes = []
  103. for ax, plot_name in zip(axes, components):
  104. if plot_name == 'trend':
  105. plot_forecast_component(
  106. m=m, fcst=fcst, name='trend', ax=ax, uncertainty=uncertainty,
  107. plot_cap=plot_cap,
  108. )
  109. elif plot_name == 'weekly':
  110. plot_weekly(
  111. m=m, ax=ax, uncertainty=uncertainty, weekly_start=weekly_start,
  112. )
  113. elif plot_name == 'yearly':
  114. plot_yearly(
  115. m=m, ax=ax, uncertainty=uncertainty, yearly_start=yearly_start,
  116. )
  117. elif plot_name in [
  118. 'holidays',
  119. 'extra_regressors_additive',
  120. 'extra_regressors_multiplicative',
  121. ]:
  122. plot_forecast_component(
  123. m=m, fcst=fcst, name=plot_name, ax=ax, uncertainty=uncertainty,
  124. plot_cap=False,
  125. )
  126. else:
  127. plot_seasonality(
  128. m=m, name=plot_name, ax=ax, uncertainty=uncertainty,
  129. )
  130. if plot_name in m.component_modes['multiplicative']:
  131. multiplicative_axes.append(ax)
  132. fig.tight_layout()
  133. # Reset multiplicative axes labels after tight_layout adjustment
  134. for ax in multiplicative_axes:
  135. ax = set_y_as_percent(ax)
  136. return fig
  137. def plot_forecast_component(
  138. m, fcst, name, ax=None, uncertainty=True, plot_cap=False,
  139. ):
  140. """Plot a particular component of the forecast.
  141. Parameters
  142. ----------
  143. m: Prophet model.
  144. fcst: pd.DataFrame output of m.predict.
  145. name: Name of the component to plot.
  146. ax: Optional matplotlib Axes to plot on.
  147. uncertainty: Optional boolean to plot uncertainty intervals.
  148. plot_cap: Optional boolean indicating if the capacity should be shown
  149. in the figure, if available.
  150. Returns
  151. -------
  152. a list of matplotlib artists
  153. """
  154. artists = []
  155. if not ax:
  156. fig = plt.figure(facecolor='w', figsize=(10, 6))
  157. ax = fig.add_subplot(111)
  158. fcst_t = fcst['ds'].dt.to_pydatetime()
  159. artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
  160. if 'cap' in fcst and plot_cap:
  161. artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  162. if m.logistic_floor and 'floor' in fcst and plot_cap:
  163. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  164. if uncertainty:
  165. artists += [ax.fill_between(
  166. fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
  167. color='#0072B2', alpha=0.2)]
  168. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  169. ax.set_xlabel('ds')
  170. ax.set_ylabel(name)
  171. if name in m.component_modes['multiplicative']:
  172. ax = set_y_as_percent(ax)
  173. return artists
  174. def seasonality_plot_df(m, ds):
  175. """Prepare dataframe for plotting seasonal components.
  176. Parameters
  177. ----------
  178. m: Prophet model.
  179. ds: List of dates for column ds.
  180. Returns
  181. -------
  182. A dataframe with seasonal components on ds.
  183. """
  184. df_dict = {'ds': ds, 'cap': 1., 'floor': 0.}
  185. for name in m.extra_regressors:
  186. df_dict[name] = 0.
  187. df = pd.DataFrame(df_dict)
  188. df = m.setup_dataframe(df)
  189. return df
  190. def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
  191. """Plot the weekly component of the forecast.
  192. Parameters
  193. ----------
  194. m: Prophet model.
  195. ax: Optional matplotlib Axes to plot on. One will be created if this
  196. is not provided.
  197. uncertainty: Optional boolean to plot uncertainty intervals.
  198. weekly_start: Optional int specifying the start day of the weekly
  199. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  200. by 1 day to Monday, and so on.
  201. Returns
  202. -------
  203. a list of matplotlib artists
  204. """
  205. artists = []
  206. if not ax:
  207. fig = plt.figure(facecolor='w', figsize=(10, 6))
  208. ax = fig.add_subplot(111)
  209. # Compute weekly seasonality for a Sun-Sat sequence of dates.
  210. days = (pd.date_range(start='2017-01-01', periods=7) +
  211. pd.Timedelta(days=weekly_start))
  212. df_w = seasonality_plot_df(m, days)
  213. seas = m.predict_seasonal_components(df_w)
  214. days = days.weekday_name
  215. artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
  216. c='#0072B2')
  217. if uncertainty:
  218. artists += [ax.fill_between(range(len(days)),
  219. seas['weekly_lower'], seas['weekly_upper'],
  220. color='#0072B2', alpha=0.2)]
  221. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  222. ax.set_xticks(range(len(days)))
  223. ax.set_xticklabels(days)
  224. ax.set_xlabel('Day of week')
  225. ax.set_ylabel('weekly')
  226. if m.seasonalities['weekly']['mode'] == 'multiplicative':
  227. ax = set_y_as_percent(ax)
  228. return artists
  229. def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
  230. """Plot the yearly component of the forecast.
  231. Parameters
  232. ----------
  233. m: Prophet model.
  234. ax: Optional matplotlib Axes to plot on. One will be created if
  235. this is not provided.
  236. uncertainty: Optional boolean to plot uncertainty intervals.
  237. yearly_start: Optional int specifying the start day of the yearly
  238. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  239. by 1 day to Jan 2, and so on.
  240. Returns
  241. -------
  242. a list of matplotlib artists
  243. """
  244. artists = []
  245. if not ax:
  246. fig = plt.figure(facecolor='w', figsize=(10, 6))
  247. ax = fig.add_subplot(111)
  248. # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  249. days = (pd.date_range(start='2017-01-01', periods=365) +
  250. pd.Timedelta(days=yearly_start))
  251. df_y = seasonality_plot_df(m, days)
  252. seas = m.predict_seasonal_components(df_y)
  253. artists += ax.plot(
  254. df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
  255. if uncertainty:
  256. artists += [ax.fill_between(
  257. df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
  258. seas['yearly_upper'], color='#0072B2', alpha=0.2)]
  259. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  260. months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
  261. ax.xaxis.set_major_formatter(FuncFormatter(
  262. lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
  263. ax.xaxis.set_major_locator(months)
  264. ax.set_xlabel('Day of year')
  265. ax.set_ylabel('yearly')
  266. if m.seasonalities['yearly']['mode'] == 'multiplicative':
  267. ax = set_y_as_percent(ax)
  268. return artists
  269. def plot_seasonality(m, name, ax=None, uncertainty=True):
  270. """Plot a custom seasonal component.
  271. Parameters
  272. ----------
  273. m: Prophet model.
  274. name: Seasonality name, like 'daily', 'weekly'.
  275. ax: Optional matplotlib Axes to plot on. One will be created if
  276. this is not provided.
  277. uncertainty: Optional boolean to plot uncertainty intervals.
  278. Returns
  279. -------
  280. a list of matplotlib artists
  281. """
  282. artists = []
  283. if not ax:
  284. fig = plt.figure(facecolor='w', figsize=(10, 6))
  285. ax = fig.add_subplot(111)
  286. # Compute seasonality from Jan 1 through a single period.
  287. start = pd.to_datetime('2017-01-01 0000')
  288. period = m.seasonalities[name]['period']
  289. end = start + pd.Timedelta(days=period)
  290. plot_points = 200
  291. days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
  292. df_y = seasonality_plot_df(m, days)
  293. seas = m.predict_seasonal_components(df_y)
  294. artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
  295. c='#0072B2')
  296. if uncertainty:
  297. artists += [ax.fill_between(
  298. df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
  299. seas[name + '_upper'], color='#0072B2', alpha=0.2)]
  300. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  301. xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
  302. ).to_pydatetime()
  303. ax.set_xticks(xticks)
  304. if period <= 2:
  305. fmt_str = '{dt:%T}'
  306. elif period < 14:
  307. fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
  308. else:
  309. fmt_str = '{dt:%m}/{dt:%d}'
  310. ax.xaxis.set_major_formatter(FuncFormatter(
  311. lambda x, pos=None: fmt_str.format(dt=num2date(x))))
  312. ax.set_xlabel('ds')
  313. ax.set_ylabel(name)
  314. if m.seasonalities[name]['mode'] == 'multiplicative':
  315. ax = set_y_as_percent(ax)
  316. return artists
  317. def set_y_as_percent(ax):
  318. yticks = 100 * ax.get_yticks()
  319. yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
  320. ax.set_yticklabels(yticklabels)
  321. return ax
  322. def add_changepoints_to_plot(
  323. ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
  324. ):
  325. """Add markers for significant changepoints to prophet forecast plot.
  326. Example:
  327. fig = m.plot(forecast)
  328. add_changepoints_to_plot(fig.gca(), m, forecast)
  329. Parameters
  330. ----------
  331. ax: axis on which to overlay changepoint markers.
  332. m: Prophet model.
  333. fcst: Forecast output from m.predict.
  334. threshold: Threshold on trend change magnitude for significance.
  335. cp_color: Color of changepoint markers.
  336. cp_linestyle: Linestyle for changepoint markers.
  337. trend: If True, will also overlay the trend.
  338. Returns
  339. -------
  340. a list of matplotlib artists
  341. """
  342. artists = []
  343. if trend:
  344. artists.append(ax.plot(fcst['ds'], fcst['trend'], c=cp_color))
  345. signif_changepoints = m.changepoints[
  346. np.abs(np.nanmean(m.params['delta'], axis=0)) >= threshold
  347. ]
  348. for cp in signif_changepoints:
  349. artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
  350. return artists
  351. def plot_cross_validation_metric(df_cv, metric, rolling_window=0.1, ax=None):
  352. """Plot a performance metric vs. forecast horizon from cross validation.
  353. Cross validation produces a collection of out-of-sample model predictions
  354. that can be compared to actual values, at a range of different horizons
  355. (distance from the cutoff). This computes a specified performance metric
  356. for each prediction, and aggregated over a rolling window with horizon.
  357. This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
  358. Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
  359. rolling_window is the proportion of data included in the rolling window of
  360. aggregation. The default value of 0.1 means 10% of data are included in the
  361. aggregation for computing the metric.
  362. As a concrete example, if metric='mse', then this plot will show the
  363. squared error for each cross validation prediction, along with the MSE
  364. averaged over rolling windows of 10% of the data.
  365. Parameters
  366. ----------
  367. df_cv: The output from fbprophet.diagnostics.cross_validation.
  368. metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
  369. rolling_window: Proportion of data to use for rolling average of metric.
  370. In [0, 1]. Defaults to 0.1.
  371. ax: Optional matplotlib axis on which to plot. If not given, a new figure
  372. will be created.
  373. Returns
  374. -------
  375. a matplotlib figure.
  376. """
  377. if ax is None:
  378. fig = plt.figure(facecolor='w', figsize=(10, 6))
  379. ax = fig.add_subplot(111)
  380. else:
  381. fig = ax.get_figure()
  382. # Get the metric at the level of individual predictions, and with the rolling window.
  383. df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
  384. df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)
  385. # Some work because matplotlib does not handle timedelta
  386. # Target ~10 ticks.
  387. tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
  388. # Find the largest time resolution that has <1 unit per bin.
  389. dts = ['D', 'h', 'm', 's', 'ms', 'us', 'ns']
  390. dt_names = [
  391. 'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds',
  392. 'nanoseconds'
  393. ]
  394. dt_conversions = [
  395. 24 * 60 * 60 * 10 ** 9,
  396. 60 * 60 * 10 ** 9,
  397. 60 * 10 ** 9,
  398. 10 ** 9,
  399. 10 ** 6,
  400. 10 ** 3,
  401. 1.,
  402. ]
  403. for i, dt in enumerate(dts):
  404. if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'):
  405. break
  406. x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
  407. x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
  408. ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
  409. ax.plot(x_plt_h, df_h[metric], '-', c='b')
  410. ax.grid(True)
  411. ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
  412. ax.set_ylabel(metric)
  413. return fig