Browse Source

External regressors v2 (#283)

Add regressors in R
Simon Kim 8 years ago
parent
commit
17efc9aecd

+ 297 - 80
R/R/prophet.R

@@ -9,7 +9,7 @@
 globalVariables(c(
   "ds", "y", "cap", ".",
   "component", "dow", "doy", "holiday", "holidays", "holidays_lower", "holidays_upper", "ix",
-  "lower", "n", "stat", "trend", "row_number",
+  "lower", "n", "stat", "trend", "row_number", "extra_regressors",
   "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
   "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
 
@@ -80,6 +80,7 @@ prophet <- function(df = NULL,
                     weekly.seasonality = 'auto',
                     daily.seasonality = 'auto',
                     holidays = NULL,
+                    extra_regressors = NULL,  #new
                     seasonality.prior.scale = 10,
                     holidays.prior.scale = 10,
                     changepoint.prior.scale = 0.05,
@@ -103,6 +104,7 @@ prophet <- function(df = NULL,
     weekly.seasonality = weekly.seasonality,
     daily.seasonality = daily.seasonality,
     holidays = holidays,
+    extra_regressors = extra_regressors,
     seasonality.prior.scale = seasonality.prior.scale,
     changepoint.prior.scale = changepoint.prior.scale,
     holidays.prior.scale = holidays.prior.scale,
@@ -130,6 +132,48 @@ prophet <- function(df = NULL,
   return(m)
 }
 
+#' Validates the name of a seasonality, holiday, or regressor
+#'
+#' @param m Prophet object.
+#' @param name string
+#' @param check_holidays bool check if name already used for holiday
+#' @param check_seasonalities bool check if name already used for seasonality
+#' @param check_regressors  bool check if name already used for regressor
+#'
+validate_column_name <- function(m, name, check_holidays = TRUE,
+                                 check_seasonalities = TRUE, check_regressors = TRUE) {
+
+  if (grepl("_delim_", name)) {
+    stop('Holiday name cannot contain "_delim_"')
+  }
+
+  reserved_names = c('trend', 'seasonal', 'seasonalities', 'daily', 'weekly', 'yearly',
+                     'holidays', 'zeros', 'extra_regressors', 'yhat')
+
+  rn_l = paste(reserved_names,"_lower",sep="")
+  rn_u = paste(reserved_names,"_upper",sep="")
+  reserved_names = c(reserved_names, rn_l, rn_u, c("ds","y"));
+
+  if(name %in% reserved_names){
+    error_message = paste("Name ", name, " is reserved.");
+    stop(error_message)
+  }
+
+  if(check_holidays & !is.null(m$holidays) & (name %in% unique(m$holidays$holiday))){
+    error_message = paste("Name ", name, " already used for a holiday.");
+    stop(error_message)
+  }
+  #m$yearly.seasonality
+  if(check_seasonalities & (name %in% m$seasonalities[[name]])){
+    error_message = paste("Name ", name, " already used for a seasonality.");
+    stop(error_message)
+  }
+  if(check_regressors & (name %in% m$extra_regressors[[name]])){
+    error_message = paste("Name ", name, " already used for an added regressor.");
+    stop(error_message)
+  }
+}
+
 #' Validates the inputs to Prophet.
 #'
 #' @param m Prophet object.
@@ -161,13 +205,7 @@ validate_inputs <- function(m) {
       }
     }
     for (h in unique(m$holidays$holiday)) {
-      if (grepl("_delim_", h)) {
-        stop('Holiday name cannot contain "_delim_"')
-      }
-      if (h %in% c('zeros', 'yearly', 'weekly', 'daily', 'yhat', 'seasonal',
-                   'trend')) {
-        stop(paste0('Holiday name "', h, '" reserved.'))
-      }
+      validate_column_name(m,h, check_holidays=FALSE)
     }
   }
 }
@@ -219,7 +257,7 @@ compile_stan_model <- function(model) {
 
 #' Convert date vector
 #'
-#' Convert the date to POSIXct object 
+#' Convert the date to POSIXct object
 #'
 #' @param ds Date vector, can be consisted of characters
 #' @param tz string time zone
@@ -230,18 +268,17 @@ compile_stan_model <- function(model) {
 set_date <- function(ds = NULL, tz = "GMT") {
   if (length(ds) == 0) {
     return(NULL)
-  } 
-  
+  }
+
   if (is.factor(ds)) {
     ds <- as.character(ds)
   }
-  
+
   if (min(nchar(ds)) < 12) {
     ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
   } else {
     ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
   }
-  attr(ds, "tzone") <- tz
   return(ds)
 }
 
@@ -267,7 +304,8 @@ time_diff <- function(ds1, ds2, units = "days") {
 #' and predicting.
 #'
 #' @param m Prophet object.
-#' @param df Data frame with columns ds, y, and cap if logistic growth.
+#' @param df Data frame with columns ds, y, and cap if logistic growth.Any
+#'           specified additional regressors must also be present.
 #' @param initialize_scales Boolean set scaling factors in m from df.
 #'
 #' @return list with items 'df' and 'm'.
@@ -283,6 +321,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
                'format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S'))
   }
 
+  #names(m$extra_regressors)
+
+
   df <- df %>%
     dplyr::arrange(ds)
 
@@ -343,8 +384,10 @@ set_changepoints <- function(m) {
               m$n.changepoints)
     }
     if (m$n.changepoints > 0) {
-      cp.indexes <- round(seq.int(1, hist.size,
-                          length.out = (m$n.changepoints + 1))[-1])
+      # Place potential changepoints evenly through the first 80 pcnt of
+      # the history.
+      cp.indexes <- round(seq.int(1, floor(nrow(m$history) * .8),
+                                  length.out = (m$n.changepoints + 1))[-1])
       m$changepoints <- m$history$ds[cp.indexes]
     } else {
       m$changepoints <- c()
@@ -422,8 +465,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
 make_holiday_features <- function(m, dates) {
   scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
   # Strip dates to be just days, for joining on holidays
-  dates <- set_date(format(dates, "%Y-%m-%d"))
-
+  dates <- set_date(format(dates))
   wide <- m$holidays %>%
     dplyr::mutate(ds = set_date(ds)) %>%
     dplyr::group_by(holiday, ds) %>%
@@ -450,6 +492,44 @@ make_holiday_features <- function(m, dates) {
   return(holiday.mat)
 }
 
+#'Add an additional regressor to be used for fitting and predicting.
+#'
+#'The dataframe passed to `fit` and `predict` will have a column with the
+#'specified name to be used as a regressor. When standardize='auto', the
+#'regressor will be standardized unless it is binary. The regression
+#'coefficient is given a prior with the specified scale parameter.
+#'Decreasing the prior scale will add additional regularization. If no
+#'prior scale is provided, self.holidays_prior_scale will be used.
+#'
+#' @param m
+#' @param  name string name of the regressor
+#' @param  prior_scale optional float scale for the normal prior. If not
+#'                    provided, self.holidays_prior_scale will be used.
+#' @param  standardize optional, specify whether this regressor will be
+#'                     standardized prior to fitting. Can be 'auto' (standardize if not
+#'                     binary), True, or False.
+#' @return  The prophet model with the regressor added.
+#' @export
+add_regressor <- function(m, prior_scale=0.0, standardize='auto'){
+  if(!is.null(m$history)){
+    stop('Regressors must be added prior to model fitting.')
+  }
+  validate_column_name(m,check_regressors=FALSE);
+  if(prior_scale == 0){
+    prior_scale = m$holidays.prior.scale
+  }
+
+  if(prior_scale < 0){
+    stop("prior_scale is less than 0");
+  }
+  m$extra_regressors = list(name = list(prior_scale = prior_scale,
+                                        standardize=standardize,
+                                        mu=0,
+                                        std=1.0))
+
+  return(m)
+}
+
 #' Add a seasonal component with specified period and number of Fourier
 #' components.
 #'
@@ -468,37 +548,93 @@ make_holiday_features <- function(m, dates) {
 #' @export
 add_seasonality <- function(m, name, period, fourier.order) {
   if (!is.null(m$holidays)) {
-    if (name %in% (unique(m$holidays$holiday) %>% as.character())) {
-      stop('Name "', name, '" already used for holiday')
-    }
+    stop("Seasonality must be added prior to model fitting.")
+  }
+
+  if (!(name %in% c('daily', 'weekly', 'yearly'))) {
+    validate_column_name(name,check_seasonalities=FALSE)
   }
   m$seasonalities[[name]] <- c(period, fourier.order)
   return(m)
 }
 
 #' Dataframe with seasonality features.
+#' Includes seasonality features, holiday features, and added regressors.
 #'
 #' @param m Prophet object.
 #' @param df Dataframe with dates for computing seasonality features.
 #'
-#' @return Dataframe with seasonality.
+#' @return Dataframe with regressor features,
+#'         list of prior scales for each colum of the features and any added regressors
 #'
 #' @keywords internal
 make_all_seasonality_features <- function(m, df) {
-  seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
+  #seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
+  seasonal.features <- c();
+  prior_scales <- c();
+
   for (name in names(m$seasonalities)) {
     period <- m$seasonalities[[name]][1]
     series.order <- m$seasonalities[[name]][2]
-    seasonal.features <- cbind(
-      seasonal.features,
-      make_seasonality_features(df$ds, period, series.order, name))
+    features = make_seasonality_features(df$ds, period, series.order, name);
+    if(is.null(seasonal.features)){
+      seasonal.features <- features;
+    }
+    seasonal.features <- cbind(seasonal.features, features) #test append 와 문제가 없는지 확인
+    prior_scales = c(prior_scales, m$seasonality.prior.scale * dim(features)[2]);
   }
   if(!is.null(m$holidays)) {
-    seasonal.features <- cbind(
-      seasonal.features,
-      make_holiday_features(m, df$ds))
+    features = make_holiday_features(m, df$ds);
+    seasonal.features <- cbind(seasonal.features, features) #test
+    prior_scales <- c(prior_scales, m$holiday_prior_scale * dim(features)[2]);
+  }
+
+  # Additional regressors
+  for(name in names(m$extra_regressors)){
+    seasonal.features = cbind(seasonal.features, df[name]); #test
+    prior_scales = cbind(prior_scales, m$extra_regressors[[name]][[prior_scale]])
+  }
+
+  if(length(df) == 0){
+    seasonal.features =cbind(seasonal.features,data.frame(zeros = rep(0, nrow(df))));
+    prior_scales = c(prior_scales,0.1)
+  }
+
+  return(list(seasonal.features=seasonal.features, prior_scales=prior_scales))
+}
+
+#' Get number of Fourier components for built-in seasonalities.
+#'
+#' @param m Prophet object.
+#' @param name String name of the seasonality component.
+#' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
+#'  provided.
+#' @param auto.disable Bool if seasonality should be disabled when 'auto'.
+#' @param default.order Int default Fourier order.
+#'
+#' @return Number of Fourier components, or 0 for disabled.
+#'
+#' @keywords internal
+parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
+  if (arg == 'auto') {
+    fourier.order <- 0
+    if (name %in% names(m$seasonalities)) {
+      warning('Found custom seasonality named "', name,
+              '", disabling built-in ', name, ' seasonality.')
+    } else if (auto.disable) {
+      warning('Disabling ', name, ' seasonality. Run prophet with ', name,
+              '.seasonality=TRUE to override this.')
+    } else {
+      fourier.order <- default.order
+    }
+  } else if (arg == TRUE) {
+    fourier.order <- default.order
+  } else if (arg == FALSE) {
+    fourier.order <- 0
+  } else {
+    fourier.order <- arg
   }
-  return(seasonal.features)
+  return(fourier.order)
 }
 
 #' Get number of Fourier components for built-in seasonalities.
@@ -666,7 +802,8 @@ fit.prophet <- function(m, df, ...) {
   m <- out$m
   m$history <- history
   m <- set_auto_seasonalities(m)
-  seasonal.features <- make_all_seasonality_features(m, history)
+  seasonal.features <- make_all_seasonality_features(m, history)[[1]]
+  prior_scales <- make_all_seasonality_features(m, history)[[2]]
 
   m <- set_changepoints(m)
   A <- get_changepoint_matrix(m)
@@ -681,7 +818,7 @@ fit.prophet <- function(m, df, ...) {
     A = A,
     t_change = array(m$changepoints.t),
     X = as.matrix(seasonal.features),
-    sigma = m$seasonality.prior.scale,
+    sigma = prior_scales,
     tau = m$changepoint.prior.scale
   )
 
@@ -882,7 +1019,7 @@ predict_trend <- function(model, df) {
 #'
 #' @keywords internal
 predict_seasonal_components <- function(m, df) {
-  seasonal.features <- make_all_seasonality_features(m, df)
+  seasonal.features <- make_all_seasonality_features(m, df)[[1]]
   lower.p <- (1 - m$interval.width)/2
   upper.p <- (1 + m$interval.width)/2
 
@@ -893,32 +1030,65 @@ predict_seasonal_components <- function(m, df) {
                     extra = "merge", fill = "right") %>%
     dplyr::filter(component != 'zeros')
 
-  if (nrow(components) > 0) {
-    component.predictions <- components %>%
-      dplyr::group_by(component) %>% dplyr::do({
-        comp <- (as.matrix(seasonal.features[, .$col])
-                 %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
-        dplyr::data_frame(ix = 1:nrow(seasonal.features),
-                          mean = rowMeans(comp, na.rm = TRUE),
-                          lower = apply(comp, 1, stats::quantile, lower.p,
-                                        na.rm = TRUE),
-                          upper = apply(comp, 1, stats::quantile, upper.p,
-                                        na.rm = TRUE))
-      }) %>%
-      tidyr::gather(stat, value, mean, lower, upper) %>%
-      dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
-      tidyr::unite(component, component, stat, sep="") %>%
-      tidyr::spread(component, value) %>%
-      dplyr::select(-ix)
-
-    component.predictions$seasonal <- rowSums(
-      component.predictions[unique(components$component)])
-  } else {
-    component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
+  #components <-
+  components <- rbind(components[,c(1,3)], data.frame("component"=rep("seasonal"),
+                                                      "col"=c(1:dim(seasonal.features)[2])));
+
+  components <- add_group_component(m,components, 'seasonalities', names(m$seasonalities));
+
+  if(!is.null(m$holidays)){
+    components <- add_group_component(m,components, 'holidays', unique(m$holidays$holiday));
   }
+
+  components <- add_group_component(m,components, 'extra_regressors', names(m$extra_regressors));
+  # I am stuck on here: I am little confused that do I need to set
+  #                     components as list or dataframe  ??
+  #
+  #if (nrow(components) > 0) {
+  component.predictions <- components %>%
+    dplyr::group_by(component) %>% dplyr::do({
+      comp <- (as.matrix(seasonal.features[, .$col])
+               %*% t(m$params$beta[, .$col, drop = FALSE])) * m$y.scale
+      dplyr::data_frame(ix = 1:nrow(seasonal.features),
+                        mean = rowMeans(comp, na.rm = TRUE),
+                        lower = apply(comp, 1, stats::quantile, lower.p,
+                                      na.rm = TRUE),
+                        upper = apply(comp, 1, stats::quantile, upper.p,
+                                      na.rm = TRUE))
+    }) %>%
+    tidyr::gather(stat, value, mean, lower, upper) %>%
+    dplyr::mutate(stat = ifelse(stat == 'mean', '', paste0('_', stat))) %>%
+    tidyr::unite(component, component, stat, sep="") %>%
+    tidyr::spread(component, value) %>%
+    dplyr::select(-ix)
+
+  component.predictions$seasonal <- rowSums(
+    component.predictions[unique(components$component)])
+  #  } else {
+  #    component.predictions <- data.frame(seasonal = rep(0, nrow(df)))
+  #  }
   return(component.predictions)
 }
 
+#' Adds a component with given name that contains all of the components
+#' in group.
+#'
+#' @param m Prophet object.
+#' @param components Dataframe with components.
+#' @param name Name of new group component.
+#' @param group  List of components that form the group.
+#'
+#' @return Dataframe with components.
+#'
+#' @keywords internal
+add_group_component <- function(m, components, name, group) {
+
+  loc = (components$component %in% group);
+  new_comp = components[loc,];
+  new_comp$component = name;
+  components= rbind(components, new_comp);
+  return(components);
+}
 #' Prophet posterior predictive samples.
 #'
 #' @param m Prophet object.
@@ -933,7 +1103,7 @@ sample_posterior_predictive <- function(m, df) {
   samp.per.iter <- max(1, ceiling(m$uncertainty.samples / n.iterations))
   nsamp <- n.iterations * samp.per.iter  # The actual number of samples
 
-  seasonal.features <- make_all_seasonality_features(m, df)
+  seasonal.features <- make_all_seasonality_features(m, df)[[1]]
   sim.values <- list("trend" = matrix(, nrow = nrow(df), ncol = nsamp),
                      "seasonal" = matrix(, nrow = nrow(df), ncol = nsamp),
                      "yhat" = matrix(, nrow = nrow(df), ncol = nsamp))
@@ -950,7 +1120,7 @@ sample_posterior_predictive <- function(m, df) {
     }
   }
   return(sim.values)
-}  
+}
 
 #' Sample from the posterior predictive distribution.
 #'
@@ -959,13 +1129,14 @@ sample_posterior_predictive <- function(m, df) {
 #'  (column cap) if logistic growth.
 #'
 #' @return A list with items "trend", "seasonal", and "yhat" containing
-#'  posterior predictive samples for that component.
+#'  posterior predictive samples for that component. "seasonal" is the sum
+#'  of seasonalities, holidays, and added regressors.
 #'
 #' @export
 predictive_samples <- function(m, df) {
-    df <- setup_dataframe(m, df)$df
-    sim.values <- sample_posterior_predictive(m, df)
-    return(sim.values)
+  df <- setup_dataframe(m, df)$df
+  sim.values <- sample_posterior_predictive(m, df)
+  return(sim.values)
 }
 
 #' Prophet uncertainty intervals.
@@ -1197,8 +1368,8 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
 #' @export
 #' @importFrom dplyr "%>%"
 prophet_plot_components <- function(
-    m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
-    yearly_start = 0) {
+  m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
+  yearly_start = 0) {
   df <- df_for_plotting(m, fcst)
   # Plot the trend
   panels <- list(plot_trend(df, uncertainty, plot_cap))
@@ -1287,11 +1458,11 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
   if (uncertainty) {
     gg.holidays <- gg.holidays +
-    ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
-                                      ymax = holidays_upper),
-                         alpha = 0.2,
-                         fill = "#0072B2",
-                         na.rm = TRUE)
+      ggplot2::geom_ribbon(ggplot2::aes(ymin = holidays_lower,
+                                        ymax = holidays_upper),
+                           alpha = 0.2,
+                           fill = "#0072B2",
+                           na.rm = TRUE)
   }
   return(gg.holidays)
 }
@@ -1311,9 +1482,10 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
   # Compute weekly seasonality for a Sun-Sat sequence of dates.
   df.w <- data.frame(
     ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
-    weekly_start, cap=1.)
+      weekly_start, cap=1.)
   df.w <- setup_dataframe(m, df.w)$df
   seas <- predict_seasonal_components(m, df.w)
+  print(seas)
   seas$dow <- factor(weekdays(df.w$ds), levels=weekdays(df.w$ds))
 
   gg.weekly <- ggplot2::ggplot(seas, ggplot2::aes(x = dow, y = weekly,
@@ -1322,11 +1494,11 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
     ggplot2::labs(x = "Day of week")
   if (uncertainty) {
     gg.weekly <- gg.weekly +
-    ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
-                                      ymax = weekly_upper),
-                         alpha = 0.2,
-                         fill = "#0072B2",
-                         na.rm = TRUE)
+      ggplot2::geom_ribbon(ggplot2::aes(ymin = weekly_lower,
+                                        ymax = weekly_upper),
+                           alpha = 0.2,
+                           fill = "#0072B2",
+                           na.rm = TRUE)
   }
   return(gg.weekly)
 }
@@ -1346,7 +1518,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
   # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
   df.y <- data.frame(
     ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
-    yearly_start, cap=1.)
+      yearly_start, cap=1.)
   df.y <- setup_dataframe(m, df.y)$df
   seas <- predict_seasonal_components(m, df.y)
   seas$ds <- df.y$ds
@@ -1358,11 +1530,11 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
     ggplot2::scale_x_datetime(labels = scales::date_format('%B %d'))
   if (uncertainty) {
     gg.yearly <- gg.yearly +
-    ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
-                                      ymax = yearly_upper),
-                         alpha = 0.2,
-                         fill = "#0072B2",
-                         na.rm = TRUE)
+      ggplot2::geom_ribbon(ggplot2::aes(ymin = yearly_lower,
+                                        ymax = yearly_upper),
+                           alpha = 0.2,
+                           fill = "#0072B2",
+                           na.rm = TRUE)
   }
   return(gg.yearly)
 }
@@ -1449,6 +1621,51 @@ prophet_copy <- function(m, cutoff = NULL) {
     uncertainty.samples = m$uncertainty.samples,
     fit = FALSE,
   ))
+
+#' Sample from the posterior predictive distribution.
+#'
+#' @param m Prophet model object.
+#' @param name String name of the seasonality.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+#'
+#' @keywords internal
+plot_seasonality <- function(m, name, uncertainty = TRUE) {
+  # Compute seasonality from Jan 1 through a single period.
+  start <- set_date('2017-01-01')
+  period <- m$seasonalities[[name]][1]
+  end <- start + period * 24 * 3600
+  plot.points <- 200
+  df.y <- data.frame(
+    ds=seq(from=start, to=end, length.out=plot.points), cap=1.)
+  df.y <- setup_dataframe(m, df.y)$df
+  seas <- predict_seasonal_components(m, df.y)
+  seas$ds <- df.y$ds
+  gg.s <- ggplot2::ggplot(
+    seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
+  if (period <= 2) {
+    fmt.str <- '%T'
+  }
+  else if (period < 14) {
+    fmt.str <- '%m/%d %R'
+  } else {
+    fmt.str <- '%m/%d'
+  }
+  gg.s <- gg.s +
+    ggplot2::scale_x_datetime(labels = scales::date_format(fmt.str))
+  if (uncertainty) {
+    gg.s <- gg.s +
+      ggplot2::geom_ribbon(
+        ggplot2::aes_string(
+          ymin = paste0(name, '_lower'), ymax = paste0(name, '_upper')
+        ),
+        alpha = 0.2,
+        fill = "#0072B2",
+        na.rm = TRUE)
+  }
+  return(gg.s)
 }
 
 # fb-block 3

+ 2 - 2
R/inst/stan/prophet_linear_growth.stan

@@ -7,7 +7,7 @@ data {
   matrix[T, S] A;                   // Split indicators
   real t_change[S];                 // Index of changepoints
   matrix[T,K] X;                // season vectors
-  real<lower=0> sigma;              // scale on seasonality prior
+  vector[K] sigmas;              // scale on seasonality prior
   real<lower=0> tau;                  // scale on changepoints prior
 }
 
@@ -33,7 +33,7 @@ model {
   m ~ normal(0, 5);
   delta ~ double_exponential(0, tau);
   sigma_obs ~ normal(0, 0.5);
-  beta ~ normal(0, sigma);
+  beta ~ normal(0, sigmas);
 
   // Likelihood
   y ~ normal((k + A * delta) .* t + (m + A * gamma) + X * beta, sigma_obs);

+ 2 - 2
R/inst/stan/prophet_logistic_growth.stan

@@ -8,7 +8,7 @@ data {
   matrix[T, S] A;                   // Split indicators
   real t_change[S];                 // Index of changepoints
   matrix[T,K] X;                    // season vectors
-  real<lower=0> sigma;              // scale on seasonality prior
+  vector[K] sigmas;               // scale on seasonality prior
   real<lower=0> tau;                  // scale on changepoints prior
 }
 
@@ -45,7 +45,7 @@ model {
   m ~ normal(0, 5);
   delta ~ double_exponential(0, tau);
   sigma_obs ~ normal(0, 0.1);
-  beta ~ normal(0, sigma);
+  beta ~ normal(0, sigmas);
 
   // Likelihood
   y ~ normal(cap ./ (1 + exp(-(k + A * delta) .* (t - (m + A * gamma)))) + X * beta, sigma_obs);

+ 2 - 2
R/man/plot.prophet.Rd

@@ -4,8 +4,8 @@
 \alias{plot.prophet}
 \title{Plot the prophet forecast.}
 \usage{
-\method{plot}{prophet}(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
-  xlabel = "ds", ylabel = "y", ...)
+plot.prophet(x, fcst, uncertainty = TRUE, plot_cap = TRUE, xlabel = "ds",
+  ylabel = "y", ...)
 }
 \arguments{
 \item{x}{Prophet object.}

+ 1 - 1
docs/_data/nav.yml

@@ -1,5 +1,5 @@
 - title: Docs
-  href: /docs/
+  href: docs/
   category: docs
 
 - title: GitHub