Переглянути джерело

Add support for fitting daily seasonality, make holiday features work when daily seasonality is enabled (#246)

* Add support for fitting daily seasonality, make holiday features work
when daily seasonality is enabled

* fix wrong comment in make_future_dataframe()
Qi Wang 8 роки тому
батько
коміт
b0938df109
2 змінених файлів з 125 додано та 43 видалено
  1. 112 30
      R/R/prophet.R
  2. 13 13
      R/tests/testthat/test_prophet.R

+ 112 - 30
R/R/prophet.R

@@ -33,6 +33,8 @@ globalVariables(c(
 #'  FALSE, or a number of Fourier terms to generate.
 #' @param weekly.seasonality Fit weekly seasonality. Can be 'auto', TRUE,
 #'  FALSE, or a number of Fourier terms to generate.
+#' @param daily.seasonality Fit daily seasonality. Can be 'auto', TRUE,
+#' FALSE, or a number of Fourier terms to generate.
 #' @param holidays data frame with columns holiday (character) and ds (date
 #'  type)and optionally columns lower_window and upper_window which specify a
 #'  range of days around the date to be included as holidays. lower_window=-2
@@ -76,6 +78,7 @@ prophet <- function(df = NULL,
                     n.changepoints = 25,
                     yearly.seasonality = 'auto',
                     weekly.seasonality = 'auto',
+                    daily.seasonality = 'auto',
                     holidays = NULL,
                     seasonality.prior.scale = 10,
                     holidays.prior.scale = 10,
@@ -98,6 +101,7 @@ prophet <- function(df = NULL,
     n.changepoints = n.changepoints,
     yearly.seasonality = yearly.seasonality,
     weekly.seasonality = weekly.seasonality,
+    daily.seasonality = daily.seasonality,
     holidays = holidays,
     seasonality.prior.scale = seasonality.prior.scale,
     changepoint.prior.scale = changepoint.prior.scale,
@@ -206,6 +210,47 @@ compile_stan_model <- function(model) {
   return(rstan::stan_model(stanc_ret = stanc, model_name = model.name))
 }
 
+#' Convert date vector
+#' 
+#' Convert the date to POSIXct object 
+#' 
+#' @param ds Date vector, can be consisted of characters
+#' 
+#' @return vector of POSIXct object converted from date
+#' 
+set_date <- function(ds = NULL, tz = "GMT") {
+  if (length(ds) == 0) {
+    return(NULL)
+  } 
+  
+  if (is.factor(ds)) {
+    ds <- as.character(ds)
+  }
+  
+  if (min(nchar(ds)) < 12) {
+    ds <- as.POSIXct(ds, format = "%Y-%m-%d", tz = tz)
+  } else {
+    ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
+  }
+  return(ds)
+}
+
+#' Extract hour
+#' 
+#' Extract hour from a POSIXct object
+#' 
+#' @param ds POSIXct object
+#' 
+#' @return hour of POSIXct object
+#' 
+get_hour <- function(ds) {
+  if (!("POSIXct" %in% is(ds))) {
+    stop("ds must be a POSIXct object, use function set_date() to convert first.")
+  }
+  
+  return(format(ds , "%H"))
+}
+
 #' Prepare dataframe for fitting or predicting.
 #'
 #' Adds a time index and scales y. Creates auxillary columns 't', 't_ix',
@@ -222,9 +267,9 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
   if (exists('y', where=df)) {
     df$y <- as.numeric(df$y)
   }
-  df$ds <- zoo::as.Date(df$ds)
+  df$ds <- set_date(df$ds)
   if (anyNA(df$ds)) {
-    stop('Unable to parse date format in column ds. Convert to date format.')
+    stop('Unable to parse date format in column ds. Convert to date format. Either %Y-%m-%d or %Y-%m-%d %H:%M:%S')
   }
 
   df <- df %>%
@@ -233,10 +278,10 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
   if (initialize_scales) {
     m$y.scale <- max(abs(df$y))
     m$start <- min(df$ds)
-    m$t.scale <- as.numeric(max(df$ds) - m$start)
+    m$t.scale <- as.numeric(difftime(max(df$ds), m$start, units = "secs"))
   }
 
-  df$t <- as.numeric(df$ds - m$start) / m$t.scale
+  df$t <- as.numeric(difftime(df$ds, m$start, units = "secs")) / m$t.scale
   if (exists('y', where=df)) {
     df$y_scaled <- df$y / m$y.scale
   }
@@ -285,8 +330,8 @@ set_changepoints <- function(m) {
     }
   }
   if (length(m$changepoints) > 0) {
-    m$changepoints <- zoo::as.Date(m$changepoints)
-    m$changepoints.t <- sort(as.numeric(m$changepoints - m$start) / m$t.scale)
+    m$changepoints <- set_date(m$changepoints)
+    m$changepoints.t <- sort(as.numeric(difftime(m$changepoints, m$start, units = "secs"))) / m$t.scale
   } else {
     m$changepoints.t <- c(0)  # dummy changepoint
   }
@@ -316,7 +361,7 @@ get_changepoint_matrix <- function(m) {
 #' @return Matrix with seasonality features.
 #'
 fourier_series <- function(dates, period, series.order) {
-  t <- dates - zoo::as.Date('1970-01-01')
+  t <- as.numeric(difftime(dates, set_date('1970-01-01 00:00:00'), units = 'days')) 
   features <- matrix(0, length(t), 2 * series.order)
   for (i in 1:series.order) {
     x <- as.numeric(2 * i * pi * t / period)
@@ -352,7 +397,7 @@ make_seasonality_features <- function(dates, period, series.order, prefix) {
 make_holiday_features <- function(m, dates) {
   scale.ratio <- m$holidays.prior.scale / m$seasonality.prior.scale
   wide <- m$holidays %>%
-    dplyr::mutate(ds = zoo::as.Date(ds)) %>%
+    dplyr::mutate(ds = set_date(ds)) %>%
     dplyr::group_by(holiday, ds) %>%
     dplyr::filter(row_number() == 1) %>%
     dplyr::do({
@@ -364,7 +409,7 @@ make_holiday_features <- function(m, dates) {
       }
       names <- paste(
         .$holiday, '_delim_', ifelse(offsets < 0, '-', '+'), abs(offsets), sep = '')
-      dplyr::data_frame(ds = .$ds + offsets, holiday = names)
+      dplyr::data_frame(ds = .$ds + offsets * 24 * 3600, holiday = names)
     }) %>%
     dplyr::mutate(x = scale.ratio) %>%
     tidyr::spread(holiday, x, fill = 0)
@@ -472,22 +517,29 @@ parse_seasonality_args <- function(m, name, arg, auto.disable, default.order) {
 set_auto_seasonalities <- function(m) {
   first <- min(m$history$ds)
   last <- max(m$history$ds)
-  dt <- diff(m$history$ds)
+  dt <- diff(as.numeric(difftime(m$history$ds, m$start, units = "d")))
   min.dt <- min(dt[dt > 0])
 
-  yearly.disable <- last - first < 730
+  yearly.disable <- as.numeric(difftime(last, first, unit = "days")) < 730
   fourier.order <- parse_seasonality_args(
     m, 'yearly', m$yearly.seasonality, yearly.disable, 10)
   if (fourier.order > 0) {
     m$seasonalities[['yearly']] <- c(365.25, fourier.order)
   }
 
-  weekly.disable <- ((last - first < 14) || (min.dt >= 7))
+  weekly.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 14) || (min.dt >= 7))
   fourier.order <- parse_seasonality_args(
     m, 'weekly', m$weekly.seasonality, weekly.disable, 3)
   if (fourier.order > 0) {
     m$seasonalities[['weekly']] <- c(7, fourier.order)
   }
+
+  daily.disable <- ((as.numeric(difftime(last, first, unit = "days")) < 2)) || (min.dt >= 1)
+  fourier.order <- parse_seasonality_args(
+    m, 'daily', m$daily.seasonality, daily.disable, 4)
+  if (fourier.order > 0) {
+    m$seasonalities[['daily']] <- c(1, fourier.order)
+  }
   return(m)
 }
 
@@ -571,7 +623,7 @@ fit.prophet <- function(m, df, ...) {
   if (any(is.infinite(history$y))) {
     stop("Found infinity in column y.")
   }
-  m$history.dates <- sort(zoo::as.Date(df$ds))
+  m$history.dates <- sort(set_date(df$ds))
 
   out <- setup_dataframe(m, history, initialize_scales = TRUE)
   history <- out$df
@@ -985,7 +1037,7 @@ sample_predictive_trend <- function(model, df, iteration) {
 #'
 #' @param m Prophet model object.
 #' @param periods Int number of periods to forecast forward.
-#' @param freq 'day', 'week', 'month', 'quarter', or 'year'.
+#' @param freq 'day', 'week', 'month', 'quarter', 'year', 1(1 sec), 60(1 minute) or 3600(1 hour).
 #' @param include_history Boolean to include the historical dates in the data
 #'  frame for predictions.
 #'
@@ -993,7 +1045,7 @@ sample_predictive_trend <- function(model, df, iteration) {
 #'  requested number of periods.
 #'
 #' @export
-make_future_dataframe <- function(m, periods, freq = 'd',
+make_future_dataframe <- function(m, periods, freq = 'day',
                                   include_history = TRUE) {
   dates <- seq(max(m$history.dates), length.out = periods + 1, by = freq)
   dates <- dates[2:(periods + 1)]  # Drop the first, which is max(history$ds)
@@ -1091,7 +1143,7 @@ plot.prophet <- function(x, fcst, uncertainty = TRUE, plot_cap = TRUE,
 #' @importFrom dplyr "%>%"
 prophet_plot_components <- function(
     m, fcst, uncertainty = TRUE, plot_cap = TRUE, weekly_start = 0,
-    yearly_start = 0) {
+    yearly_start = 0, daily_start = 0) {
   df <- df_for_plotting(m, fcst)
   # Plot the trend
   panels <- list(plot_trend(df, uncertainty, plot_cap))
@@ -1099,6 +1151,10 @@ prophet_plot_components <- function(
   if (!is.null(m$holidays)) {
     panels[[length(panels) + 1]] <- plot_holidays(m, df, uncertainty)
   }
+  # Plot daily seasonality, if present
+  if ("daily" %in% colnames(df)) {
+    panels[[length(panels) + 1]] <- plot_daily(m, uncertainty, daily_start)
+  }
   # Plot weekly seasonality, if present
   if ("weekly" %in% colnames(df)) {
     panels[[length(panels) + 1]] <- plot_weekly(m, uncertainty, weekly_start)
@@ -1109,7 +1165,7 @@ prophet_plot_components <- function(
   }
   # Plot other seasonalities
   for (name in names(m$seasonalities)) {
-    if (!(name %in% c('weekly', 'yearly')) && (name %in% colnames(df))) {
+    if (!(name %in% c('daily', 'weekly', 'yearly')) && (name %in% colnames(df))) {
       panels[[length(panels) + 1]] <- plot_seasonality(m, name, uncertainty)
     }
   }
@@ -1184,6 +1240,39 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
   return(gg.holidays)
 }
 
+#' Plot the daily component of the forecast.
+#'
+#' @param m Prophet model object
+#' @param uncertainty Boolean to plot uncertainty intervals.
+#' @param daily_start Integer specifying the start day of the daily
+#'  seasonality plot. 0 (default) starts the week on Sunday. 1 shifts by 1 day
+#'  to Monday, and so on.
+#'
+#' @return A ggplot2 plot.
+plot_daily <- function(m, uncertainty = TRUE, daily_start = 0) {
+  # Compute weekly seasonality for a Sun-Sat sequence of dates.
+  df.d <- data.frame(
+    ds=seq(set_date('2017-01-01 00:00:00'), length.out=24, by = "hour") +
+      daily_start, cap=1.)
+  df.d <- setup_dataframe(m, df.d)$df
+  seas <- predict_seasonal_components(m, df.d)
+  seas$hod <- factor(get_hour(df.d$ds), levels=get_hour(df.d$ds))
+  
+  gg.daily <- ggplot2::ggplot(seas, ggplot2::aes(x = hod, y = daily,
+                                                  group = 1)) +
+    ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
+    ggplot2::labs(x = "Hour of day")
+  if (uncertainty) {
+    gg.daily <- gg.daily +
+      ggplot2::geom_ribbon(ggplot2::aes(ymin = daily_lower,
+                                        ymax = daily_upper),
+                           alpha = 0.2,
+                           fill = "#0072B2",
+                           na.rm = TRUE)
+  }
+  return(gg.daily)
+}
+
 #' Plot the weekly component of the forecast.
 #'
 #' @param m Prophet model object
@@ -1196,7 +1285,7 @@ plot_holidays <- function(m, df, uncertainty = TRUE) {
 plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
   # Compute weekly seasonality for a Sun-Sat sequence of dates.
   df.w <- data.frame(
-    ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=7) +
+    ds=seq(set_date('2017-01-01'), by='d', length.out=7) +
     weekly_start, cap=1.)
   df.w <- setup_dataframe(m, df.w)$df
   seas <- predict_seasonal_components(m, df.w)
@@ -1229,7 +1318,7 @@ plot_weekly <- function(m, uncertainty = TRUE, weekly_start = 0) {
 plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
   # Compute yearly seasonality for a Jan 1 - Dec 31 sequence of dates.
   df.y <- data.frame(
-    ds=seq.Date(zoo::as.Date('2017-01-01'), by='d', length.out=365) +
+    ds=seq(set_date('2017-01-01'), by='d', length.out=365) +
     yearly_start, cap=1.)
   df.y <- setup_dataframe(m, df.y)$df
   seas <- predict_seasonal_components(m, df.y)
@@ -1238,7 +1327,6 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
   gg.yearly <- ggplot2::ggplot(seas, ggplot2::aes(x = ds, y = yearly,
                                                   group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE) +
-    ggplot2::scale_x_date(labels = scales::date_format('%B %d')) +
     ggplot2::labs(x = "Day of year")
   if (uncertainty) {
     gg.yearly <- gg.yearly +
@@ -1260,24 +1348,18 @@ plot_yearly <- function(m, uncertainty = TRUE, yearly_start = 0) {
 #' @return A ggplot2 plot.
 plot_seasonality <- function(m, name, uncertainty = TRUE) {
   # Compute seasonality from Jan 1 through a single period.
-  start <- zoo::as.Date('2017-01-01')
+  start <- set_date('2017-01-01')
   period <- m$seasonalities[[name]][1]
-  end <- start + period
-  plot.points <- as.numeric(end - start)
+  end <- start + period * 24 * 3600
+  plot.points <- as.numeric(difftime(end, start))
   df.y <- data.frame(
-    ds=seq.Date(from=start, to=end, length.out=plot.points), cap=1.)
+    ds=seq(from=start, by='d', length.out=plot.points), cap=1.)
   df.y <- setup_dataframe(m, df.y)$df
   seas <- predict_seasonal_components(m, df.y)
   seas$ds <- df.y$ds
   gg.s <- ggplot2::ggplot(
       seas, ggplot2::aes_string(x = 'ds', y = name, group = 1)) +
     ggplot2::geom_line(color = "#0072B2", na.rm = TRUE)
-  if (period < 14) {
-    fmt.str <- '%m/%d %R'
-  } else {
-    fmt.str <- '%m/%d'
-  }
-  gg.s <- gg.s + ggplot2::scale_x_date(labels = scales::date_format(fmt.str))
   if (uncertainty) {
     gg.s <- gg.s +
     ggplot2::geom_ribbon(

+ 13 - 13
R/tests/testthat/test_prophet.R

@@ -2,7 +2,7 @@ library(prophet)
 context("Prophet tests")
 
 DATA <- read.csv('data.csv')
-DATA$ds <- as.Date(DATA$ds)
+DATA$ds <- set_date(DATA$ds)
 N <- nrow(DATA)
 train <- DATA[1:floor(N / 2), ]
 future <- DATA[(ceiling(N/2) + 1):N, ]
@@ -27,9 +27,9 @@ test_that("fit_predict_no_changepoints", {
 
 test_that("fit_predict_changepoint_not_in_history", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  train_t <- dplyr::mutate(DATA, ds=zoo::as.Date(ds))
-  train_t <- dplyr::filter(train_t, (ds < zoo::as.Date('2013-01-01')) |
-                                (ds > zoo::as.Date('2014-01-01')))
+  train_t <- dplyr::mutate(DATA, ds=set_date(ds))
+  train_t <- dplyr::filter(train_t, (ds < set_date('2013-01-01')) |
+                                (ds > set_date('2014-01-01')))
   future <- data.frame(ds=DATA$ds)
   m <- prophet(train_t, changepoints=c('2013-06-06'))
   expect_error(predict(m, future), NA)
@@ -170,19 +170,19 @@ test_that("piecewise_logistic", {
 })
 
 test_that("holidays", {
-  holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
+  holidays = data.frame(ds = set_date(c('2016-12-25')),
                         holiday = c('xmas'),
                         lower_window = c(-1),
                         upper_window = c(0))
   df <- data.frame(
-    ds = seq(zoo::as.Date('2016-12-20'), zoo::as.Date('2016-12-31'), by='d'))
+    ds = seq(set_date('2016-12-20'), set_date('2016-12-31'), by='d'))
   m <- prophet(train, holidays = holidays, fit = FALSE)
   feats <- prophet:::make_holiday_features(m, df$ds)
   expect_equal(nrow(feats), nrow(df))
   expect_equal(ncol(feats), 2)
   expect_equal(sum(colSums(feats) - c(1, 1)), 0)
 
-  holidays = data.frame(ds = zoo::as.Date(c('2016-12-25')),
+  holidays = data.frame(ds = set_date(c('2016-12-25')),
                         holiday = c('xmas'),
                         lower_window = c(-1),
                         upper_window = c(10))
@@ -194,7 +194,7 @@ test_that("holidays", {
 
 test_that("fit_with_holidays", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  holidays <- data.frame(ds = zoo::as.Date(c('2012-06-06', '2013-06-06')),
+  holidays <- data.frame(ds = set_date(c('2012-06-06', '2013-06-06')),
                          holiday = c('seans-bday', 'seans-bday'),
                          lower_window = c(0, 0),
                          upper_window = c(1, 1))
@@ -206,14 +206,14 @@ test_that("make_future_dataframe", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   train.t <- DATA[1:234, ]
   m <- prophet(train.t)
-  future <- make_future_dataframe(m, periods = 3, freq = 'd',
+  future <- make_future_dataframe(m, periods = 3, freq = 'day',
                                   include_history = FALSE)
-  correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
+  correct <- set_date(c('2013-04-26', '2013-04-27', '2013-04-28'))
   expect_equal(future$ds, correct)
 
-  future <- make_future_dataframe(m, periods = 3, freq = 'm',
+  future <- make_future_dataframe(m, periods = 3, freq = 'month',
                                   include_history = FALSE)
-  correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
+  correct <- set_date(c('2013-05-25', '2013-06-25', '2013-07-25'))
   expect_equal(future$ds, correct)
 })
 
@@ -263,7 +263,7 @@ test_that("auto_yearly_seasonality", {
 
 test_that("custom_seasonality", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
-  holidays <- data.frame(ds = zoo::as.Date(c('2017-01-02')),
+  holidays <- data.frame(ds = set_date(c('2017-01-02')),
                          holiday = c('special_day'))
   m <- prophet(holidays=holidays)
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)