|
@@ -8,7 +8,7 @@
|
|
|
## Makes R CMD CHECK happy due to dplyr syntax below
|
|
|
globalVariables(c(
|
|
|
"ds", "y", "cap", ".",
|
|
|
- "component", "dow", "doy", "holiday", "holidays", "append.holidays", "holidays_lower",
|
|
|
+ "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
|
|
|
"holidays_upper", "ix", "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
|
|
|
"trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
|
|
|
"x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
|
|
@@ -43,7 +43,6 @@ globalVariables(c(
|
|
|
#' range of days around the date to be included as holidays. lower_window=-2
|
|
|
#' will include 2 days prior to the date as holidays. Also optionally can have
|
|
|
#' a column prior_scale specifying the prior scale for each holiday.
|
|
|
-#' @param append.holidays country name or abbreviation (character).
|
|
|
#' @param seasonality.mode 'additive' (default) or 'multiplicative'.
|
|
|
#' @param seasonality.prior.scale Parameter modulating the strength of the
|
|
|
#' seasonality model. Larger values allow the model to fit larger seasonal
|
|
@@ -88,7 +87,6 @@ prophet <- function(df = NULL,
|
|
|
weekly.seasonality = 'auto',
|
|
|
daily.seasonality = 'auto',
|
|
|
holidays = NULL,
|
|
|
- append.holidays = NULL,
|
|
|
seasonality.mode = 'additive',
|
|
|
seasonality.prior.scale = 10,
|
|
|
holidays.prior.scale = 10,
|
|
@@ -112,7 +110,6 @@ prophet <- function(df = NULL,
|
|
|
weekly.seasonality = weekly.seasonality,
|
|
|
daily.seasonality = daily.seasonality,
|
|
|
holidays = holidays,
|
|
|
- append.holidays = append.holidays,
|
|
|
seasonality.mode = seasonality.mode,
|
|
|
seasonality.prior.scale = seasonality.prior.scale,
|
|
|
changepoint.prior.scale = changepoint.prior.scale,
|
|
@@ -128,6 +125,7 @@ prophet <- function(df = NULL,
|
|
|
changepoints.t = NULL,
|
|
|
seasonalities = list(),
|
|
|
extra_regressors = list(),
|
|
|
+ country_holidays = NULL,
|
|
|
stan.fit = NULL,
|
|
|
params = list(),
|
|
|
history = NULL,
|
|
@@ -181,11 +179,6 @@ validate_inputs <- function(m) {
|
|
|
validate_column_name(m, h, check_holidays = FALSE)
|
|
|
}
|
|
|
}
|
|
|
- if (!is.null(m$append.holidays)) {
|
|
|
- if (!(m$append.holidays %in% generated_holidays$country)){
|
|
|
- stop("Holidays in ", m$append.holidays," are not currently supported!")
|
|
|
- }
|
|
|
- }
|
|
|
if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
|
|
|
stop("seasonality.mode must be 'additive' or 'multiplicative'")
|
|
|
}
|
|
@@ -223,9 +216,9 @@ validate_column_name <- function(
|
|
|
(name %in% unique(m$holidays$holiday))){
|
|
|
stop("Name ", name, " already used for a holiday.")
|
|
|
}
|
|
|
- if(check_holidays & !is.null(m$append.holidays)){
|
|
|
- if(name %in% get_holiday_names(m$append.holidays)){
|
|
|
- stop("Name ", name, " is a holiday name in ", m$append.holidays, ".")
|
|
|
+ if(check_holidays & !is.null(m$country_holidays)){
|
|
|
+ if(name %in% get_holiday_names(m$country_holidays)){
|
|
|
+ stop("Name ", name, " is a holiday name in ", m$country_holidays, ".")
|
|
|
}
|
|
|
}
|
|
|
if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){
|
|
@@ -533,10 +526,46 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|
|
return(data.frame(features))
|
|
|
}
|
|
|
|
|
|
+#' Construct a dataframe of holiday dates from m$holidays and country holidays.
|
|
|
+#'
|
|
|
+#' @param m Prophet object.
|
|
|
+#' @param dates Vector of dates; their years select the country holidays built.
|
|
|
+#'
|
|
|
+#' @return A dataframe of holiday dates, in holiday dataframe format used in
|
|
|
+#' initialization; limited to m$train.holiday.names when that is set.
|
|
|
+#'
|
|
|
+#' @importFrom dplyr "%>%"
|
|
|
+#' @keywords internal
|
|
|
+construct_holiday_dataframe <- function(m, dates) {
|
|
|
+ all.holidays <- data.frame()
|
|
|
+ if (!is.null(m$holidays)){
|
|
|
+ all.holidays <- m$holidays
|
|
|
+ }
|
|
|
+ if (!is.null(m$country_holidays)) {
|
|
|
+ year.list <- as.numeric(unique(format(dates, "%Y")))
|
|
|
+ country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
|
|
|
+ dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
|
|
|
+ all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, country.holidays.df))
|
|
|
+ }
|
|
|
+ # If the model was already fit, keep only the holidays seen in training and
|
|
|
+ # add a holiday-name-only row for each trained holiday that is missing here.
|
|
|
+ if (!is.null(m$train.holiday.names)) {
|
|
|
+ row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
|
|
|
+ all.holidays <- all.holidays[row.to.keep,]
|
|
|
+ holidays.to.add <- data.frame(
|
|
|
+ holiday=setdiff(m$train.holiday.names, all.holidays$holiday)
|
|
|
+ )
|
|
|
+ all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
|
|
|
+ }
|
|
|
+ return(all.holidays)
|
|
|
+}
|
|
|
+
|
|
|
#' Construct a matrix of holiday features.
|
|
|
#'
|
|
|
#' @param m Prophet object.
|
|
|
#' @param dates Vector with dates used for computing seasonality.
|
|
|
+#' @param holidays Dataframe containing holidays, as returned by
|
|
|
+#' construct_holiday_dataframe.
|
|
|
#'
|
|
|
#' @return A list with entries
|
|
|
#' holiday.features: dataframe with a column for each holiday.
|
|
@@ -545,28 +574,10 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
|
|
|
#'
|
|
|
#' @importFrom dplyr "%>%"
|
|
|
#' @keywords internal
|
|
|
-make_holiday_features <- function(m, dates) {
|
|
|
+make_holiday_features <- function(m, dates, holidays) {
|
|
|
# Strip dates to be just days, for joining on holidays
|
|
|
dates <- set_date(format(dates, "%Y-%m-%d"))
|
|
|
- all.holidays <- m$holidays
|
|
|
- if (!is.null(m$append.holidays)){
|
|
|
- years <- as.numeric(unique(format(dates, "%Y")))
|
|
|
- append.holidays.df <- make_holidays_df(years, m$append.holidays) %>%
|
|
|
- dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
|
|
|
- all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, append.holidays.df))
|
|
|
- }
|
|
|
- # Make fit.prophet and predict.prophet holidays components match
|
|
|
- if (!is.null(m$append.holidays) && !is.null(m$train.holiday.names)){
|
|
|
- row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
|
|
|
- all.holidays <- all.holidays[row.to.keep,]
|
|
|
- holidays.to.add <- data.frame(holiday=setdiff(m$train.holiday.names,
|
|
|
- all.holidays$holiday))
|
|
|
- all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
|
|
|
- }
|
|
|
- if (nrow(all.holidays)==0){
|
|
|
- return(NULL)
|
|
|
- }
|
|
|
- wide <- all.holidays %>%
|
|
|
+ wide <- holidays %>%
|
|
|
dplyr::mutate(ds = set_date(ds)) %>%
|
|
|
dplyr::group_by(holiday, ds) %>%
|
|
|
dplyr::filter(dplyr::row_number() == 1) %>%
|
|
@@ -587,17 +598,17 @@ make_holiday_features <- function(m, dates) {
|
|
|
holiday.features <- data.frame(ds = set_date(dates)) %>%
|
|
|
dplyr::left_join(wide, by = 'ds') %>%
|
|
|
dplyr::select(-ds)
|
|
|
- # Make sure fit.prophet and predict.prophet component.cols perfectly equal
|
|
|
+ # Make sure column order is consistent
|
|
|
holiday.features <- holiday.features %>% dplyr::select(sort(names(.)))
|
|
|
holiday.features[is.na(holiday.features)] <- 0
|
|
|
-
|
|
|
+
|
|
|
# Prior scales
|
|
|
- if (!('prior_scale' %in% colnames(all.holidays))) {
|
|
|
- all.holidays$prior_scale <- m$holidays.prior.scale
|
|
|
+ if (!('prior_scale' %in% colnames(holidays))) {
|
|
|
+ holidays$prior_scale <- m$holidays.prior.scale
|
|
|
}
|
|
|
prior.scales.list <- list()
|
|
|
- for (name in unique(all.holidays$holiday)) {
|
|
|
- df.h <- all.holidays[all.holidays$holiday == name, ]
|
|
|
+ for (name in unique(holidays$holiday)) {
|
|
|
+ df.h <- holidays[holidays$holiday == name, ]
|
|
|
ps <- unique(df.h$prior_scale)
|
|
|
if (length(ps) > 1) {
|
|
|
stop('Holiday ', name, ' does not have a consistent prior scale ',
|
|
@@ -707,7 +718,6 @@ add_regressor <- function(
|
|
|
#'
|
|
|
#' @return The prophet model with the seasonality added.
|
|
|
#'
|
|
|
-#' @importFrom dplyr "%>%"
|
|
|
#' @export
|
|
|
add_seasonality <- function(
|
|
|
m, name, period, fourier.order, prior.scale = NULL, mode = NULL
|
|
@@ -742,6 +752,46 @@ add_seasonality <- function(
|
|
|
return(m)
|
|
|
}
|
|
|
|
|
|
+#' Add in built-in holidays for the specified country.
|
|
|
+#'
|
|
|
+#' These holidays will be included in addition to any specified on model
|
|
|
+#' initialization.
|
|
|
+#'
|
|
|
+#' Holidays will be calculated for arbitrary date ranges in the history
|
|
|
+#' and future. See the online documentation for the list of countries with
|
|
|
+#' built-in holidays.
|
|
|
+#'
|
|
|
+#' Built-in country holidays can only be set for a single country; calling this
|
|
|
+#' again replaces the previously set country and emits a message.
|
|
|
+#'
|
|
|
+#' @param m Prophet object. Must not have been fit yet (m$history is NULL).
|
|
|
+#' @param country_name Name of the country, like 'UnitedStates' or 'US'. An
|
|
|
+#' error is raised if it is not listed in generated_holidays$country.
|
|
|
+#'
|
|
|
+#' @return The prophet model with country_holidays set to country_name.
|
|
|
+#'
|
|
|
+#' @export
|
|
|
+add_country_holidays <- function(m, country_name) {
|
|
|
+ if (!is.null(m$history)) {
|
|
|
+ stop("Country holidays must be added prior to model fitting.")
|
|
|
+ }
|
|
|
+ if (!(country_name %in% generated_holidays$country)){
|
|
|
+ stop("Holidays in ", country_name," are not currently supported!")
|
|
|
+ }
|
|
|
+ # Validate each built-in holiday name as a column name for this model.
|
|
|
+ for (name in get_holiday_names(country_name)) {
|
|
|
+ # check_holidays = FALSE: allow merging with existing holidays
|
|
|
+ validate_column_name(m, name, check_holidays = FALSE)
|
|
|
+ }
|
|
|
+ # Set the holidays, announcing the change if a country was already set.
|
|
|
+ if (!is.null(m$country_holidays)) {
|
|
|
+ message(
|
|
|
+ 'Changing country holidays from ', m$country_holidays, ' to ',
|
|
|
+ country_name
|
|
|
+ )
|
|
|
+ }
|
|
|
+ m$country_holidays <- country_name
|
|
|
+ return(m)
|
|
|
+}
|
|
|
+
|
|
|
#' Dataframe with seasonality features.
|
|
|
#' Includes seasonality features, holiday features, and added regressors.
|
|
|
#'
|
|
@@ -776,15 +826,15 @@ make_all_seasonality_features <- function(m, df) {
|
|
|
}
|
|
|
|
|
|
# Holiday features
|
|
|
- if (!is.null(m$holidays) || !is.null(m$append.holidays)) {
|
|
|
- out <- make_holiday_features(m, df$ds)
|
|
|
- if (!is.null(out)){
|
|
|
- m <- out$m
|
|
|
- seasonal.features <- cbind(seasonal.features, out$holiday.features)
|
|
|
- prior.scales <- c(prior.scales, out$prior.scales)
|
|
|
- modes[[m$seasonality.mode]] <- c(
|
|
|
- modes[[m$seasonality.mode]], out$holiday.names)
|
|
|
- }
|
|
|
+ holidays <- construct_holiday_dataframe(m, df$ds)
|
|
|
+ if (nrow(holidays) > 0) {
|
|
|
+ out <- make_holiday_features(m, df$ds, holidays)
|
|
|
+ m <- out$m
|
|
|
+ seasonal.features <- cbind(seasonal.features, out$holiday.features)
|
|
|
+ prior.scales <- c(prior.scales, out$prior.scales)
|
|
|
+ modes[[m$seasonality.mode]] <- c(
|
|
|
+ modes[[m$seasonality.mode]], out$holiday.names
|
|
|
+ )
|
|
|
}
|
|
|
|
|
|
# Additional regressors
|