Browse Source

Refactor R component plotting to match #84

Ben Letham 8 years ago
parent
commit
fcbd957bcc
5 changed files with 204 additions and 83 deletions
  1. 122 83
      R/R/prophet.R
  2. 22 0
      R/man/plot_holidays.Rd
  3. 20 0
      R/man/plot_trend.Rd
  4. 20 0
      R/man/plot_weekly.Rd
  5. 20 0
      R/man/plot_yearly.Rd

+ 122 - 83
R/R/prophet.R

@@ -873,7 +873,6 @@ df_for_plotting <- function(m, fcst) {
 plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds',
                          ylabel = 'y', ...) {
   df <- df_for_plotting(x, fcst)
-  forecast.color <- "#0072B2"
   gg <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = y)) +
     ggplot2::labs(x = xlabel, y = ylabel)
   if (exists('cap', where = df)) {
@@ -884,12 +883,12 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds',
     gg <- gg +
       ggplot2::geom_ribbon(ggplot2::aes(ymin = yhat_lower, ymax = yhat_upper),
                            alpha = 0.2,
-                           fill = forecast.color,
+                           fill = "#0072B2",
                            na.rm = TRUE)
   }
   gg <- gg +
     ggplot2::geom_point(na.rm=TRUE) +
-    ggplot2::geom_line(ggplot2::aes(y = yhat), color = forecast.color,
+    ggplot2::geom_line(ggplot2::aes(y = yhat), color = "#0072B2",
                        na.rm = TRUE) +
     ggplot2::theme(aspect.ratio = 3 / 5)
   return(gg)
@@ -908,95 +907,19 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, xlabel = 'ds',
 #' @importFrom dplyr "%>%"
 prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
   df <- df_for_plotting(m, fcst)
-  forecast.color <- "#0072B2"
   # Plot the trend
-  gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
-    ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
-  if (exists('cap', where = df)) {
-    gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
-                                              linetype = 'dashed',
-                                              na.rm = TRUE)
-  }
-  if (uncertainty) {
-    gg.trend <- gg.trend +
-      ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
-                                        ymax = trend_upper),
-                           alpha = 0.2,
-                           fill = forecast.color,
-                           na.rm = TRUE)
-  }
-  panels <- list(gg.trend)
+  panels <- list(plot_trend(df, uncertainty))
   # Plot holiday components, if present.
   if (!is.null(m$holidays)) {
-    holiday.comps <- unique(m$holidays$holiday) %>% as.character()
-    df.s <- data.frame(ds = df$ds,
-                       holidays = rowSums(df[, holiday.comps, drop = FALSE]),
-                       holidays_lower = rowSums(df[, paste0(holiday.comps,
-                                                            "_lower"), drop = FALSE]),
-                       holidays_upper = rowSums(df[, paste0(holiday.comps,
-                                                            "_upper"), drop = FALSE]))
-    # NOTE the above CI calculation is incorrect if holidays overlap in time.
-    # Since it is just for the visualization we will not worry about it now.
-    gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
-      ggplot2::geom_line(color = forecast.color, na.rm = TRUE)
-    if (uncertainty) {
-      gg.holidays <- gg.holidays +
-      ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
-                                        ymax = holidays_upper),
-                           alpha = 0.2,
-                           fill = forecast.color,
-                           na.rm = TRUE)
-    }
-    panels[[length(panels) + 1]] <- gg.holidays
+    panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty)
   }
   # Plot weekly seasonality, if present
   if ("weekly" %in% colnames(df)) {
-    # Get weekday names in current locale
-    days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7))
-    df.s <- df %>%
-      dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>%
-      dplyr::group_by(dow) %>%
-      dplyr::slice(1) %>%
-      dplyr::ungroup() %>%
-      dplyr::arrange(dow)
-    gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
-                                                    group = 1)) +
-      ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
-      ggplot2::labs(x = "Day of week")
-    if (uncertainty) {
-      gg.weekly <- gg.weekly +
-      ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
-                                        ymax = weekly_upper),
-                           alpha = 0.2,
-                           fill = forecast.color,
-                           na.rm = TRUE)
-    }
-    panels[[length(panels) + 1]] <- gg.weekly
+    panels[[length(panels) + 1]] <- plot_weekly(df, uncertainty)
   }
   # Plot yearly seasonality, if present
   if ("yearly" %in% colnames(df)) {
-    # Drop year from the dates
-    df.s <- df %>%
-      dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
-      dplyr::group_by(doy) %>%
-      dplyr::slice(1) %>%
-      dplyr::ungroup() %>%
-      dplyr::mutate(doy = zoo::as.Date(doy)) %>%
-      dplyr::arrange(doy)
-    gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
-                                                    group = 1)) +
-      ggplot2::geom_line(color = forecast.color, na.rm = TRUE) +
-      ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
-      ggplot2::labs(x = "Day of year")
-    if (uncertainty) {
-      gg.yearly <- gg.yearly +
-      ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
-                                        ymax = yearly_upper),
-                           alpha = 0.2,
-                           fill = forecast.color,
-                           na.rm = TRUE)
-    }
-    panels[[length(panels) + 1]] = gg.yearly
+    panels[[length(panels) + 1]] <- plot_yearly(df, uncertainty)
   }
   # Make the plot.
   grid::grid.newpage()
@@ -1008,4 +931,120 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
   }
 }
 
+#' Plot the prophet trend.
+#'
+#' @param df Forecast dataframe for plotting.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+plot_trend <- function(df, uncertainty = TRUE) {
+  gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
+  if (exists('cap', where = df)) {
+    gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
+                                              linetype = 'dashed',
+                                              na.rm = TRUE)
+  }
+  if (uncertainty) {
+    gg.trend <- gg.trend +
+      ggplot2::geom_ribbon(ggplot2::aes(ymin = trend_lower,
+                                        ymax = trend_upper),
+                           alpha = 0.2,
+                           fill = "#0072B2",
+                           na.rm = TRUE)
+  }
+  return(gg.trend)
+}
+
+#' Plot the holidays component of the forecast.
+#'
+#' @param m Prophet model
+#' @param df Forecast dataframe for plotting.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+plot_holidays <- function(m, df, uncertainty = TRUE) {
+  holiday.comps <- unique(m$holidays$holiday) %>% as.character()
+  df.s <- data.frame(ds = df$ds,
+                     holidays = rowSums(df[, holiday.comps, drop = FALSE]),
+                     holidays_lower = rowSums(df[, paste0(holiday.comps,
+                                                          "_lower"), drop = FALSE]),
+                     holidays_upper = rowSums(df[, paste0(holiday.comps,
+                                                          "_upper"), drop = FALSE]))
+  # NOTE the above CI calculation is incorrect if holidays overlap in time.
+  # Since it is just for the visualization we will not worry about it now.
+  gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
+  if (uncertainty) {
+    gg.holidays <- gg.holidays +
+    ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
+                                      ymax = holidays_upper),
+                         alpha = 0.2,
+                         fill = "#0072B2",
+                         na.rm = TRUE)
+  }
+  return(gg.holidays)
+}
+
+#' Plot the weekly component of the forecast.
+#'
+#' @param df Forecast dataframe for plotting.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+plot_weekly <- function(df, uncertainty = TRUE) {
+  # Get weekday names in current locale
+  days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7))
+  df.s <- df %>%
+    dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>%
+    dplyr::group_by(dow) %>%
+    dplyr::slice(1) %>%
+    dplyr::ungroup() %>%
+    dplyr::arrange(dow)
+  gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
+                                                  group = 1)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
+    ggplot2::labs(x = "Day of week")
+  if (uncertainty) {
+    gg.weekly <- gg.weekly +
+    ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
+                                      ymax = weekly_upper),
+                         alpha = 0.2,
+                         fill = "#0072B2",
+                         na.rm = TRUE)
+  }
+  return(gg.weekly)
+}
+
+#' Plot the yearly component of the forecast.
+#'
+#' @param df Forecast dataframe for plotting.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+plot_yearly <- function(df, uncertainty = TRUE) {
+  # Drop year from the dates
+  df.s <- df %>%
+    dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
+    dplyr::group_by(doy) %>%
+    dplyr::slice(1) %>%
+    dplyr::ungroup() %>%
+    dplyr::mutate(doy = zoo::as.Date(doy)) %>%
+    dplyr::arrange(doy)
+  gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
+                                                  group = 1)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
+    ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
+    ggplot2::labs(x = "Day of year")
+  if (uncertainty) {
+    gg.yearly <- gg.yearly +
+    ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
+                                      ymax = yearly_upper),
+                         alpha = 0.2,
+                         fill = "#0072B2",
+                         na.rm = TRUE)
+  }
+  return(gg.yearly)
+}
+
 # fb-block 3

+ 22 - 0
R/man/plot_holidays.Rd

@@ -0,0 +1,22 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{plot_holidays}
+\alias{plot_holidays}
+\title{Plot the holidays component of the forecast.}
+\usage{
+plot_holidays(m, df, uncertainty = TRUE)
+}
+\arguments{
+\item{m}{Prophet model}
+
+\item{df}{Forecast dataframe for plotting.}
+
+\item{uncertainty}{Boolean to plot uncertainty intervals.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+Plot the holidays component of the forecast.
+}
+

+ 20 - 0
R/man/plot_trend.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{plot_trend}
+\alias{plot_trend}
+\title{Plot the prophet trend.}
+\usage{
+plot_trend(df, uncertainty = TRUE)
+}
+\arguments{
+\item{df}{Forecast dataframe for plotting.}
+
+\item{uncertainty}{Boolean to plot uncertainty intervals.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+Plot the prophet trend.
+}
+

+ 20 - 0
R/man/plot_weekly.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{plot_weekly}
+\alias{plot_weekly}
+\title{Plot the weekly component of the forecast.}
+\usage{
+plot_weekly(df, uncertainty = TRUE)
+}
+\arguments{
+\item{df}{Forecast dataframe for plotting.}
+
+\item{uncertainty}{Boolean to plot uncertainty intervals.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+Plot the weekly component of the forecast.
+}
+

+ 20 - 0
R/man/plot_yearly.Rd

@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{plot_yearly}
+\alias{plot_yearly}
+\title{Plot the yearly component of the forecast.}
+\usage{
+plot_yearly(df, uncertainty = TRUE)
+}
+\arguments{
+\item{df}{Forecast dataframe for plotting.}
+
+\item{uncertainty}{Boolean to plot uncertainty intervals.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+Plot the yearly component of the forecast.
+}
+