test_prophet.R 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257
  library(prophet)
  context("Prophet tests")

  # Shared fixtures used by the tests below. The series in data.csv (with at
  # least `ds` and `y` columns) has `ds` parsed to Date. It is then split
  # into a first half (`train`) and a second half (`future`). When N is odd,
  # the middle row belongs to neither half.
  DATA <- read.csv("data.csv")
  DATA$ds <- as.Date(DATA$ds)
  N <- nrow(DATA)
  train <- DATA[seq(1, floor(N / 2)), ]
  future <- DATA[seq(ceiling(N / 2) + 1, N), ]
  # Fitting on the first half and predicting on the second half must not
  # raise an error (expect_error(..., NA) asserts "no error").
  # Skipped when R_ARCH is "/i386".
  test_that("fit_predict", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(train)
    expect_error(predict(model, future), NA)
  })
  # Same fit/predict round trip as "fit_predict", with weekly and yearly
  # seasonality explicitly disabled.
  test_that("fit_predict_no_seasons", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(
      train,
      weekly.seasonality = FALSE,
      yearly.seasonality = FALSE
    )
    expect_error(predict(model, future), NA)
  })
  # Same fit/predict round trip as "fit_predict", with no potential
  # changepoints (n.changepoints = 0).
  test_that("fit_predict_no_changepoints", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(train, n.changepoints = 0)
    expect_error(predict(model, future), NA)
  })
  # A user-supplied changepoint (2013-06-06) may fall inside a gap in the
  # training data. Every date from 2013-01-01 through 2014-01-01 is dropped
  # from the history. Prediction then covers every date in DATA.
  test_that("fit_predict_changepoint_not_in_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    gap_start <- zoo::as.Date("2013-01-01")
    gap_end <- zoo::as.Date("2014-01-01")
    gappy <- dplyr::mutate(DATA, ds = zoo::as.Date(ds))
    gappy <- dplyr::filter(gappy, (ds < gap_start) | (ds > gap_end))
    all_dates <- data.frame(ds = DATA$ds)
    model <- prophet(gappy, changepoints = c("2013-06-06"))
    expect_error(predict(model, all_dates), NA)
  })
  # Every `ds` in the training frame appears twice. The second copy has `y`
  # shifted by +10. Fitting and predicting must still succeed.
  test_that("fit_predict_duplicates", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    shifted <- train
    shifted$y <- shifted$y + 10
    doubled <- rbind(train, shifted)
    model <- prophet(doubled)
    expect_error(predict(model, future), NA)
  })
  # setup_dataframe(initialize_scales = TRUE) must add a time column `t`
  # spanning exactly [0, 1] and a `y_scaled` column whose maximum is 1.
  test_that("setup_dataframe", {
    model <- prophet(train, fit = FALSE)
    prepared <- prophet:::setup_dataframe(model, train,
                                          initialize_scales = TRUE)$df
    expect_true("t" %in% colnames(prepared))
    expect_equal(min(prepared$t), 0)
    expect_equal(max(prepared$t), 1)
    expect_true("y_scaled" %in% colnames(prepared))
    expect_equal(max(prepared$y_scaled), 1)
  })
  # set_changepoints on the training half must place m$n.changepoints
  # changepoints. All of them lie strictly after the start of the history and
  # within its first 80%. get_changepoint_matrix must have one row per
  # history row and one column per changepoint.
  test_that("get_changepoints", {
    history <- train
    m <- prophet(history, fit = FALSE)
    out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
    history <- out$df
    m <- out$m
    m$history <- history
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    expect_equal(length(cp), m$n.changepoints)
    expect_true(min(cp) > 0)
    # changepoints.t is on the scaled time axis (history$t spans [0, 1]), so
    # comparing it with the row count N would always pass. By default,
    # potential changepoints are only placed in the first 80% of the history.
    # Bound them by the scaled time at that position instead.
    cp_limit_idx <- ceiling(0.8 * nrow(history))
    expect_true(max(cp) <= history$t[cp_limit_idx])
    mat <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(mat), floor(N / 2))
    expect_equal(ncol(mat), m$n.changepoints)
  })
  # With n.changepoints = 0, set_changepoints must still yield exactly one
  # changepoint, at t = 0. The changepoint matrix then has a single column.
  test_that("get_zero_changepoints", {
    model <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(model, train,
                                          initialize_scales = TRUE)
    model <- prepared$m
    model$history <- prepared$df
    model <- prophet:::set_changepoints(model)
    cps <- model$changepoints.t
    expect_equal(length(cps), 1)
    expect_equal(cps[1], 0)
    A <- prophet:::get_changepoint_matrix(model)
    expect_equal(nrow(A), floor(N / 2))
    expect_equal(ncol(A), 1)
  })
  # The first row of the order-3 Fourier basis with period 7, evaluated at
  # DATA$ds, must match precomputed reference values (6 columns).
  test_that("fourier_series_weekly", {
    mat <- prophet:::fourier_series(DATA$ds, 7, 3)
    true.values <- c(0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837,
                     -0.9009689)
    # testthat's contract is expect_equal(object, expected). Pass the computed
    # row first so failure messages label actual and expected correctly.
    expect_equal(mat[1, ], true.values, tolerance = 1e-6)
  })
  # The first row of the order-3 Fourier basis with period 365.25, evaluated
  # at DATA$ds, must match precomputed reference values (6 columns).
  test_that("fourier_series_yearly", {
    mat <- prophet:::fourier_series(DATA$ds, 365.25, 3)
    true.values <- c(0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249,
                     0.6874572)
    # testthat's contract is expect_equal(object, expected). Pass the computed
    # row first so failure messages label actual and expected correctly.
    expect_equal(mat[1, ], true.values, tolerance = 1e-6)
  })
  # linear_growth_init and logistic_growth_init are run on the first 468 rows,
  # with a constant cap equal to max(y). Their first two returned parameters
  # must match reference values.
  test_that("growth_init", {
    history <- DATA[1:468, ]
    history$cap <- max(history$y)
    model <- prophet(history, growth = "logistic", fit = FALSE)
    history <- prophet:::setup_dataframe(model, history,
                                         initialize_scales = TRUE)$df

    linear <- prophet:::linear_growth_init(history)
    expect_equal(linear[1], 0.3055671, tolerance = 1e-6)
    expect_equal(linear[2], 0.5307511, tolerance = 1e-6)

    logistic <- prophet:::logistic_growth_init(history)
    expect_equal(logistic[1], 1.507925, tolerance = 1e-6)
    expect_equal(logistic[2], -0.08167497, tolerance = 1e-6)
  })
  # With k = 1, m = 0 and a delta of 0.5 at t = 5, the expected curve has
  # slope 1 up to t = 5 and slope 1.5 afterwards. Evaluating on the suffix
  # t = 7..10 alone must give the matching suffix of that curve.
  test_that("piecewise_linear", {
    t <- seq(0, 10)
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    y.true <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
    y <- prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts)
    expect_equal(y, y.true)

    keep <- 8:length(t)
    y <- prophet:::piecewise_linear(t[keep], deltas, k, m, changepoint.ts)
    expect_equal(y, y.true[keep])
  })
  # Piecewise logistic curve with cap 10, k = 1, m = 0 and a delta of 0.5 at
  # t = 5. It is checked against reference values on t = 0..10, then again
  # on the suffix t = 7..10 using the matching suffix of `cap`.
  test_that("piecewise_logistic", {
    t <- seq(0, 10)
    cap <- rep(10, 11)
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    y.true <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
    y <- prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts)
    expect_equal(y, y.true, tolerance = 1e-6)

    keep <- 8:length(t)
    y <- prophet:::piecewise_logistic(t[keep], cap[keep], deltas, k, m,
                                      changepoint.ts)
    expect_equal(y, y.true[keep], tolerance = 1e-6)
  })
  # make_holiday_features is evaluated over 2016-12-20..2016-12-31. As the
  # expected column counts show, it yields one feature column per day offset
  # in the holiday window (lower_window..upper_window).
  test_that("holidays", {
    holidays <- data.frame(ds = zoo::as.Date(c("2016-12-25")),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(0))
    df <- data.frame(
      ds = seq(zoo::as.Date("2016-12-20"), zoo::as.Date("2016-12-31"),
               by = "d"))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 2)
    # Each of the two window days (Dec 24, Dec 25) falls in the range exactly
    # once, so every column must sum to 1. Comparing per column avoids the
    # old sum(colSums(feats) - c(1, 1)) == 0 check, which also accepted
    # sums such as c(0, 2).
    expect_equal(unname(colSums(feats)), c(1, 1))

    # Widening upper_window to 10 gives the 12 offsets -1..10.
    holidays <- data.frame(ds = zoo::as.Date(c("2016-12-25")),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(10))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    feats <- prophet:::make_holiday_features(m, df$ds)
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 12)
  })
  # A model fitted on all of DATA with a two-day holiday window (day of and
  # day after each 2012-06-06 / 2013-06-06 entry) must predict over its own
  # history (predict(m) with no new data) without error.
  test_that("fit_with_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    bdays <- data.frame(ds = zoo::as.Date(c("2012-06-06", "2013-06-06")),
                        holiday = rep("seans-bday", 2),
                        lower_window = c(0, 0),
                        upper_window = c(1, 1))
    model <- prophet(DATA, holidays = bdays, uncertainty.samples = 0)
    expect_error(predict(model), NA)
  })
  # With include_history = FALSE, make_future_dataframe returns only the
  # `periods` new dates after a history made of the first 234 rows. These
  # are consecutive days for freq = "d" and the same day of month in the
  # following months for freq = "m".
  test_that("make_future_dataframe", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(DATA[1:234, ])

    by_day <- make_future_dataframe(model, periods = 3, freq = "d",
                                    include_history = FALSE)
    expect_equal(by_day$ds,
                 as.Date(c("2013-04-26", "2013-04-27", "2013-04-28")))

    by_month <- make_future_dataframe(model, periods = 3, freq = "m",
                                      include_history = FALSE)
    expect_equal(by_month$ds,
                 as.Date(c("2013-05-25", "2013-06-25", "2013-07-25")))
  })
  # weekly.seasonality defaults to "auto" and is resolved at fit time:
  # TRUE for the first 15 rows, FALSE for only 9 rows (too short a history,
  # unless set explicitly), and FALSE for rows sampled every 7th row.
  test_that("auto_weekly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    first15 <- DATA[1:15, ]
    model <- prophet(first15, fit = FALSE)
    expect_equal(model$weekly.seasonality, "auto")
    model <- prophet:::fit.prophet(model, first15)
    expect_equal(model$weekly.seasonality, TRUE)

    # Too short a history: "auto" resolves to FALSE; an explicit TRUE wins.
    first9 <- DATA[1:9, ]
    model <- prophet(first9)
    expect_equal(model$weekly.seasonality, FALSE)
    model <- prophet(first9, weekly.seasonality = TRUE)
    expect_equal(model$weekly.seasonality, TRUE)

    # Weekly spacing between observations: "auto" resolves to FALSE.
    every7th <- DATA[seq(1, nrow(DATA), 7), ]
    model <- prophet(every7th)
    expect_equal(model$weekly.seasonality, FALSE)
  })
  # yearly.seasonality defaults to "auto" and is resolved at fit time:
  # TRUE for the full DATA, and FALSE for only the first 240 rows (too short
  # a history) unless set explicitly to TRUE.
  test_that("auto_yearly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    model <- prophet(DATA, fit = FALSE)
    expect_equal(model$yearly.seasonality, "auto")
    model <- prophet:::fit.prophet(model, DATA)
    expect_equal(model$yearly.seasonality, TRUE)

    # Too short a history: "auto" resolves to FALSE; an explicit TRUE wins.
    first240 <- DATA[1:240, ]
    model <- prophet(first240)
    expect_equal(model$yearly.seasonality, FALSE)
    model <- prophet(first240, yearly.seasonality = TRUE)
    expect_equal(model$yearly.seasonality, TRUE)
  })