plot.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the BSD-style license found in the
  5. # LICENSE file in the root directory of this source tree. An additional grant
  6. # of patent rights can be found in the PATENTS file in the same directory.
  7. from __future__ import absolute_import
  8. from __future__ import division
  9. from __future__ import print_function
  10. from __future__ import unicode_literals
  11. import logging
  12. import numpy as np
  13. import pandas as pd
  14. logging.basicConfig()
  15. logger = logging.getLogger(__name__)
  16. try:
  17. from matplotlib import pyplot as plt
  18. from matplotlib.dates import MonthLocator, num2date
  19. from matplotlib.ticker import FuncFormatter
  20. except ImportError:
  21. logger.error('Importing matplotlib failed. Plotting will not work.')
  22. def plot(
  23. m, fcst, ax=None, uncertainty=True, plot_cap=True, xlabel='ds', ylabel='y',
  24. ):
  25. """Plot the Prophet forecast.
  26. Parameters
  27. ----------
  28. m: Prophet model.
  29. fcst: pd.DataFrame output of m.predict.
  30. ax: Optional matplotlib axes on which to plot.
  31. uncertainty: Optional boolean to plot uncertainty intervals.
  32. plot_cap: Optional boolean indicating if the capacity should be shown
  33. in the figure, if available.
  34. xlabel: Optional label name on X-axis
  35. ylabel: Optional label name on Y-axis
  36. Returns
  37. -------
  38. A matplotlib figure.
  39. """
  40. if ax is None:
  41. fig = plt.figure(facecolor='w', figsize=(10, 6))
  42. ax = fig.add_subplot(111)
  43. else:
  44. fig = ax.get_figure()
  45. fcst_t = fcst['ds'].dt.to_pydatetime()
  46. ax.plot(m.history['ds'].dt.to_pydatetime(), m.history['y'], 'k.')
  47. ax.plot(fcst_t, fcst['yhat'], ls='-', c='#0072B2')
  48. if 'cap' in fcst and plot_cap:
  49. ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  50. if m.logistic_floor and 'floor' in fcst and plot_cap:
  51. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  52. if uncertainty:
  53. ax.fill_between(fcst_t, fcst['yhat_lower'], fcst['yhat_upper'],
  54. color='#0072B2', alpha=0.2)
  55. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  56. ax.set_xlabel(xlabel)
  57. ax.set_ylabel(ylabel)
  58. fig.tight_layout()
  59. return fig
  60. def plot_components(
  61. m, fcst, uncertainty=True, plot_cap=True, weekly_start=0, yearly_start=0,
  62. ):
  63. """Plot the Prophet forecast components.
  64. Will plot whichever are available of: trend, holidays, weekly
  65. seasonality, and yearly seasonality.
  66. Parameters
  67. ----------
  68. m: Prophet model.
  69. fcst: pd.DataFrame output of m.predict.
  70. uncertainty: Optional boolean to plot uncertainty intervals.
  71. plot_cap: Optional boolean indicating if the capacity should be shown
  72. in the figure, if available.
  73. weekly_start: Optional int specifying the start day of the weekly
  74. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  75. by 1 day to Monday, and so on.
  76. yearly_start: Optional int specifying the start day of the yearly
  77. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  78. by 1 day to Jan 2, and so on.
  79. Returns
  80. -------
  81. A matplotlib figure.
  82. """
  83. # Identify components to be plotted
  84. components = ['trend']
  85. if m.holidays is not None and 'holidays' in fcst:
  86. components.append('holidays')
  87. components.extend([name for name in m.seasonalities
  88. if name in fcst])
  89. if len(m.extra_regressors) > 0 and 'extra_regressors' in fcst:
  90. components.append('extra_regressors')
  91. npanel = len(components)
  92. fig, axes = plt.subplots(npanel, 1, facecolor='w',
  93. figsize=(9, 3 * npanel))
  94. if npanel == 1:
  95. axes = [axes]
  96. for ax, plot_name in zip(axes, components):
  97. if plot_name == 'trend':
  98. plot_forecast_component(
  99. m=m, fcst=fcst, name='trend', ax=ax, uncertainty=uncertainty,
  100. plot_cap=plot_cap,
  101. )
  102. elif plot_name == 'holidays':
  103. plot_forecast_component(
  104. m=m, fcst=fcst, name='holidays', ax=ax,
  105. uncertainty=uncertainty, plot_cap=False,
  106. )
  107. elif plot_name == 'weekly':
  108. plot_weekly(
  109. m=m, ax=ax, uncertainty=uncertainty, weekly_start=weekly_start,
  110. )
  111. elif plot_name == 'yearly':
  112. plot_yearly(
  113. m=m, ax=ax, uncertainty=uncertainty, yearly_start=yearly_start,
  114. )
  115. elif plot_name == 'extra_regressors':
  116. plot_forecast_component(
  117. m=m, fcst=fcst, name='extra_regressors', ax=ax,
  118. uncertainty=uncertainty, plot_cap=False,
  119. )
  120. else:
  121. plot_seasonality(
  122. m=m, name=plot_name, ax=ax, uncertainty=uncertainty,
  123. )
  124. fig.tight_layout()
  125. return fig
  126. def plot_forecast_component(
  127. m, fcst, name, ax=None, uncertainty=True, plot_cap=False,
  128. ):
  129. """Plot a particular component of the forecast.
  130. Parameters
  131. ----------
  132. m: Prophet model.
  133. fcst: pd.DataFrame output of m.predict.
  134. name: Name of the component to plot.
  135. ax: Optional matplotlib Axes to plot on.
  136. uncertainty: Optional boolean to plot uncertainty intervals.
  137. plot_cap: Optional boolean indicating if the capacity should be shown
  138. in the figure, if available.
  139. Returns
  140. -------
  141. a list of matplotlib artists
  142. """
  143. artists = []
  144. if not ax:
  145. fig = plt.figure(facecolor='w', figsize=(10, 6))
  146. ax = fig.add_subplot(111)
  147. fcst_t = fcst['ds'].dt.to_pydatetime()
  148. artists += ax.plot(fcst_t, fcst[name], ls='-', c='#0072B2')
  149. if 'cap' in fcst and plot_cap:
  150. artists += ax.plot(fcst_t, fcst['cap'], ls='--', c='k')
  151. if m.logistic_floor and 'floor' in fcst and plot_cap:
  152. ax.plot(fcst_t, fcst['floor'], ls='--', c='k')
  153. if uncertainty:
  154. artists += [ax.fill_between(
  155. fcst_t, fcst[name + '_lower'], fcst[name + '_upper'],
  156. color='#0072B2', alpha=0.2)]
  157. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  158. ax.set_xlabel('ds')
  159. ax.set_ylabel(name)
  160. return artists
  161. def seasonality_plot_df(m, ds):
  162. """Prepare dataframe for plotting seasonal components.
  163. Parameters
  164. ----------
  165. m: Prophet model.
  166. ds: List of dates for column ds.
  167. Returns
  168. -------
  169. A dataframe with seasonal components on ds.
  170. """
  171. df_dict = {'ds': ds, 'cap': 1., 'floor': 0.}
  172. for name in m.extra_regressors:
  173. df_dict[name] = 0.
  174. df = pd.DataFrame(df_dict)
  175. df = m.setup_dataframe(df)
  176. return df
  177. def plot_weekly(m, ax=None, uncertainty=True, weekly_start=0):
  178. """Plot the weekly component of the forecast.
  179. Parameters
  180. ----------
  181. m: Prophet model.
  182. ax: Optional matplotlib Axes to plot on. One will be created if this
  183. is not provided.
  184. uncertainty: Optional boolean to plot uncertainty intervals.
  185. weekly_start: Optional int specifying the start day of the weekly
  186. seasonality plot. 0 (default) starts the week on Sunday. 1 shifts
  187. by 1 day to Monday, and so on.
  188. Returns
  189. -------
  190. a list of matplotlib artists
  191. """
  192. artists = []
  193. if not ax:
  194. fig = plt.figure(facecolor='w', figsize=(10, 6))
  195. ax = fig.add_subplot(111)
  196. # Compute weekly seasonality for a Sun-Sat sequence of dates.
  197. days = (pd.date_range(start='2017-01-01', periods=7) +
  198. pd.Timedelta(days=weekly_start))
  199. df_w = seasonality_plot_df(m, days)
  200. seas = m.predict_seasonal_components(df_w)
  201. days = days.weekday_name
  202. artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
  203. c='#0072B2')
  204. if uncertainty:
  205. artists += [ax.fill_between(range(len(days)),
  206. seas['weekly_lower'], seas['weekly_upper'],
  207. color='#0072B2', alpha=0.2)]
  208. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  209. ax.set_xticks(range(len(days)))
  210. ax.set_xticklabels(days)
  211. ax.set_xlabel('Day of week')
  212. ax.set_ylabel('weekly')
  213. return artists
  214. def plot_yearly(m, ax=None, uncertainty=True, yearly_start=0):
  215. """Plot the yearly component of the forecast.
  216. Parameters
  217. ----------
  218. m: Prophet model.
  219. ax: Optional matplotlib Axes to plot on. One will be created if
  220. this is not provided.
  221. uncertainty: Optional boolean to plot uncertainty intervals.
  222. yearly_start: Optional int specifying the start day of the yearly
  223. seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts
  224. by 1 day to Jan 2, and so on.
  225. Returns
  226. -------
  227. a list of matplotlib artists
  228. """
  229. artists = []
  230. if not ax:
  231. fig = plt.figure(facecolor='w', figsize=(10, 6))
  232. ax = fig.add_subplot(111)
  233. # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
  234. days = (pd.date_range(start='2017-01-01', periods=365) +
  235. pd.Timedelta(days=yearly_start))
  236. df_y = seasonality_plot_df(m, days)
  237. seas = m.predict_seasonal_components(df_y)
  238. artists += ax.plot(
  239. df_y['ds'].dt.to_pydatetime(), seas['yearly'], ls='-', c='#0072B2')
  240. if uncertainty:
  241. artists += [ax.fill_between(
  242. df_y['ds'].dt.to_pydatetime(), seas['yearly_lower'],
  243. seas['yearly_upper'], color='#0072B2', alpha=0.2)]
  244. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  245. months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
  246. ax.xaxis.set_major_formatter(FuncFormatter(
  247. lambda x, pos=None: '{dt:%B} {dt.day}'.format(dt=num2date(x))))
  248. ax.xaxis.set_major_locator(months)
  249. ax.set_xlabel('Day of year')
  250. ax.set_ylabel('yearly')
  251. return artists
  252. def plot_seasonality(m, name, ax=None, uncertainty=True):
  253. """Plot a custom seasonal component.
  254. Parameters
  255. ----------
  256. m: Prophet model.
  257. name: Seasonality name, like 'daily', 'weekly'.
  258. ax: Optional matplotlib Axes to plot on. One will be created if
  259. this is not provided.
  260. uncertainty: Optional boolean to plot uncertainty intervals.
  261. Returns
  262. -------
  263. a list of matplotlib artists
  264. """
  265. artists = []
  266. if not ax:
  267. fig = plt.figure(facecolor='w', figsize=(10, 6))
  268. ax = fig.add_subplot(111)
  269. # Compute seasonality from Jan 1 through a single period.
  270. start = pd.to_datetime('2017-01-01 0000')
  271. period = m.seasonalities[name]['period']
  272. end = start + pd.Timedelta(days=period)
  273. plot_points = 200
  274. days = pd.to_datetime(np.linspace(start.value, end.value, plot_points))
  275. df_y = seasonality_plot_df(m, days)
  276. seas = m.predict_seasonal_components(df_y)
  277. artists += ax.plot(df_y['ds'].dt.to_pydatetime(), seas[name], ls='-',
  278. c='#0072B2')
  279. if uncertainty:
  280. artists += [ax.fill_between(
  281. df_y['ds'].dt.to_pydatetime(), seas[name + '_lower'],
  282. seas[name + '_upper'], color='#0072B2', alpha=0.2)]
  283. ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
  284. xticks = pd.to_datetime(np.linspace(start.value, end.value, 7)
  285. ).to_pydatetime()
  286. ax.set_xticks(xticks)
  287. if period <= 2:
  288. fmt_str = '{dt:%T}'
  289. elif period < 14:
  290. fmt_str = '{dt:%m}/{dt:%d} {dt:%R}'
  291. else:
  292. fmt_str = '{dt:%m}/{dt:%d}'
  293. ax.xaxis.set_major_formatter(FuncFormatter(
  294. lambda x, pos=None: fmt_str.format(dt=num2date(x))))
  295. ax.set_xlabel('ds')
  296. ax.set_ylabel(name)
  297. return artists