浏览代码

Custom seasonalities in R

bl 8 年之前
父节点
当前提交
8be35c2f34

+ 1 - 0
R/NAMESPACE

@@ -2,6 +2,7 @@
 
 S3method(plot,prophet)
 S3method(predict,prophet)
+export(add_seasonality)
 export(fit.prophet)
 export(make_future_dataframe)
 export(predictive_samples)

+ 135 - 34
R/R/prophet.R

@@ -15,9 +15,11 @@ globalVariables(c(
 
 #' Prophet forecaster.
 #'
-#' @param df Dataframe containing the history. Must have columns ds (date type)
-#'  and y, the time series. If growth is logistic, then df must also have a
-#'  column cap that specifies the capacity at each ds.
+#' @param df (optional) Dataframe containing the history. Must have columns ds
+#'  (date type) and y, the time series. If growth is logistic, then df must
+#'  also have a column cap that specifies the capacity at each ds. If not
+#'  provided, then the model object will be instantiated but not fit; use
+#'  fit.prophet(m, df) to fit the model.
 #' @param growth String 'linear' or 'logistic' to specify a linear or logistic
 #'  trend.
 #' @param changepoints Vector of dates at which to include potential
@@ -27,8 +29,10 @@ globalVariables(c(
 #'  if input `changepoints` is supplied. If `changepoints` is not supplied,
 #'  then n.changepoints potential changepoints are selected uniformly from the
 #'  first 80 percent of df$ds.
-#' @param yearly.seasonality Fit yearly seasonality; 'auto', TRUE, or FALSE.
-#' @param weekly.seasonality Fit weekly seasonality; 'auto', TRUE, or FALSE.
+#' @param yearly.seasonality Fit yearly seasonality. Can be 'auto', TRUE,
+#'  FALSE, or a number of Fourier terms to generate.
+#' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
+#'  FALSE, or a number of Fourier terms to generate.
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type) and optionally columns lower_window and upper_window which specify a
 #'  range of days around the date to be included as holidays. lower_window=-2
@@ -66,7 +70,7 @@ globalVariables(c(
 #' @export
 #' @importFrom dplyr "%>%"
 #' @import Rcpp
-prophet <- function(df = df,
+prophet <- function(df = NULL,
                     growth = 'linear',
                     changepoints = NULL,
                     n.changepoints = 25,
@@ -105,6 +109,7 @@ prophet <- function(df = df,
     y.scale = NULL,
     t.scale = NULL,
     changepoints.t = NULL,
+    seasonalities = list(),
     stan.fit = NULL,
     params = list(),
     history = NULL,
@@ -112,7 +117,7 @@ prophet <- function(df = df,
   )
   validate_inputs(m)
   class(m) <- append("prophet", class(m))
-  if (fit) {
+  if ((fit) && (!is.null(df))) {
     m <- fit.prophet(m, df, ...)
   }
 
@@ -372,6 +377,31 @@ make_holiday_features <- function(m, dates) {
   return(holiday.mat)
 }
 
+#' Add a seasonal component with specified period and number of Fourier
+#' components.
+#'
+#' Increasing the number of Fourier components allows the seasonality to change
+#' more quickly (at risk of overfitting).
+#'
+#' @param m Prophet object.
+#' @param name String name of the seasonality component.
+#' @param period Float number of days in one period.
+#' @param fourier.order Int number of Fourier components to use.
+#'
+#' @return The prophet model with the seasonality added.
+#'
+#' @importFrom dplyr "%>%"
+#' @export
+add_seasonality <- function(m, name, period, fourier.order) {
+  # A NULL m$holidays yields character(0) here, so no name can collide.
+  holiday.names <- as.character(unique(m$holidays$holiday))
+  if (name %in% holiday.names) {
+    stop('Name "', name, '" already used for holiday')
+  }
+  m$seasonalities[[name]] <- c(period, fourier.order)
+  m
+}
+
 #' Dataframe with seasonality features.
 #'
 #' @param m Prophet object.
@@ -381,19 +411,14 @@ make_holiday_features <- function(m, dates) {
 #'
 make_all_seasonality_features <- function(m, df) {
   seasonal.features <- data.frame(zeros = rep(0, nrow(df)))
-  if (m$yearly.seasonality) {
-    seasonal.features <- cbind(
-      seasonal.features,
-      make_seasonality_features(df$ds, 365.25, 10, 'yearly'))
-  }
-  if (m$weekly.seasonality) {
+  for (name in names(m$seasonalities)) {
+    period <- m$seasonalities[[name]][1]
+    series.order <- m$seasonalities[[name]][2]
     seasonal.features <- cbind(
       seasonal.features,
-      make_seasonality_features(df$ds, 7, 3, 'weekly'))
+      make_seasonality_features(df$ds, period, series.order, name))
   }
   if(!is.null(m$holidays)) {
-    # A smaller prior scale will shrink holiday estimates more than seasonality
-    scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
     seasonal.features <- cbind(
       seasonal.features,
       make_holiday_features(m, df$ds))
@@ -401,6 +426,39 @@ make_all_seasonality_features <- function(m, df) {
   return(seasonal.features)
 }
 
+#' Get number of Fourier components for built-in seasonalities.
+#'
+#' @param m Prophet object.
+#' @param name String name of the seasonality component.
+#' @param arg 'auto', TRUE, FALSE, or number of Fourier components as
+#'  provided.
+#' @param auto.disable Bool if seasonality should be disabled when 'auto'.
+#' @param default.order Int default Fourier order.
+#'
+#' @return Number of Fourier components, or 0 for disabled.
+#'
+parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
+  if (identical(arg, 'auto')) {
+    fourier.order <- 0
+    if (name %in% names(m$seasonalities)) {
+      warning('Found custom seasonality named "', name,
+              '", disabling built-in ', name, ' seasonality.')
+    } else if (auto.disable) {
+      warning('Disabling ', name, ' seasonality. Run prophet with ', name,
+              '.seasonality=TRUE to override this.')
+    } else {
+      fourier.order <- default.order
+    }
+  } else if (is.logical(arg)) {
+    # Branch on type, not `==`: `1 == TRUE` holds in R, so `==` would turn a
+    # requested Fourier order of 1 into default.order.
+    fourier.order <- if (arg) default.order else 0
+  } else {
+    fourier.order <- arg
+  }
+  return(fourier.order)
+}
+
 #' Set seasonalities that were left on auto.
 #'
 #' Turns on yearly seasonality if there is >=2 years of history.
@@ -414,25 +472,21 @@ make_all_seasonality_features <- function(m, df) {
 set_auto_seasonalities <- function(m) {
   first <- min(m$history$ds)
   last <- max(m$history$ds)
-  if (m$yearly.seasonality == 'auto') {
-    if (last - first < 730) {
-      warning('Disabling yearly seasonality. ',
-              'Run prophet with `yearly.seasonality=TRUE` to override this.')
-      m$yearly.seasonality <- FALSE
-    } else {
-      m$yearly.seasonality <- TRUE
-    }
+  dt <- diff(m$history$ds)
+  min.dt <- min(dt[dt > 0])
+
+  yearly.disable <- last - first < 730
+  fourier.order <- parse_seasonality_args(
+    m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
+  if (fourier.order > 0) {
+    m$seasonalities[['yearly']] <- c(365.25, fourier.order)
   }
-  if (m$weekly.seasonality == 'auto') {
-    dt <- diff(m$history$ds)
-    min.dt <- min(dt[dt > 0])
-    if ((last - first < 14) || (min.dt >= 7)) {
-      warning('Disabling weekly seasonality. ',
-              'Run prophet with `weekly.seasonality=TRUE` to override this.')
-      m$weekly.seasonality <- FALSE
-    } else {
-      m$weekly.seasonality <- TRUE
-    }
+
+  weekly.disable <- ((last - first < 14) || (min.dt >= 7))
+  fourier.order <- parse_seasonality_args(
+    m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
+  if (fourier.order > 0) {
+    m$seasonalities[['weekly']] <- c(7, fourier.order)
   }
   return(m)
 }
@@ -1053,6 +1107,13 @@ prophet_plot_components <- function(
   if ("yearly" %in% colnames(df)) {
     panels[[length(panels) + 1]] <- plot_yearly(m, uncertainty, yearly_start)
   }
+  # Plot other seasonalities
+  for (name in names(m$seasonalities)) {
+    if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
+      panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
+    }
+  }
+
   # Make the plot.
   grid::grid.newpage()
   grid::pushViewport(grid::viewport(layout = grid::grid.layout(length(panels),
@@ -1190,4 +1251,44 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
   return(gg.yearly)
 }
 
+#' Plot a custom seasonal component.
+#'
+#' @param m Prophet model object.
+#' @param name String name of the seasonality.
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#'
+#' @return A ggplot2 plot.
+plot_seasonality <- function(m, name, uncertainty = TRUE) {
+  period <- m$seasonalities[[name]][1]
+  # Evaluate the component over one period beginning 2017-01-01; length.out
+  # for the date sequence is the length of that span in days.
+  first.day <- zoo::as.Date('2017-01-01')
+  last.day <- first.day + period
+  n.dates <- as.numeric(last.day - first.day)
+  one.period <- data.frame(
+    ds = seq.Date(from = first.day, to = last.day, length.out = n.dates),
+    cap = 1.)
+  one.period <- setup_dataframe(m, one.period)$df
+  seas <- predict_seasonal_components(m, one.period)
+  seas$ds <- one.period$ds
+
+  # Periods shorter than 14 days also show the time of day on the x axis.
+  fmt.str <- if (period < 14) '%m/%d %R' else '%m/%d'
+  gg.s <- ggplot2::ggplot(
+      seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
+    ggplot2::scale_x_date(labels = scales::date_format(fmt.str))
+
+  if (uncertainty) {
+    # Shade the band between the <name>_lower and <name>_upper columns.
+    band <- ggplot2::aes_string(
+      ymin = paste0(name, '_lower'),
+      ymax = paste0(name, '_upper'))
+    gg.s <- gg.s +
+      ggplot2::geom_ribbon(
+        band, alpha = 0.2, fill = "#0072B2", na.rm = TRUE)
+  }
+  gg.s
+}
+
 # fb-block 3

+ 25 - 0
R/man/add_seasonality.Rd

@@ -0,0 +1,25 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{add_seasonality}
+\alias{add_seasonality}
+\title{Add a seasonal component with specified period and number of Fourier
+components.}
+\usage{
+add_seasonality(m, name, period, fourier.order)
+}
+\arguments{
+\item{m}{Prophet object.}
+
+\item{name}{String name of the seasonality component.}
+
+\item{period}{Float number of days in one period.}
+
+\item{fourier.order}{Int number of Fourier components to use.}
+}
+\value{
+The prophet model with the seasonality added.
+}
+\description{
+Increasing the number of Fourier components allows the seasonality to change
+more quickly (at risk of overfitting).
+}

+ 26 - 0
R/man/parse_seasonality_args.Rd

@@ -0,0 +1,26 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{parse_seasonality_args}
+\alias{parse_seasonality_args}
+\title{Get number of Fourier components for built-in seasonalities.}
+\usage{
+parse_seasonality_args(m, name, arg, auto.disable, default.order)
+}
+\arguments{
+\item{m}{Prophet object.}
+
+\item{name}{String name of the seasonality component.}
+
+\item{arg}{'auto', TRUE, FALSE, or number of Fourier components as
+provided.}
+
+\item{auto.disable}{Bool if seasonality should be disabled when 'auto'.}
+
+\item{default.order}{Int default Fourier order.}
+}
+\value{
+Number of Fourier components, or 0 for disabled.
+}
+\description{
+Get number of Fourier components for built-in seasonalities.
+}

+ 21 - 0
R/man/plot_seasonality.Rd

@@ -0,0 +1,21 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{plot_seasonality}
+\alias{plot_seasonality}
+\title{Plot a custom seasonal component.}
+\usage{
+plot_seasonality(m, name, uncertainty = TRUE)
+}
+\arguments{
+\item{m}{Prophet model object.}
+
+\item{name}{String name of the seasonality.}
+
+\item{uncertainty}{Boolean to plot uncertainty intervals.}
+}
+\value{
+A ggplot2 plot.
+}
+\description{
+Plot a custom seasonal component.
+}

+ 10 - 6
R/man/prophet.Rd

@@ -4,7 +4,7 @@
 \alias{prophet}
 \title{Prophet forecaster.}
 \usage{
-prophet(df = df, growth = "linear", changepoints = NULL,
+prophet(df = NULL, growth = "linear", changepoints = NULL,
   n.changepoints = 25, yearly.seasonality = "auto",
   weekly.seasonality = "auto", holidays = NULL,
   seasonality.prior.scale = 10, holidays.prior.scale = 10,
@@ -12,9 +12,11 @@ prophet(df = df, growth = "linear", changepoints = NULL,
   uncertainty.samples = 1000, fit = TRUE, ...)
 }
 \arguments{
-\item{df}{Dataframe containing the history. Must have columns ds (date type)
-and y, the time series. If growth is logistic, then df must also have a
-column cap that specifies the capacity at each ds.}
+\item{df}{(optional) Dataframe containing the history. Must have columns ds
+(date type) and y, the time series. If growth is logistic, then df must
+also have a column cap that specifies the capacity at each ds. If not
+provided, then the model object will be instantiated but not fit; use
+fit.prophet(m, df) to fit the model.}
 
 \item{growth}{String 'linear' or 'logistic' to specify a linear or logistic
 trend.}
@@ -28,9 +30,11 @@ if input `changepoints` is supplied. If `changepoints` is not supplied,
 then n.changepoints potential changepoints are selected uniformly from the
 first 80 percent of df$ds.}
 
-\item{yearly.seasonality}{Fit yearly seasonality; 'auto', TRUE, or FALSE.}
+\item{yearly.seasonality}{Fit yearly seasonality. Can be 'auto', TRUE,
+FALSE, or a number of Fourier terms to generate.}
 
-\item{weekly.seasonality}{Fit weekly seasonality; 'auto', TRUE, or FALSE.}
+\item{weekly.seasonality}{Fit weekly seasonality. Can be 'auto', TRUE,
+FALSE, or a number of Fourier terms to generate.}
 
 \item{holidays}{data frame with columns holiday (character) and ds (date
 type) and optionally columns lower_window and upper_window which specify a

+ 24 - 11
R/tests/testthat/test_prophet.R

@@ -219,38 +219,51 @@ test_that("make_future_dataframe", {
 
 test_that("auto_weekly_seasonality", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  # Should be True
+  # Should be enabled
   N.w <- 15
   train.w <- DATA[1:N.w, ]
   m <- prophet(train.w, fit = FALSE)
   expect_equal(m$weekly.seasonality, 'auto')
   m <- prophet:::fit.prophet(m, train.w)
-  expect_equal(m$weekly.seasonality, TRUE)
-  # Should be False due to too short history
+  expect_true('weekly' %in% names(m$seasonalities))
+  expect_equal(m$seasonalities[['weekly']], c(7, 3))
+  # Should be disabled due to too short history
   N.w <- 9
   train.w <- DATA[1:N.w, ]
   m <- prophet(train.w)
-  expect_equal(m$weekly.seasonality, FALSE)
+  expect_false('weekly' %in% names(m$seasonalities))
   m <- prophet(train.w, weekly.seasonality = TRUE)
-  expect_equal(m$weekly.seasonality, TRUE)
+  expect_true('weekly' %in% names(m$seasonalities))
   # Should be False due to weekly spacing
   train.w <- DATA[seq(1, nrow(DATA), 7), ]
   m <- prophet(train.w)
-  expect_equal(m$weekly.seasonality, FALSE)
+  expect_false('weekly' %in% names(m$seasonalities))
+  m <- prophet(DATA, weekly.seasonality=2)
+  expect_equal(m$seasonalities[['weekly']], c(7, 2))
 })
 
 test_that("auto_yearly_seasonality", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  # Should be True
+  # Should be enabled
   m <- prophet(DATA, fit = FALSE)
   expect_equal(m$yearly.seasonality, 'auto')
   m <- prophet:::fit.prophet(m, DATA)
-  expect_equal(m$yearly.seasonality, TRUE)
-  # Should be False due to too short history
+  expect_true('yearly' %in% names(m$seasonalities))
+  expect_equal(m$seasonalities[['yearly']], c(365.25, 10))
+  # Should be disabled due to too short history
   N.w <- 240
   train.y <- DATA[1:N.w, ]
   m <- prophet(train.y)
-  expect_equal(m$yearly.seasonality, FALSE)
+  expect_false('yearly' %in% names(m$seasonalities))
   m <- prophet(train.y, yearly.seasonality = TRUE)
-  expect_equal(m$yearly.seasonality, TRUE)
+  expect_true('yearly' %in% names(m$seasonalities))
+  m <- prophet(DATA, yearly.seasonality=7)
+  expect_equal(m$seasonalities[['yearly']], c(365.25, 7))
+})
+
+test_that("custom_seasonality", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet()
+  m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
+  expect_equal(m$seasonalities[['monthly']], c(30, 5))
 })

+ 1 - 1
docs/_data/nav_docs.yml

@@ -4,7 +4,7 @@
   - id: quick_start
   - id: forecasting_growth
   - id: trend_changepoints
-  - id: holiday_effects
+  - id: seasonality_and_holiday_effects
   - id: uncertainty_intervals
   - id: outliers
   - id: non-daily_data

文件差异内容过多而无法显示
+ 95 - 37
notebooks/holiday_effects.ipynb