test_diagnostics.R 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  # Shared setup for the Prophet diagnostics tests.
  library(prophet)
  context("Prophet diagnostics tests")

  ## Makes R CMD CHECK happy due to dplyr syntax below
  globalVariables(c("y", "yhat"))

  # Read the fixture once and parse the `ds` column into Date objects.
  # DATA_all is the complete table; DATA is its first 100 rows.
  DATA_all <- read.csv("data.csv")
  DATA_all[["ds"]] <- as.Date(DATA_all[["ds"]])
  DATA <- head(DATA_all, n = 100)
  # cross_validation() on a default model fitted to DATA:
  #   * the expected number of distinct cutoffs is produced,
  #   * the largest forecast offset equals the requested horizon,
  #   * the first cutoff leaves at least `initial` of training data,
  #   * consecutive cutoffs are at least `period` apart,
  #   * every forecast date lies strictly after its cutoff,
  #   * the observed y carried in the CV output matches the model history,
  #   * a larger `initial` yields a single cutoff, and a still larger
  #     initial/horizon combination raises an error.
  test_that("cross_validation", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Earliest training date; the first cutoff is checked against it below.
    ts <- min(DATA$ds)
    # difftime versions of the numeric arguments passed to cross_validation(),
    # used to compare against date differences in the result.
    horizon <- as.difftime(4, units = "days")
    period <- as.difftime(10, units = "days")
    initial <- as.difftime(115, units = "days")
    df.cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 115)
    expect_equal(length(unique(df.cv$cutoff)), 3)
    expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
    expect_true(min(df.cv$cutoff) >= ts + initial)
    # Smallest positive gap between successive cutoffs (zero gaps come from
    # rows that share the same cutoff).
    dc <- diff(df.cv$cutoff)
    dc <- min(dc[dc > 0])
    expect_true(dc >= period)
    expect_true(all(df.cv$cutoff < df.cv$ds))
    # Each y in df.cv and DATA with same ds should be equal
    df.merged <- dplyr::left_join(df.cv, m$history, by = "ds")
    expect_equal(sum((df.merged$y.x - df.merged$y.y) ^ 2), 0)
    # With initial = 135 only one cutoff is expected.
    df.cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 135)
    expect_equal(length(unique(df.cv$cutoff)), 1)
    # This initial/horizon combination is expected to be rejected
    # (presumably too little data remains -- confirm against
    # cross_validation()).
    expect_error(
      cross_validation(
        m, horizon = 10, units = "days", period = 10, initial = 140)
    )
  })
  # cross_validation() on a logistic-growth model with a constant cap:
  # expects two cutoffs, forecasts strictly after their cutoff, and observed
  # y values that match the model history.
  test_that("cross_validation_logistic", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    capped <- DATA
    capped$cap <- 40
    model <- prophet(capped, growth = "logistic")
    cv <- cross_validation(
      model, horizon = 1, units = "days", period = 1, initial = 140)
    n_cutoffs <- length(unique(cv$cutoff))
    expect_equal(n_cutoffs, 2)
    expect_true(all(cv$cutoff < cv$ds))
    # Join on ds so CV y (y.x) can be compared with history y (y.y).
    joined <- dplyr::left_join(cv, model$history, by = "ds")
    squared_diff <- (joined$y.x - joined$y.y) ^ 2
    expect_equal(sum(squared_diff), 0)
  })
  # cross_validation() on a model with a custom monthly seasonality and an
  # extra regressor: checks cutoff count, spacing between cutoffs, forecast
  # ordering, and agreement of observed y with the model history.
  test_that("cross_validation_extra_regressors", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    with_extra <- DATA
    # Regressor values 0, 1, ..., nrow - 1.
    with_extra$extra <- seq(0, nrow(with_extra) - 1)
    model <- prophet()
    model <- add_seasonality(
      model, name = "monthly", period = 30.5, fourier.order = 5)
    model <- add_regressor(model, "extra")
    model <- fit.prophet(model, with_extra)
    cv <- cross_validation(
      model, horizon = 4, units = "days", period = 4, initial = 135)
    expect_equal(length(unique(cv$cutoff)), 2)
    # Smallest positive gap between cutoffs must be at least the period.
    min_spacing <- as.difftime(4, units = "days")
    gaps <- diff(cv$cutoff)
    smallest_gap <- min(gaps[gaps > 0])
    expect_true(smallest_gap >= min_spacing)
    expect_true(all(cv$cutoff < cv$ds))
    joined <- dplyr::left_join(cv, model$history, by = "ds")
    squared_diff <- (joined$y.x - joined$y.y) ^ 2
    expect_equal(sum(squared_diff), 0)
  })
  # Omitting `initial` should give the same y / yhat as passing initial = 96
  # explicitly (with horizon = 32).
  test_that("cross_validation_default_value_check", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(DATA)
    cv_default <- cross_validation(
      model, horizon = 32, units = "days", period = 10)
    cv_explicit <- cross_validation(
      model, horizon = 32, units = "days", period = 10, initial = 96)
    # Column-wise difference of the two results, reduced to y and yhat.
    delta <- dplyr::select(cv_default - cv_explicit, y, yhat)
    expect_equal(sum(delta), 0)
  })
  # performance_metrics() on a cross-validation result at several
  # rolling_window settings, plus a custom metric selection.
  test_that("performance_metrics", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(DATA)
    cv <- cross_validation(
      model, horizon = 4, units = "days", period = 10, initial = 90)

    # Aggregation level none
    per_row <- performance_metrics(cv, rolling_window = 0)
    all_columns <- c("horizon", "coverage", "mae", "mape", "mse", "rmse")
    expect_true(all(sort(colnames(per_row)) == sort(all_columns)))
    expect_equal(nrow(per_row), 16)

    # Aggregation level 0.2
    windowed <- performance_metrics(cv, rolling_window = 0.2)
    expect_equal(length(unique(windowed$horizon)), 4)
    expect_equal(nrow(windowed), 14)

    # Aggregation level all
    overall <- performance_metrics(cv, rolling_window = 1)
    expect_equal(nrow(overall), 1)
    # Each fully aggregated metric must equal the mean of the per-row values.
    averaged_metrics <- c("mse", "mape", "mae", "coverage")
    for (metric_name in averaged_metrics) {
      expect_equal(
        overall[[metric_name]][1], mean(per_row[[metric_name]]))
    }

    # Custom list of metrics
    selected <- performance_metrics(cv, metrics = c("coverage", "mse"))
    selected_columns <- c("coverage", "mse", "horizon")
    expect_true(all(sort(colnames(selected)) == sort(selected_columns)))
  })
  # prophet_copy() must reproduce a model's configuration.
  #
  # Part 1 builds an unfitted model for every combination of the constructor
  # options in `inputs`, prepares it with setup_dataframe() and
  # set_auto_seasonalities(), copies it, and compares the copied fields.
  # Part 2 copies a fitted model with a cutoff date and checks that the
  # changepoints are truncated at the cutoff while the custom seasonality and
  # the extra regressor survive the copy.
  test_that("copy", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    df <- DATA_all
    df$cap <- 200.
    df$binary_feature <- c(rep(0, 255), rep(1, 255))
    inputs <- list(
      growth = c('linear', 'logistic'),
      yearly.seasonality = c(TRUE, FALSE),
      weekly.seasonality = c(TRUE, FALSE),
      daily.seasonality = c(TRUE, FALSE),
      holidays = c('null', 'insert_dataframe'),
      seasonality.mode = c('additive', 'multiplicative')
    )
    # One row per combination of the options above.
    products <- expand.grid(inputs)
    # Iterate over ROWS. length() of a data frame is its column count, so
    # 1:length(products) only visited the first six combinations.
    for (i in seq_len(nrow(products))) {
      if (products$holidays[i] == 'insert_dataframe') {
        holidays <- data.frame(ds = c('2016-12-25'), holiday = c('x'))
      } else {
        holidays <- NULL
      }
      m1 <- prophet(
        # expand.grid() turns character options into factors; pass strings.
        growth = as.character(products$growth[i]),
        changepoints = NULL,
        n.changepoints = 3,
        changepoint.range = 0.9,
        yearly.seasonality = products$yearly.seasonality[i],
        weekly.seasonality = products$weekly.seasonality[i],
        daily.seasonality = products$daily.seasonality[i],
        holidays = holidays,
        # The grid varies seasonality.mode and the copy check below compares
        # it, so it has to be set on the model under test.
        seasonality.mode = as.character(products$seasonality.mode[i]),
        seasonality.prior.scale = 1.1,
        holidays.prior.scale = 1.1,
        # prophet()'s documented argument is `changepoint.prior.scale`; the
        # misspelled `changepoints.prior.scale` never set this field.
        changepoint.prior.scale = 0.1,
        mcmc.samples = 100,
        interval.width = 0.9,
        uncertainty.samples = 200,
        fit = FALSE
      )
      out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
      m1 <- out$m
      m1$history <- out$df
      m1 <- prophet:::set_auto_seasonalities(m1)
      m2 <- prophet:::prophet_copy(m1)
      # Values should be copied correctly
      args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
                'seasonality.prior.scale', 'holidays.prior.scale',
                'changepoint.prior.scale', 'mcmc.samples', 'interval.width',
                'uncertainty.samples', 'seasonality.mode', 'changepoint.range')
      for (arg in args) {
        expect_equal(m1[[arg]], m2[[arg]])
      }
      # The copy has the auto-seasonality flags switched off; whether each
      # seasonality is present in the copy must follow the flag of the
      # original model.
      expect_equal(FALSE, m2$yearly.seasonality)
      expect_equal(FALSE, m2$weekly.seasonality)
      expect_equal(FALSE, m2$daily.seasonality)
      expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities))
      expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities))
      expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities))
    }
    # Check for cutoff and custom seasonality and extra regressors
    changepoints <- seq.Date(
      as.Date('2012-06-15'), as.Date('2012-09-15'), by = 'd')
    cutoff <- as.Date('2012-07-25')
    m1 <- prophet(changepoints = changepoints)
    m1 <- add_seasonality(m1, 'custom', 10, 5)
    m1 <- add_regressor(m1, 'binary_feature')
    m1 <- fit.prophet(m1, df)
    m2 <- prophet:::prophet_copy(m1, cutoff)
    # Only changepoints on or before the cutoff should remain in the copy.
    changepoints <- changepoints[changepoints <= cutoff]
    expect_equal(prophet:::set_date(changepoints), m2$changepoints)
    expect_true('custom' %in% names(m2$seasonalities))
    expect_true('binary_feature' %in% names(m2$extra_regressors))
  })