Procházet zdrojové kódy

Make component plots work if forecast range is less than seasonality period

Ben Letham před 8 roky
rodič
revize
208399678c
5 změnil soubory, kde provedl 46 přidání a 46 odebrání
  1. 26 26
      R/R/prophet.R
  2. 2 2
      R/man/plot_weekly.Rd
  3. 2 2
      R/man/plot_yearly.Rd
  4. 1 1
      R/man/prophet.Rd
  5. 15 15
      python/fbprophet/forecaster.py

+ 26 - 26
R/R/prophet.R

@@ -964,11 +964,11 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
   }
   }
   # Plot weekly seasonality, if present
   # Plot weekly seasonality, if present
   if ("weekly" %in% colnames(df)) {
   if ("weekly" %in% colnames(df)) {
-    panels[[length(panels) + 1]] <- plot_weekly(df, uncertainty)
+    panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty)
   }
   }
   # Plot yearly seasonality, if present
   # Plot yearly seasonality, if present
   if ("yearly" %in% colnames(df)) {
   if ("yearly" %in% colnames(df)) {
-    panels[[length(panels) + 1]] <- plot_yearly(df, uncertainty)
+    panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty)
   }
   }
   # Make the plot.
   # Make the plot.
   grid::grid.newpage()
   grid::grid.newpage()
@@ -988,9 +988,10 @@ prophet_plot_components <- function(m, fcst, uncertainty = TRUE) {
 #'
 #'
 #' @return A ggplot2 plot.
 #' @return A ggplot2 plot.
 plot_trend <- function(df, uncertainty = TRUE) {
 plot_trend <- function(df, uncertainty = TRUE) {
-  gg.trend <- ggplot2::ggplot(df, ggplot2::aes(x = ds, y = trend)) +
+  df.t <- df[!is.na(df$trend),]
+  gg.trend <- ggplot2::ggplot(df.t, ggplot2::aes(x = ds, y = trend)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
-  if (exists('cap', where = df)) {
+  if (exists('cap', where = df.t)) {
     gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
     gg.trend <- gg.trend + ggplot2::geom_line(ggplot2::aes(y = cap),
                                               linetype = 'dashed',
                                               linetype = 'dashed',
                                               na.rm = TRUE)
                                               na.rm = TRUE)
@@ -1021,6 +1022,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
                                                           "_lower"), drop = FALSE]),
                                                           "_lower"), drop = FALSE]),
                      holidays_upper = rowSums(df[, paste0(holiday.comps,
                      holidays_upper = rowSums(df[, paste0(holiday.comps,
                                                           "_upper"), drop = FALSE]))
                                                           "_upper"), drop = FALSE]))
+  df.s <- df.s[!is.na(df.s$holidays),]
   # NOTE the above CI calculation is incorrect if holidays overlap in time.
   # NOTE the above CI calculation is incorrect if holidays overlap in time.
   # Since it is just for the visualization we will not worry about it now.
   # Since it is just for the visualization we will not worry about it now.
   gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
   gg.holidays <- ggplot2::ggplot(df.s, ggplot2::aes(x = ds, y = holidays)) +
@@ -1038,20 +1040,19 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
 
 
 #' Plot the weekly component of the forecast.
 #' Plot the weekly component of the forecast.
 #'
 #'
-#' @param df Forecast dataframe for plotting.
+#' @param m Prophet model object
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #'
 #'
 #' @return A ggplot2 plot.
 #' @return A ggplot2 plot.
-plot_weekly <- function(df, uncertainty = TRUE) {
-  # Get weekday names in current locale
-  days <- weekdays(seq.Date(as.Date('2017-01-01'), by='d', length.out=7))
-  df.s <- df %>%
-    dplyr::mutate(dow = factor(weekdays(ds), levels = days)) %>%
-    dplyr::group_by(dow) %>%
-    dplyr::slice(1) %>%
-    dplyr::ungroup() %>%
-    dplyr::arrange(dow)
-  gg.weekly <- ggplot2::ggplot(df.s, ggplot2::aes(x = dow, y = weekly,
+plot_weekly <- function(m, uncertainty = TRUE) {
+  # Compute weekly seasonality for a Sun-Sat sequence of dates.
+  df.w <- data.frame(ds=seq.Date(zoo::as.Date('2017-01-01'), by='d',
+                                 length.out=7))
+  df.w <- setup_dataframe(m, df.w)$df
+  seas <- predict_seasonal_components(m, df.w)
+  seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
+
+  gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
                                                   group = 1)) +
                                                   group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::labs(x = "Day of week")
     ggplot2::labs(x = "Day of week")
@@ -1068,20 +1069,19 @@ plot_weekly <- function(df, uncertainty = TRUE) {
 
 
 #' Plot the yearly component of the forecast.
 #' Plot the yearly component of the forecast.
 #'
 #'
-#' @param df Forecast dataframe for plotting.
+#' @param m Prophet model object.
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #' @param uncertainty Boolean to plot uncertainty intervals.
 #'
 #'
 #' @return A ggplot2 plot.
 #' @return A ggplot2 plot.
-plot_yearly <- function(df, uncertainty = TRUE) {
-  # Drop year from the dates
-  df.s <- df %>%
-    dplyr::mutate(doy = strftime(ds, format = "2000-%m-%d")) %>%
-    dplyr::group_by(doy) %>%
-    dplyr::slice(1) %>%
-    dplyr::ungroup() %>%
-    dplyr::mutate(doy = zoo::as.Date(doy)) %>%
-    dplyr::arrange(doy)
-  gg.yearly <- ggplot2::ggplot(df.s, ggplot2::aes(x = doy, y = yearly,
+plot_yearly <- function(m, uncertainty = TRUE) {
+  # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
+  df.y <- data.frame(ds=seq.Date(zoo::as.Date('2017-01-01'), by='d',
+                                 length.out=365))
+  df.y <- setup_dataframe(m, df.y)$df
+  seas <- predict_seasonal_components(m, df.y)
+  seas$ds <- df.y$ds
+
+  gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
                                                   group = 1)) +
                                                   group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
     ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
     ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +

+ 2 - 2
R/man/plot_weekly.Rd

@@ -4,10 +4,10 @@
 \alias{plot_weekly}
 \alias{plot_weekly}
 \title{Plot the weekly component of the forecast.}
 \title{Plot the weekly component of the forecast.}
 \usage{
 \usage{
-plot_weekly(df, uncertainty = TRUE)
+plot_weekly(m, uncertainty = TRUE)
 }
 }
 \arguments{
 \arguments{
-\item{df}{Forecast dataframe for plotting.}
+\item{m}{Prophet model object}
 
 
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 }
 }

+ 2 - 2
R/man/plot_yearly.Rd

@@ -4,10 +4,10 @@
 \alias{plot_yearly}
 \alias{plot_yearly}
 \title{Plot the yearly component of the forecast.}
 \title{Plot the yearly component of the forecast.}
 \usage{
 \usage{
-plot_yearly(df, uncertainty = TRUE)
+plot_yearly(m, uncertainty = TRUE)
 }
 }
 \arguments{
 \arguments{
-\item{df}{Forecast dataframe for plotting.}
+\item{m}{Prophet model object.}
 
 
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 \item{uncertainty}{Boolean to plot uncertainty intervals.}
 }
 }

+ 1 - 1
R/man/prophet.Rd

@@ -48,7 +48,7 @@ components model.}
 automatic changepoint selection. Large values will allow many changepoints,
 automatic changepoint selection. Large values will allow many changepoints,
 small values will allow few changepoints.}
 small values will allow few changepoints.}
 
 
-\item{mcmc.samples}{Integer, if great than 0, will do full Bayesian
+\item{mcmc.samples}{Integer, if greater than 0, will do full Bayesian
 inference with the specified number of MCMC samples. If 0, will do MAP
 inference with the specified number of MCMC samples. If 0, will do MAP
 estimation.}
 estimation.}
 
 

+ 15 - 15
python/fbprophet/forecaster.py

@@ -1006,18 +1006,17 @@ class Prophet(object):
         if not ax:
         if not ax:
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             ax = fig.add_subplot(111)
             ax = fig.add_subplot(111)
-        df_s = fcst.copy()
-        df_s['dow'] = df_s['ds'].dt.weekday_name
-        df_s = df_s.groupby('dow').first()
-        days = pd.date_range(start='2017-01-01', periods=7).weekday_name
-        y_weekly = [df_s.loc[d]['weekly'] for d in days]
-        y_weekly_l = [df_s.loc[d]['weekly_lower'] for d in days]
-        y_weekly_u = [df_s.loc[d]['weekly_upper'] for d in days]
-        artists += ax.plot(range(len(days)), y_weekly, ls='-',
+        # Compute weekly seasonality for a Sun-Sat sequence of dates.
+        days = pd.date_range(start='2017-01-01', periods=7)
+        df_w = pd.DataFrame({'ds': days})
+        df_w = self.setup_dataframe(df_w)
+        seas = self.predict_seasonal_components(df_w)
+        days = days.weekday_name
+        artists += ax.plot(range(len(days)), seas['weekly'], ls='-',
                            c='#0072B2')
                            c='#0072B2')
         if uncertainty:
         if uncertainty:
             artists += [ax.fill_between(range(len(days)),
             artists += [ax.fill_between(range(len(days)),
-                                        y_weekly_l, y_weekly_u,
+                                        seas['weekly_lower'], seas['weekly_upper'],
                                         color='#0072B2', alpha=0.2)]
                                         color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.set_xticks(range(len(days)))
         ax.set_xticks(range(len(days)))
@@ -1044,15 +1043,16 @@ class Prophet(object):
         if not ax:
         if not ax:
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             fig = plt.figure(facecolor='w', figsize=(10, 6))
             ax = fig.add_subplot(111)
             ax = fig.add_subplot(111)
-        df_s = fcst.copy()
-        df_s['doy'] = df_s['ds'].map(lambda x: x.strftime('2000-%m-%d'))
-        df_s = df_s.groupby('doy').first().sort_index()
-        artists += ax.plot(pd.to_datetime(df_s.index), df_s['yearly'], ls='-',
+        # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
+        df_y = pd.DataFrame({'ds': pd.date_range(start='2017-01-01', periods=365)})
+        df_y = self.setup_dataframe(df_y)
+        seas = self.predict_seasonal_components(df_y)
+        artists += ax.plot(df_y['ds'], seas['yearly'], ls='-',
                            c='#0072B2')
                            c='#0072B2')
         if uncertainty:
         if uncertainty:
             artists += [ax.fill_between(
             artists += [ax.fill_between(
-                pd.to_datetime(df_s.index), df_s['yearly_lower'],
-                df_s['yearly_upper'], color='#0072B2', alpha=0.2)]
+                df_y['ds'].values, seas['yearly_lower'],
+                seas['yearly_upper'], color='#0072B2', alpha=0.2)]
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         ax.grid(True, which='major', c='gray', ls='-', lw=1, alpha=0.2)
         months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
         months = MonthLocator(range(1, 13), bymonthday=1, interval=2)
         ax.xaxis.set_major_formatter(FuncFormatter(
         ax.xaxis.set_major_formatter(FuncFormatter(