Преглед на файлове

Move built-in country holidays to a function (R)

Ben Letham преди 6 години
родител
ревизия
287fb2f6de
променени са 9 файла, в които са добавени 183 реда и са изтрити 83 реда
  1. 1 1
      R/DESCRIPTION
  2. 1 0
      R/NAMESPACE
  3. 98 48
      R/R/prophet.R
  4. 27 0
      R/man/add_country_holidays.Rd
  5. 21 0
      R/man/construct_holiday_dataframe.Rd
  6. 4 1
      R/man/make_holiday_features.Rd
  7. 4 6
      R/man/prophet.Rd
  8. 22 24
      R/tests/testthat/test_prophet.R
  9. 5 3
      python/fbprophet/forecaster.py

+ 1 - 1
R/DESCRIPTION

@@ -32,7 +32,7 @@ Suggests:
     readr
 License: BSD_3_clause + file LICENSE
 LazyData: true
-RoxygenNote: 6.1.0
+RoxygenNote: 6.1.1
 VignetteBuilder: knitr
 SystemRequirements: C++11
 Encoding: UTF-8

+ 1 - 0
R/NAMESPACE

@@ -3,6 +3,7 @@
 S3method(plot,prophet)
 S3method(predict,prophet)
 export(add_changepoints_to_plot)
+export(add_country_holidays)
 export(add_regressor)
 export(add_seasonality)
 export(cross_validation)

+ 98 - 48
R/R/prophet.R

@@ -8,7 +8,7 @@
 ## Makes R CMD CHECK happy due to dplyr syntax below
 globalVariables(c(
   "ds", "y", "cap", ".",
-  "component", "dow", "doy", "holiday", "holidays", "append.holidays", "holidays_lower", 
+  "component", "dow", "doy", "holiday", "holidays", "holidays_lower",
   "holidays_upper", "ix", "lower", "n", "stat", "trend", "row_number", "extra_regressors", "col",
   "trend_lower", "trend_upper", "upper", "value", "weekly", "weekly_lower", "weekly_upper",
   "x", "yearly", "yearly_lower", "yearly_upper", "yhat", "yhat_lower", "yhat_upper"))
@@ -43,7 +43,6 @@ globalVariables(c(
 #'  range of days around the date to be included as holidays. lower_window=-2
 #'  will include 2 days prior to the date as holidays. Also optionally can have
 #'  a column prior_scale specifying the prior scale for each holiday.
-#' @param append.holidays country name or abbreviation (character).
 #' @param seasonality.mode 'additive' (default) or 'multiplicative'.
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #'  seasonality model. Larger values allow the model to fit larger seasonal
@@ -88,7 +87,6 @@ prophet <- function(df = NULL,
                     weekly.seasonality = 'auto',
                     daily.seasonality = 'auto',
                     holidays = NULL,
-                    append.holidays = NULL,
                     seasonality.mode = 'additive',
                     seasonality.prior.scale = 10,
                     holidays.prior.scale = 10,
@@ -112,7 +110,6 @@ prophet <- function(df = NULL,
     weekly.seasonality = weekly.seasonality,
     daily.seasonality = daily.seasonality,
     holidays = holidays,
-    append.holidays = append.holidays,
     seasonality.mode = seasonality.mode,
     seasonality.prior.scale = seasonality.prior.scale,
     changepoint.prior.scale = changepoint.prior.scale,
@@ -128,6 +125,7 @@ prophet <- function(df = NULL,
     changepoints.t = NULL,
     seasonalities = list(),
     extra_regressors = list(),
+    country_holidays = NULL,
     stan.fit = NULL,
     params = list(),
     history = NULL,
@@ -181,11 +179,6 @@ validate_inputs <- function(m) {
       validate_column_name(m, h, check_holidays = FALSE)
     }
   }
-  if (!is.null(m$append.holidays)) {
-    if (!(m$append.holidays %in% generated_holidays$country)){
-      stop("Holidays in ", m$append.holidays," are not currently supported!")
-    }
-  }
   if (!(m$seasonality.mode %in% c('additive', 'multiplicative'))) {
     stop("seasonality.mode must be 'additive' or 'multiplicative'")
   }
@@ -223,9 +216,9 @@ validate_column_name <- function(
      (name %in% unique(m$holidays$holiday))){
     stop("Name ", name, " already used for a holiday.")
   }
-  if(check_holidays & !is.null(m$append.holidays)){
-    if(name %in% get_holiday_names(m$append.holidays)){
-      stop("Name ", name, " is a holiday name in ", m$append.holidays, ".")
+  if(check_holidays & !is.null(m$country_holidays)){
+    if(name %in% get_holiday_names(m$country_holidays)){
+      stop("Name ", name, " is a holiday name in ", m$country_holidays, ".")
     }
   }
   if(check_seasonalities & (!is.null(m$seasonalities[[name]]))){
@@ -533,10 +526,46 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
   return(data.frame(features))
 }
 
+#' Construct a dataframe of holiday dates.
+#'
+#' Combines the holidays stored on the model (m$holidays) with the built-in
+#' holidays of the country in m$country_holidays, if one has been set. Country
+#' holidays are generated for every calendar year that appears in dates.
+#'
+#' If m$train.holiday.names is set, the result is restricted to those holiday
+#' names, and any of those names left without a row is appended as a row that
+#' has only the holiday column filled in.
+#'
+#' @param m Prophet object.
+#' @param dates Vector with dates for which holidays are needed. Only the
+#'  years of these dates are used, to generate the country holidays.
+#'
+#' @return A dataframe of holiday dates, in holiday dataframe format used in
+#'  initialization. Has zero rows if the model has no holidays at all.
+#'
+#' @importFrom dplyr "%>%"
+#' @keywords internal
+construct_holiday_dataframe <- function(m, dates) {
+  all.holidays <- data.frame()
+  if (!is.null(m$holidays)) {
+    all.holidays <- m$holidays
+  }
+  if (!is.null(m$country_holidays)) {
+    year.list <- as.numeric(unique(format(dates, "%Y")))
+    # Coerce to character so the rows bind cleanly with user-supplied holidays.
+    country.holidays.df <- make_holidays_df(year.list, m$country_holidays) %>%
+      dplyr::mutate(ds = as.character(ds), holiday = as.character(holiday))
+    all.holidays <- suppressWarnings(
+      dplyr::bind_rows(all.holidays, country.holidays.df)
+    )
+  }
+  # If the model has already been fit with a certain set of holidays,
+  # make sure we are using those same ones.
+  if (!is.null(m$train.holiday.names)) {
+    row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
+    all.holidays <- all.holidays[row.to.keep, ]
+    # Placeholder rows for training holidays that do not occur in these dates.
+    # Keep the names as character (not factor) to match the rows above.
+    holidays.to.add <- data.frame(
+      holiday = setdiff(m$train.holiday.names, all.holidays$holiday),
+      stringsAsFactors = FALSE
+    )
+    all.holidays <- suppressWarnings(
+      dplyr::bind_rows(all.holidays, holidays.to.add)
+    )
+  }
+  return(all.holidays)
+}
+
+
 #' Construct a matrix of holiday features.
 #'
 #' @param m Prophet object.
 #' @param dates Vector with dates used for computing seasonality.
+#' @param holidays Dataframe containing holidays, as returned by
+#'  construct_holiday_dataframe.
 #'
 #' @return A list with entries
 #'  holiday.features: dataframe with a column for each holiday.
@@ -545,28 +574,10 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
 #'
 #' @importFrom dplyr "%>%"
 #' @keywords internal
-make_holiday_features <- function(m, dates) {
+make_holiday_features <- function(m, dates, holidays) {
   # Strip dates to be just days, for joining on holidays
   dates <- set_date(format(dates, "%Y-%m-%d"))
-  all.holidays <- m$holidays 
-  if (!is.null(m$append.holidays)){
-    years <- as.numeric(unique(format(dates, "%Y")))
-    append.holidays.df <- make_holidays_df(years, m$append.holidays) %>%
-      dplyr::mutate(ds=as.character(ds), holiday=as.character(holiday))
-    all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, append.holidays.df))
-  }
-  # Make fit.prophet and predict.prophet holidays components match
-  if (!is.null(m$append.holidays) && !is.null(m$train.holiday.names)){
-    row.to.keep <- which(all.holidays$holiday %in% m$train.holiday.names)
-    all.holidays <- all.holidays[row.to.keep,]
-    holidays.to.add <- data.frame(holiday=setdiff(m$train.holiday.names,
-                                                    all.holidays$holiday))
-    all.holidays <- suppressWarnings(dplyr::bind_rows(all.holidays, holidays.to.add))
-  }
-  if (nrow(all.holidays)==0){
-    return(NULL)
-  }
-  wide <- all.holidays %>%
+  wide <- holidays %>%
     dplyr::mutate(ds = set_date(ds)) %>%
     dplyr::group_by(holiday, ds) %>%
     dplyr::filter(dplyr::row_number() == 1) %>%
@@ -587,17 +598,17 @@ make_holiday_features <- function(m, dates) {
   holiday.features <- data.frame(ds = set_date(dates)) %>%
     dplyr::left_join(wide, by = 'ds') %>%
     dplyr::select(-ds)
-  # Make sure fit.prophet and predict.prophet component.cols perfectly equal
+  # Make sure column order is consistent
   holiday.features <- holiday.features %>% dplyr::select(sort(names(.)))
   holiday.features[is.na(holiday.features)] <- 0
-  
+
   # Prior scales
-  if (!('prior_scale' %in% colnames(all.holidays))) {
-    all.holidays$prior_scale <- m$holidays.prior.scale
+  if (!('prior_scale' %in% colnames(holidays))) {
+    holidays$prior_scale <- m$holidays.prior.scale
   }
   prior.scales.list <- list()
-  for (name in unique(all.holidays$holiday)) {
-    df.h <- all.holidays[all.holidays$holiday == name, ]
+  for (name in unique(holidays$holiday)) {
+    df.h <- holidays[holidays$holiday == name, ]
     ps <- unique(df.h$prior_scale)
     if (length(ps) > 1) {
       stop('Holiday ', name, ' does not have a consistent prior scale ',
@@ -707,7 +718,6 @@ add_regressor <- function(
 #'
 #' @return The prophet model with the seasonality added.
 #'
-#' @importFrom dplyr "%>%"
 #' @export
 add_seasonality <- function(
   m, name, period, fourier.order, prior.scale = NULL, mode = NULL
@@ -742,6 +752,46 @@ add_seasonality <- function(
   return(m)
 }
 
+#' Add in built-in holidays for the specified country.
+#'
+#' These holidays will be included in addition to any specified on model
+#' initialization.
+#'
+#' Holidays will be calculated for arbitrary date ranges in the history
+#' and future. See the online documentation for the list of countries with
+#' built-in holidays.
+#'
+#' Built-in country holidays can only be set for a single country. Calling
+#' this function again replaces the previously set country and prints a
+#' message saying so.
+#'
+#' An error is raised if the model has already been fit, or if country_name is
+#' not one of the countries with built-in holidays.
+#'
+#' @param m Prophet object.
+#' @param country_name Name of the country, like 'UnitedStates' or 'US'
+#'
+#' @return The prophet model with the holidays country set.
+#'
+#' @export
+add_country_holidays <- function(m, country_name) {
+  if (!is.null(m$history)) {
+    stop("Country holidays must be added prior to model fitting.")
+  }
+  if (!(country_name %in% generated_holidays$country)) {
+    stop("Holidays in ", country_name, " are not currently supported!")
+  }
+  # Validate names.
+  for (name in get_holiday_names(country_name)) {
+    # Allow merging with existing holidays
+    validate_column_name(m, name, check_holidays = FALSE)
+  }
+  # Set the holidays.
+  if (!is.null(m$country_holidays)) {
+    message(
+      'Changing country holidays from ', m$country_holidays, ' to ',
+      country_name
+    )
+  }
+  m$country_holidays <- country_name
+  return(m)
+}
+
+
 #' Dataframe with seasonality features.
 #' Includes seasonality features, holiday features, and added regressors.
 #'
@@ -776,15 +826,15 @@ make_all_seasonality_features <- function(m, df) {
   }
 
   # Holiday features
-  if (!is.null(m$holidays) || !is.null(m$append.holidays)) {
-    out <- make_holiday_features(m, df$ds)
-    if (!is.null(out)){
-      m <- out$m
-      seasonal.features <- cbind(seasonal.features, out$holiday.features)
-      prior.scales <- c(prior.scales, out$prior.scales)
-      modes[[m$seasonality.mode]] <- c(
-        modes[[m$seasonality.mode]], out$holiday.names)
-    }
+  holidays <- construct_holiday_dataframe(m, df$ds)
+  if (nrow(holidays) > 0) {
+    out <- make_holiday_features(m, df$ds, holidays)
+    m <- out$m
+    seasonal.features <- cbind(seasonal.features, out$holiday.features)
+    prior.scales <- c(prior.scales, out$prior.scales)
+    modes[[m$seasonality.mode]] <- c(
+      modes[[m$seasonality.mode]], out$holiday.names
+    )
   }
 
   # Additional regressors

+ 27 - 0
R/man/add_country_holidays.Rd

@@ -0,0 +1,27 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{add_country_holidays}
+\alias{add_country_holidays}
+\title{Add in built-in holidays for the specified country.}
+\usage{
+add_country_holidays(m, country_name)
+}
+\arguments{
+\item{m}{Prophet object.}
+
+\item{country_name}{Name of the country, like 'UnitedStates' or 'US'}
+}
+\value{
+The prophet model with the holidays country set.
+}
+\description{
+These holidays will be included in addition to any specified on model
+initialization.
+}
+\details{
+Holidays will be calculated for arbitrary date ranges in the history
+and future. See the online documentation for the list of countries with
+built-in holidays.
+
+Built-in country holidays can only be set for a single country.
+}

+ 21 - 0
R/man/construct_holiday_dataframe.Rd

@@ -0,0 +1,21 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{construct_holiday_dataframe}
+\alias{construct_holiday_dataframe}
+\title{Construct a dataframe of holiday dates.}
+\usage{
+construct_holiday_dataframe(m, dates)
+}
+\arguments{
+\item{m}{Prophet object.}
+
+\item{dates}{Vector with dates used for computing seasonality.}
+}
+\value{
+A dataframe of holiday dates, in holiday dataframe format used in
+ initialization.
+}
+\description{
+Construct a dataframe of holiday dates.
+}
+\keyword{internal}

+ 4 - 1
R/man/make_holiday_features.Rd

@@ -4,12 +4,15 @@
 \alias{make_holiday_features}
 \title{Construct a matrix of holiday features.}
 \usage{
-make_holiday_features(m, dates)
+make_holiday_features(m, dates, holidays)
 }
 \arguments{
 \item{m}{Prophet object.}
 
 \item{dates}{Vector with dates used for computing seasonality.}
+
+\item{holidays}{Dataframe containing holidays, as returned by
+construct_holiday_dataframe.}
 }
 \value{
 A list with entries

+ 4 - 6
R/man/prophet.Rd

@@ -8,10 +8,10 @@ prophet(df = NULL, growth = "linear", changepoints = NULL,
   n.changepoints = 25, changepoint.range = 0.8,
   yearly.seasonality = "auto", weekly.seasonality = "auto",
   daily.seasonality = "auto", holidays = NULL,
-  append.holidays = NULL, seasonality.mode = "additive",
-  seasonality.prior.scale = 10, holidays.prior.scale = 10,
-  changepoint.prior.scale = 0.05, mcmc.samples = 0,
-  interval.width = 0.8, uncertainty.samples = 1000, fit = TRUE, ...)
+  seasonality.mode = "additive", seasonality.prior.scale = 10,
+  holidays.prior.scale = 10, changepoint.prior.scale = 0.05,
+  mcmc.samples = 0, interval.width = 0.8, uncertainty.samples = 1000,
+  fit = TRUE, ...)
 }
 \arguments{
 \item{df}{(optional) Dataframe containing the history. Must have columns ds
@@ -51,8 +51,6 @@ range of days around the date to be included as holidays. lower_window=-2
 will include 2 days prior to the date as holidays. Also optionally can have
 a column prior_scale specifying the prior scale for each holiday.}
 
-\item{append.holidays}{country name or abbreviation (character).}
-
 \item{seasonality.mode}{'additive' (default) or 'multiplicative'.}
 
 \item{seasonality.prior.scale}{Parameter modulating the strength of the

+ 22 - 24
R/tests/testthat/test_prophet.R

@@ -259,7 +259,7 @@ test_that("holidays", {
     ds = seq(prophet:::set_date('2016-12-20'),
              prophet:::set_date('2016-12-31'), by='d'))
   m <- prophet(train, holidays = holidays, fit = FALSE)
-  out <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
   feats <- out$holiday.features
   priors <- out$prior.scales
   names <- out$holiday.names
@@ -274,7 +274,7 @@ test_that("holidays", {
                         lower_window = c(-1),
                         upper_window = c(10))
   m <- prophet(train, holidays = holidays, fit = FALSE)
-  out <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
   feats <- out$holiday.features
   priors <- out$prior.scales
   names <- out$holiday.names
@@ -291,7 +291,7 @@ test_that("holidays", {
     prior_scale = c(5., 5.)
   )
   m <- prophet(holidays = holidays, fit = FALSE)
-  out <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
   priors <- out$prior.scales
   names <- out$holiday.names
   expect_true(all(priors == c(5., 5.)))
@@ -306,7 +306,7 @@ test_that("holidays", {
   )
   holidays2 <- rbind(holidays, holidays2)
   m <- prophet(holidays = holidays2, fit = FALSE)
-  out <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
   priors <- out$prior.scales
   names <- out$holiday.names
   expect_true(all(priors == c(8, 8, 5, 5)))
@@ -324,7 +324,7 @@ test_that("holidays", {
   # manual factorizing to avoid above bind_rows() warning
   holidays2$holiday <- factor(holidays2$holiday)
   m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
-  out <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds, m$holidays)
   priors <- out$prior.scales
   expect_true(all(priors == c(4, 4, 5, 5)))
   # Check incompatible priors
@@ -336,7 +336,7 @@ test_that("holidays", {
     prior_scale = c(5., 6.)
   )
   m <- prophet(holidays = holidays, fit = FALSE)
-  expect_error(prophet:::make_holiday_features(m, df$ds))
+  expect_error(prophet:::make_holiday_features(m, df$ds, m$holidays))
 })
 
 test_that("fit_with_holidays", {
@@ -349,47 +349,45 @@ test_that("fit_with_holidays", {
   expect_error(predict(m), NA)
 })
 
-test_that("fit_with_append_holidays", {
+test_that("fit_with_country_holidays", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
                          holiday = c('seans-bday', 'seans-bday'),
                          lower_window = c(0, 0),
                          upper_window = c(1, 1))
-  append.holidays = 'US'
   # Test with holidays and country holidays
-  m <- prophet(DATA, 
-               holidays = holidays, 
-               append.holidays = append.holidays, 
-               uncertainty.samples = 0)
+  m <- prophet(holidays = holidays, uncertainty.samples = 0)
+  m <- add_country_holidays(m, 'US')
+  m <- fit.prophet(m, DATA)
   expect_error(predict(m), NA)
   # There are training holidays missing in the test set
   train2 <- DATA %>% head(155)
   future2 <- DATA %>% tail(355)
-  model <- prophet(train2,
-                   append.holidays = append.holidays, 
-                   uncertainty.samples = 0)
+  m <- prophet(uncertainty.samples = 0)
+  m <- add_country_holidays(m, 'US')
+  m <- fit.prophet(m, train2)
   expect_error(predict(m, future2), NA)
   # There are test holidays missing in the training set
   train2 <- DATA %>% tail(355)
   future2 <- DATA2
-  model <- prophet(train2,
-                   append.holidays = append.holidays, 
-                   uncertainty.samples = 0)
+  m <- prophet(uncertainty.samples = 0)
+  m <- add_country_holidays(m, 'US')
+  m <- fit.prophet(m, train2)
   expect_error(predict(m, future2), NA)
   # Country holidays with non-existing year
   max.year <- generated_holidays %>% 
-    dplyr::filter(country==append.holidays) %>%
+    dplyr::filter(country=='US') %>%
     dplyr::select(year) %>%
     max()
   train2 <- data.frame('ds'=c(paste(max.year+1, "-01-01", sep=''),
                               paste(max.year+1, "-01-02", sep='')),
                        'y'=1)
-  expect_warning(prophet(train2, 
-                         append.holidays = append.holidays))
+  m <- prophet()
+  m <- add_country_holidays(m, 'US')
+  expect_warning(m <- fit.prophet(m, train2))
   # Country holidays with non-existing country
-  append.holidays = 'Utopia'
-  expect_error(prophet(DATA, 
-                       append.holidays = append.holidays))
+  m <- prophet()
+  expect_error(add_country_holidays(m, 'Utopia'))
 })
 
 test_that("make_future_dataframe", {

+ 5 - 3
python/fbprophet/forecaster.py

@@ -434,10 +434,12 @@ class Prophet(object):
         
         Returns
         -------
+        dataframe of holiday dates, in holiday dataframe format used in
+        initialization.
         """
         all_holidays = pd.DataFrame()
         if self.holidays is not None:
-            all_holidays = pd.concat((all_holidays, self.holidays))
+            all_holidays = self.holidays.copy()
         if self.country_holidays is not None:
             year_list = list({x.year for x in dates})
             country_holidays_df = make_holidays_df(
@@ -464,7 +466,7 @@ class Prophet(object):
             all_holidays = pd.concat((all_holidays, holidays_to_add), sort=False)
             all_holidays.reset_index(drop=True, inplace=True)
         return all_holidays
-        
+
     def make_holiday_features(self, dates, holidays):
         """Construct a dataframe of holiday features.
 
@@ -526,7 +528,7 @@ class Prophet(object):
                     # Access key to generate value
                     expanded_holidays[key]
         holiday_features = pd.DataFrame(expanded_holidays)
-        # Make sure fit and predict component_cols perfectly equal
+        # Make sure column order is consistent
         holiday_features = holiday_features[sorted(holiday_features.columns.tolist())]
         prior_scale_list = [
             prior_scales[h.split('_delim_')[0]]