plot.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import logging
  12. import numpy as np
  13. import pandas as pd
  14. from fbprophet.diagnostics import performance_metrics
# NOTE(review): calling basicConfig() at import time configures the root
# logger of any application that imports this module; consider leaving
# logging configuration to the application.
logging.basicConfig()
logger = logging.getLogger(__name__)
# matplotlib is optional at import time: if it is missing, log an error instead
# of failing the import. The plotting functions in this module then cannot run,
# because plt / MonthLocator / num2date / FuncFormatter stay undefined.
try:
    from matplotlib import pyplot as plt
    from matplotlib.dates import MonthLocator, num2date
    from matplotlib.ticker import FuncFormatter
except ImportError:
    logger.error('Importing matplotlib failed. Plotting will not work.')
  23. def plot(
  24. m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
  25. figsize=(10, 6)
  26. ):
  27. """Plot the Prophet forecast.
  28. Parameters
  29. ----------
  30. m: Prophet model.
  31. fcst: pd.DataFrame output of m.predict.
  32. ax: Optional matplotlib axes on which to plot.
  33. uncertainty: Optional boolean to plot uncertainty intervals.
  34. plot_cap: Optional boolean indicating if the capacity should be shown
  35. in the figure, if available.
  36. xlabel: Optional label name on X-axis
  37. ylabel: Optional label name on Y-axis
  38. figsize: Optional tuple width, height in inches.
  39. Returns
  40. -------
  41. A matplotlib figure.
  42. """
  43. if ax is None:
  44. fig = plt.figure(facecolor='w', figsize=figsize)
  45. ax = fig.add_subplot(111)
  46. else:
  47. fig = ax.get_figure()
  48. fcst_t = fcst['ds'].dt.to_pydatetime()
  49. ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
  50. ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
  51. if 'cap' in fcst and plot_cap:
  52. ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  53. if m.logistic_floor and 'floor' in fcst and plot_cap:
  54. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  55. if uncertainty:
  56. ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
  57. color='#0072B2', alpha=0.2)
  58. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  59. ax.set_xlabel(xlabel)
  60. ax.set_ylabel(ylabel)
  61. fig.tight_layout()
  62. return fig
  63. def plot_components(
  64. m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0,
  65. figsize=None
  66. ):
  67. """Plot the Prophet forecast components.
  68. Will plot whichever are available of: trend, holidays, weekly
  69. seasonality, yearly seasonality, and additive and multiplicative extra
  70. regressors.
  71. Parameters
  72. ----------
  73. m: Prophet model.
  74. fcst: pd.DataFrame output of m.predict.
  75. uncertainty: Optional boolean to plot uncertainty intervals.
  76. plot_cap: Optional boolean indicating if the capacity should be shown
  77. in the figure, if available.
  78. weekly_start: Optional int specifying the start day of the weekly
  79. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  80. by 1 day to Monday, and so on.
  81. yearly_start: Optional int specifying the start day of the yearly
  82. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  83. by 1 day to Jan 2, and so on.
  84. figsize: Optional tuple width, height in inches.
  85. Returns
  86. -------
  87. A matplotlib figure.
  88. """
  89. # Identify components to be plotted
  90. components = ['trend']
  91. if m.holidays is not None and 'holidays' in fcst:
  92. components.append('holidays')
  93. components.extend([name for name in m.seasonalities
  94. if name in fcst])
  95. regressors = {'additive': False, 'multiplicative': False}
  96. for name, props in m.extra_regressors.items():
  97. regressors[props['mode']] = True
  98. for mode in ['additive', 'multiplicative']:
  99. if regressors[mode] and 'extra_regressors_{}'.format(mode) in fcst:
  100. components.append('extra_regressors_{}'.format(mode))
  101. npanel = len(components)
  102. figsize = figsize if figsize else (9, 3 * npanel)
  103. fig, axes = plt.subplots(npanel, 1, facecolor='w', figsize=figsize)
  104. if npanel == 1:
  105. axes = [axes]
  106. multiplicative_axes = []
  107. for ax, plot_name in zip(axes, components):
  108. if plot_name == 'trend':
  109. plot_forecast_component(
  110. m=m, fcst=fcst, name='trend', ax=ax, uncertainty=uncertainty,
  111. plot_cap=plot_cap,
  112. )
  113. elif plot_name == 'weekly':
  114. plot_weekly(
  115. m=m, ax=ax, uncertainty=uncertainty, weekly_start=weekly_start,
  116. )
  117. elif plot_name == 'yearly':
  118. plot_yearly(
  119. m=m, ax=ax, uncertainty=uncertainty, yearly_start=yearly_start,
  120. )
  121. elif plot_name in [
  122. 'holidays',
  123. 'extra_regressors_additive',
  124. 'extra_regressors_multiplicative',
  125. ]:
  126. plot_forecast_component(
  127. m=m, fcst=fcst, name=plot_name, ax=ax, uncertainty=uncertainty,
  128. plot_cap=False,
  129. )
  130. else:
  131. plot_seasonality(
  132. m=m, name=plot_name, ax=ax, uncertainty=uncertainty,
  133. )
  134. if plot_name in m.component_modes['multiplicative']:
  135. multiplicative_axes.append(ax)
  136. fig.tight_layout()
  137. # Reset multiplicative axes labels after tight_layout adjustment
  138. for ax in multiplicative_axes:
  139. ax = set_y_as_percent(ax)
  140. return fig
  141. def plot_forecast_component(
  142. m, fcst, name, ax=None, uncertainty=True, plot_cap=False, figsize=(10, 6)
  143. ):
  144. """Plot a particular component of the forecast.
  145. Parameters
  146. ----------
  147. m: Prophet model.
  148. fcst: pd.DataFrame output of m.predict.
  149. name: Name of the component to plot.
  150. ax: Optional matplotlib Axes to plot on.
  151. uncertainty: Optional boolean to plot uncertainty intervals.
  152. plot_cap: Optional boolean indicating if the capacity should be shown
  153. in the figure, if available.
  154. figsize: Optional tuple width, height in inches.
  155. Returns
  156. -------
  157. a list of matplotlib artists
  158. """
  159. artists = []
  160. if not ax:
  161. fig = plt.figure(facecolor='w', figsize=figsize)
  162. ax = fig.add_subplot(111)
  163. fcst_t = fcst['ds'].dt.to_pydatetime()
  164. artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
  165. if 'cap' in fcst and plot_cap:
  166. artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  167. if m.logistic_floor and 'floor' in fcst and plot_cap:
  168. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  169. if uncertainty:
  170. artists += [ax.fill_between(
  171. fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
  172. color='#0072B2', alpha=0.2)]
  173. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  174. ax.set_xlabel('ds')
  175. ax.set_ylabel(name)
  176. if name in m.component_modes['multiplicative']:
  177. ax = set_y_as_percent(ax)
  178. return artists
  179. def seasonality_plot_df(m, ds):
  180. """Prepare dataframe for plotting seasonal components.
  181. Parameters
  182. ----------
  183. m: Prophet model.
  184. ds: List of dates for column ds.
  185. Returns
  186. -------
  187. A dataframe with seasonal components on ds.
  188. """
  189. df_dict = {'ds': ds, 'cap': 1., 'floor': 0.}
  190. for name in m.extra_regressors:
  191. df_dict[name] = 0.
  192. df = pd.DataFrame(df_dict)
  193. df = m.setup_dataframe(df)
  194. return df
  195. def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0, figsize=(10, 6)):
  196. """Plot the weekly component of the forecast.
  197. Parameters
  198. ----------
  199. m: Prophet model.
  200. ax: Optional matplotlib Axes to plot on. One will be created if this
  201. is not provided.
  202. uncertainty: Optional boolean to plot uncertainty intervals.
  203. weekly_start: Optional int specifying the start day of the weekly
  204. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  205. by 1 day to Monday, and so on.
  206. figsize: Optional tuple width, height in inches.
  207. Returns
  208. -------
  209. a list of matplotlib artists
  210. """
  211. artists = []
  212. if not ax:
  213. fig = plt.figure(facecolor='w', figsize=figsize)
  214. ax = fig.add_subplot(111)
  215. # Compute weekly seasonality for a Sun-Sat sequence of dates.
  216. days = (pd.date_range(start='2017-01-01', periods=7) +
  217. pd.Timedelta(days=weekly_start))
  218. df_w = seasonality_plot_df(m, days)
  219. seas = m.predict_seasonal_components(df_w)
  220. days = days.weekday_name
  221. artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
  222. c='#0072B2')
  223. if uncertainty:
  224. artists += [ax.fill_between(range(len(days)),
  225. seas['weekly_lower'], seas['weekly_upper'],
  226. color='#0072B2', alpha=0.2)]
  227. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  228. ax.set_xticks(range(len(days)))
  229. ax.set_xticklabels(days)
  230. ax.set_xlabel('Day of week')
  231. ax.set_ylabel('weekly')
  232. if m.seasonalities['weekly']['mode'] == 'multiplicative':
  233. ax = set_y_as_percent(ax)
  234. return artists
  235. def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0, figsize=(10, 6)):
  236. """Plot the yearly component of the forecast.
  237. Parameters
  238. ----------
  239. m: Prophet model.
  240. ax: Optional matplotlib Axes to plot on. One will be created if
  241. this is not provided.
  242. uncertainty: Optional boolean to plot uncertainty intervals.
  243. yearly_start: Optional int specifying the start day of the yearly
  244. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  245. by 1 day to Jan 2, and so on.
  246. figsize: Optional tuple width, height in inches.
  247. Returns
  248. -------
  249. a list of matplotlib artists
  250. """
  251. artists = []
  252. if not ax:
  253. fig = plt.figure(facecolor='w', figsize=figsize)
  254. ax = fig.add_subplot(111)
  255. # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  256. days = (pd.date_range(start='2017-01-01', periods=365) +
  257. pd.Timedelta(days=yearly_start))
  258. df_y = seasonality_plot_df(m, days)
  259. seas = m.predict_seasonal_components(df_y)
  260. artists += ax.plot(
  261. df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
  262. if uncertainty:
  263. artists += [ax.fill_between(
  264. df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
  265. seas['yearly_upper'], color='#0072B2', alpha=0.2)]
  266. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  267. months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
  268. ax.xaxis.set_major_formatter(FuncFormatter(
  269. lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
  270. ax.xaxis.set_major_locator(months)
  271. ax.set_xlabel('Day of year')
  272. ax.set_ylabel('yearly')
  273. if m.seasonalities['yearly']['mode'] == 'multiplicative':
  274. ax = set_y_as_percent(ax)
  275. return artists
  276. def plot_seasonality(m, name, ax=None, uncertainty=True, figsize=(10, 6)):
  277. """Plot a custom seasonal component.
  278. Parameters
  279. ----------
  280. m: Prophet model.
  281. name: Seasonality name, like 'daily', 'weekly'.
  282. ax: Optional matplotlib Axes to plot on. One will be created if
  283. this is not provided.
  284. uncertainty: Optional boolean to plot uncertainty intervals.
  285. figsize: Optional tuple width, height in inches.
  286. Returns
  287. -------
  288. a list of matplotlib artists
  289. """
  290. artists = []
  291. if not ax:
  292. fig = plt.figure(facecolor='w', figsize=figsize)
  293. ax = fig.add_subplot(111)
  294. # Compute seasonality from Jan 1 through a single period.
  295. start = pd.to_datetime('2017-01-01 0000')
  296. period = m.seasonalities[name]['period']
  297. end = start + pd.Timedelta(days=period)
  298. plot_points = 200
  299. days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
  300. df_y = seasonality_plot_df(m, days)
  301. seas = m.predict_seasonal_components(df_y)
  302. artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
  303. c='#0072B2')
  304. if uncertainty:
  305. artists += [ax.fill_between(
  306. df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
  307. seas[name + '_upper'], color='#0072B2', alpha=0.2)]
  308. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  309. xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
  310. ).to_pydatetime()
  311. ax.set_xticks(xticks)
  312. if period <= 2:
  313. fmt_str = '{dt:%T}'
  314. elif period < 14:
  315. fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
  316. else:
  317. fmt_str = '{dt:%m}/{dt:%d}'
  318. ax.xaxis.set_major_formatter(FuncFormatter(
  319. lambda x, pos=None: fmt_str.format(dt=num2date(x))))
  320. ax.set_xlabel('ds')
  321. ax.set_ylabel(name)
  322. if m.seasonalities[name]['mode'] == 'multiplicative':
  323. ax = set_y_as_percent(ax)
  324. return artists
  325. def set_y_as_percent(ax):
  326. yticks = 100 * ax.get_yticks()
  327. yticklabels = ['{0:.4g}%'.format(y) for y in yticks]
  328. ax.set_yticklabels(yticklabels)
  329. return ax
  330. def add_changepoints_to_plot(
  331. ax, m, fcst, threshold=0.01, cp_color='r', cp_linestyle='--', trend=True,
  332. ):
  333. """Add markers for significant changepoints to prophet forecast plot.
  334. Example:
  335. fig = m.plot(forecast)
  336. add_changepoints_to_plot(fig.gca(), m, forecast)
  337. Parameters
  338. ----------
  339. ax: axis on which to overlay changepoint markers.
  340. m: Prophet model.
  341. fcst: Forecast output from m.predict.
  342. threshold: Threshold on trend change magnitude for significance.
  343. cp_color: Color of changepoint markers.
  344. cp_linestyle: Linestyle for changepoint markers.
  345. trend: If True, will also overlay the trend.
  346. Returns
  347. -------
  348. a list of matplotlib artists
  349. """
  350. artists = []
  351. if trend:
  352. artists.append(ax.plot(fcst['ds'], fcst['trend'], c=cp_color))
  353. signif_changepoints = m.changepoints[
  354. np.abs(np.nanmean(m.params['delta'], axis=0)) >= threshold
  355. ]
  356. for cp in signif_changepoints:
  357. artists.append(ax.axvline(x=cp, c=cp_color, ls=cp_linestyle))
  358. return artists
  359. def plot_cross_validation_metric(
  360. df_cv, metric, rolling_window=0.1, ax=None, figsize=(10, 6)
  361. ):
  362. """Plot a performance metric vs. forecast horizon from cross validation.
  363. Cross validation produces a collection of out-of-sample model predictions
  364. that can be compared to actual values, at a range of different horizons
  365. (distance from the cutoff). This computes a specified performance metric
  366. for each prediction, and aggregated over a rolling window with horizon.
  367. This uses fbprophet.diagnostics.performance_metrics to compute the metrics.
  368. Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
  369. rolling_window is the proportion of data included in the rolling window of
  370. aggregation. The default value of 0.1 means 10% of data are included in the
  371. aggregation for computing the metric.
  372. As a concrete example, if metric='mse', then this plot will show the
  373. squared error for each cross validation prediction, along with the MSE
  374. averaged over rolling windows of 10% of the data.
  375. Parameters
  376. ----------
  377. df_cv: The output from fbprophet.diagnostics.cross_validation.
  378. metric: Metric name, one of ['mse', 'rmse', 'mae', 'mape', 'coverage'].
  379. rolling_window: Proportion of data to use for rolling average of metric.
  380. In [0, 1]. Defaults to 0.1.
  381. ax: Optional matplotlib axis on which to plot. If not given, a new figure
  382. will be created.
  383. Returns
  384. -------
  385. a matplotlib figure.
  386. """
  387. if ax is None:
  388. fig = plt.figure(facecolor='w', figsize=figsize)
  389. ax = fig.add_subplot(111)
  390. else:
  391. fig = ax.get_figure()
  392. # Get the metric at the level of individual predictions, and with the rolling window.
  393. df_none = performance_metrics(df_cv, metrics=[metric], rolling_window=0)
  394. df_h = performance_metrics(df_cv, metrics=[metric], rolling_window=rolling_window)
  395. # Some work because matplotlib does not handle timedelta
  396. # Target ~10 ticks.
  397. tick_w = max(df_none['horizon'].astype('timedelta64[ns]')) / 10.
  398. # Find the largest time resolution that has <1 unit per bin.
  399. dts = ['D', 'h', 'm', 's', 'ms', 'us', 'ns']
  400. dt_names = [
  401. 'days', 'hours', 'minutes', 'seconds', 'milliseconds', 'microseconds',
  402. 'nanoseconds'
  403. ]
  404. dt_conversions = [
  405. 24 * 60 * 60 * 10 ** 9,
  406. 60 * 60 * 10 ** 9,
  407. 60 * 10 ** 9,
  408. 10 ** 9,
  409. 10 ** 6,
  410. 10 ** 3,
  411. 1.,
  412. ]
  413. for i, dt in enumerate(dts):
  414. if np.timedelta64(1, dt) < np.timedelta64(tick_w, 'ns'):
  415. break
  416. x_plt = df_none['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
  417. x_plt_h = df_h['horizon'].astype('timedelta64[ns]').astype(np.int64) / float(dt_conversions[i])
  418. ax.plot(x_plt, df_none[metric], '.', alpha=0.5, c='gray')
  419. ax.plot(x_plt_h, df_h[metric], '-', c='b')
  420. ax.grid(True)
  421. ax.set_xlabel('Horizon ({})'.format(dt_names[i]))
  422. ax.set_ylabel(metric)
  423. return fig