# plot.R
  #' Merge history and forecast for plotting.
  #'
  #' Joins the observed history (columns ds and y) with the forecast on ds,
  #' keeping rows present in either frame, and sorts the result by ds.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by prophet predict.
  #'
  #' @return A data frame with ds, the observed y, and every forecast column,
  #'   ordered by ds.
  #'
  #' @importFrom dplyr "%>%"
  #' @keywords internal
  df_for_plotting <- function(m, fcst) {
    # A y column in the forecast would collide with the history's y in the
    # join, so drop it before merging.
    fcst$y <- NULL
    observed <- dplyr::select(m$history, ds, y)
    merged <- dplyr::full_join(observed, fcst, by = "ds")
    dplyr::arrange(merged, ds)
  }
  #' Plot the prophet forecast.
  #'
  #' Draws the observed history as points and the forecast yhat as a line.
  #' Optionally adds the yhat uncertainty band and the cap/floor lines.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #' should be plotted. Must be present in fcst as yhat_lower and yhat_upper;
  #' the band is skipped if either column is missing.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #' figure, if available. The floor is drawn only when x$logistic.floor is
  #' TRUE.
  #' @param xlabel Optional label for x-axis
  #' @param ylabel Optional label for y-axis
  #' @param ... additional arguments
  #'
  #' @return A ggplot2 plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' plot(m, forecast)
  #' }
  #'
  #' @export
  plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
                           xlabel = 'ds', ylabel = 'y', ...) {
    df <- df_for_plotting(x, fcst)
    df_cols <- colnames(df)
    gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
      ggplot2::labs(x = xlabel, y = ylabel)
    if (plot_cap && 'cap' %in% df_cols) {
      gg <- gg + ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
    }
    # isTRUE(): a model object whose logistic.floor is NULL/NA skips the floor
    # line instead of failing inside if().
    if (plot_cap && isTRUE(x$logistic.floor) && 'floor' %in% df_cols) {
      gg <- gg + ggplot2::geom_line(
        ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
    }
    # The ribbon maps both bounds, so both columns must be present.
    if (uncertainty && all(c('yhat_lower', 'yhat_upper') %in% df_cols)) {
      gg <- gg +
        ggplot2::geom_ribbon(
          ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
          alpha = 0.2,
          fill = "#0072B2",
          na.rm = TRUE)
    }
    gg +
      ggplot2::geom_point(na.rm = TRUE) +
      ggplot2::geom_line(
        ggplot2::aes(y = yhat), color = "#0072B2", na.rm = TRUE) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }
  #' Plot the components of a prophet forecast.
  #' Prints a ggplot2 with whichever are available of: trend, holidays, weekly
  #' seasonality, yearly seasonality, any additional seasonalities, and additive
  #' and multiplicative extra regressors. The panels are stacked in a single
  #' column on a new grid page.
  #'
  #' @param m Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval should be
  #' plotted for the trend, from fcst columns trend_lower and trend_upper.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #' figure, if available.
  #' @param weekly_start Integer specifying the start day of the weekly
  #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #' to Monday, and so on.
  #' @param yearly_start Integer specifying the start day of the yearly
  #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #' to Jan 2, and so on.
  #'
  #' @return Invisibly return a list containing the plotted ggplot objects
  #'
  #' @export
  #' @importFrom dplyr "%>%"
  prophet_plot_components <- function(
    m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
    yearly_start = 0
  ) {
    fcst_cols <- colnames(fcst)
    # The trend panel is always first.
    panels <- list(
      plot_forecast_component(m, fcst, 'trend', uncertainty, plot_cap))
    # Holiday components, if the model has holidays and they were predicted.
    if (!is.null(m$holidays) && ('holidays' %in% fcst_cols)) {
      panels[[length(panels) + 1]] <- plot_forecast_component(
        m, fcst, 'holidays', uncertainty, FALSE)
    }
    if ('weekly' %in% fcst_cols) {
      panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
    }
    if ('yearly' %in% fcst_cols) {
      panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
    }
    # Any other seasonalities that appear in the forecast.
    for (name in names(m$seasonalities)) {
      if (!(name %in% c('weekly', 'yearly')) && (name %in% fcst_cols)) {
        panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
      }
    }
    # Modes used by at least one extra regressor (character(0) when none).
    regressor_modes <- vapply(
      m$extra_regressors, function(reg) reg$mode, character(1))
    for (mode in c('additive', 'multiplicative')) {
      col <- paste0('extra_regressors_', mode)
      # Scalar condition inside if(): && rather than the vectorised &.
      if ((mode %in% regressor_modes) && (col %in% fcst_cols)) {
        panels[[length(panels) + 1]] <- plot_forecast_component(
          m, fcst, col, uncertainty, FALSE)
      }
    }
    # Draw one panel per row of a single-column grid layout.
    grid::grid.newpage()
    grid::pushViewport(grid::viewport(
      layout = grid::grid.layout(length(panels), 1)))
    for (i in seq_along(panels)) {
      print(panels[[i]],
            vp = grid::viewport(layout.pos.row = i, layout.pos.col = 1))
    }
    invisible(panels)
  }
  #' Plot a particular component of the forecast.
  #'
  #' @param m Prophet model
  #' @param fcst Dataframe output of `predict`.
  #' @param name String name of the component to plot (column of fcst).
  #' @param uncertainty Boolean to plot uncertainty intervals. The interval is
  #' read from the columns `<name>_lower` and `<name>_upper` and is skipped if
  #' either column is missing.
  #' @param plot_cap Boolean indicating if the capacity should be shown in the
  #' figure, if available. The floor is drawn only when m$logistic.floor is
  #' TRUE.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @export
  plot_forecast_component <- function(
    m, fcst, name, uncertainty = TRUE, plot_cap = FALSE
  ) {
    fcst_cols <- colnames(fcst)
    lower_col <- paste0(name, '_lower')
    upper_col <- paste0(name, '_upper')
    gg.comp <- ggplot2::ggplot(
      fcst, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
    if (plot_cap && 'cap' %in% fcst_cols) {
      gg.comp <- gg.comp + ggplot2::geom_line(
        ggplot2::aes(y = cap), linetype = 'dashed', na.rm = TRUE)
    }
    # Only models fit with a logistic floor have a meaningful floor to show.
    if (plot_cap && isTRUE(m$logistic.floor) && 'floor' %in% fcst_cols) {
      gg.comp <- gg.comp + ggplot2::geom_line(
        ggplot2::aes(y = floor), linetype = 'dashed', na.rm = TRUE)
    }
    if (uncertainty && all(c(lower_col, upper_col) %in% fcst_cols)) {
      gg.comp <- gg.comp +
        ggplot2::geom_ribbon(
          ggplot2::aes_string(ymin = lower_col, ymax = upper_col),
          alpha = 0.2,
          fill = "#0072B2",
          na.rm = TRUE)
    }
    # Multiplicative components are relative effects, so label them as percent.
    if (name %in% m$component.modes$multiplicative) {
      gg.comp <- gg.comp + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    gg.comp
  }
  #' Prepare dataframe for plotting seasonal components.
  #'
  #' Builds a frame with the given ds, cap = 1, floor = 0, and every extra
  #' regressor of m set to 0, then runs it through setup_dataframe.
  #'
  #' @param m Prophet object.
  #' @param ds Array of dates for column ds.
  #'
  #' @return The `df` element returned by setup_dataframe for that frame.
  #'
  #' @keywords internal
  seasonality_plot_df <- function(m, ds) {
    df <- data.frame(ds = ds, cap = 1, floor = 0)
    # Assign regressor columns with [[<- instead of building the frame from a
    # list: as.data.frame() applies make.names() and would rename
    # non-syntactic regressor names (e.g. "temp-max" -> "temp.max"), so
    # setup_dataframe could no longer find them.
    for (name in names(m$extra_regressors)) {
      df[[name]] <- 0
    }
    setup_dataframe(m, df)$df
  }
  #' Plot the weekly component of the forecast.
  #'
  #' Evaluates the weekly seasonality on seven consecutive days and plots it
  #' against the weekday name of each day.
  #'
  #' @param m Prophet model object
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param weekly_start Integer specifying the start day of the weekly
  #' seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
  #' to Monday, and so on.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
    # 2017-01-01 was a Sunday, so the unshifted run is Sun..Sat.
    shift <- as.difftime(weekly_start, units = "days")
    week <- seq(set_date('2017-01-01'), by = 'd', length.out = 7) + shift
    week_df <- seasonality_plot_df(m, week)
    components <- predict_seasonal_components(m, week_df)
    # Use the days' own order as factor levels (not alphabetical order).
    day_names <- weekdays(week_df$ds)
    components$dow <- factor(day_names, levels = day_names)

    plt <- ggplot2::ggplot(
      components, ggplot2::aes(x = dow, y = weekly, group = 1)
    ) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::labs(x = "Day of week")
    if (uncertainty) {
      plt <- plt + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = weekly_lower, ymax = weekly_upper),
        alpha = 0.2, fill = "#0072B2", na.rm = TRUE
      )
    }
    if (m$seasonalities$weekly$mode == 'multiplicative') {
      plt <- plt + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    plt
  }
  #' Plot the yearly component of the forecast.
  #'
  #' Evaluates the yearly seasonality on 365 consecutive days and plots it
  #' against the date, labelled as month and day.
  #'
  #' @param m Prophet model object.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #' @param yearly_start Integer specifying the start day of the yearly
  #' seasonality plot. 0 (default) starts the year on Jan 1. 1 shifts by 1 day
  #' to Jan 2, and so on.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
    # 2017 is not a leap year: 365 days span Jan 1 - Dec 31 before shifting.
    shift <- as.difftime(yearly_start, units = "days")
    year_days <- seq(set_date('2017-01-01'), by = 'd', length.out = 365) +
      shift
    year_df <- seasonality_plot_df(m, year_days)
    components <- predict_seasonal_components(m, year_df)
    components$ds <- year_df$ds

    plt <- ggplot2::ggplot(
      components, ggplot2::aes(x = ds, y = yearly, group = 1)
    ) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::labs(x = "Day of year") +
      ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
    if (uncertainty) {
      plt <- plt + ggplot2::geom_ribbon(
        ggplot2::aes(ymin = yearly_lower, ymax = yearly_upper),
        alpha = 0.2, fill = "#0072B2", na.rm = TRUE
      )
    }
    if (m$seasonalities$yearly$mode == 'multiplicative') {
      plt <- plt + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    plt
  }
  #' Plot a custom seasonal component.
  #'
  #' Samples one full period of the named seasonality, starting on Jan 1 2017,
  #' at 200 evenly spaced points and plots it over time.
  #'
  #' @param m Prophet model object.
  #' @param name String name of the seasonality.
  #' @param uncertainty Boolean to plot uncertainty intervals.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @keywords internal
  plot_seasonality <- function(m, name, uncertainty = TRUE) {
    origin <- set_date('2017-01-01')
    period <- m$seasonalities[[name]]$period
    # The period is in days; convert it to seconds to offset the start time.
    grid_pts <- seq(
      from = origin, to = origin + period * 24 * 3600, length.out = 200
    )
    grid_df <- seasonality_plot_df(m, grid_pts)
    components <- predict_seasonal_components(m, grid_df)
    components$ds <- grid_df$ds

    # Finer tick labels for shorter periods.
    fmt_str <- if (period <= 2) {
      '%T'
    } else if (period < 14) {
      '%m/%d %R'
    } else {
      '%m/%d'
    }

    plt <- ggplot2::ggplot(
      components, ggplot2::aes_string(x = 'ds', y = name, group = 1)
    ) +
      ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
      ggplot2::scale_x_datetime(labels = scales::date_format(fmt_str))
    if (uncertainty) {
      plt <- plt + ggplot2::geom_ribbon(
        ggplot2::aes_string(
          ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
        ),
        alpha = 0.2, fill = "#0072B2", na.rm = TRUE
      )
    }
    if (m$seasonalities[[name]]$mode == 'multiplicative') {
      plt <- plt + ggplot2::scale_y_continuous(labels = scales::percent)
    }
    plt
  }
  #' Get layers to overlay significant changepoints on prophet forecast plot.
  #'
  #' @param m Prophet model object.
  #' @param threshold Numeric, changepoints where abs(delta) >= threshold are
  #' significant. (Default 0.01) When m$params$delta is a matrix (e.g. one row
  #' per posterior sample), its column means are compared to the threshold.
  #' @param cp_color Character, line color. (Default "red")
  #' @param cp_linetype Character or integer, line type. (Default "dashed")
  #' @param trend Logical, if FALSE, do not draw trend line. (Default TRUE)
  #' @param ... Other arguments passed on to layers.
  #'
  #' @return A list of ggplot2 layers.
  #'
  #' @examples
  #' \dontrun{
  #' plot(m, fcst) + add_changepoints_to_plot(m)
  #' }
  #'
  #' @export
  add_changepoints_to_plot <- function(m, threshold = 0.01, cp_color = "red",
                                       cp_linetype = "dashed", trend = TRUE,
                                       ...) {
    layers <- list()
    if (trend) {
      trend_layer <- ggplot2::geom_line(
        ggplot2::aes_string("ds", "trend"), color = cp_color, ...)
      layers <- append(layers, trend_layer)
    }
    delta <- m$params$delta
    # Assumes matrix columns line up with m$changepoints. Comparing a
    # multi-row matrix directly would give an n*k mask that indexes past the
    # end of m$changepoints, so average over rows first.
    if (is.matrix(delta)) {
      delta <- colMeans(delta, na.rm = TRUE)
    }
    # which() drops NA comparisons rather than producing NA changepoints.
    signif_changepoints <- m$changepoints[which(abs(delta) >= threshold)]
    cp_layer <- ggplot2::geom_vline(
      xintercept = as.integer(signif_changepoints), color = cp_color,
      linetype = cp_linetype, ...)
    append(layers, cp_layer)
  }
  #' Plot the prophet forecast.
  #'
  #' Interactive dygraphs version of the forecast plot. It shows observed
  #' values as points and the forecast as a line, with its uncertainty band
  #' when requested and available. It adds a range selector and unzoom
  #' control, plus an event line and annotation for each holiday of the model.
  #'
  #' @param x Prophet object.
  #' @param fcst Data frame returned by predict(m, df).
  #' @param uncertainty Boolean indicating if the uncertainty interval for yhat
  #' should be plotted. Must be present in fcst as yhat_lower and yhat_upper.
  #' @param ... additional arguments
  #' @importFrom dplyr "%>%"
  #' @return A dygraph plot.
  #'
  #' @examples
  #' \dontrun{
  #' history <- data.frame(
  #' ds = seq(as.Date('2015-01-01'), as.Date('2016-01-01'), by = 'd'),
  #' y = sin(1:366/200) + rnorm(366)/10)
  #' m <- prophet(history)
  #' future <- make_future_dataframe(m, periods = 365)
  #' forecast <- predict(m, future)
  #' dyplot.prophet(m, forecast)
  #' }
  #'
  #' @export
  dyplot.prophet <- function(x, fcst, uncertainty = TRUE, ...) {
    forecast.label <- 'Predicted'
    actual.label <- 'Actual'
    df <- df_for_plotting(x, fcst)
    # dySeries draws a shaded band when given (lower, value, upper) columns.
    if (uncertainty &&
        all(c('yhat_lower', 'yhat_upper') %in% colnames(df))) {
      colsToKeep <- c('y', 'yhat', 'yhat_lower', 'yhat_upper')
      forecastCols <- c('yhat_lower', 'yhat', 'yhat_upper')
    } else {
      colsToKeep <- c('y', 'yhat')
      forecastCols <- c('yhat')
    }
    # Base-R column subset instead of the deprecated dplyr::select_(); xts
    # supplies the time index that dygraphs plots against.
    dfTS <- xts::xts(df[, colsToKeep, drop = FALSE], order.by = df$ds)
    dyBase <- dygraphs::dygraph(dfTS) %>%
      # Actual values as points only.
      dygraphs::dySeries(
        'y', label = actual.label, color = 'black', drawPoints = TRUE,
        strokeWidth = 0
      ) %>%
      # Forecast line (and ribbon, if bounds were kept).
      dygraphs::dySeries(forecastCols, label = forecast.label,
                         color = 'blue') %>%
      # Allow zooming, with a button to reset it.
      dygraphs::dyRangeSelector() %>%
      dygraphs::dyUnzoom()
    if (!is.null(x$holidays)) {
      # seq_len(): an empty holidays frame adds nothing (1:0 would loop twice).
      for (i in seq_len(nrow(x$holidays))) {
        # Gray vertical line plus a label for each holiday date.
        dyBase <- dyBase %>% dygraphs::dyEvent(
          x$holidays$ds[i], color = "rgb(200,200,200)", strokePattern = "solid")
        dyBase <- dyBase %>% dygraphs::dyAnnotation(
          x$holidays$ds[i], x$holidays$holiday[i], x$holidays$holiday[i],
          attachAtBottom = TRUE)
      }
    }
    dyBase
  }
  #' Plot a performance metric vs. forecast horizon from cross validation.
  #' Cross validation produces a collection of out-of-sample model predictions
  #' that can be compared to actual values, at a range of different horizons
  #' (distance from the cutoff). This computes a specified performance metric
  #' for each prediction, and aggregates it over a rolling window of horizons.
  #'
  #' This uses performance_metrics to compute the metrics.
  #' Valid values of metric are 'mse', 'rmse', 'mae', 'mape', and 'coverage'.
  #'
  #' rolling_window is the proportion of data included in the rolling window of
  #' aggregation. The default value of 0.1 means 10% of data are included in the
  #' aggregation for computing the metric.
  #'
  #' As a concrete example, if metric='mse', then this plot will show the
  #' squared error for each cross validation prediction, along with the MSE
  #' averaged over rolling windows of 10% of the data.
  #'
  #' @param df_cv The output from cross_validation.
  #' @param metric Metric name, one of 'mse', 'rmse', 'mae', 'mape', 'coverage'.
  #' @param rolling_window Proportion of data to use for rolling average of
  #' metric. In [0, 1]. Defaults to 0.1.
  #'
  #' @return A ggplot2 plot.
  #'
  #' @export
  plot_cross_validation_metric <- function(df_cv, metric, rolling_window=0.1) {
    # Unaggregated values (rolling_window = 0) for the gray points, and the
    # rolling aggregate for the blue line.
    df_none <- performance_metrics(df_cv, metrics = metric, rolling_window = 0)
    df_h <- performance_metrics(
      df_cv, metrics = metric, rolling_window = rolling_window
    )
    # Better plotting of difftime: target ~10 ticks, then express horizons in
    # the largest unit that is shorter than one tick width.
    tick_w <- max(as.double(df_none$horizon, units = 'secs')) / 10.
    dts <- c('days', 'hours', 'mins', 'secs')
    dt_conversions <- c(24 * 60 * 60, 60 * 60, 60, 1)
    i <- which(dt_conversions < tick_w)[1]
    # Fall back to seconds when no unit is shorter than a tick.
    if (is.na(i)) {
      i <- length(dts)
    }
    df_none$x_plt <- as.double(df_none$horizon, units = 'secs') /
      dt_conversions[i]
    df_h$x_plt <- as.double(df_h$horizon, units = 'secs') / dt_conversions[i]
    ggplot2::ggplot(df_none, ggplot2::aes_string(x = 'x_plt', y = metric)) +
      ggplot2::labs(x = paste0('Horizon (', dts[i], ')'), y = metric) +
      ggplot2::geom_point(color = 'gray') +
      ggplot2::geom_line(
        data = df_h, ggplot2::aes_string(x = 'x_plt', y = metric),
        color = 'blue'
      ) +
      ggplot2::theme(aspect.ratio = 3 / 5)
  }