# plot.R 17 KB
  #' Combine the training history with a forecast for plotting.
  #'
  #' @param m Prophet object; its `history` element must provide the columns
  #'  `ds` and `y`.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @return A data frame holding the observed `y` from the history next to
  #'  every forecast column, full-joined on `ds` and sorted by `ds`.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  df_for_plotting <- function(m, fcst) {
    # Drop any y carried by the forecast so that the join yields a single y
    # column: the observed values taken from the history.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observed history as points and the forecast `yhat` as a line,
  #' optionally adding the uncertainty band and the capacity / floor lines.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available. The floor line is additionally drawn only when
  #'  `x$logistic.floor` is TRUE.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments (not used by this method).
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'                       y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
                           xlabel = 'ds', ylabel = 'y', ...) {
    df <- df_for_plotting(x, fcst)
    # TRUE when the merged history/forecast frame has a column of this name.
    has_column <- function(col) col %in% colnames(df)
    forecast_blue <- "#0072B2"

    fig <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)
    # The cap and floor lines are added before the data layers.
    if (has_column('cap') && plot_cap) {
      fig <- fig + ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
    }
    if (x$logistic.floor && has_column('floor') && plot_cap) {
      fig <- fig + ggplot2::geom_line(
        ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
    }
    # The ribbon is only drawn when the forecast carries interval columns.
    if (uncertainty && has_column('yhat_lower')) {
      fig <- fig + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
        alpha = 0.2,
        fill = forecast_blue,
        na.rm = TRUE)
    }
    fig +
      ggplot2::geom_point(na.rm = TRUE) +
      ggplot2::geom_line(
        ggplot2::aes(y = yhat), color = forecast_blue, na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }
  #' Plot the components of a prophet forecast.
  #'
  #' Prints a ggplot2 with whichever are available of: trend, holidays, weekly
  #' seasonality, yearly seasonality, any other seasonalities listed in
  #' `m$seasonalities`, and additive and multiplicative extra regressors.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #'  plotted for the trend, from fcst columns trend_lower and trend_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available. Only applied to the trend panel.
  #' @param weekly_start Integer specifying the start day of the weekly
  #'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #'  to Monday, and so on.
  #' @param yearly_start Integer specifying the start day of the yearly
  #'  seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #'  to Jan 2, and so on.
  #' @param render_plot Boolean indicating if the plots should be rendered.
  #'  Set to FALSE if you want the function to only return the list of panels.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(
    m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
    yearly_start = 0, render_plot = TRUE
  ) {
    fcst_cols <- colnames(fcst)
    # Plot the trend
    panels <- list(
      plot_forecast_component(m, fcst, 'trend', uncertainty, plot_cap))
    # Plot holiday components, if present.
    if (!is.null(m$holidays) && ('holidays' %in% fcst_cols)) {
      panels[[length(panels) + 1]] <- plot_forecast_component(
        m, fcst, 'holidays', uncertainty, FALSE)
    }
    # Plot weekly seasonality, if present
    if ("weekly" %in% fcst_cols) {
      panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
    }
    # Plot yearly seasonality, if present
    if ("yearly" %in% fcst_cols) {
      panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
    }
    # Plot other seasonalities
    for (name in names(m$seasonalities)) {
      if (!(name %in% c('weekly', 'yearly')) && (name %in% fcst_cols)) {
        panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
      }
    }
    # Plot extra regressors: record which modes are used by at least one
    # regressor, then draw one aggregate panel per mode that was used.
    regressors <- list(additive = FALSE, multiplicative = FALSE)
    for (name in names(m$extra_regressors)) {
      regressors[[m$extra_regressors[[name]]$mode]] <- TRUE
    }
    for (mode in c('additive', 'multiplicative')) {
      component <- paste0('extra_regressors_', mode)
      # Scalar, short-circuiting `&&` is the correct operator inside `if`.
      if (regressors[[mode]] && (component %in% fcst_cols)) {
        panels[[length(panels) + 1]] <- plot_forecast_component(
          m, fcst, component, uncertainty, FALSE)
      }
    }
    if (render_plot) {
      # Make the plot: one grid row per panel, stacked in a single column.
      grid::grid.newpage()
      grid::pushViewport(
        grid::viewport(layout = grid::grid.layout(length(panels), 1)))
      for (i in seq_along(panels)) {
        print(
          panels[[i]],
          vp = grid::viewport(layout.pos.row = i, layout.pos.col = 1))
      }
    }
    return(invisible(panels))
  }
  #' Plot a particular component of the forecast.
  #'
  #' @param m Prophet model. Only `m$component.modes$multiplicative` is read, to
  #'  decide whether the y-axis is labelled in percent.
  #' @param fcst Dataframe output of `predict`.
  #' @param name String name of the component to plot (column of fcst). Names
  #'  that are not syntactic R names (e.g. containing spaces) are supported.
  #' @param uncertainty Boolean to plot uncertainty intervals. The interval is
  #'  only drawn when fcst has both `<name>_lower` and `<name>_upper` columns.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #'  figure, if available.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @export
  plot_forecast_component <- function(
    m, fcst, name, uncertainty = TRUE, plot_cap = FALSE
  ) {
    # aes_string() parses its arguments as R code, so column names are wrapped
    # in backticks to keep non-syntactic names from being mis-parsed.
    quote_name <- function(col) paste0("`", col, "`")
    lower_name <- paste0(name, '_lower')
    upper_name <- paste0(name, '_upper')
    fcst_cols <- colnames(fcst)

    gg.comp <- ggplot2::ggplot(
        fcst, ggplot2::aes_string(x = 'ds', y = quote_name(name), group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    if (('cap' %in% fcst_cols) && plot_cap) {
      gg.comp <- gg.comp + ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
    }
    if (('floor' %in% fcst_cols) && plot_cap) {
      gg.comp <- gg.comp + ggplot2::geom_line(
        ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
    }
    # Skip the ribbon when the interval columns are missing; adding it anyway
    # would make the plot fail when it is rendered.
    if (uncertainty && all(c(lower_name, upper_name) %in% fcst_cols)) {
      gg.comp <- gg.comp +
        ggplot2::geom_ribbon(
          ggplot2::aes_string(
            ymin = quote_name(lower_name), ymax = quote_name(upper_name)
          ),
          alpha = 0.2,
          fill = "#0072B2",
          na.rm = TRUE)
    }
    if (name %in% m$component.modes$multiplicative) {
      gg.comp <- gg.comp + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    return(gg.comp)
  }
  #' Prepare dataframe for plotting seasonal components.
  #'
  #' Builds a data frame on the requested dates with placeholder values for the
  #' other columns (`cap` = 1, `floor` = 0 and every extra regressor = 0), then
  #' passes it through `setup_dataframe`.
  #'
  #' @param m Prophet object.
  #' @param ds Array of dates for column ds.
  #'
  #' @return A dataframe with seasonal components on ds.
  #'
  #' @keywords internal
  seasonality_plot_df <- function(m, ds) {
    df_list <- list(ds = ds, cap = 1, floor = 0)
    for (name in names(m$extra_regressors)) {
      df_list[[name]] <- 0
    }
    # check.names = FALSE keeps regressor names exactly as registered; the
    # default would rewrite non-syntactic names (e.g. "temp-1" -> "temp.1"),
    # so the column would no longer match the regressor's name.
    df <- as.data.frame(df_list, check.names = FALSE)
    df <- setup_dataframe(m, df)$df
    return(df)
  }
  #' Plot the weekly component of the forecast.
  #'
  #' Evaluates the seasonal components on seven consecutive days and plots the
  #' `weekly` column against the day of the week.
  #'
  #' @param m Prophet model object
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param weekly_start Integer specifying the start day of the weekly
  #'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #'  to Monday, and so on.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
    # Seven daily dates starting on 2017-01-01 (a Sunday), shifted forward by
    # weekly_start days.
    shift <- as.difftime(weekly_start, units = "days")
    week_dates <- seq(set_date('2017-01-01'), by = 'd', length.out = 7) + shift
    plot_df <- seasonality_plot_df(m, week_dates)
    weekly_seas <- predict_seasonal_components(m, plot_df)
    # Using the day names in date order as factor levels keeps the x-axis in
    # calendar order instead of alphabetical order.
    day_names <- weekdays(plot_df$ds)
    weekly_seas$dow <- factor(day_names, levels = day_names)

    fig <- ggplot2::ggplot(
        weekly_seas, ggplot2::aes(x = dow, y = weekly, group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::labs(x = "Day of week")
    if (uncertainty) {
      fig <- fig + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = weekly_lower, ymax = weekly_upper),
        alpha = 0.2,
        fill = "#0072B2",
        na.rm = TRUE)
    }
    # Multiplicative seasonality is shown with a percent-formatted y-axis.
    if (m$seasonalities$weekly$mode == 'multiplicative') {
      fig <- fig + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    fig
  }
  #' Plot the yearly component of the forecast.
  #'
  #' Evaluates the seasonal components on 365 consecutive days and plots the
  #' `yearly` column against the date.
  #'
  #' @param m Prophet model object.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param yearly_start Integer specifying the start day of the yearly
  #'  seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #'  to Jan 2, and so on.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
    # 365 daily dates starting on 2017-01-01, shifted forward by yearly_start
    # days.
    shift <- as.difftime(yearly_start, units = "days")
    year_dates <- seq(set_date('2017-01-01'), by = 'd', length.out = 365) + shift
    plot_df <- seasonality_plot_df(m, year_dates)
    yearly_seas <- predict_seasonal_components(m, plot_df)
    yearly_seas$ds <- plot_df$ds

    fig <- ggplot2::ggplot(
        yearly_seas, ggplot2::aes(x = ds, y = yearly, group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::labs(x = "Day of year") +
      ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
    if (uncertainty) {
      fig <- fig + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yearly_lower, ymax = yearly_upper),
        alpha = 0.2,
        fill = "#0072B2",
        na.rm = TRUE)
    }
    # Multiplicative seasonality is shown with a percent-formatted y-axis.
    if (m$seasonalities$yearly$mode == 'multiplicative') {
      fig <- fig + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    fig
  }
  #' Plot a custom seasonal component.
  #'
  #' Evaluates the seasonal components on 200 evenly spaced time points that
  #' span one period of the seasonality and plots the column called `name`.
  #'
  #' @param m Prophet model object.
  #' @param name String name of the seasonality. Names that are not syntactic R
  #'  names (e.g. containing spaces) are supported.
  #' @param uncertainty Boolean to plot uncertainty intervals. The interval is
  #'  only drawn when the predicted components contain both `<name>_lower` and
  #'  `<name>_upper`.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_seasonality <- function(m, name, uncertainty = TRUE) {
    # Compute seasonality from Jan 1 through a single period.
    start <- set_date('2017-01-01')
    period <- m$seasonalities[[name]]$period
    # period is scaled by 24 * 3600 before being added to start (presumably
    # days to seconds; assumes set_date() returns a date-time - TODO confirm).
    end <- start + period * 24 * 3600
    plot.points <- 200
    days <- seq(from = start, to = end, length.out = plot.points)
    df.y <- seasonality_plot_df(m, days)
    seas <- predict_seasonal_components(m, df.y)
    seas$ds <- df.y$ds

    # aes_string() parses its arguments as R code, so column names are wrapped
    # in backticks to keep non-syntactic names from being mis-parsed.
    quote_name <- function(col) paste0("`", col, "`")
    lower_name <- paste0(name, '_lower')
    upper_name <- paste0(name, '_upper')

    gg.s <- ggplot2::ggplot(
        seas, ggplot2::aes_string(x = 'ds', y = quote_name(name), group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    # Choose the tick label format from the period length: time of day only
    # for short periods, date plus time for medium ones, date only otherwise.
    if (period <= 2) {
      fmt.str <- '%T'
    } else if (period < 14) {
      fmt.str <- '%m/%d %R'
    } else {
      fmt.str <- '%m/%d'
    }
    gg.s <- gg.s +
      ggplot2::scale_x_datetime(labels = scales::date_format(fmt.str))
    # Skip the ribbon when the interval columns are missing; adding it anyway
    # would make the plot fail when it is rendered.
    if (uncertainty && all(c(lower_name, upper_name) %in% colnames(seas))) {
      gg.s <- gg.s +
        ggplot2::geom_ribbon(
          ggplot2::aes_string(
            ymin = quote_name(lower_name), ymax = quote_name(upper_name)
          ),
          alpha = 0.2,
          fill = "#0072B2",
          na.rm = TRUE)
    }
    if (m$seasonalities[[name]]$mode == 'multiplicative') {
      gg.s <- gg.s + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    return(gg.s)
  }
  #' Get layers to overlay significant changepoints on prophet forecast plot.
  #'
  #' @param m Prophet model object. `m$changepoints` and `m$params$delta` are
  #'  read. When `delta` is a matrix it is averaged column-wise first, so that
  #'  there is one value per changepoint (rows are presumably parameter draws,
  #'  one column per changepoint - TODO confirm against the fitting code).
  #' @param threshold Numeric, changepoints where abs(delta) >= threshold are
  #'  significant. (Default 0.01)
  #' @param cp_color Character, line color. (Default "red")
  #' @param cp_linetype Character or integer, line type. (Default "dashed")
  #' @param trend Logical, if FALSE, do not draw trend line. (Default TRUE)
  #' @param ... Other arguments passed on to layers.
  #'
  #' @return A list of ggplot2 layers.
  #'
  #' @examples
  #' \dontrun{
  #' plot(m, fcst) + add_changepoints_to_plot(m)
  #' }
  #'
  #' @export
  add_changepoints_to_plot <- function(m, threshold = 0.01, cp_color = "red",
                                 cp_linetype = "dashed", trend = TRUE, ...) {
    layers <- list()
    if (trend) {
      trend_layer <- ggplot2::geom_line(
        ggplot2::aes_string("ds", "trend"), color = cp_color, ...)
      layers[[length(layers) + 1]] <- trend_layer
    }
    delta <- m$params$delta
    if (is.matrix(delta)) {
      # Collapse to one value per changepoint. Indexing the changepoints with
      # a multi-row logical matrix would otherwise produce a mask longer than
      # the changepoint vector and select wrong / NA entries. A one-row matrix
      # is left numerically unchanged by the column mean.
      delta <- colMeans(delta, na.rm = TRUE)
    }
    signif_changepoints <- m$changepoints[abs(delta) >= threshold]
    cp_layer <- ggplot2::geom_vline(
      xintercept = as.integer(signif_changepoints), color = cp_color,
      linetype = cp_linetype, ...)
    layers[[length(layers) + 1]] <- cp_layer
    return(layers)
  }
  #' Plot the prophet forecast as an interactive dygraph.
  #'
  #' Observed values are drawn as black points and the forecast as a blue
  #' line, with the uncertainty interval as a band when available. Every row
  #' of `x$holidays`, if present, is marked with a gray event line and an
  #' annotation carrying the holiday name.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #'  should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param ... additional arguments (not used by this function).
  #' @importFrom dplyr "%>%"
  #' @return A dygraph plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(
  #'  ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #'  y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' dyplot.prophet(m, forecast)
  #' }
  #'
  #' @export
  dyplot.prophet <- function(x, fcst, uncertainty=TRUE,
                             ...)
  {
    forecast.label <- 'Predicted'
    actual.label <- 'Actual'
    # create data.frame for plotting
    df <- df_for_plotting(x, fcst)

    # build variables to include, or not, the uncertainty data
    if (uncertainty && ('yhat_lower' %in% colnames(df))) {
      colsToKeep <- c('y', 'yhat', 'yhat_lower', 'yhat_upper')
      # lower / value / upper triple: dySeries draws it as a line with a band
      forecastCols <- c('yhat_lower', 'yhat', 'yhat_upper')
    } else {
      colsToKeep <- c('y', 'yhat')
      forecastCols <- c('yhat')
    }
    # convert to xts for easier date handling by dygraph; plain column
    # subsetting replaces the deprecated dplyr::select_()
    dfTS <- xts::xts(df[, colsToKeep, drop = FALSE], order.by = df$ds)

    # base plot
    dyBase <- dygraphs::dygraph(dfTS) %>%
      # plot actual values
      dygraphs::dySeries(
        'y', label = actual.label, color = 'black', drawPoints = TRUE,
        strokeWidth = 0
      ) %>%
      # plot forecast and ribbon
      dygraphs::dySeries(forecastCols, label = forecast.label, color = 'blue') %>%
      # allow zooming
      dygraphs::dyRangeSelector() %>%
      # make unzoom button
      dygraphs::dyUnzoom()

    if (!is.null(x$holidays)) {
      # seq_len() yields no iterations for an empty holidays frame, whereas
      # 1:nrow() would iterate over 1 and 0.
      for (i in seq_len(nrow(x$holidays))) {
        # make a gray line
        dyBase <- dyBase %>% dygraphs::dyEvent(
          x$holidays$ds[i], color = "rgb(200,200,200)", strokePattern = "solid")
        dyBase <- dyBase %>% dygraphs::dyAnnotation(
          x$holidays$ds[i], x$holidays$holiday[i], x$holidays$holiday[i],
          attachAtBottom = TRUE)
      }
    }
    return(dyBase)
  }
  #' Plot a performance metric vs. forecast horizon from cross validation.
  #'
  #' Cross validation produces a collection of out-of-sample model predictions
  #' that can be compared to actual values, at a range of different horizons
  #' (distance from the cutoff). This computes a specified performance metric
  #' for each prediction, and aggregated over a rolling window with horizon.
  #'
  #' This uses `performance_metrics` to compute the metrics.
  #' Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
  #'
  #' rolling_window is the proportion of data included in the rolling window of
  #' aggregation. The default value of 0.1 means 10% of data are included in the
  #' aggregation for computing the metric.
  #'
  #' As a concrete example, if metric='mse', then this plot will show the
  #' squared error for each cross validation prediction, along with the MSE
  #' averaged over rolling windows of 10% of the data.
  #'
  #' @param df_cv The output from cross validation, as accepted by
  #'  `performance_metrics`.
  #' @param metric Metric name, one of 'mse', 'rmse', 'mae', 'mape', 'coverage'.
  #' @param rolling_window Proportion of data to use for rolling average of
  #'  metric. In [0, 1]. Defaults to 0.1.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @export
  plot_cross_validation_metric <- function(df_cv, metric, rolling_window=0.1) {
    # rolling_window = 0 gives the gray points; the requested window gives the
    # blue line.
    per_point <- performance_metrics(df_cv, metrics = metric, rolling_window = 0)
    rolled <- performance_metrics(
      df_cv, metrics = metric, rolling_window = rolling_window
    )

    # Better plotting of difftime
    # Target ~10 ticks
    horizon_secs <- as.double(per_point$horizon, units = 'secs')
    tick_w <- max(horizon_secs) / 10.
    # Find the largest time resolution that has <1 unit per bin. Candidates
    # run from coarse to fine; if none qualifies the loop ends on 'secs'.
    unit_secs <- c(days = 24 * 60 * 60, hours = 60 * 60, mins = 60, secs = 1)
    for (unit in names(unit_secs)) {
      if (as.difftime(1, units = unit) < as.difftime(tick_w, units = 'secs')) {
        break
      }
    }
    per_unit <- unit_secs[[unit]]

    # Express both horizons in the chosen unit for the x-axis.
    per_point$x_plt <- horizon_secs / per_unit
    rolled$x_plt <- as.double(rolled$horizon, units = 'secs') / per_unit

    ggplot2::ggplot(per_point, ggplot2::aes_string(x = 'x_plt', y = metric)) +
      ggplot2::labs(x = paste0('Horizon (', unit, ')'), y = metric) +
      ggplot2::geom_point(color = 'gray') +
      ggplot2::geom_line(
        data = rolled, ggplot2::aes_string(x = 'x_plt', y = metric),
        color = 'blue'
      ) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }