Forráskód Böngészése

Custom prior scales R

Ben Letham 8 éve
szülő
commit
ddbb353278
3 módosított fájl, 83 hozzáadás és 19 törlés
  1. 38 15
      R/R/prophet.R
  2. 43 2
      R/tests/testthat/test_prophet.R
  3. 2 2
      python/fbprophet/tests/test_prophet.py

+ 38 - 15
R/R/prophet.R

@@ -38,12 +38,13 @@ globalVariables(c(
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type) and optionally columns lower_window and upper_window which specify a
 #'  range of days around the date to be included as holidays. lower_window=-2
-#'  will include 2 days prior to the date as holidays.
+#'  will include 2 days prior to the date as holidays. May also optionally have
+#'  a column prior_scale specifying the prior scale for each holiday.
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #'  seasonality model. Larger values allow the model to fit larger seasonal
 #'  fluctuations, smaller values dampen the seasonality.
 #' @param holidays.prior.scale Parameter modulating the strength of the holiday
-#'  components model.
+#'  components model, unless overridden in the holidays input.
 #' @param changepoint.prior.scale Parameter modulating the flexibility of the
 #'  automatic changepoint selection. Large values will allow many changepoints,
 #'  small values will allow few changepoints.
@@ -487,7 +488,9 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
 #' @param m Prophet object.
 #' @param dates Vector with dates used for computing seasonality.
 #'
-#' @return A dataframe with a column for each holiday.
+#' @return A list with entries
+#'  holiday.features: dataframe with a column for each holiday.
+#'  prior.scales: array of prior scales for each holiday column.
 #'
 #' @importFrom dplyr "%>%"
 #' @keywords internal
@@ -505,19 +508,40 @@ make_holiday_features <- function(m, dates) {
       } else {
         offsets <- c(0)
       }
-      names <- paste(
-        .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
-      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
+      if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) {
+        ps <- .$prior_scale
+      } else {
+        ps <- m$holidays.prior.scale
+      }
+      names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
+                     abs(offsets), sep = '')
+      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names,
+                        prior_scale = ps)
     }) %>%
     dplyr::mutate(x = 1.) %>%
     tidyr::spread(holiday, x, fill = 0)
 
-  holiday.mat <- data.frame(ds = dates) %>%
-    dplyr::left_join(wide, by = 'ds') %>%
-    dplyr::select(-ds)
+  holiday.features <- data.frame(ds = set_date(dates)) %>%
+    dplyr::left_join(wide, by = 'ds')
+
+  prior.scales.all <- holiday.features$prior_scale
+  prior.scales <- c()
+
+  holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale)
+  holiday.features[is.na(holiday.features)] <- 0
 
-  holiday.mat[is.na(holiday.mat)] <- 0
-  return(holiday.mat)
+  for (name in colnames(holiday.features)) {
+    rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1)
+    ps <- unique(prior.scales.all[rows])
+    if (length(ps) > 1) {
+      sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
+      stop('Holiday ', sn, ' does not have a consistent prior scale ',
+           'specification')
+    }
+    prior.scales <- c(prior.scales, ps)
+  }
+  return(list(holiday.features = holiday.features,
+              prior.scales = prior.scales))
 }
 
 #' Add an additional regressor to be used for fitting and predicting.
@@ -617,10 +641,9 @@ make_all_seasonality_features <- function(m, df) {
 
   # Holiday features
   if (!is.null(m$holidays)) {
-    features <- make_holiday_features(m, df$ds)
-    seasonal.features <- cbind(seasonal.features, features)
-    prior.scales <- c(prior.scales,
-                      m$holidays.prior.scale * rep(1, ncol(features)))
+    hf <- make_holiday_features(m, df$ds)
+    seasonal.features <- cbind(seasonal.features, hf$holiday.features)
+    prior.scales <- c(prior.scales, hf$prior.scales)
   }
 
   # Additional regressors

+ 43 - 2
R/tests/testthat/test_prophet.R

@@ -208,19 +208,60 @@ test_that("holidays", {
     ds = seq(prophet:::set_date('2016-12-20'),
              prophet:::set_date('2016-12-31'), by='d'))
   m <- prophet(train, holidays = holidays, fit = FALSE)
-  feats <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  feats <- out$holiday.features
+  priors <- out$prior.scales
   expect_equal(nrow(feats), nrow(df))
   expect_equal(ncol(feats), 2)
   expect_equal(sum(colSums(feats) - c(1, 1)), 0)
+  expect_true(all(priors == c(10., 10.)))
 
   holidays = data.frame(ds = c('2016-12-25'),
                         holiday = c('xmas'),
                         lower_window = c(-1),
                         upper_window = c(10))
   m <- prophet(train, holidays = holidays, fit = FALSE)
-  feats <- prophet:::make_holiday_features(m, df$ds)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  feats <- out$holiday.features
+  priors <- out$prior.scales
   expect_equal(nrow(feats), nrow(df))
   expect_equal(ncol(feats), 12)
+  expect_true(all(priors == rep(10, 12)))
+  # Check prior specifications
+  holidays <- data.frame(
+    ds = prophet:::set_date(c('2016-12-25', '2017-12-25')),
+    holiday = c('xmas', 'xmas'),
+    lower_window = c(-1, -1),
+    upper_window = c(0, 0),
+    prior_scale = c(5., 5.)
+  )
+  m <- prophet(holidays = holidays, fit = FALSE)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  priors <- out$prior.scales
+  expect_true(all(priors == c(5., 5.)))
+  # 2 different priors
+  holidays2 <- data.frame(
+    ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
+    holiday = c('seans-bday', 'seans-bday'),
+    lower_window = c(0, 0),
+    upper_window = c(1, 1),
+    prior_scale = c(8, 8)
+  )
+  holiday2 <- rbind(holidays, holidays2)
+  m <- prophet(holidays = holidays2, fit = FALSE)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  priors <- out$prior.scales
+  expect_true(all(priors == c(8,8, 5, 5)))
+  # Check incompatible priors
+  holidays <- data.frame(
+    ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
+    holiday = c('xmasish', 'xmasish'),
+    lower_window = c(-1, -1),
+    upper_window = c(0, 0),
+    prior_scale = c(5., 6.)
+  )
+  m <- prophet(holidays = holidays, fit = FALSE)
+  expect_error(prophet:::make_holiday_features(m, df$ds))
 })
 
 test_that("fit_with_holidays", {

+ 2 - 2
python/fbprophet/tests/test_prophet.py

@@ -310,8 +310,8 @@ class TestProphet(TestCase):
         self.assertEqual(priors, [8., 8., 5., 5.])
         # Check incompatible priors
         holidays = pd.DataFrame({
-            'ds': pd.to_datetime(['2016-12-25', '2017-12-25']),
-            'holiday': ['xmas', 'xmas'],
+            'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),
+            'holiday': ['xmasish', 'xmasish'],
             'lower_window': [-1, -1],
             'upper_window': [0, 0],
             'prior_scale': [5., 6.],