Browse Source

Make component plots work if forecast range is less than seasonality period

Ben Letham 8 years ago
parent
commit
208399678c
5 changed files with 46 additions and 46 deletions
  1. 26 26
      R/R/prophet.R
  2. 2 2
      R/man/plot_weekly.Rd
  3. 2 2
      R/man/plot_yearly.Rd
  4. 1 1
      R/man/prophet.Rd
  5. 15 15
      python/fbprophet/forecaster.py

+ 26 - 26
R/R/prophet.R

@@ -964,11 +964,11 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
   }
   # Plot weekly seasonality, if present
   if ("weekly" %in% colnames(df)) {
-    panels[[length(panels) + 1]] <- plot_weekly(df, uncertainty)
+    panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty)
   }
   # Plot yearly seasonality, if present
   if ("yearly" %in% colnames(df)) {
-    panels[[length(panels) + 1]] <- plot_yearly(df, uncertainty)
+    panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty)
   }
   # Make the plot.
   grid::grid.newpage()
@@ -988,9 +988,10 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
 #'
 #' @return A ggplot2 plot.
 plot_trend <- function(df, uncertainty = TRUE) {
-  gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
+  df.t <- df[!is.na(df$trend),]
+  gg.trend <- ggplot2::ggplot(df.t, ggplot2::aes(x = ds, y = trend)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
-  if (exists('cap', where = df)) {
+  if (exists('cap', where = df.t)) {
     gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                               linetype = 'dashed',
                                               na.rm = TRUE)
@@ -1021,6 +1022,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
                                                           "_lower"), drop = FALSE]),
                      holidays_upper = rowSums(df[, paste0(holiday.comps,
                                                           "_upper"), drop = FALSE]))
+  df.s <- df.s[!is.na(df.s$holidays),]
   # NOTE the above CI calculation is incorrect if holidays overlap in time.
   # Since it is just for the visualization we will not worry about it now.
   gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
@@ -1038,20 +1040,19 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
 
 #' Plot the weekly component of the forecast.
 #'
-#' @param df Forecast dataframe for plotting.
+#' @param m Prophet model object
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #'
 #' @return A ggplot2 plot.
-plot_weekly <- function(df, uncertainty = TRUE) {
-  # Get weekday names in current locale
-  days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7))
-  df.s <- df %>%
-    dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>%
-    dplyr::group_by(dow) %>%
-    dplyr::slice(1) %>%
-    dplyr::ungroup() %>%
-    dplyr::arrange(dow)
-  gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
+plot_weekly <- function(m, uncertainty = TRUE) {
+  # Compute weekly seasonality for a Sun-Sat sequence of dates.
+  df.w <- data.frame(ds=seq.Date(zoo::as.Date('2017-01-01'), by='d',
+                                 length.out=7))
+  df.w <- setup_dataframe(m, df.w)$df
+  seas <- predict_seasonal_components(m, df.w)
+  seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
+
+  gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
                                                   group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::labs(x = "Day of week")
@@ -1068,20 +1069,19 @@ plot_weekly <- function(df, uncertainty = TRUE) {
 
 #' Plot the yearly component of the forecast.
 #'
-#' @param df Forecast dataframe for plotting.
+#' @param m Prophet model object.
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #'
 #' @return A ggplot2 plot.
-plot_yearly <- function(df, uncertainty = TRUE) {
-  # Drop year from the dates
-  df.s <- df %>%
-    dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
-    dplyr::group_by(doy) %>%
-    dplyr::slice(1) %>%
-    dplyr::ungroup() %>%
-    dplyr::mutate(doy = zoo::as.Date(doy)) %>%
-    dplyr::arrange(doy)
-  gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
+plot_yearly <- function(m, uncertainty = TRUE) {
+  # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
+  df.y <- data.frame(ds=seq.Date(zoo::as.Date('2017-01-01'), by='d',
+                                 length.out=365))
+  df.y <- setup_dataframe(m, df.y)$df
+  seas <- predict_seasonal_components(m, df.y)
+  seas$ds <- df.y$ds
+
+  gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
                                                   group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +

+ 2 - 2
R/man/plot_weekly.Rd

@@ -4,10 +4,10 @@
 \alias{plot_weekly}
 \title{Plot the weekly component of the forecast.}
 \usage{
-plot_weekly(df, uncertainty = TRUE)
+plot_weekly(m, uncertainty = TRUE)
 }
 \arguments{
-\item{df}{Forecast dataframe for plotting.}
+\item{m}{Prophet model object}
 
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 }

+ 2 - 2
R/man/plot_yearly.Rd

@@ -4,10 +4,10 @@
 \alias{plot_yearly}
 \title{Plot the yearly component of the forecast.}
 \usage{
-plot_yearly(df, uncertainty = TRUE)
+plot_yearly(m, uncertainty = TRUE)
 }
 \arguments{
-\item{df}{Forecast dataframe for plotting.}
+\item{m}{Prophet model object.}
 
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 }

+ 1 - 1
R/man/prophet.Rd

@@ -48,7 +48,7 @@ components model.}
 automatic changepoint selection. Large values will allow many changepoints,
 small values will allow few changepoints.}
 
-\item{mcmc.samples}{Integer, if great than 0, will do full Bayesian
+\item{mcmc.samples}{Integer, if greater than 0, will do full Bayesian
 inference with the specified number of MCMC samples. If 0, will do MAP
 estimation.}
 

+ 15 - 15
python/fbprophet/forecaster.py

@@ -1006,18 +1006,17 @@ class Prophet(object):
         if not ax:
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             ax = fig.add_subplot(111)
-        df_s = fcst.copy()
-        df_s['dow'] = df_s['ds'].dt.weekday_name
-        df_s = df_s.groupby('dow').first()
-        days = pd.date_range(start='2017-01-01', periods=7).weekday_name
-        y_weekly = [df_s.loc[d]['weekly'] for d in days]
-        y_weekly_l = [df_s.loc[d]['weekly_lower'] for d in days]
-        y_weekly_u = [df_s.loc[d]['weekly_upper'] for d in days]
-        artists += ax.plot(range(len(days)), y_weekly, ls='-',
+        # Compute weekly seasonality for a Sun-Sat sequence of dates.
+        days = pd.date_range(start='2017-01-01', periods=7)
+        df_w = pd.DataFrame({'ds': days})
+        df_w = self.setup_dataframe(df_w)
+        seas = self.predict_seasonal_components(df_w)
+        days = days.weekday_name
+        artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
                            c='#0072B2')
         if uncertainty:
             artists += [ax.fill_between(range(len(days)),
-                                        y_weekly_l, y_weekly_u,
+                                        seas['weekly_lower'], seas['weekly_upper'],
                                         color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.set_xticks(range(len(days)))
@@ -1044,15 +1043,16 @@ class Prophet(object):
         if not ax:
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             ax = fig.add_subplot(111)
-        df_s = fcst.copy()
-        df_s['doy'] = df_s['ds'].map(lambda x: x.strftime('2000-%m-%d'))
-        df_s = df_s.groupby('doy').first().sort_index()
-        artists += ax.plot(pd.to_datetime(df_s.index), df_s['yearly'], ls='-',
+        # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
+        df_y = pd.DataFrame({'ds': pd.date_range(start='2017-01-01', periods=365)})
+        df_y = self.setup_dataframe(df_y)
+        seas = self.predict_seasonal_components(df_y)
+        artists += ax.plot(df_y['ds'], seas['yearly'], ls='-',
                            c='#0072B2')
         if uncertainty:
             artists += [ax.fill_between(
-                pd.to_datetime(df_s.index), df_s['yearly_lower'],
-                df_s['yearly_upper'], color='#0072B2', alpha=0.2)]
+                df_y['ds'].values, seas['yearly_lower'],
+                seas['yearly_upper'], color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
         ax.xaxis.set_major_formatter(FuncFormatter(