Sfoglia il codice sorgente

Fix copy with extra seasonalities / regressors R

bletham 8 anni fa
parent
commit
0addabcad7
2 ha cambiato i file con 50 aggiunte e 27 eliminazioni
  1. 15 6
      R/R/prophet.R
  2. 35 21
      R/tests/testthat/test_prophet.R

+ 15 - 6
R/R/prophet.R

@@ -1701,6 +1701,10 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
 #'
 #' @keywords internal
 prophet_copy <- function(m, cutoff = NULL) {
+  if (is.null(m$history)) {
+    stop("This is for copying a fitted Prophet object.")
+  }
+
   if (m$specified.changepoints) {
     changepoints <- m$changepoints
     if (!is.null(cutoff)) {
@@ -1710,13 +1714,15 @@ prophet_copy <- function(m, cutoff = NULL) {
   } else {
     changepoints <- NULL
   }
-  return(prophet(
+  # Auto seasonalities are set to FALSE because they are already set in
+  # m$seasonalities.
+  m2 <- prophet(
     growth = m$growth,
     changepoints = changepoints,
     n.changepoints = m$n.changepoints,
-    yearly.seasonality = m$yearly.seasonality,
-    weekly.seasonality = m$weekly.seasonality,
-    daily.seasonality = m$daily.seasonality,
+    yearly.seasonality = FALSE,
+    weekly.seasonality = FALSE,
+    daily.seasonality = FALSE,
     holidays = m$holidays,
     seasonality.prior.scale = m$seasonality.prior.scale,
     changepoint.prior.scale = m$changepoint.prior.scale,
@@ -1724,8 +1730,11 @@ prophet_copy <- function(m, cutoff = NULL) {
     mcmc.samples = m$mcmc.samples,
     interval.width = m$interval.width,
     uncertainty.samples = m$uncertainty.samples,
-    fit = FALSE,
-  ))
+    fit = FALSE
+  )
+  m2$extra_regressors <- m$extra_regressors
+  m2$seasonalities <- m$seasonalities
+  return(m2)
 }
 
 # fb-block 3

+ 35 - 21
R/tests/testthat/test_prophet.R

@@ -520,20 +520,15 @@ test_that("added_regressors", {
 
 test_that("copy", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  df <- DATA
+  df$cap <- 200.
+  df$binary_feature <- c(rep(0, 255), rep(1, 255))
   inputs <- list(
     growth = c('linear', 'logistic'),
-    changepoints = c(NULL, c('2016-12-25')),
-    n.changepoints = c(3),
     yearly.seasonality = c(TRUE, FALSE),
     weekly.seasonality = c(TRUE, FALSE),
     daily.seasonality = c(TRUE, FALSE),
-    holidays = c(NULL, 'insert_dataframe'),
-    seasonality.prior.scale = c(1.1),
-    holidays.prior.scale = c(1.1),
-    changepoints.prior.scale = c(0.1),
-    mcmc.samples = c(100),
-    interval.width = c(0.9),
-    uncertainty.samples = c(200)
+    holidays = c('null', 'insert_dataframe')
   )
   products <- expand.grid(inputs)
   for (i in 1:length(products)) {
@@ -543,32 +538,51 @@ test_that("copy", {
       holidays <- NULL
     }
     m1 <- prophet(
-      growth = products$growth[i],
-      changepoints = products$changepoints[i],
-      n.changepoints = products$n.changepoints[i],
+      growth = as.character(products$growth[i]),
+      changepoints = NULL,
+      n.changepoints = 3,
       yearly.seasonality = products$yearly.seasonality[i],
       weekly.seasonality = products$weekly.seasonality[i],
       daily.seasonality = products$daily.seasonality[i],
       holidays = holidays,
-      seasonality.prior.scale = products$seasonality.prior.scale[i],
-      holidays.prior.scale = products$holidays.prior.scale[i],
-      changepoints.prior.scale = products$changepoints.prior.scale[i],
-      mcmc.samples = products$mcmc.samples[i],
-      interval.width = products$interval.width[i],
-      uncertainty.samples = products$uncertainty.samples[i],
+      seasonality.prior.scale = 1.1,
+      holidays.prior.scale = 1.1,
+      changepoints.prior.scale = 0.1,
+      mcmc.samples = 100,
+      interval.width = 0.9,
+      uncertainty.samples = 200,
       fit = FALSE
     )
+    out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
+    m1 <- out$m
+    m1$history <- out$df
+    m1 <- prophet:::set_auto_seasonalities(m1)
     m2 <- prophet:::prophet_copy(m1)
     # Values should be copied correctly
-    for (arg in names(inputs)) {
+    args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
+              'seasonality.prior.scale', 'holidays.prior.scale',
+              'changepoints.prior.scale', 'mcmc.samples', 'interval.width',
+              'uncertainty.samples')
+    for (arg in args) {
       expect_equal(m1[[arg]], m2[[arg]])
     }
+    expect_equal(FALSE, m2$yearly.seasonality)
+    expect_equal(FALSE, m2$weekly.seasonality)
+    expect_equal(FALSE, m2$daily.seasonality)
+    expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities))
+    expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities))
+    expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities))
   }
-  # Check for cutoff
+  # Check for cutoff and custom seasonality and extra regressors
   changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
   cutoff <- as.Date('2012-07-25')
-  m1 <- prophet(DATA, changepoints = changepoints)
+  m1 <- prophet(changepoints = changepoints)
+  m1 <- add_seasonality(m1, 'custom', 10, 5)
+  m1 <- add_regressor(m1, 'binary_feature')
+  m1 <- fit.prophet(m1, df)
   m2 <- prophet:::prophet_copy(m1, cutoff)
   changepoints <- changepoints[changepoints <= cutoff]
   expect_equal(prophet:::set_date(changepoints), m2$changepoints)
+  expect_true('custom' %in% names(m2$seasonalities))
+  expect_true('binary_feature' %in% names(m2$extra_regressors))
 })