test_diagnostics.R 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  library(prophet)
  context("Prophet diagnostics tests")

  # Declare the bare column names used inside dplyr::select() below so that
  # R CMD check does not report them as undefined global variables.
  globalVariables(c("y", "yhat"))

  # Shared fixtures: the full series read from data.csv (with `ds` parsed to
  # Date) and its first 100 rows, which most of the tests below use.
  DATA_all <- read.csv('data.csv')
  DATA_all[['ds']] <- as.Date(DATA_all[['ds']])
  DATA <- head(DATA_all, n = 100)
  test_that("simulated_historical_forecasts", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    model <- prophet(DATA)
    n_cutoffs <- 2
    # Exercise every (period, horizon) pair, both measured in days.
    for (period_days in c(1, 10)) {
      for (horizon_days in c(1, 3)) {
        shf <- simulated_historical_forecasts(
          model,
          horizon = horizon_days,
          units = 'days',
          k = n_cutoffs,
          period = period_days
        )
        # Each forecast point lies strictly after the cutoff it came from.
        expect_true(all(shf$cutoff < shf$ds))
        # Exactly `n_cutoffs` distinct cutoffs are simulated.
        expect_equal(length(unique(shf$cutoff)), n_cutoffs)
        # The furthest forecast point is exactly one horizon past its cutoff.
        expect_equal(max(shf$ds - shf$cutoff),
                     as.difftime(horizon_days, units = 'days'))
        # Distinct consecutive cutoffs are at least one period apart.
        cutoff_gaps <- diff(shf$cutoff)
        smallest_gap <- min(cutoff_gaps[cutoff_gaps > 0])
        expect_true(smallest_gap >= as.difftime(period_days, units = 'days'))
        # The observed y values carried into the output match the history.
        merged <- dplyr::left_join(shf, model$history, by = "ds")
        expect_equal(sum((merged$y.x - merged$y.y)^2), 0)
      }
    }
  })
  test_that("simulated_historical_forecasts_logistic", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Fit a logistic-growth model with a constant `cap` column of 40.
    history <- DATA
    history$cap <- 40
    model <- prophet(history, growth = 'logistic')
    shf <- simulated_historical_forecasts(
      model, horizon = 3, units = 'days', k = 2, period = 3
    )
    # Each forecast point lies strictly after its cutoff.
    expect_true(all(shf$cutoff < shf$ds))
    # k = 2 must yield exactly two distinct cutoffs.
    expect_equal(length(unique(shf$cutoff)), 2)
    # The observed y values carried into the output match the history.
    joined <- dplyr::left_join(shf, model$history, by = "ds")
    expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
  })
  test_that("simulated_historical_forecasts_extra_regressors", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Add an increasing regressor 0, 1, ..., nrow - 1 to the fixture.
    history <- DATA
    history$extra <- seq(0, nrow(history) - 1)
    # Unfitted model + custom monthly seasonality + the extra regressor,
    # then fitted on the augmented history.
    model <- add_regressor(
      add_seasonality(prophet(), name = 'monthly', period = 30.5,
                      fourier.order = 5),
      'extra'
    )
    model <- fit.prophet(model, history)
    shf <- simulated_historical_forecasts(
      model, horizon = 3, units = 'days', k = 2, period = 3
    )
    # Each forecast point lies strictly after its cutoff.
    expect_true(all(shf$cutoff < shf$ds))
    # k = 2 must yield exactly two distinct cutoffs.
    expect_equal(length(unique(shf$cutoff)), 2)
    # The observed y values carried into the output match the history.
    joined <- dplyr::left_join(shf, model$history, by = "ds")
    expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
  })
  test_that("simulated_historical_forecasts_default_value_check", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Omitting `period` is expected to behave exactly like period = 5 for a
    # 10-day horizon.
    df.shf1 <- simulated_historical_forecasts(
      m, horizon = 10, units = 'days', k = 1)
    df.shf2 <- simulated_historical_forecasts(
      m, horizon = 10, units = 'days', k = 1, period = 5)
    # Compare the y/yhat columns element-wise. Summing the signed differences
    # into one scalar would let positive and negative errors cancel out.
    # (NOTE(review): yhat_lower/yhat_upper are presumably excluded because
    # the interval bounds are sampled; confirm.)
    expect_equal(
      dplyr::select(df.shf1, y, yhat),
      dplyr::select(df.shf2, y, yhat)
    )
  })
  test_that("cross_validation", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    horizon <- as.difftime(4, units = "days")
    period <- as.difftime(10, units = "days")
    # Expected number of distinct cutoffs for this fixture with the
    # horizon/period/initial passed below.
    k <- 5
    df.cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 90)
    expect_equal(length(unique(df.cv$cutoff)), k)
    # The furthest forecast point is exactly one horizon past its cutoff.
    expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
    # Distinct cutoffs are spaced at least one period apart. Sorting the
    # unique cutoffs first means the gaps are all positive regardless of row
    # order, so this cannot pass vacuously on an empty set of gaps.
    dc <- diff(sort(unique(df.cv$cutoff)))
    expect_true(min(dc) >= period)
  })
  test_that("cross_validation_default_value_check", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Omitting `initial` is expected to behave exactly like initial = 96
    # (three 32-day horizons).
    df.cv1 <- cross_validation(
      m, horizon = 32, units = "days", period = 10)
    df.cv2 <- cross_validation(
      m, horizon = 32, units = 'days', period = 10, initial = 96)
    # Compare the y/yhat columns element-wise. Summing the signed differences
    # into one scalar would let positive and negative errors cancel out.
    expect_equal(
      dplyr::select(df.cv1, y, yhat),
      dplyr::select(df.cv2, y, yhat)
    )
  })
  test_that("performance_metrics", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    df_cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 90)
    # Aggregation level none
    df_none <- performance_metrics(df_cv, rolling_window = 0)
    # Compare the sorted names directly. all(a == b) is vacuously TRUE when
    # there are no columns, and it silently recycles when one length is a
    # multiple of the other.
    expect_equal(
      sort(colnames(df_none)),
      sort(c('horizon', 'coverage', 'mae', 'mape', 'mse', 'rmse'))
    )
    expect_equal(nrow(df_none), 14)
    # Aggregation level 0.2
    df_horizon <- performance_metrics(df_cv, rolling_window = 0.2)
    expect_equal(length(unique(df_horizon$horizon)), 4)
    expect_equal(nrow(df_horizon), 13)
    # Aggregation level all: a single row whose metrics are the means of the
    # unaggregated metrics. (rmse is presumably left out because the square
    # root of a mean is not the mean of square roots.)
    df_all <- performance_metrics(df_cv, rolling_window = 1)
    expect_equal(nrow(df_all), 1)
    for (metric in c('mse', 'mape', 'mae', 'coverage')) {
      expect_equal(df_all[[metric]][1], mean(df_none[[metric]]))
    }
    # Custom list of metrics
    df_horizon <- performance_metrics(df_cv, metrics = c('coverage', 'mse'))
    expect_equal(
      sort(colnames(df_horizon)),
      sort(c('coverage', 'mse', 'horizon'))
    )
  })
  test_that("copy", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    df <- DATA_all
    df$cap <- 200.
    # Binary regressor: 0 for the first half of the history, 1 for the rest.
    # Sized from nrow(df) rather than a hard-coded row count.
    n.zero <- nrow(df) %/% 2
    df$binary_feature <- c(rep(0, n.zero), rep(1, nrow(df) - n.zero))
    inputs <- list(
      growth = c('linear', 'logistic'),
      yearly.seasonality = c(TRUE, FALSE),
      weekly.seasonality = c(TRUE, FALSE),
      daily.seasonality = c(TRUE, FALSE),
      holidays = c('null', 'insert_dataframe'),
      seasonality.mode = c('additive', 'multiplicative')
    )
    # One row per combination of the options above. expand.grid() turns the
    # character options into factors, hence the as.character() calls below.
    products <- expand.grid(inputs)
    # Visit every combination (row). length() of a data frame counts its
    # columns, so 1:length(products) would only reach the first few rows.
    for (i in seq_len(nrow(products))) {
      if (products$holidays[i] == 'insert_dataframe') {
        holidays <- data.frame(ds = c('2016-12-25'), holiday = c('x'))
      } else {
        holidays <- NULL
      }
      # Argument names must match prophet()'s formals exactly
      # (changepoint.prior.scale, not changepoints.prior.scale). A misspelled
      # name is absorbed by prophet()'s `...` and never sets the prior.
      m1 <- prophet(
        growth = as.character(products$growth[i]),
        changepoints = NULL,
        n.changepoints = 3,
        changepoint.range = 0.9,
        yearly.seasonality = products$yearly.seasonality[i],
        weekly.seasonality = products$weekly.seasonality[i],
        daily.seasonality = products$daily.seasonality[i],
        holidays = holidays,
        seasonality.mode = as.character(products$seasonality.mode[i]),
        seasonality.prior.scale = 1.1,
        holidays.prior.scale = 1.1,
        changepoint.prior.scale = 0.1,
        mcmc.samples = 100,
        interval.width = 0.9,
        uncertainty.samples = 200,
        fit = FALSE
      )
      # Attach a history and resolve auto-seasonalities through prophet
      # internals, without running a fit, so that m1 can be copied.
      out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
      m1 <- out$m
      m1$history <- out$df
      m1 <- prophet:::set_auto_seasonalities(m1)
      m2 <- prophet:::prophet_copy(m1)
      # Values should be copied correctly
      args <- c('growth', 'changepoints', 'n.changepoints', 'holidays',
                'seasonality.prior.scale', 'holidays.prior.scale',
                'changepoint.prior.scale', 'mcmc.samples', 'interval.width',
                'uncertainty.samples', 'seasonality.mode', 'changepoint.range')
      for (arg in args) {
        expect_equal(m1[[arg]], m2[[arg]])
      }
      # The copy has its auto-seasonality flags off; the seasonalities m1
      # resolved must instead appear in m2$seasonalities.
      expect_equal(FALSE, m2$yearly.seasonality)
      expect_equal(FALSE, m2$weekly.seasonality)
      expect_equal(FALSE, m2$daily.seasonality)
      expect_equal(m1$yearly.seasonality, 'yearly' %in% names(m2$seasonalities))
      expect_equal(m1$weekly.seasonality, 'weekly' %in% names(m2$seasonalities))
      expect_equal(m1$daily.seasonality, 'daily' %in% names(m2$seasonalities))
    }
    # Check for cutoff and custom seasonality and extra regressors
    changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
    cutoff <- as.Date('2012-07-25')
    m1 <- prophet(changepoints = changepoints)
    m1 <- add_seasonality(m1, 'custom', 10, 5)
    m1 <- add_regressor(m1, 'binary_feature')
    m1 <- fit.prophet(m1, df)
    m2 <- prophet:::prophet_copy(m1, cutoff)
    # Only the changepoints on or before the cutoff should be carried over.
    changepoints <- changepoints[changepoints <= cutoff]
    expect_equal(prophet:::set_date(changepoints), m2$changepoints)
    expect_true('custom' %in% names(m2$seasonalities))
    expect_true('binary_feature' %in% names(m2$extra_regressors))
  })