test_diagnostics.R 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. library(prophet)
  context("Prophet diagnostics tests")

  # `y` and `yhat` are used as bare column names inside dplyr::select() calls
  # in this file; declaring them keeps R CMD check from reporting them as
  # undefined global variables.
  globalVariables(c("y", "yhat"))

  # Shared fixture for every test below: the first 100 rows of data.csv,
  # with the `ds` column converted to Date.
  DATA <- read.csv('data.csv')
  DATA <- head(DATA, 100)
  DATA$ds <- as.Date(DATA$ds)
  test_that("simulated_historical_forecasts", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    model <- prophet(DATA)
    n.cutoffs <- 2
    # Exercise every combination of cutoff spacing and forecast horizon.
    for (period.days in c(1, 10)) {
      for (horizon.days in c(1, 3)) {
        forecasts <- simulated_historical_forecasts(
          model,
          horizon = horizon.days,
          units = 'days',
          k = n.cutoffs,
          period = period.days
        )
        cutoffs <- forecasts$cutoff

        # Every forecast row must be dated strictly after its cutoff.
        expect_true(all(cutoffs < forecasts$ds))

        # The number of distinct cutoffs must match the requested `k`.
        expect_equal(length(unique(cutoffs)), n.cutoffs)

        # The longest lead time past a cutoff equals the horizon.
        expect_equal(
          max(forecasts$ds - cutoffs),
          as.difftime(horizon.days, units = 'days')
        )

        # Distinct cutoffs are at least one period apart; zero gaps between
        # rows sharing a cutoff are filtered out first.
        gaps <- diff(cutoffs)
        smallest.gap <- min(gaps[gaps > 0])
        expect_true(smallest.gap >= as.difftime(period.days, units = 'days'))

        # After joining on `ds`, the forecast frame's y must equal the y
        # stored in the model history.
        joined <- dplyr::left_join(forecasts, model$history, by = "ds")
        expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
      }
    }
  })
  test_that("simulated_historical_forecasts_logistic", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    # Logistic growth is fitted with a constant `cap` column added to the
    # fixture.
    capped <- DATA
    capped$cap <- 40
    model <- prophet(capped, growth = 'logistic')
    n.cutoffs <- 2
    forecasts <- simulated_historical_forecasts(
      model,
      horizon = 3,
      units = 'days',
      k = n.cutoffs,
      period = 3
    )

    # Every forecast row must be dated strictly after its cutoff.
    expect_true(all(forecasts$cutoff < forecasts$ds))

    # The number of distinct cutoffs must match the requested `k`.
    expect_equal(length(unique(forecasts$cutoff)), n.cutoffs)

    # After joining on `ds`, the forecast frame's y must equal the y stored
    # in the model history.
    joined <- dplyr::left_join(forecasts, model$history, by = "ds")
    expect_equal(sum((joined$y.x - joined$y.y)^2), 0)
  })
  test_that("simulated_historical_forecasts_default_value_check", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Leaving `period` unset and passing period = 5 explicitly are expected
    # to give the same forecasts for a 10-day horizon.
    # NOTE(review): this implies the default period is half the horizon --
    # confirm against simulated_historical_forecasts().
    df.shf1 <- simulated_historical_forecasts(
      m, horizon = 10, units = 'days', k = 1)
    df.shf2 <- simulated_historical_forecasts(
      m, horizon = 10, units = 'days', k = 1, period = 5)
    diffs <- dplyr::select(df.shf1 - df.shf2, y, yhat)
    # Square before summing so that positive and negative differences cannot
    # cancel out and hide a mismatch (same check as the other tests here).
    expect_equal(sum(diffs^2), 0)
  })
  test_that("cross_validation", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Horizon and period are defined once and reused both in the call and in
    # the expectations, so the two cannot drift apart.
    horizon.days <- 4
    period.days <- 10
    horizon <- as.difftime(horizon.days, units = "days")
    period <- as.difftime(period.days, units = "days")
    # Expected number of cutoff points for these settings on the DATA
    # fixture. This is a fixed expectation, not derived from the data here.
    k <- 5
    df.cv <- cross_validation(
      m, horizon = horizon.days, units = "days", period = period.days,
      initial = 90)
    # The number of distinct cutoffs matches the expected count.
    expect_equal(length(unique(df.cv$cutoff)), k)
    # The longest lead time past a cutoff equals the horizon.
    expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
    # Distinct cutoffs are at least one period apart; zero gaps between rows
    # sharing a cutoff are filtered out first.
    dc <- diff(df.cv$cutoff)
    dc <- min(dc[dc > 0])
    expect_true(dc >= period)
  })
  test_that("cross_validation_default_value_check", {
    skip_if_not(Sys.getenv('R_ARCH') != '/i386')
    m <- prophet(DATA)
    # Leaving `initial` unset and passing initial = 96 explicitly are
    # expected to give the same forecasts for a 32-day horizon.
    # NOTE(review): this implies the default initial window is three times
    # the horizon -- confirm against cross_validation().
    df.cv1 <- cross_validation(
      m, horizon = 32, units = "days", period = 10)
    df.cv2 <- cross_validation(
      m, horizon = 32, units = 'days', period = 10, initial = 96)
    diffs <- dplyr::select(df.cv1 - df.cv2, y, yhat)
    # Square before summing so that positive and negative differences cannot
    # cancel out and hide a mismatch (same check as the other tests here).
    expect_equal(sum(diffs^2), 0)
  })