Bläddra i källkod

Custom seasonality prior scales R, handle NAs in holiday priors

Ben Letham 8 år sedan
förälder
incheckning
66ea9444eb

+ 68 - 31
R/R/prophet.R

@@ -42,7 +42,8 @@ globalVariables(c(
 #'  a column prior_scale specifying the prior scale for each holiday.
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #'  seasonality model. Larger values allow the model to fit larger seasonal
-#'  fluctuations, smaller values dampen the seasonality.
+#'  fluctuations, smaller values dampen the seasonality. Can be specified for
+#'  individual seasonalities using add_seasonality.
 #' @param holidays.prior.scale Parameter modulating the strength of the holiday
 #'  components model, unless overridden in the holidays input.
 #' @param changepoint.prior.scale Parameter modulating the flexibility of the
@@ -508,37 +509,44 @@ make_holiday_features <- function(m, dates) {
       } else {
         offsets <- c(0)
       }
-      if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) {
-        ps <- .$prior_scale
-      } else {
-        ps <- m$holidays.prior.scale
-      }
       names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
                      abs(offsets), sep = '')
-      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names,
-                        prior_scale = ps)
+      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
     }) %>%
     dplyr::mutate(x = 1.) %>%
     tidyr::spread(holiday, x, fill = 0)
 
   holiday.features <- data.frame(ds = set_date(dates)) %>%
-    dplyr::left_join(wide, by = 'ds')
-
-  prior.scales.all <- holiday.features$prior_scale
-  prior.scales <- c()
+    dplyr::left_join(wide, by = 'ds') %>%
+    dplyr::select(-ds)
 
-  holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale)
   holiday.features[is.na(holiday.features)] <- 0
 
-  for (name in colnames(holiday.features)) {
-    rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1)
-    ps <- unique(prior.scales.all[rows])
+  # Prior scales
+  if (!('prior_scale' %in% colnames(m$holidays))) {
+    m$holidays$prior_scale <- m$holidays.prior.scale
+  }
+  prior.scales.list <- list()
+  for (name in unique(m$holidays$holiday)) {
+    df.h <- m$holidays[m$holidays$holiday == name, ]
+    ps <- unique(df.h$prior_scale)
     if (length(ps) > 1) {
-      sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
-      stop('Holiday ', sn, ' does not have a consistent prior scale ',
+      stop('Holiday ', name, ' does not have a consistent prior scale ',
            'specification')
     }
-    prior.scales <- c(prior.scales, ps)
+    if (is.na(ps)) {
+      ps <- m$holidays.prior.scale
+    }
+    if (ps <= 0) {
+      stop('Prior scale must be > 0.')
+    }
+    prior.scales.list[[name]] <- ps
+  }
+
+  prior.scales <- c()
+  for (name in colnames(holiday.features)) {
+    sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
+    prior.scales <- c(prior.scales, prior.scales.list[[sn]])
   }
   return(list(holiday.features = holiday.features,
               prior.scales = prior.scales))
@@ -584,23 +592,28 @@ add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto'){
   return(m)
 }
 
-#' Add a seasonal component with specified period and number of Fourier
-#' components.
+#' Add a seasonal component with specified period, number of Fourier
+#' components, and prior scale.
 #'
 #' Increasing the number of Fourier components allows the seasonality to change
 #' more quickly (at risk of overfitting). Default values for yearly and weekly
 #' seasonalities are 10 and 3 respectively.
 #'
+#' Increasing prior scale will allow this seasonality component more
+#' flexibility, decreasing will dampen it. If not provided, will use the
+#' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
+#'
 #' @param m Prophet object.
 #' @param name String name of the seasonality component.
 #' @param period Float number of days in one period.
 #' @param fourier.order Int number of Fourier components to use.
+#' @param prior.scale Float prior scale for this component.
 #'
 #' @return The prophet model with the seasonality added.
 #'
 #' @importFrom dplyr "%>%"
 #' @export
-add_seasonality <- function(m, name, period, fourier.order) {
+add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) {
   if (!is.null(m$history)) {
     stop("Seasonality must be added prior to model fitting.")
   }
@@ -608,7 +621,19 @@ add_seasonality <- function(m, name, period, fourier.order) {
     # Allow overriding built-in seasonalities
     validate_column_name(m, name, check_seasonalities = FALSE)
   }
-  m$seasonalities[[name]] <- c(period, fourier.order)
+  if (is.null(prior.scale)) {
+    ps <- m$seasonality.prior.scale
+  } else {
+    ps <- prior.scale
+  }
+  if (ps <= 0) {
+    stop('Prior scale must be > 0')
+  }
+  m$seasonalities[[name]] <- list(
+    period = period,
+    fourier.order = fourier.order,
+    prior.scale = ps
+  )
   return(m)
 }
 
@@ -631,12 +656,12 @@ make_all_seasonality_features <- function(m, df) {
 
   # Seasonality features
   for (name in names(m$seasonalities)) {
-    period <- m$seasonalities[[name]][1]
-    series.order <- m$seasonalities[[name]][2]
-    features <- make_seasonality_features(df$ds, period, series.order, name)
+    props <- m$seasonalities[[name]]
+    features <- make_seasonality_features(
+      df$ds, props$period, props$fourier.order, name)
     seasonal.features <- cbind(seasonal.features, features)
     prior.scales <- c(prior.scales,
-                      m$seasonality.prior.scale * rep(1, ncol(features)))
+                      props$prior.scale * rep(1, ncol(features)))
   }
 
   # Holiday features
@@ -751,21 +776,33 @@ set_auto_seasonalities <- function(m) {
   fourier.order <- parse_seasonality_args(
     m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
   if (fourier.order > 0) {
-    m$seasonalities[['yearly']] <- c(365.25, fourier.order)
+    m$seasonalities[['yearly']] <- list(
+      period = 365.25,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
 
   weekly.disable <- ((time_diff(last, first) < 14) || (min.dt >= 7))
   fourier.order <- parse_seasonality_args(
     m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
   if (fourier.order > 0) {
-    m$seasonalities[['weekly']] <- c(7, fourier.order)
+    m$seasonalities[['weekly']] <- list(
+      period = 7,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
 
   daily.disable <- ((time_diff(last, first) < 2) || (min.dt >= 1))
   fourier.order <- parse_seasonality_args(
     m, 'daily', m$daily.seasonality, daily.disable, 4)
   if (fourier.order > 0) {
-    m$seasonalities[['daily']] <- c(1, fourier.order)
+    m$seasonalities[['daily']] <- list(
+      period = 1,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
   return(m)
 }
@@ -1598,7 +1635,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
 plot_seasonality <- function(m, name, uncertainty = TRUE) {
   # Compute seasonality from Jan 1 through a single period.
   start <- set_date('2017-01-01')
-  period <- m$seasonalities[[name]][1]
+  period <- m$seasonalities[[name]]$period
   end <- start + period * 24 * 3600
   plot.points <- 200
   days <- seq(from=start, to=end, length.out=plot.points)

+ 10 - 3
R/man/add_seasonality.Rd

@@ -2,10 +2,10 @@
 % Please edit documentation in R/prophet.R
 \name{add_seasonality}
 \alias{add_seasonality}
-\title{Add a seasonal component with specified period and number of Fourier
-components.}
+\title{Add a seasonal component with specified period, number of Fourier
+components, and prior scale.}
 \usage{
-add_seasonality(m, name, period, fourier.order)
+add_seasonality(m, name, period, fourier.order, prior.scale = NULL)
 }
 \arguments{
 \item{m}{Prophet object.}
@@ -15,6 +15,8 @@ add_seasonality(m, name, period, fourier.order)
 \item{period}{Float number of days in one period.}
 
 \item{fourier.order}{Int number of Fourier components to use.}
+
+\item{prior.scale}{Float prior scale for this component.}
 }
 \value{
 The prophet model with the seasonality added.
@@ -24,3 +26,8 @@ Increasing the number of Fourier components allows the seasonality to change
 more quickly (at risk of overfitting). Default values for yearly and weekly
 seasonalities are 10 and 3 respectively.
 }
+\details{
+Increasing prior scale will allow this seasonality component more
+flexibility, decreasing will dampen it. If not provided, will use the
+seasonality.prior.scale provided on Prophet initialization (defaults to 10).
+}

+ 3 - 1
R/man/make_holiday_features.Rd

@@ -12,7 +12,9 @@ make_holiday_features(m, dates)
 \item{dates}{Vector with dates used for computing seasonality.}
 }
 \value{
-A dataframe with a column for each holiday.
+A list with entries
+ holiday.features: dataframe with a column for each holiday.
+ prior.scales: array of prior scales for each holiday column.
 }
 \description{
 Construct a matrix of holiday features.

+ 5 - 3
R/man/prophet.Rd

@@ -43,14 +43,16 @@ FALSE, or a number of Fourier terms to generate.}
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type) and optionally columns lower_window and upper_window which specify a
 range of days around the date to be included as holidays. lower_window=-2
-will include 2 days prior to the date as holidays.}
+will include 2 days prior to the date as holidays. Also optionally can have
+a column prior_scale specifying the prior scale for each holiday.}
 
 \item{seasonality.prior.scale}{Parameter modulating the strength of the
 seasonality model. Larger values allow the model to fit larger seasonal
-fluctuations, smaller values dampen the seasonality.}
+fluctuations, smaller values dampen the seasonality. Can be specified for
+individual seasonalities using add_seasonality.}
 
 \item{holidays.prior.scale}{Parameter modulating the strength of the holiday
-components model.}
+components model, unless overridden in the holidays input.}
 
 \item{changepoint.prior.scale}{Parameter modulating the flexibility of the
 automatic changepoint selection. Large values will allow many changepoints,

+ 57 - 16
R/tests/testthat/test_prophet.R

@@ -247,11 +247,22 @@ test_that("holidays", {
     upper_window = c(1, 1),
     prior_scale = c(8, 8)
   )
-  holiday2 <- rbind(holidays, holidays2)
+  holidays2 <- rbind(holidays, holidays2)
   m <- prophet(holidays = holidays2, fit = FALSE)
   out <- prophet:::make_holiday_features(m, df$ds)
   priors <- out$prior.scales
-  expect_true(all(priors == c(8,8, 5, 5)))
+  expect_true(all(priors == c(8, 8, 5, 5)))
+  holidays2 <- data.frame(
+    ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
+    holiday = c('seans-bday', 'seans-bday'),
+    lower_window = c(0, 0),
+    upper_window = c(1, 1)
+  )
+  holidays2 <- dplyr::bind_rows(holidays, holidays2)
+  m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  priors <- out$prior.scales
+  expect_true(all(priors == c(4, 4, 5, 5)))
   # Check incompatible priors
   holidays <- data.frame(
     ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
@@ -296,9 +307,12 @@ test_that("auto_weekly_seasonality", {
   train.w <- DATA[1:N.w, ]
   m <- prophet(train.w, fit = FALSE)
   expect_equal(m$weekly.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, train.w)
+  m <- fit.prophet(m, train.w)
   expect_true('weekly' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['weekly']], c(7, 3))
+  true <- list(period = 7, fourier.order = 3, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$weekly[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   N.w <- 9
   train.w <- DATA[1:N.w, ]
@@ -310,8 +324,11 @@ test_that("auto_weekly_seasonality", {
   train.w <- DATA[seq(1, nrow(DATA), 7), ]
   m <- prophet(train.w)
   expect_false('weekly' %in% names(m$seasonalities))
-  m <- prophet(DATA, weekly.seasonality=2)
-  expect_equal(m$seasonalities[['weekly']], c(7, 2))
+  m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3)
+  true <- list(period = 7, fourier.order = 2, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$weekly[[name]], true[[name]])
+  }
 })
 
 test_that("auto_yearly_seasonality", {
@@ -319,9 +336,12 @@ test_that("auto_yearly_seasonality", {
   # Should be enabled
   m <- prophet(DATA, fit = FALSE)
   expect_equal(m$yearly.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, DATA)
+  m <- fit.prophet(m, DATA)
   expect_true('yearly' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['yearly']], c(365.25, 10))
+  true <- list(period = 365.25, fourier.order = 10, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$yearly[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   N.w <- 240
   train.y <- DATA[1:N.w, ]
@@ -329,8 +349,11 @@ test_that("auto_yearly_seasonality", {
   expect_false('yearly' %in% names(m$seasonalities))
   m <- prophet(train.y, yearly.seasonality = TRUE)
   expect_true('yearly' %in% names(m$seasonalities))
-  m <- prophet(DATA, yearly.seasonality=7)
-  expect_equal(m$seasonalities[['yearly']], c(365.25, 7))
+  m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3)
+  true <- list(period = 365.25, fourier.order = 7, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$yearly[[name]], true[[name]])
+  }
 })
 
 test_that("auto_daily_seasonality", {
@@ -338,9 +361,12 @@ test_that("auto_daily_seasonality", {
   # Should be enabled
   m <- prophet(DATA2, fit = FALSE)
   expect_equal(m$daily.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, DATA2)
+  m <- fit.prophet(m, DATA2)
   expect_true('daily' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['daily']], c(1, 4))
+  true <- list(period = 1, fourier.order = 4, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$daily[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   N.d <- 430
   train.y <- DATA2[1:N.d, ]
@@ -348,8 +374,11 @@ test_that("auto_daily_seasonality", {
   expect_false('daily' %in% names(m$seasonalities))
   m <- prophet(train.y, daily.seasonality = TRUE)
   expect_true('daily' %in% names(m$seasonalities))
-  m <- prophet(DATA2, daily.seasonality=7)
-  expect_equal(m$seasonalities[['daily']], c(1, 7))
+  m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3)
+  true <- list(period = 1, fourier.order = 7, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$daily[[name]], true[[name]])
+  }
   m <- prophet(DATA)
   expect_false('daily' %in% names(m$seasonalities))
 })
@@ -366,10 +395,14 @@ test_that("test_subdaily_holidays", {
 test_that("custom_seasonality", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   holidays <- data.frame(ds = c('2017-01-02'),
-                         holiday = c('special_day'))
+                         holiday = c('special_day'),
+                         prior_scale = c(4))
   m <- prophet(holidays=holidays)
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
-  expect_equal(m$seasonalities[['monthly']], c(30, 5))
+  true <- list(period = 30, fourier.order = 5, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$monthly[[name]], true[[name]])
+  }
   expect_error(
     add_seasonality(m, name='special_day', period=30, fourier_order=5)
   )
@@ -377,6 +410,14 @@ test_that("custom_seasonality", {
     add_seasonality(m, name='trend', period=30, fourier_order=5)
   )
   m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
+  # Test priors
+  m <- prophet(holidays = holidays, yearly.seasonality = FALSE)
+  m <- add_seasonality(
+    m, name='monthly', period=30, fourier.order=5, prior.scale = 2)
+  m <- fit.prophet(m, DATA)
+  prior.scales <- prophet:::make_all_seasonality_features(
+    m, m$history)$prior.scales
+  expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4)))
 })
 
 test_that("added_regressors", {

+ 2 - 3
python/fbprophet/forecaster.py

@@ -413,9 +413,8 @@ class Prophet(object):
             except ValueError:
                 lw = 0
                 uw = 0
-            try:
-                ps = float(row.get('prior_scale', self.holidays_prior_scale))
-            except ValueError:
+            ps = float(row.get('prior_scale', self.holidays_prior_scale))
+            if np.isnan(ps):
                 ps = float(self.holidays_prior_scale)
             if (
                 row.holiday in prior_scales and prior_scales[row.holiday] != ps

+ 11 - 0
python/fbprophet/tests/test_prophet.py

@@ -308,6 +308,17 @@ class TestProphet(TestCase):
         holidays2 = pd.concat((holidays, holidays2))
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
         self.assertEqual(priors, [8., 8., 5., 5.])
+        holidays2 = pd.DataFrame({
+            'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
+            'holiday': ['seans-bday'] * 2,
+            'lower_window': [0] * 2,
+            'upper_window': [1] * 2,
+        })
+        holidays2 = pd.concat((holidays, holidays2))
+        feats, priors = Prophet(
+            holidays=holidays2, holidays_prior_scale=4
+        ).make_holiday_features(df['ds'])
+        self.assertEqual(priors, [4., 4., 5., 5.])
         # Check incompatible priors
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),