Browse Source

Custom seasonality prior scales R, handle NAs in holiday priors

Ben Letham 8 years ago
parent
commit
66ea9444eb

+ 68 - 31
R/R/prophet.R

@@ -42,7 +42,8 @@ globalVariables(c(
 #'  a column prior_scale specifying the prior scale for each holiday.
 #'  a column prior_scale specifying the prior scale for each holiday.
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #' @param seasonality.prior.scale Parameter modulating the strength of the
 #'  seasonality model. Larger values allow the model to fit larger seasonal
 #'  seasonality model. Larger values allow the model to fit larger seasonal
-#'  fluctuations, smaller values dampen the seasonality.
+#'  fluctuations, smaller values dampen the seasonality. Can be specified for
+#'  individual seasonalities using add_seasonality.
 #' @param holidays.prior.scale Parameter modulating the strength of the holiday
 #' @param holidays.prior.scale Parameter modulating the strength of the holiday
 #'  components model, unless overridden in the holidays input.
 #'  components model, unless overridden in the holidays input.
 #' @param changepoint.prior.scale Parameter modulating the flexibility of the
 #' @param changepoint.prior.scale Parameter modulating the flexibility of the
@@ -508,37 +509,44 @@ make_holiday_features <- function(m, dates) {
       } else {
       } else {
         offsets <- c(0)
         offsets <- c(0)
       }
       }
-      if (exists('prior_scale', where = .) && !is.na(.$prior_scale)) {
-        ps <- .$prior_scale
-      } else {
-        ps <- m$holidays.prior.scale
-      }
       names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
       names <- paste(.$holiday, '_delim_', ifelse(offsets < 0, '-', '+'),
                      abs(offsets), sep = '')
                      abs(offsets), sep = '')
-      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names,
-                        prior_scale = ps)
+      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
     }) %>%
     }) %>%
     dplyr::mutate(x = 1.) %>%
     dplyr::mutate(x = 1.) %>%
     tidyr::spread(holiday, x, fill = 0)
     tidyr::spread(holiday, x, fill = 0)
 
 
   holiday.features <- data.frame(ds = set_date(dates)) %>%
   holiday.features <- data.frame(ds = set_date(dates)) %>%
-    dplyr::left_join(wide, by = 'ds')
-
-  prior.scales.all <- holiday.features$prior_scale
-  prior.scales <- c()
+    dplyr::left_join(wide, by = 'ds') %>%
+    dplyr::select(-ds)
 
 
-  holiday.features <- dplyr::select(holiday.features, -ds, -prior_scale)
   holiday.features[is.na(holiday.features)] <- 0
   holiday.features[is.na(holiday.features)] <- 0
 
 
-  for (name in colnames(holiday.features)) {
-    rows <- !is.na(holiday.features[[name]]) & (holiday.features[[name]] == 1)
-    ps <- unique(prior.scales.all[rows])
+  # Prior scales
+  if (!('prior_scale' %in% colnames(m$holidays))) {
+    m$holidays$prior_scale <- m$holidays.prior.scale
+  }
+  prior.scales.list <- list()
+  for (name in unique(m$holidays$holiday)) {
+    df.h <- m$holidays[m$holidays$holiday == name, ]
+    ps <- unique(df.h$prior_scale)
     if (length(ps) > 1) {
     if (length(ps) > 1) {
-      sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
-      stop('Holiday ', sn, ' does not have a consistent prior scale ',
+      stop('Holiday ', name, ' does not have a consistent prior scale ',
            'specification')
            'specification')
     }
     }
-    prior.scales <- c(prior.scales, ps)
+    if (is.na(ps)) {
+      ps <- m$holidays.prior.scale
+    }
+    if (ps <= 0) {
+      stop('Prior scale must be > 0.')
+    }
+    prior.scales.list[[name]] <- ps
+  }
+
+  prior.scales <- c()
+  for (name in colnames(holiday.features)) {
+    sn <- strsplit(name, '_delim_', fixed = TRUE)[[1]][1]
+    prior.scales <- c(prior.scales, prior.scales.list[[sn]])
   }
   }
   return(list(holiday.features = holiday.features,
   return(list(holiday.features = holiday.features,
               prior.scales = prior.scales))
               prior.scales = prior.scales))
@@ -584,23 +592,28 @@ add_regressor <- function(m, name, prior.scale = NULL, standardize = 'auto'){
   return(m)
   return(m)
 }
 }
 
 
-#' Add a seasonal component with specified period and number of Fourier
-#' components.
+#' Add a seasonal component with specified period, number of Fourier
+#' components, and prior scale.
 #'
 #'
 #' Increasing the number of Fourier components allows the seasonality to change
 #' Increasing the number of Fourier components allows the seasonality to change
 #' more quickly (at risk of overfitting). Default values for yearly and weekly
 #' more quickly (at risk of overfitting). Default values for yearly and weekly
 #' seasonalities are 10 and 3 respectively.
 #' seasonalities are 10 and 3 respectively.
 #'
 #'
+#' Increasing prior scale will allow this seasonality component more
+#' flexibility, decreasing will dampen it. If not provided, will use the
+#' seasonality.prior.scale provided on Prophet initialization (defaults to 10).
+#'
 #' @param m Prophet object.
 #' @param m Prophet object.
 #' @param name String name of the seasonality component.
 #' @param name String name of the seasonality component.
 #' @param period Float number of days in one period.
 #' @param period Float number of days in one period.
 #' @param fourier.order Int number of Fourier components to use.
 #' @param fourier.order Int number of Fourier components to use.
+#' @param prior.scale Float prior scale for this component.
 #'
 #'
 #' @return The prophet model with the seasonality added.
 #' @return The prophet model with the seasonality added.
 #'
 #'
 #' @importFrom dplyr "%>%"
 #' @importFrom dplyr "%>%"
 #' @export
 #' @export
-add_seasonality <- function(m, name, period, fourier.order) {
+add_seasonality <- function(m, name, period, fourier.order, prior.scale = NULL) {
   if (!is.null(m$history)) {
   if (!is.null(m$history)) {
     stop("Seasonality must be added prior to model fitting.")
     stop("Seasonality must be added prior to model fitting.")
   }
   }
@@ -608,7 +621,19 @@ add_seasonality <- function(m, name, period, fourier.order) {
     # Allow overriding built-in seasonalities
     # Allow overriding built-in seasonalities
     validate_column_name(m, name, check_seasonalities = FALSE)
     validate_column_name(m, name, check_seasonalities = FALSE)
   }
   }
-  m$seasonalities[[name]] <- c(period, fourier.order)
+  if (is.null(prior.scale)) {
+    ps <- m$seasonality.prior.scale
+  } else {
+    ps <- prior.scale
+  }
+  if (ps <= 0) {
+    stop('Prior scale must be > 0')
+  }
+  m$seasonalities[[name]] <- list(
+    period = period,
+    fourier.order = fourier.order,
+    prior.scale = ps
+  )
   return(m)
   return(m)
 }
 }
 
 
@@ -631,12 +656,12 @@ make_all_seasonality_features <- function(m, df) {
 
 
   # Seasonality features
   # Seasonality features
   for (name in names(m$seasonalities)) {
   for (name in names(m$seasonalities)) {
-    period <- m$seasonalities[[name]][1]
-    series.order <- m$seasonalities[[name]][2]
-    features <- make_seasonality_features(df$ds, period, series.order, name)
+    props <- m$seasonalities[[name]]
+    features <- make_seasonality_features(
+      df$ds, props$period, props$fourier.order, name)
     seasonal.features <- cbind(seasonal.features, features)
     seasonal.features <- cbind(seasonal.features, features)
     prior.scales <- c(prior.scales,
     prior.scales <- c(prior.scales,
-                      m$seasonality.prior.scale * rep(1, ncol(features)))
+                      props$prior.scale * rep(1, ncol(features)))
   }
   }
 
 
   # Holiday features
   # Holiday features
@@ -751,21 +776,33 @@ set_auto_seasonalities <- function(m) {
   fourier.order <- parse_seasonality_args(
   fourier.order <- parse_seasonality_args(
     m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
     m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
   if (fourier.order > 0) {
   if (fourier.order > 0) {
-    m$seasonalities[['yearly']] <- c(365.25, fourier.order)
+    m$seasonalities[['yearly']] <- list(
+      period = 365.25,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
   }
 
 
   weekly.disable <- ((time_diff(last, first) < 14) || (min.dt >= 7))
   weekly.disable <- ((time_diff(last, first) < 14) || (min.dt >= 7))
   fourier.order <- parse_seasonality_args(
   fourier.order <- parse_seasonality_args(
     m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
     m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
   if (fourier.order > 0) {
   if (fourier.order > 0) {
-    m$seasonalities[['weekly']] <- c(7, fourier.order)
+    m$seasonalities[['weekly']] <- list(
+      period = 7,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
   }
 
 
   daily.disable <- ((time_diff(last, first) < 2) || (min.dt >= 1))
   daily.disable <- ((time_diff(last, first) < 2) || (min.dt >= 1))
   fourier.order <- parse_seasonality_args(
   fourier.order <- parse_seasonality_args(
     m, 'daily', m$daily.seasonality, daily.disable, 4)
     m, 'daily', m$daily.seasonality, daily.disable, 4)
   if (fourier.order > 0) {
   if (fourier.order > 0) {
-    m$seasonalities[['daily']] <- c(1, fourier.order)
+    m$seasonalities[['daily']] <- list(
+      period = 1,
+      fourier.order = fourier.order,
+      prior.scale = m$seasonality.prior.scale
+    )
   }
   }
   return(m)
   return(m)
 }
 }
@@ -1598,7 +1635,7 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
 plot_seasonality <- function(m, name, uncertainty = TRUE) {
 plot_seasonality <- function(m, name, uncertainty = TRUE) {
   # Compute seasonality from Jan 1 through a single period.
   # Compute seasonality from Jan 1 through a single period.
   start <- set_date('2017-01-01')
   start <- set_date('2017-01-01')
-  period <- m$seasonalities[[name]][1]
+  period <- m$seasonalities[[name]]$period
   end <- start + period * 24 * 3600
   end <- start + period * 24 * 3600
   plot.points <- 200
   plot.points <- 200
   days <- seq(from=start, to=end, length.out=plot.points)
   days <- seq(from=start, to=end, length.out=plot.points)

+ 10 - 3
R/man/add_seasonality.Rd

@@ -2,10 +2,10 @@
 % Please edit documentation in R/prophet.R
 % Please edit documentation in R/prophet.R
 \name{add_seasonality}
 \name{add_seasonality}
 \alias{add_seasonality}
 \alias{add_seasonality}
-\title{Add a seasonal component with specified period and number of Fourier
-components.}
+\title{Add a seasonal component with specified period, number of Fourier
+components, and prior scale.}
 \usage{
 \usage{
-add_seasonality(m, name, period, fourier.order)
+add_seasonality(m, name, period, fourier.order, prior.scale = NULL)
 }
 }
 \arguments{
 \arguments{
 \item{m}{Prophet object.}
 \item{m}{Prophet object.}
@@ -15,6 +15,8 @@ add_seasonality(m, name, period, fourier.order)
 \item{period}{Float number of days in one period.}
 \item{period}{Float number of days in one period.}
 
 
 \item{fourier.order}{Int number of Fourier components to use.}
 \item{fourier.order}{Int number of Fourier components to use.}
+
+\item{prior.scale}{Float prior scale for this component.}
 }
 }
 \value{
 \value{
 The prophet model with the seasonality added.
 The prophet model with the seasonality added.
@@ -24,3 +26,8 @@ Increasing the number of Fourier components allows the seasonality to change
 more quickly (at risk of overfitting). Default values for yearly and weekly
 more quickly (at risk of overfitting). Default values for yearly and weekly
 seasonalities are 10 and 3 respectively.
 seasonalities are 10 and 3 respectively.
 }
 }
+\details{
+Increasing prior scale will allow this seasonality component more
+flexibility, decreasing will dampen it. If not provided, will use the
+seasonality.prior.scale provided on Prophet initialization (defaults to 10).
+}

+ 3 - 1
R/man/make_holiday_features.Rd

@@ -12,7 +12,9 @@ make_holiday_features(m, dates)
 \item{dates}{Vector with dates used for computing seasonality.}
 \item{dates}{Vector with dates used for computing seasonality.}
 }
 }
 \value{
 \value{
-A dataframe with a column for each holiday.
+A list with entries
+ holiday.features: dataframe with a column for each holiday.
+ prior.scales: array of prior scales for each holiday column.
 }
 }
 \description{
 \description{
 Construct a matrix of holiday features.
 Construct a matrix of holiday features.

+ 5 - 3
R/man/prophet.Rd

@@ -43,14 +43,16 @@ FALSE, or a number of Fourier terms to generate.}
 \item{holidays}{data frame with columns holiday (character) and ds (date
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type)and optionally columns lower_window and upper_window which specify a
 type)and optionally columns lower_window and upper_window which specify a
 range of days around the date to be included as holidays. lower_window=-2
 range of days around the date to be included as holidays. lower_window=-2
-will include 2 days prior to the date as holidays.}
+will include 2 days prior to the date as holidays. Also optionally can have
+a column prior_scale specifying the prior scale for each holiday.}
 
 
 \item{seasonality.prior.scale}{Parameter modulating the strength of the
 \item{seasonality.prior.scale}{Parameter modulating the strength of the
 seasonality model. Larger values allow the model to fit larger seasonal
 seasonality model. Larger values allow the model to fit larger seasonal
-fluctuations, smaller values dampen the seasonality.}
+fluctuations, smaller values dampen the seasonality. Can be specified for
+individual seasonalities using add_seasonality.}
 
 
 \item{holidays.prior.scale}{Parameter modulating the strength of the holiday
 \item{holidays.prior.scale}{Parameter modulating the strength of the holiday
-components model.}
+components model, unless overridden in the holidays input.}
 
 
 \item{changepoint.prior.scale}{Parameter modulating the flexibility of the
 \item{changepoint.prior.scale}{Parameter modulating the flexibility of the
 automatic changepoint selection. Large values will allow many changepoints,
 automatic changepoint selection. Large values will allow many changepoints,

+ 57 - 16
R/tests/testthat/test_prophet.R

@@ -247,11 +247,22 @@ test_that("holidays", {
     upper_window = c(1, 1),
     upper_window = c(1, 1),
     prior_scale = c(8, 8)
     prior_scale = c(8, 8)
   )
   )
-  holiday2 <- rbind(holidays, holidays2)
+  holidays2 <- rbind(holidays, holidays2)
   m <- prophet(holidays = holidays2, fit = FALSE)
   m <- prophet(holidays = holidays2, fit = FALSE)
   out <- prophet:::make_holiday_features(m, df$ds)
   out <- prophet:::make_holiday_features(m, df$ds)
   priors <- out$prior.scales
   priors <- out$prior.scales
-  expect_true(all(priors == c(8,8, 5, 5)))
+  expect_true(all(priors == c(8, 8, 5, 5)))
+  holidays2 <- data.frame(
+    ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
+    holiday = c('seans-bday', 'seans-bday'),
+    lower_window = c(0, 0),
+    upper_window = c(1, 1)
+  )
+  holidays2 <- dplyr::bind_rows(holidays, holidays2)
+  m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
+  out <- prophet:::make_holiday_features(m, df$ds)
+  priors <- out$prior.scales
+  expect_true(all(priors == c(4, 4, 5, 5)))
   # Check incompatible priors
   # Check incompatible priors
   holidays <- data.frame(
   holidays <- data.frame(
     ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
     ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
@@ -296,9 +307,12 @@ test_that("auto_weekly_seasonality", {
   train.w <- DATA[1:N.w, ]
   train.w <- DATA[1:N.w, ]
   m <- prophet(train.w, fit = FALSE)
   m <- prophet(train.w, fit = FALSE)
   expect_equal(m$weekly.seasonality, 'auto')
   expect_equal(m$weekly.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, train.w)
+  m <- fit.prophet(m, train.w)
   expect_true('weekly' %in% names(m$seasonalities))
   expect_true('weekly' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['weekly']], c(7, 3))
+  true <- list(period = 7, fourier.order = 3, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$weekly[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   # Should be disabled due to too short history
   N.w <- 9
   N.w <- 9
   train.w <- DATA[1:N.w, ]
   train.w <- DATA[1:N.w, ]
@@ -310,8 +324,11 @@ test_that("auto_weekly_seasonality", {
   train.w <- DATA[seq(1, nrow(DATA), 7), ]
   train.w <- DATA[seq(1, nrow(DATA), 7), ]
   m <- prophet(train.w)
   m <- prophet(train.w)
   expect_false('weekly' %in% names(m$seasonalities))
   expect_false('weekly' %in% names(m$seasonalities))
-  m <- prophet(DATA, weekly.seasonality=2)
-  expect_equal(m$seasonalities[['weekly']], c(7, 2))
+  m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3)
+  true <- list(period = 7, fourier.order = 2, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$weekly[[name]], true[[name]])
+  }
 })
 })
 
 
 test_that("auto_yearly_seasonality", {
 test_that("auto_yearly_seasonality", {
@@ -319,9 +336,12 @@ test_that("auto_yearly_seasonality", {
   # Should be enabled
   # Should be enabled
   m <- prophet(DATA, fit = FALSE)
   m <- prophet(DATA, fit = FALSE)
   expect_equal(m$yearly.seasonality, 'auto')
   expect_equal(m$yearly.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, DATA)
+  m <- fit.prophet(m, DATA)
   expect_true('yearly' %in% names(m$seasonalities))
   expect_true('yearly' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['yearly']], c(365.25, 10))
+  true <- list(period = 365.25, fourier.order = 10, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$yearly[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   # Should be disabled due to too short history
   N.w <- 240
   N.w <- 240
   train.y <- DATA[1:N.w, ]
   train.y <- DATA[1:N.w, ]
@@ -329,8 +349,11 @@ test_that("auto_yearly_seasonality", {
   expect_false('yearly' %in% names(m$seasonalities))
   expect_false('yearly' %in% names(m$seasonalities))
   m <- prophet(train.y, yearly.seasonality = TRUE)
   m <- prophet(train.y, yearly.seasonality = TRUE)
   expect_true('yearly' %in% names(m$seasonalities))
   expect_true('yearly' %in% names(m$seasonalities))
-  m <- prophet(DATA, yearly.seasonality=7)
-  expect_equal(m$seasonalities[['yearly']], c(365.25, 7))
+  m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3)
+  true <- list(period = 365.25, fourier.order = 7, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$yearly[[name]], true[[name]])
+  }
 })
 })
 
 
 test_that("auto_daily_seasonality", {
 test_that("auto_daily_seasonality", {
@@ -338,9 +361,12 @@ test_that("auto_daily_seasonality", {
   # Should be enabled
   # Should be enabled
   m <- prophet(DATA2, fit = FALSE)
   m <- prophet(DATA2, fit = FALSE)
   expect_equal(m$daily.seasonality, 'auto')
   expect_equal(m$daily.seasonality, 'auto')
-  m <- prophet:::fit.prophet(m, DATA2)
+  m <- fit.prophet(m, DATA2)
   expect_true('daily' %in% names(m$seasonalities))
   expect_true('daily' %in% names(m$seasonalities))
-  expect_equal(m$seasonalities[['daily']], c(1, 4))
+  true <- list(period = 1, fourier.order = 4, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$daily[[name]], true[[name]])
+  }
   # Should be disabled due to too short history
   # Should be disabled due to too short history
   N.d <- 430
   N.d <- 430
   train.y <- DATA2[1:N.d, ]
   train.y <- DATA2[1:N.d, ]
@@ -348,8 +374,11 @@ test_that("auto_daily_seasonality", {
   expect_false('daily' %in% names(m$seasonalities))
   expect_false('daily' %in% names(m$seasonalities))
   m <- prophet(train.y, daily.seasonality = TRUE)
   m <- prophet(train.y, daily.seasonality = TRUE)
   expect_true('daily' %in% names(m$seasonalities))
   expect_true('daily' %in% names(m$seasonalities))
-  m <- prophet(DATA2, daily.seasonality=7)
-  expect_equal(m$seasonalities[['daily']], c(1, 7))
+  m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3)
+  true <- list(period = 1, fourier.order = 7, prior.scale = 3)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$daily[[name]], true[[name]])
+  }
   m <- prophet(DATA)
   m <- prophet(DATA)
   expect_false('daily' %in% names(m$seasonalities))
   expect_false('daily' %in% names(m$seasonalities))
 })
 })
@@ -366,10 +395,14 @@ test_that("test_subdaily_holidays", {
 test_that("custom_seasonality", {
 test_that("custom_seasonality", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   holidays <- data.frame(ds = c('2017-01-02'),
   holidays <- data.frame(ds = c('2017-01-02'),
-                         holiday = c('special_day'))
+                         holiday = c('special_day'),
+                         prior_scale = c(4))
   m <- prophet(holidays=holidays)
   m <- prophet(holidays=holidays)
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
-  expect_equal(m$seasonalities[['monthly']], c(30, 5))
+  true <- list(period = 30, fourier.order = 5, prior.scale = 10)
+  for (name in names(true)) {
+    expect_equal(m$seasonalities$monthly[[name]], true[[name]])
+  }
   expect_error(
   expect_error(
     add_seasonality(m, name='special_day', period=30, fourier_order=5)
     add_seasonality(m, name='special_day', period=30, fourier_order=5)
   )
   )
@@ -377,6 +410,14 @@ test_that("custom_seasonality", {
     add_seasonality(m, name='trend', period=30, fourier_order=5)
     add_seasonality(m, name='trend', period=30, fourier_order=5)
   )
   )
   m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
   m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
+  # Test priors
+  m <- prophet(holidays = holidays, yearly.seasonality = FALSE)
+  m <- add_seasonality(
+    m, name='monthly', period=30, fourier.order=5, prior.scale = 2)
+  m <- fit.prophet(m, DATA)
+  prior.scales <- prophet:::make_all_seasonality_features(
+    m, m$history)$prior.scales
+  expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4)))
 })
 })
 
 
 test_that("added_regressors", {
 test_that("added_regressors", {

+ 2 - 3
python/fbprophet/forecaster.py

@@ -413,9 +413,8 @@ class Prophet(object):
             except ValueError:
             except ValueError:
                 lw = 0
                 lw = 0
                 uw = 0
                 uw = 0
-            try:
-                ps = float(row.get('prior_scale', self.holidays_prior_scale))
-            except ValueError:
+            ps = float(row.get('prior_scale', self.holidays_prior_scale))
+            if np.isnan(ps):
                 ps = float(self.holidays_prior_scale)
                 ps = float(self.holidays_prior_scale)
             if (
             if (
                 row.holiday in prior_scales and prior_scales[row.holiday] != ps
                 row.holiday in prior_scales and prior_scales[row.holiday] != ps

+ 11 - 0
python/fbprophet/tests/test_prophet.py

@@ -308,6 +308,17 @@ class TestProphet(TestCase):
         holidays2 = pd.concat((holidays, holidays2))
         holidays2 = pd.concat((holidays, holidays2))
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
         feats, priors = Prophet(holidays=holidays2).make_holiday_features(df['ds'])
         self.assertEqual(priors, [8., 8., 5., 5.])
         self.assertEqual(priors, [8., 8., 5., 5.])
+        holidays2 = pd.DataFrame({
+            'ds': pd.to_datetime(['2012-06-06', '2013-06-06']),
+            'holiday': ['seans-bday'] * 2,
+            'lower_window': [0] * 2,
+            'upper_window': [1] * 2,
+        })
+        holidays2 = pd.concat((holidays, holidays2))
+        feats, priors = Prophet(
+            holidays=holidays2, holidays_prior_scale=4
+        ).make_holiday_features(df['ds'])
+        self.assertEqual(priors, [4., 4., 5., 5.])
         # Check incompatible priors
         # Check incompatible priors
         holidays = pd.DataFrame({
         holidays = pd.DataFrame({
             'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),
             'ds': pd.to_datetime(['2016-12-25', '2016-12-27']),