123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- library(prophet)
context("Prophet diagnostics tests")

## Keep R CMD check quiet about the bare column names `y` and `yhat`
## that appear in dplyr-style calls.
globalVariables(c("y", "yhat"))

## Shared fixtures: the full series read from disk with `ds` converted to
## Date, plus its first 100 rows for the quicker tests.
DATA_all <- read.csv("data.csv")
DATA_all[["ds"]] <- as.Date(DATA_all[["ds"]])
DATA <- head(DATA_all, n = 100)
test_that("cross_validation", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")
  m <- prophet(DATA)

  # The same settings passed to cross_validation(), expressed as difftimes
  # so they can be compared against date arithmetic on the result.
  ts <- min(DATA$ds)
  horizon <- as.difftime(4, units = "days")
  period <- as.difftime(10, units = "days")
  initial <- as.difftime(115, units = "days")

  df.cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 115)
  expect_equal(length(unique(df.cv$cutoff)), 3)
  # The furthest forecast lies exactly one horizon past its cutoff.
  expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
  # No cutoff falls inside the initial training window.
  expect_true(min(df.cv$cutoff) >= ts + initial)
  # Distinct consecutive cutoffs are at least one period apart.
  dc <- diff(df.cv$cutoff)
  dc <- min(dc[dc > 0])
  expect_true(dc >= period)
  # Every forecast date is strictly after its cutoff.
  expect_true(all(df.cv$cutoff < df.cv$ds))

  # Each y in df.cv and the model history with the same ds should be equal
  df.merged <- dplyr::left_join(df.cv, m$history, by = "ds")
  expect_equal(sum((df.merged$y.x - df.merged$y.y) ^ 2), 0)

  # A longer initial window is expected to leave a single cutoff.
  df.cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 135)
  expect_equal(length(unique(df.cv$cutoff)), 1)

  # This combination is expected to raise an error (presumably because too
  # little data remains after the initial window -- TODO confirm).
  expect_error(
    cross_validation(
      m, horizon = 10, units = "days", period = 10, initial = 140)
  )
})
test_that("cross_validation_logistic", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")

  # Logistic growth needs a `cap` column; use a constant cap of 40.
  capped <- DATA
  capped$cap <- 40
  model <- prophet(capped, growth = "logistic")

  cv <- cross_validation(
    model, horizon = 1, units = "days", period = 1, initial = 140)

  n_cutoffs <- length(unique(cv$cutoff))
  expect_equal(n_cutoffs, 2)
  # Every forecast date is strictly after its cutoff.
  expect_true(all(cv$cutoff < cv$ds))

  # Observed values in the CV output must match the model history row for
  # the same date.
  joined <- dplyr::left_join(cv, model$history, by = "ds")
  squared_gap <- (joined$y.x - joined$y.y) ^ 2
  expect_equal(sum(squared_gap), 0)
})
test_that("cross_validation_extra_regressors", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")

  # Add a simple 0, 1, 2, ... column to serve as the extra regressor.
  train <- DATA
  train$extra <- seq(0, nrow(train) - 1)

  # Build an unfitted model, attach a custom seasonality and the regressor,
  # then fit it.
  model <- prophet()
  model <- add_seasonality(
    model, name = "monthly", period = 30.5, fourier.order = 5)
  model <- add_regressor(model, "extra")
  model <- fit.prophet(model, train)

  cv <- cross_validation(
    model, horizon = 4, units = "days", period = 4, initial = 135)
  expect_equal(length(unique(cv$cutoff)), 2)

  # Distinct consecutive cutoffs are at least one period (4 days) apart.
  period <- as.difftime(4, units = "days")
  gaps <- diff(cv$cutoff)
  smallest_gap <- min(gaps[gaps > 0])
  expect_true(smallest_gap >= period)

  # Every forecast date is strictly after its cutoff.
  expect_true(all(cv$cutoff < cv$ds))

  # Observed values in the CV output must match the model history row for
  # the same date.
  joined <- dplyr::left_join(cv, model$history, by = "ds")
  squared_gap <- (joined$y.x - joined$y.y) ^ 2
  expect_equal(sum(squared_gap), 0)
})
test_that("cross_validation_default_value_check", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")
  m <- prophet(DATA)

  # Run once without `initial` and once with an explicit 96 days (3 x the
  # 32-day horizon); the test expects both runs to give the same result.
  df.cv1 <- cross_validation(
    m, horizon = 32, units = "days", period = 10)
  df.cv2 <- cross_validation(
    m, horizon = 32, units = "days", period = 10, initial = 96)

  expect_equal(nrow(df.cv1), nrow(df.cv2))
  # Compare squared differences per column: a sum of signed differences
  # could cancel out and hide real discrepancies.
  expect_equal(sum((df.cv1$y - df.cv2$y) ^ 2), 0)
  expect_equal(sum((df.cv1$yhat - df.cv2$yhat) ^ 2), 0)
})
test_that("performance_metrics", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")
  m <- prophet(DATA)
  df_cv <- cross_validation(
    m, horizon = 4, units = "days", period = 10, initial = 90)

  # Aggregation level none
  df_none <- performance_metrics(df_cv, rolling_window = 0)
  # Compare the sorted name vectors directly: an elementwise `==` wrapped in
  # all() would silently recycle if the number of columns differed.
  expect_equal(
    sort(colnames(df_none)),
    sort(c("horizon", "coverage", "mae", "mape", "mse", "rmse"))
  )
  expect_equal(nrow(df_none), 16)

  # Aggregation level 0.2
  df_horizon <- performance_metrics(df_cv, rolling_window = 0.2)
  expect_equal(length(unique(df_horizon$horizon)), 4)
  expect_equal(nrow(df_horizon), 14)

  # Aggregation level all
  df_all <- performance_metrics(df_cv, rolling_window = 1)
  expect_equal(nrow(df_all), 1)
  # With a single window, each metric should equal the mean of the
  # unaggregated values.
  for (metric in c("mse", "mape", "mae", "coverage")) {
    expect_equal(df_all[[metric]][1], mean(df_none[[metric]]))
  }

  # Custom list of metrics
  df_horizon <- performance_metrics(df_cv, metrics = c("coverage", "mse"))
  expect_equal(
    sort(colnames(df_horizon)),
    sort(c("coverage", "mse", "horizon"))
  )
})
test_that("copy", {
  # Skipped when R_ARCH is "/i386".
  skip_if_not(Sys.getenv("R_ARCH") != "/i386")
  df <- DATA_all
  df$cap <- 200.
  # Half zeros, half ones; hard-codes a 510-row fixture.
  df$binary_feature <- c(rep(0, 255), rep(1, 255))

  # Every combination of these settings is copied and checked below.
  inputs <- list(
    growth = c("linear", "logistic"),
    yearly.seasonality = c(TRUE, FALSE),
    weekly.seasonality = c(TRUE, FALSE),
    daily.seasonality = c(TRUE, FALSE),
    holidays = c("null", "insert_dataframe"),
    seasonality.mode = c("additive", "multiplicative")
  )
  products <- expand.grid(inputs)

  # Iterate over the rows of the grid. length() of a data frame is its
  # number of columns, so `1:length(products)` only visited the first 6 of
  # the 64 combinations.
  for (i in seq_len(nrow(products))) {
    if (products$holidays[i] == "insert_dataframe") {
      holidays <- data.frame(ds = c("2016-12-25"), holiday = c("x"))
    } else {
      holidays <- NULL
    }
    # expand.grid() may return the character settings as factors, hence the
    # as.character() conversions.
    # NOTE(review): the prior-scale argument is spelled
    # `changepoint.prior.scale` per the prophet() documentation; the former
    # `changepoints.prior.scale` spelling matched no model field, so the
    # copy check compared NULL with NULL.
    m1 <- prophet(
      growth = as.character(products$growth[i]),
      changepoints = NULL,
      n.changepoints = 3,
      changepoint.range = 0.9,
      yearly.seasonality = products$yearly.seasonality[i],
      weekly.seasonality = products$weekly.seasonality[i],
      daily.seasonality = products$daily.seasonality[i],
      holidays = holidays,
      seasonality.mode = as.character(products$seasonality.mode[i]),
      seasonality.prior.scale = 1.1,
      holidays.prior.scale = 1.1,
      changepoint.prior.scale = 0.1,
      mcmc.samples = 100,
      interval.width = 0.9,
      uncertainty.samples = 200,
      fit = FALSE
    )
    # Prepare the unfitted model by hand (history and auto seasonalities)
    # before copying it.
    out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
    m1 <- out$m
    m1$history <- out$df
    m1 <- prophet:::set_auto_seasonalities(m1)
    m2 <- prophet:::prophet_copy(m1)

    # Values should be copied correctly
    args <- c("growth", "changepoints", "n.changepoints", "holidays",
              "seasonality.prior.scale", "holidays.prior.scale",
              "changepoint.prior.scale", "mcmc.samples", "interval.width",
              "uncertainty.samples", "seasonality.mode", "changepoint.range")
    for (arg in args) {
      expect_equal(m1[[arg]], m2[[arg]])
    }
    # The copy is expected to have its seasonality flags switched off and to
    # carry a seasonality entry exactly when the flag was set on the source.
    expect_equal(FALSE, m2$yearly.seasonality)
    expect_equal(FALSE, m2$weekly.seasonality)
    expect_equal(FALSE, m2$daily.seasonality)
    expect_equal(m1$yearly.seasonality, "yearly" %in% names(m2$seasonalities))
    expect_equal(m1$weekly.seasonality, "weekly" %in% names(m2$seasonalities))
    expect_equal(m1$daily.seasonality, "daily" %in% names(m2$seasonalities))
  }

  # Check for cutoff and custom seasonality and extra regressors
  changepoints <- seq.Date(
    as.Date("2012-06-15"), as.Date("2012-09-15"), by = "d")
  cutoff <- as.Date("2012-07-25")
  m1 <- prophet(changepoints = changepoints)
  m1 <- add_seasonality(m1, "custom", 10, 5)
  m1 <- add_regressor(m1, "binary_feature")
  m1 <- fit.prophet(m1, df)
  m2 <- prophet:::prophet_copy(m1, cutoff)
  # Only changepoints on or before the cutoff should survive the copy.
  changepoints <- changepoints[changepoints <= cutoff]
  expect_equal(prophet:::set_date(changepoints), m2$changepoints)
  expect_true("custom" %in% names(m2$seasonalities))
  expect_true("binary_feature" %in% names(m2$extra_regressors))
})
|