test_prophet.R 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. library(prophet)
  2. context("Prophet tests")
  3. DATA <- read.csv('data.csv')
  4. N <- nrow(DATA)
  5. train <- DATA[1:floor(N / 2), ]
  6. future <- DATA[(ceiling(N/2) + 1):N, ]
  7. DATA2 <- read.csv('data2.csv')
  8. DATA$ds <- prophet:::set_date(DATA$ds)
  9. DATA2$ds <- prophet:::set_date(DATA2$ds)
  10. test_that("fit_predict", {
  11. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  12. m <- prophet(train)
  13. expect_error(predict(m, future), NA)
  14. })
  15. test_that("fit_predict_no_seasons", {
  16. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  17. m <- prophet(train, weekly.seasonality = FALSE, yearly.seasonality = FALSE)
  18. expect_error(predict(m, future), NA)
  19. })
  20. test_that("fit_predict_no_changepoints", {
  21. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  22. m <- prophet(train, n.changepoints = 0)
  23. expect_error(predict(m, future), NA)
  24. })
  25. test_that("fit_predict_changepoint_not_in_history", {
  26. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  27. train_t <- dplyr::mutate(DATA, ds=prophet:::set_date(ds))
  28. train_t <- dplyr::filter(train_t,
  29. (ds < prophet:::set_date('2013-01-01')) |
  30. (ds > prophet:::set_date('2014-01-01')))
  31. future <- data.frame(ds=DATA$ds)
  32. m <- prophet(train_t, changepoints=c('2013-06-06'))
  33. expect_error(predict(m, future), NA)
  34. })
  35. test_that("fit_predict_duplicates", {
  36. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  37. train2 <- train
  38. train2$y <- train2$y + 10
  39. train_t <- rbind(train, train2)
  40. m <- prophet(train_t)
  41. expect_error(predict(m, future), NA)
  42. })
  43. test_that("fit_predict_constant_history", {
  44. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  45. train2 <- train
  46. train2$y <- 20
  47. m <- prophet(train2)
  48. fcst <- predict(m, future)
  49. expect_equal(tail(fcst$yhat, 1), 20)
  50. train2$y <- 0
  51. m <- prophet(train2)
  52. fcst <- predict(m, future)
  53. expect_equal(tail(fcst$yhat, 1), 0)
  54. })
  55. test_that("setup_dataframe", {
  56. history <- train
  57. m <- prophet(history, fit = FALSE)
  58. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  59. history <- out$df
  60. expect_true('t' %in% colnames(history))
  61. expect_equal(min(history$t), 0)
  62. expect_equal(max(history$t), 1)
  63. expect_true('y_scaled' %in% colnames(history))
  64. expect_equal(max(history$y_scaled), 1)
  65. })
  66. test_that("logistic_floor", {
  67. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  68. m <- prophet(growth = 'logistic')
  69. history <- train
  70. history$floor <- 10.
  71. history$cap <- 80.
  72. future1 <- future
  73. future1$cap <- 80.
  74. future1$floor <- 10.
  75. m <- fit.prophet(m, history, algorithm = 'Newton')
  76. expect_true(m$logistic.floor)
  77. expect_true('floor' %in% colnames(m$history))
  78. expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
  79. fcst1 <- predict(m, future1)
  80. m2 <- prophet(growth = 'logistic')
  81. history2 <- history
  82. history2$y <- history2$y + 10.
  83. history2$floor <- history2$floor + 10.
  84. history2$cap <- history2$cap + 10.
  85. future1$cap <- future1$cap + 10.
  86. future1$floor <- future1$floor + 10.
  87. m2 <- fit.prophet(m2, history2, algorithm = 'Newton')
  88. expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
  89. fcst2 <- predict(m, future1)
  90. fcst2$yhat <- fcst2$yhat - 10.
  91. # Check for approximate shift invariance
  92. expect_true(all(abs(fcst1$yhat - fcst2$yhat) < 1))
  93. })
  94. test_that("get_changepoints", {
  95. history <- train
  96. m <- prophet(history, fit = FALSE)
  97. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  98. history <- out$df
  99. m <- out$m
  100. m$history <- history
  101. m <- prophet:::set_changepoints(m)
  102. cp <- m$changepoints.t
  103. expect_equal(length(cp), m$n.changepoints)
  104. expect_true(min(cp) > 0)
  105. expect_true(max(cp) < 1)
  106. })
  107. test_that("get_zero_changepoints", {
  108. history <- train
  109. m <- prophet(history, n.changepoints = 0, fit = FALSE)
  110. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  111. m <- out$m
  112. history <- out$df
  113. m$history <- history
  114. m <- prophet:::set_changepoints(m)
  115. cp <- m$changepoints.t
  116. expect_equal(length(cp), 1)
  117. expect_equal(cp[1], 0)
  118. })
  119. test_that("override_n_changepoints", {
  120. history <- train[1:20,]
  121. m <- prophet(history, fit = FALSE)
  122. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  123. m <- out$m
  124. history <- out$df
  125. m$history <- history
  126. m <- prophet:::set_changepoints(m)
  127. expect_equal(m$n.changepoints, 15)
  128. cp <- m$changepoints.t
  129. expect_equal(length(cp), 15)
  130. })
  131. test_that("fourier_series_weekly", {
  132. true.values <- c(0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837,
  133. -0.9009689)
  134. mat <- prophet:::fourier_series(DATA$ds, 7, 3)
  135. expect_equal(true.values, mat[1, ], tolerance = 1e-6)
  136. })
  137. test_that("fourier_series_yearly", {
  138. true.values <- c(0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249,
  139. 0.6874572)
  140. mat <- prophet:::fourier_series(DATA$ds, 365.25, 3)
  141. expect_equal(true.values, mat[1, ], tolerance = 1e-6)
  142. })
  143. test_that("growth_init", {
  144. history <- DATA[1:468, ]
  145. history$cap <- max(history$y)
  146. m <- prophet(history, growth = 'logistic', fit = FALSE)
  147. out <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)
  148. m <- out$m
  149. history <- out$df
  150. params <- prophet:::linear_growth_init(history)
  151. expect_equal(params[1], 0.3055671, tolerance = 1e-6)
  152. expect_equal(params[2], 0.5307511, tolerance = 1e-6)
  153. params <- prophet:::logistic_growth_init(history)
  154. expect_equal(params[1], 1.507925, tolerance = 1e-6)
  155. expect_equal(params[2], -0.08167497, tolerance = 1e-6)
  156. })
  157. test_that("piecewise_linear", {
  158. t <- seq(0, 10)
  159. m <- 0
  160. k <- 1.0
  161. deltas <- c(0.5)
  162. changepoint.ts <- c(5)
  163. y <- prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts)
  164. y.true <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
  165. expect_equal(y, y.true)
  166. t <- t[8:length(t)]
  167. y.true <- y.true[8:length(y.true)]
  168. y <- prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts)
  169. expect_equal(y, y.true)
  170. })
  171. test_that("piecewise_logistic", {
  172. t <- seq(0, 10)
  173. cap <- rep(10, 11)
  174. m <- 0
  175. k <- 1.0
  176. deltas <- c(0.5)
  177. changepoint.ts <- c(5)
  178. y <- prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts)
  179. y.true <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
  180. 9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
  181. expect_equal(y, y.true, tolerance = 1e-6)
  182. t <- t[8:length(t)]
  183. y.true <- y.true[8:length(y.true)]
  184. cap <- cap[8:length(cap)]
  185. y <- prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts)
  186. expect_equal(y, y.true, tolerance = 1e-6)
  187. })
  188. test_that("holidays", {
  189. holidays <- data.frame(ds = c('2016-12-25'),
  190. holiday = c('xmas'),
  191. lower_window = c(-1),
  192. upper_window = c(0))
  193. df <- data.frame(
  194. ds = seq(prophet:::set_date('2016-12-20'),
  195. prophet:::set_date('2016-12-31'), by='d'))
  196. m <- prophet(train, holidays = holidays, fit = FALSE)
  197. out <- prophet:::make_holiday_features(m, df$ds)
  198. feats <- out$holiday.features
  199. priors <- out$prior.scales
  200. names <- out$holiday.names
  201. expect_equal(nrow(feats), nrow(df))
  202. expect_equal(ncol(feats), 2)
  203. expect_equal(sum(colSums(feats) - c(1, 1)), 0)
  204. expect_true(all(priors == c(10., 10.)))
  205. expect_equal(names, c('xmas'))
  206. holidays <- data.frame(ds = c('2016-12-25'),
  207. holiday = c('xmas'),
  208. lower_window = c(-1),
  209. upper_window = c(10))
  210. m <- prophet(train, holidays = holidays, fit = FALSE)
  211. out <- prophet:::make_holiday_features(m, df$ds)
  212. feats <- out$holiday.features
  213. priors <- out$prior.scales
  214. names <- out$holiday.names
  215. expect_equal(nrow(feats), nrow(df))
  216. expect_equal(ncol(feats), 12)
  217. expect_true(all(priors == rep(10, 12)))
  218. expect_equal(names, c('xmas'))
  219. # Check prior specifications
  220. holidays <- data.frame(
  221. ds = prophet:::set_date(c('2016-12-25', '2017-12-25')),
  222. holiday = c('xmas', 'xmas'),
  223. lower_window = c(-1, -1),
  224. upper_window = c(0, 0),
  225. prior_scale = c(5., 5.)
  226. )
  227. m <- prophet(holidays = holidays, fit = FALSE)
  228. out <- prophet:::make_holiday_features(m, df$ds)
  229. priors <- out$prior.scales
  230. names <- out$holiday.names
  231. expect_true(all(priors == c(5., 5.)))
  232. expect_equal(names, c('xmas'))
  233. # 2 different priors
  234. holidays2 <- data.frame(
  235. ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
  236. holiday = c('seans-bday', 'seans-bday'),
  237. lower_window = c(0, 0),
  238. upper_window = c(1, 1),
  239. prior_scale = c(8, 8)
  240. )
  241. holidays2 <- rbind(holidays, holidays2)
  242. m <- prophet(holidays = holidays2, fit = FALSE)
  243. out <- prophet:::make_holiday_features(m, df$ds)
  244. priors <- out$prior.scales
  245. names <- out$holiday.names
  246. expect_true(all(priors == c(8, 8, 5, 5)))
  247. expect_true(all(sort(names) == c('seans-bday', 'xmas')))
  248. holidays2 <- data.frame(
  249. ds = prophet:::set_date(c('2012-06-06', '2013-06-06')),
  250. holiday = c('seans-bday', 'seans-bday'),
  251. lower_window = c(0, 0),
  252. upper_window = c(1, 1)
  253. )
  254. holidays2 <- dplyr::bind_rows(holidays, holidays2)
  255. m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
  256. out <- prophet:::make_holiday_features(m, df$ds)
  257. priors <- out$prior.scales
  258. expect_true(all(priors == c(4, 4, 5, 5)))
  259. # Check incompatible priors
  260. holidays <- data.frame(
  261. ds = prophet:::set_date(c('2016-12-25', '2016-12-27')),
  262. holiday = c('xmasish', 'xmasish'),
  263. lower_window = c(-1, -1),
  264. upper_window = c(0, 0),
  265. prior_scale = c(5., 6.)
  266. )
  267. m <- prophet(holidays = holidays, fit = FALSE)
  268. expect_error(prophet:::make_holiday_features(m, df$ds))
  269. })
  270. test_that("fit_with_holidays", {
  271. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  272. holidays <- data.frame(ds = c('2012-06-06', '2013-06-06'),
  273. holiday = c('seans-bday', 'seans-bday'),
  274. lower_window = c(0, 0),
  275. upper_window = c(1, 1))
  276. m <- prophet(DATA, holidays = holidays, uncertainty.samples = 0)
  277. expect_error(predict(m), NA)
  278. })
  279. test_that("make_future_dataframe", {
  280. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  281. train.t <- DATA[1:234, ]
  282. m <- prophet(train.t)
  283. future <- make_future_dataframe(m, periods = 3, freq = 'day',
  284. include_history = FALSE)
  285. correct <- prophet:::set_date(c('2013-04-26', '2013-04-27', '2013-04-28'))
  286. expect_equal(future$ds, correct)
  287. future <- make_future_dataframe(m, periods = 3, freq = 'month',
  288. include_history = FALSE)
  289. correct <- prophet:::set_date(c('2013-05-25', '2013-06-25', '2013-07-25'))
  290. expect_equal(future$ds, correct)
  291. })
  292. test_that("auto_weekly_seasonality", {
  293. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  294. # Should be enabled
  295. N.w <- 15
  296. train.w <- DATA[1:N.w, ]
  297. m <- prophet(train.w, fit = FALSE)
  298. expect_equal(m$weekly.seasonality, 'auto')
  299. m <- fit.prophet(m, train.w)
  300. expect_true('weekly' %in% names(m$seasonalities))
  301. true <- list(
  302. period = 7, fourier.order = 3, prior.scale = 10, mode = 'additive')
  303. for (name in names(true)) {
  304. expect_equal(m$seasonalities$weekly[[name]], true[[name]])
  305. }
  306. # Should be disabled due to too short history
  307. N.w <- 9
  308. train.w <- DATA[1:N.w, ]
  309. m <- prophet(train.w)
  310. expect_false('weekly' %in% names(m$seasonalities))
  311. m <- prophet(train.w, weekly.seasonality = TRUE)
  312. expect_true('weekly' %in% names(m$seasonalities))
  313. # Should be False due to weekly spacing
  314. train.w <- DATA[seq(1, nrow(DATA), 7), ]
  315. m <- prophet(train.w)
  316. expect_false('weekly' %in% names(m$seasonalities))
  317. m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3)
  318. true <- list(
  319. period = 7, fourier.order = 2, prior.scale = 3, mode = 'additive')
  320. for (name in names(true)) {
  321. expect_equal(m$seasonalities$weekly[[name]], true[[name]])
  322. }
  323. })
  324. test_that("auto_yearly_seasonality", {
  325. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  326. # Should be enabled
  327. m <- prophet(DATA, fit = FALSE)
  328. expect_equal(m$yearly.seasonality, 'auto')
  329. m <- fit.prophet(m, DATA)
  330. expect_true('yearly' %in% names(m$seasonalities))
  331. true <- list(
  332. period = 365.25, fourier.order = 10, prior.scale = 10, mode = 'additive')
  333. for (name in names(true)) {
  334. expect_equal(m$seasonalities$yearly[[name]], true[[name]])
  335. }
  336. # Should be disabled due to too short history
  337. N.w <- 240
  338. train.y <- DATA[1:N.w, ]
  339. m <- prophet(train.y)
  340. expect_false('yearly' %in% names(m$seasonalities))
  341. m <- prophet(train.y, yearly.seasonality = TRUE)
  342. expect_true('yearly' %in% names(m$seasonalities))
  343. m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3)
  344. true <- list(
  345. period = 365.25, fourier.order = 7, prior.scale = 3, mode = 'additive')
  346. for (name in names(true)) {
  347. expect_equal(m$seasonalities$yearly[[name]], true[[name]])
  348. }
  349. })
  350. test_that("auto_daily_seasonality", {
  351. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  352. # Should be enabled
  353. m <- prophet(DATA2, fit = FALSE)
  354. expect_equal(m$daily.seasonality, 'auto')
  355. m <- fit.prophet(m, DATA2)
  356. expect_true('daily' %in% names(m$seasonalities))
  357. true <- list(
  358. period = 1, fourier.order = 4, prior.scale = 10, mode = 'additive')
  359. for (name in names(true)) {
  360. expect_equal(m$seasonalities$daily[[name]], true[[name]])
  361. }
  362. # Should be disabled due to too short history
  363. N.d <- 430
  364. train.y <- DATA2[1:N.d, ]
  365. m <- prophet(train.y)
  366. expect_false('daily' %in% names(m$seasonalities))
  367. m <- prophet(train.y, daily.seasonality = TRUE)
  368. expect_true('daily' %in% names(m$seasonalities))
  369. m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3)
  370. true <- list(
  371. period = 1, fourier.order = 7, prior.scale = 3, mode = 'additive')
  372. for (name in names(true)) {
  373. expect_equal(m$seasonalities$daily[[name]], true[[name]])
  374. }
  375. m <- prophet(DATA)
  376. expect_false('daily' %in% names(m$seasonalities))
  377. })
  378. test_that("test_subdaily_holidays", {
  379. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  380. holidays <- data.frame(ds = c('2017-01-02'),
  381. holiday = c('special_day'))
  382. m <- prophet(DATA2, holidays=holidays)
  383. fcst <- predict(m)
  384. expect_equal(sum(fcst$special_day == 0), 575)
  385. })
  386. test_that("custom_seasonality", {
  387. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  388. holidays <- data.frame(ds = c('2017-01-02'),
  389. holiday = c('special_day'),
  390. prior_scale = c(4))
  391. m <- prophet(holidays=holidays)
  392. m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
  393. true <- list(
  394. period = 30, fourier.order = 5, prior.scale = 10, mode = 'additive')
  395. for (name in names(true)) {
  396. expect_equal(m$seasonalities$monthly[[name]], true[[name]])
  397. }
  398. expect_error(
  399. add_seasonality(m, name='special_day', period=30, fourier_order=5)
  400. )
  401. expect_error(
  402. add_seasonality(m, name='trend', period=30, fourier_order=5)
  403. )
  404. m <- add_seasonality(m, name='weekly', period=30, fourier.order=5)
  405. # Test priors
  406. m <- prophet(
  407. holidays = holidays, yearly.seasonality = FALSE,
  408. seasonality.mode = 'multiplicative')
  409. m <- add_seasonality(
  410. m, name='monthly', period=30, fourier.order=5, prior.scale = 2,
  411. mode = 'additive')
  412. m <- fit.prophet(m, DATA)
  413. expect_equal(m$seasonalities$monthly$mode, 'additive')
  414. expect_equal(m$seasonalities$weekly$mode, 'multiplicative')
  415. out <- prophet:::make_all_seasonality_features(m, m$history)
  416. prior.scales <- out$prior.scales
  417. component.cols <- out$component.cols
  418. expect_equal(sum(component.cols$monthly), 10)
  419. expect_equal(sum(component.cols$special_day), 1)
  420. expect_equal(sum(component.cols$weekly), 6)
  421. expect_equal(sum(component.cols$additive_terms), 10)
  422. expect_equal(sum(component.cols$multiplicative_terms), 7)
  423. expect_equal(sum(component.cols$monthly[1:11]), 10)
  424. expect_equal(sum(component.cols$weekly[11:17]), 6)
  425. expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4)))
  426. })
  427. test_that("added_regressors", {
  428. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  429. m <- prophet()
  430. m <- add_regressor(m, 'binary_feature', prior.scale=0.2)
  431. m <- add_regressor(m, 'numeric_feature', prior.scale=0.5)
  432. m <- add_regressor(
  433. m, 'numeric_feature2', prior.scale=0.5, mode = 'multiplicative')
  434. m <- add_regressor(m, 'binary_feature2', standardize=TRUE)
  435. df <- DATA
  436. df$binary_feature <- c(rep(0, 255), rep(1, 255))
  437. df$numeric_feature <- 0:509
  438. df$numeric_feature2 <- 0:509
  439. # Require all regressors in df
  440. expect_error(
  441. fit.prophet(m, df)
  442. )
  443. df$binary_feature2 <- c(rep(1, 100), rep(0, 410))
  444. m <- fit.prophet(m, df)
  445. # Check that standardizations are correctly set
  446. true <- list(
  447. prior.scale = 0.2, mu = 0, std = 1, standardize = 'auto', mode = 'additive'
  448. )
  449. for (name in names(true)) {
  450. expect_equal(true[[name]], m$extra_regressors$binary_feature[[name]])
  451. }
  452. true <- list(prior.scale = 0.5, mu = 254.5, std = 147.368585)
  453. for (name in names(true)) {
  454. expect_equal(true[[name]], m$extra_regressors$numeric_feature[[name]],
  455. tolerance = 1e-5)
  456. }
  457. expect_equal(m$extra_regressors$numeric_feature2$mode, 'multiplicative')
  458. true <- list(prior.scale = 10., mu = 0.1960784, std = 0.3974183)
  459. for (name in names(true)) {
  460. expect_equal(true[[name]], m$extra_regressors$binary_feature2[[name]],
  461. tolerance = 1e-5)
  462. }
  463. # Check that standardization is done correctly
  464. df2 <- prophet:::setup_dataframe(m, df)$df
  465. expect_equal(df2$binary_feature[1], 0)
  466. expect_equal(df2$numeric_feature[1], -1.726962, tolerance = 1e-4)
  467. expect_equal(df2$binary_feature2[1], 2.022859, tolerance = 1e-4)
  468. # Check that feature matrix and prior scales are correctly constructed
  469. out <- prophet:::make_all_seasonality_features(m, df2)
  470. seasonal.features <- out$seasonal.features
  471. prior.scales <- out$prior.scales
  472. component.cols <- out$component.cols
  473. modes <- out$modes
  474. expect_equal(ncol(seasonal.features), 30)
  475. r_names <- c('binary_feature', 'numeric_feature', 'binary_feature2')
  476. true.priors <- c(0.2, 0.5, 10.)
  477. for (i in seq_along(r_names)) {
  478. name <- r_names[i]
  479. expect_true(name %in% colnames(seasonal.features))
  480. expect_equal(sum(component.cols[[name]]), 1)
  481. expect_equal(sum(prior.scales * component.cols[[name]]), true.priors[i])
  482. }
  483. # Check that forecast components are reasonable
  484. future <- data.frame(
  485. ds = c('2014-06-01'),
  486. binary_feature = c(0),
  487. numeric_feature = c(10),
  488. numeric_feature2 = c(10)
  489. )
  490. expect_error(predict(m, future))
  491. future$binary_feature2 <- 0.
  492. fcst <- predict(m, future)
  493. expect_equal(ncol(fcst), 37)
  494. expect_equal(fcst$binary_feature[1], 0)
  495. expect_equal(fcst$extra_regressors_additive[1],
  496. fcst$numeric_feature[1] + fcst$binary_feature2[1])
  497. expect_equal(fcst$extra_regressors_multiplicative[1],
  498. fcst$numeric_feature2[1])
  499. expect_equal(fcst$additive_terms[1],
  500. fcst$yearly[1] + fcst$weekly[1]
  501. + fcst$extra_regressors_additive[1])
  502. expect_equal(fcst$multiplicative_terms[1],
  503. fcst$extra_regressors_multiplicative[1])
  504. expect_equal(
  505. fcst$yhat[1],
  506. fcst$trend[1] * (1 + fcst$multiplicative_terms[1]) + fcst$additive_terms[1]
  507. )
  508. # Check fails if constant extra regressor
  509. df$constant_feature <- 5
  510. m <- prophet()
  511. m <- add_regressor(m, 'constant_feature')
  512. expect_error(fit.prophet(m, df))
  513. })
  514. test_that("set_seasonality_mode", {
  515. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  516. m <- prophet()
  517. expect_equal(m$seasonality.mode, 'additive')
  518. m <- prophet(seasonality.mode = 'multiplicative')
  519. expect_equal(m$seasonality.mode, 'multiplicative')
  520. expect_error(prophet(seasonality.mode = 'batman'))
  521. })
  522. test_that("seasonality_modes", {
  523. skip_if_not(Sys.getenv('R_ARCH') != '/i386')
  524. holidays <- data.frame(ds = c('2016-12-25'),
  525. holiday = c('xmas'),
  526. lower_window = c(-1),
  527. upper_window = c(0))
  528. m <- prophet(seasonality.mode = 'multiplicative', holidays = holidays)
  529. m <- add_seasonality(
  530. m, name = 'monthly', period = 30, fourier.order = 3, mode = 'additive')
  531. m <- add_regressor(m, name = 'binary_feature', mode = 'additive')
  532. m <- add_regressor(m, name = 'numeric_feature')
  533. # Construct seasonal features
  534. df <- DATA
  535. df$binary_feature <- c(rep(0, 255), rep(1, 255))
  536. df$numeric_feature <- 0:509
  537. out <- prophet:::setup_dataframe(m, df, initialize_scales = TRUE)
  538. df <- out$df
  539. m <- out$m
  540. m$history <- df
  541. m <- prophet:::set_auto_seasonalities(m)
  542. out <- prophet:::make_all_seasonality_features(m, df)
  543. component.cols <- out$component.cols
  544. modes <- out$modes
  545. expect_equal(sum(component.cols$additive_terms), 7)
  546. expect_equal(sum(component.cols$multiplicative_terms), 29)
  547. expect_equal(
  548. sort(modes$additive),
  549. c('additive_terms', 'binary_feature', 'extra_regressors_additive',
  550. 'monthly')
  551. )
  552. expect_equal(
  553. sort(modes$multiplicative),
  554. c('extra_regressors_multiplicative', 'holidays', 'multiplicative_terms',
  555. 'numeric_feature', 'weekly', 'xmas', 'yearly')
  556. )
  557. })