test_diagnostics.R 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  # Shared fixtures for the diagnostics tests ----
  library(prophet)
  context("Prophet diagnostics tests")

  # Makes R CMD check happy about bare `y` / `yhat` column references
  # (dplyr non-standard evaluation).
  globalVariables(c("y", "yhat"))

  # First 100 rows of the example series, with `ds` parsed as a Date.
  DATA <- read.csv("data.csv")
  DATA <- head(DATA, 100)
  DATA$ds <- as.Date(DATA$ds)
  # Simulated historical forecasts on a linear-growth model, checked over a
  # grid of cutoff spacings (`period`) and forecast horizons (in days).
  test_that("simulated_historical_forecasts", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    fit <- prophet(DATA)
    n_cutoffs <- 2
    for (spacing in c(1, 10)) {
      for (horizon_days in c(1, 3)) {
        shf <- simulated_historical_forecasts(
          fit,
          horizon = horizon_days,
          units = "days",
          k = n_cutoffs,
          period = spacing
        )
        # Every forecast date lies strictly after its cutoff.
        expect_true(all(shf$cutoff < shf$ds))
        # Exactly `k` distinct cutoffs are produced.
        expect_equal(length(unique(shf$cutoff)), n_cutoffs)
        # The furthest forecast is exactly `horizon` days past its cutoff.
        expect_equal(
          max(shf$ds - shf$cutoff),
          as.difftime(horizon_days, units = "days")
        )
        # Distinct cutoffs are spaced at least `period` days apart.
        gaps <- diff(shf$cutoff)
        min_gap <- min(gaps[gaps > 0])
        expect_true(min_gap >= as.difftime(spacing, units = "days"))
        # Observed y values agree with the training history on shared dates.
        joined <- dplyr::left_join(shf, fit$history, by = "ds")
        expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
      }
    }
  })
  # Simulated historical forecasts on a logistic-growth model whose capacity
  # is fixed at 40 for every row.
  test_that("simulated_historical_forecasts_logistic", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    train <- DATA
    train[["cap"]] <- 40
    fit <- prophet(train, growth = "logistic")
    shf <- simulated_historical_forecasts(
      fit, horizon = 3, units = "days", k = 2, period = 3
    )
    # Every forecast date lies strictly after its cutoff.
    expect_true(all(shf$cutoff < shf$ds))
    # Exactly `k` (= 2) distinct cutoffs are produced.
    expect_equal(length(unique(shf$cutoff)), 2)
    # Observed y values agree with the training history on shared dates.
    joined <- dplyr::left_join(shf, fit$history, by = "ds")
    expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
  })
  # Simulated historical forecasts on a model with a custom monthly
  # seasonality and an extra regressor (a 0-based row counter).
  test_that("simulated_historical_forecasts_extra_regressors", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    train <- DATA
    train[["extra"]] <- seq(0, nrow(train) - 1)
    fit <- prophet()
    fit <- add_seasonality(
      fit, name = "monthly", period = 30.5, fourier.order = 5
    )
    fit <- add_regressor(fit, "extra")
    fit <- fit.prophet(fit, train)
    shf <- simulated_historical_forecasts(
      fit, horizon = 3, units = "days", k = 2, period = 3
    )
    # Every forecast date lies strictly after its cutoff.
    expect_true(all(shf$cutoff < shf$ds))
    # Exactly `k` (= 2) distinct cutoffs are produced.
    expect_equal(length(unique(shf$cutoff)), 2)
    # Observed y values agree with the training history on shared dates.
    joined <- dplyr::left_join(shf, fit$history, by = "ds")
    expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
  })
  # Omitting `period` must give the same forecasts as passing period = 5 for
  # horizon = 10 days (presumably because the default is horizon / 2 --
  # confirm against simulated_historical_forecasts() docs).
  test_that("simulated_historical_forecasts_default_value_check", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet(DATA)
    shf_default <- simulated_historical_forecasts(
      m, horizon = 10, units = "days", k = 1
    )
    shf_explicit <- simulated_historical_forecasts(
      m, horizon = 10, units = "days", k = 1, period = 5
    )
    # Compare element-wise instead of summing signed differences: a sum of
    # differences can be zero even when individual rows differ, because
    # positive and negative errors cancel. expect_equal() on the vectors
    # also fails if the two results have different lengths.
    expect_equal(shf_default$y, shf_explicit$y)
    expect_equal(shf_default$yhat, shf_explicit$yhat)
  })
  # Cross-validation on a linear-growth model: checks the number of distinct
  # cutoffs, the furthest horizon reached, and the minimum cutoff spacing.
  test_that("cross_validation", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet(DATA)
    horizon <- as.difftime(4, units = "days")
    period <- as.difftime(10, units = "days")
    # Expected number of distinct cutoffs for horizon = 4, period = 10 and
    # initial = 90 days on this dataset (a fixed expectation, not computed).
    k <- 5
    df_cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 90
    )
    expect_equal(length(unique(df_cv$cutoff)), k)
    # The furthest forecast is exactly `horizon` past its cutoff.
    expect_equal(max(df_cv$ds - df_cv$cutoff), horizon)
    # Distinct cutoffs are spaced at least `period` apart.
    gaps <- diff(df_cv$cutoff)
    expect_true(min(gaps[gaps > 0]) >= period)
  })
  # Omitting `initial` must give the same forecasts as passing initial = 96
  # for horizon = 32 days (presumably because the default is 3 * horizon --
  # confirm against cross_validation() docs).
  test_that("cross_validation_default_value_check", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet(DATA)
    cv_default <- cross_validation(
      m, horizon = 32, units = "days", period = 10
    )
    cv_explicit <- cross_validation(
      m, horizon = 32, units = "days", period = 10, initial = 96
    )
    # Compare element-wise instead of summing signed differences, which can
    # cancel to zero even when individual rows differ. expect_equal() on the
    # vectors also fails if the two results have different lengths.
    expect_equal(cv_default$y, cv_explicit$y)
    expect_equal(cv_default$yhat, cv_explicit$yhat)
  })
  # performance_metrics() at different aggregation levels (rolling_window)
  # and with a custom subset of metrics.
  test_that("performance_metrics", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet(DATA)
    df_cv <- cross_validation(
      m, horizon = 4, units = "days", period = 10, initial = 90
    )
    # Aggregation level none (rolling_window = 0).
    df_none <- performance_metrics(df_cv, rolling_window = 0)
    # expect_equal() on the sorted names fails on a length mismatch and
    # reports which names differ; `all(a == b)` would silently recycle.
    expect_equal(
      sort(colnames(df_none)),
      sort(c("horizon", "coverage", "mae", "mape", "mse", "rmse"))
    )
    expect_equal(nrow(df_none), 14)
    # Aggregation level 0.2
    df_horizon <- performance_metrics(df_cv, rolling_window = 0.2)
    expect_equal(length(unique(df_horizon$horizon)), 4)
    expect_equal(nrow(df_horizon), 13)
    # Aggregation level all (rolling_window = 1): a single summary row.
    df_all <- performance_metrics(df_cv, rolling_window = 1)
    expect_equal(nrow(df_all), 1)
    # The full-window value of each metric equals the mean of the
    # unaggregated values. rmse is presumably left out because the square
    # root of a mean is not the mean of square roots.
    for (metric in c("mse", "mape", "mae", "coverage")) {
      expect_equal(df_all[[metric]][1], mean(df_none[[metric]]))
    }
    # Custom list of metrics
    df_horizon <- performance_metrics(df_cv, metrics = c("coverage", "mse"))
    expect_equal(
      sort(colnames(df_horizon)),
      sort(c("coverage", "mse", "horizon"))
    )
  })