# test_prophet.R

  library(prophet)
  context("Prophet tests")

  # Shared fixtures for the tests below. data.csv and data2.csv are read
  # from the working directory; tests use their `ds` and `y` columns.
  DATA <- read.csv("data.csv")
  N <- nrow(DATA)
  # Split DATA into a first half (`train`) and the remaining rows (`future`).
  # Starting `future` at floor(N / 2) + 1 keeps the two pieces contiguous
  # and disjoint for odd N too; starting at ceiling(N / 2) + 1 would leave
  # one row in neither split.
  train <- DATA[seq_len(floor(N / 2)), ]
  future <- DATA[(floor(N / 2) + 1):N, ]
  DATA2 <- read.csv("data2.csv")
  # train/future are sliced before this conversion, so their `ds` columns
  # keep the type read.csv produced; only DATA and DATA2 are converted.
  DATA$ds <- prophet:::set_date(DATA$ds)
  DATA2$ds <- prophet:::set_date(DATA2$ds)
  test_that("fit_predict", {
    # Model fitting is skipped when R_ARCH is "/i386" (32-bit builds).
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    fitted <- prophet(train)
    # regexp = NA: predicting on the held-out rows must not raise an error.
    expect_error(predict(fitted, future), NA)
  })
  test_that("fit_predict_no_seasons", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Both weekly and yearly seasonal components switched off.
    fitted <- prophet(
      train,
      weekly.seasonality = FALSE,
      yearly.seasonality = FALSE
    )
    expect_error(predict(fitted, future), NA)
  })
  test_that("fit_predict_no_changepoints", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Fitting with zero changepoints is expected to emit a warning from
    # prophet(); the subsequent predict() must succeed without error.
    expect_warning({
      fitted <- prophet(train, n.changepoints = 0)
      expect_error(predict(fitted, future), NA)
    })
  })
  test_that("fit_predict_changepoint_not_in_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Drop every row dated 2013-01-01 through 2014-01-01 so the manually
    # supplied changepoint (2013-06-06) falls inside a gap in the history.
    gap_start <- prophet:::set_date("2013-01-01")
    gap_end <- prophet:::set_date("2014-01-01")
    train_t <- dplyr::mutate(DATA, ds = prophet:::set_date(ds))
    train_t <- dplyr::filter(train_t, (ds < gap_start) | (ds > gap_end))
    all_dates <- data.frame(ds = DATA$ds)
    # Warning expected from prophet(); predict() must not error.
    expect_warning({
      m <- prophet(train_t, changepoints = c("2013-06-06"))
      expect_error(predict(m, all_dates), NA)
    })
  })
  test_that("fit_predict_duplicates", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Every ds appears twice in the training data, with y values 10 apart.
    shifted <- transform(train, y = y + 10)
    m <- prophet(rbind(train, shifted))
    expect_error(predict(m, future), NA)
  })
  test_that("fit_predict_constant_history", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # A flat history must be forecast as that same constant, including the
    # all-zero case.
    for (level in c(20, 0)) {
      flat <- train
      flat$y <- level
      m <- prophet(flat)
      fcst <- predict(m, future)
      expect_equal(tail(fcst$yhat, 1), level)
    }
  })
  test_that("setup_dataframe", {
    m <- prophet(train, fit = FALSE)
    prepared <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)$df
    # setup_dataframe() adds a time column `t` spanning exactly [0, 1] ...
    expect_true("t" %in% colnames(prepared))
    expect_equal(range(prepared$t), c(0, 1))
    # ... and a `y_scaled` column whose maximum is 1.
    expect_true("y_scaled" %in% colnames(prepared))
    expect_equal(max(prepared$y_scaled), 1)
  })
  test_that("logistic_floor", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Fit a logistic-growth model with a floor, then a second model on the
    # same data with y, floor and cap all shifted up by 10. After removing
    # that shift, the two forecasts should (approximately) agree.
    m <- prophet(growth = "logistic")
    history <- train
    history$floor <- 10.
    history$cap <- 80.
    future1 <- future
    future1$cap <- 80.
    future1$floor <- 10.
    m <- fit.prophet(m, history, algorithm = "Newton")
    expect_true(m$logistic.floor)
    expect_true("floor" %in% colnames(m$history))
    expect_equal(m$history$y_scaled[1], 1., tolerance = 1e-6)
    fcst1 <- predict(m, future1)

    m2 <- prophet(growth = "logistic")
    history2 <- history
    history2$y <- history2$y + 10.
    history2$floor <- history2$floor + 10.
    history2$cap <- history2$cap + 10.
    future2 <- future1
    future2$cap <- future2$cap + 10.
    future2$floor <- future2$floor + 10.
    m2 <- fit.prophet(m2, history2, algorithm = "Newton")
    expect_equal(m2$history$y_scaled[1], 1., tolerance = 1e-6)
    # Forecast with m2, the model fitted on the shifted data. Using m here
    # would never exercise m2's fit at all.
    fcst2 <- predict(m2, future2)
    fcst2$yhat <- fcst2$yhat - 10.
    # Check for approximate shift invariance.
    expect_true(all(abs(fcst1$yhat - fcst2$yhat) < 1))
  })
  test_that("get_changepoints", {
    m <- prophet(train, fit = FALSE)
    out <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)
    m <- out$m
    m$history <- out$df
    m <- prophet:::set_changepoints(m)
    cp <- m$changepoints.t
    expect_equal(length(cp), m$n.changepoints)
    # changepoints.t lives on the scaled time axis, where history$t spans
    # [0, 1]. Every changepoint must lie strictly inside it. Comparing
    # against the row count N would be vacuous.
    expect_true(min(cp) > 0)
    expect_true(max(cp) < 1)
    # One matrix row per history row, one column per changepoint.
    mat <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(mat), nrow(train))
    expect_equal(ncol(mat), m$n.changepoints)
  })
  test_that("get_zero_changepoints", {
    m <- prophet(train, n.changepoints = 0, fit = FALSE)
    prepared <- prophet:::setup_dataframe(m, train, initialize_scales = TRUE)
    m <- prepared$m
    m$history <- prepared$df
    m <- prophet:::set_changepoints(m)
    # With zero changepoints requested, a single changepoint at t = 0 is
    # still expected, giving a one-column changepoint matrix.
    cp <- m$changepoints.t
    expect_equal(length(cp), 1)
    expect_equal(cp[1], 0)
    cp_matrix <- prophet:::get_changepoint_matrix(m)
    expect_equal(nrow(cp_matrix), floor(N / 2))
    expect_equal(ncol(cp_matrix), 1)
  })
  test_that("override_n_changepoints", {
    short <- train[1:20, ]
    m <- prophet(short, fit = FALSE)
    prepared <- prophet:::setup_dataframe(m, short, initialize_scales = TRUE)
    m <- prepared$m
    m$history <- prepared$df
    m <- prophet:::set_changepoints(m)
    # With only 20 history rows, n.changepoints is expected to be lowered
    # to 15 (presumably capped by history length; the default itself is
    # not visible here).
    expect_equal(m$n.changepoints, 15)
    expect_equal(length(m$changepoints.t), 15)
  })
  test_that("fourier_series_weekly", {
    # Reference first row of the weekly (period 7) basis at Fourier order 3:
    # six values for the first date in DATA.
    expected <- c(
      0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837, -0.9009689
    )
    features <- prophet:::fourier_series(DATA$ds, 7, 3)
    expect_equal(expected, features[1, ], tolerance = 1e-6)
  })
  test_that("fourier_series_yearly", {
    # Reference first row of the yearly (period 365.25) basis at Fourier
    # order 3: six values for the first date in DATA.
    expected <- c(
      0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249, 0.6874572
    )
    features <- prophet:::fourier_series(DATA$ds, 365.25, 3)
    expect_equal(expected, features[1, ], tolerance = 1e-6)
  })
  test_that("growth_init", {
    history <- DATA[1:468, ]
    history$cap <- max(history$y)
    m <- prophet(history, growth = "logistic", fit = FALSE)
    scaled <- prophet:::setup_dataframe(m, history, initialize_scales = TRUE)$df
    # Reference initial parameters from each growth initializer (presumably
    # the initial slope and offset; confirm against the implementation).
    linear <- prophet:::linear_growth_init(scaled)
    expect_equal(linear[1], 0.3055671, tolerance = 1e-6)
    expect_equal(linear[2], 0.5307511, tolerance = 1e-6)
    logistic <- prophet:::logistic_growth_init(scaled)
    expect_equal(logistic[1], 1.507925, tolerance = 1e-6)
    expect_equal(logistic[2], -0.08167497, tolerance = 1e-6)
  })
  test_that("piecewise_linear", {
    # Slope 1 up to the changepoint at t = 5, then 1 + 0.5 = 1.5 after it.
    t <- seq(0, 10)
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    expected <- c(0, 1, 2, 3, 4, 5, 6.5, 8, 9.5, 11, 12.5)
    expect_equal(
      prophet:::piecewise_linear(t, deltas, k, m, changepoint.ts),
      expected
    )
    # Evaluating only the tail (t = 7..10) must give the same values there.
    keep <- 8:length(t)
    expect_equal(
      prophet:::piecewise_linear(t[keep], deltas, k, m, changepoint.ts),
      expected[keep]
    )
  })
  test_that("piecewise_logistic", {
    # Logistic curve with constant cap 10, starting at 5 (= cap / 2) at t = 0
    # and a rate change at t = 5.
    t <- seq(0, 10)
    cap <- rep(10, 11)
    k <- 1.0
    m <- 0
    deltas <- c(0.5)
    changepoint.ts <- c(5)
    expected <- c(5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071,
                  9.984988, 9.996646, 9.999252, 9.999833, 9.999963)
    expect_equal(
      prophet:::piecewise_logistic(t, cap, deltas, k, m, changepoint.ts),
      expected,
      tolerance = 1e-6
    )
    # Evaluating only the tail (t = 7..10) must give the same values there.
    keep <- 8:length(t)
    expect_equal(
      prophet:::piecewise_logistic(
        t[keep], cap[keep], deltas, k, m, changepoint.ts
      ),
      expected[keep],
      tolerance = 1e-6
    )
  })
  test_that("holidays", {
    # One holiday whose window covers the day before and the day itself.
    holidays <- data.frame(ds = c("2016-12-25"),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(0))
    df <- data.frame(
      ds = seq(prophet:::set_date("2016-12-20"),
               prophet:::set_date("2016-12-31"), by = "d"))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    feats <- out$holiday.features
    priors <- out$prior.scales
    expect_equal(nrow(feats), nrow(df))
    # One feature column per day offset in the window (-1 and 0).
    expect_equal(ncol(feats), 2)
    # Each column should sum to exactly 1 over df. Comparing the sums
    # directly is stricter than summing their deviations from 1, where a
    # column sum of 2 and one of 0 would cancel out.
    expect_equal(unname(colSums(feats)), c(1, 1))
    expect_true(all(priors == c(10., 10.)))
    # Widening the window to 10 days after gives offsets -1..10: 12 columns.
    holidays <- data.frame(ds = c("2016-12-25"),
                           holiday = c("xmas"),
                           lower_window = c(-1),
                           upper_window = c(10))
    m <- prophet(train, holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    feats <- out$holiday.features
    priors <- out$prior.scales
    expect_equal(nrow(feats), nrow(df))
    expect_equal(ncol(feats), 12)
    expect_true(all(priors == rep(10, 12)))
    # Check prior specifications: a per-holiday prior_scale column.
    holidays <- data.frame(
      ds = prophet:::set_date(c("2016-12-25", "2017-12-25")),
      holiday = c("xmas", "xmas"),
      lower_window = c(-1, -1),
      upper_window = c(0, 0),
      prior_scale = c(5., 5.)
    )
    m <- prophet(holidays = holidays, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_true(all(priors == c(5., 5.)))
    # Two holidays with two different priors.
    holidays2 <- data.frame(
      ds = prophet:::set_date(c("2012-06-06", "2013-06-06")),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1),
      prior_scale = c(8, 8)
    )
    holidays2 <- rbind(holidays, holidays2)
    m <- prophet(holidays = holidays2, fit = FALSE)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_true(all(priors == c(8, 8, 5, 5)))
    # Rows without a prior_scale should fall back to holidays.prior.scale.
    holidays2 <- data.frame(
      ds = prophet:::set_date(c("2012-06-06", "2013-06-06")),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1)
    )
    # Convert the factor columns to character before bind_rows() to avoid
    # its warning about combining factors, then re-factor afterwards.
    holidays$holiday <- as.character(holidays$holiday)
    holidays2$holiday <- as.character(holidays2$holiday)
    holidays2 <- dplyr::bind_rows(holidays, holidays2)
    holidays2$holiday <- factor(holidays2$holiday)
    m <- prophet(holidays = holidays2, fit = FALSE, holidays.prior.scale = 4)
    out <- prophet:::make_holiday_features(m, df$ds)
    priors <- out$prior.scales
    expect_true(all(priors == c(4, 4, 5, 5)))
    # Check incompatible priors: one holiday with two different prior_scales.
    holidays <- data.frame(
      ds = prophet:::set_date(c("2016-12-25", "2016-12-27")),
      holiday = c("xmasish", "xmasish"),
      lower_window = c(-1, -1),
      upper_window = c(0, 0),
      prior_scale = c(5., 6.)
    )
    m <- prophet(holidays = holidays, fit = FALSE)
    expect_error(prophet:::make_holiday_features(m, df$ds))
  })
  test_that("fit_with_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    bday_holidays <- data.frame(
      ds = c("2012-06-06", "2013-06-06"),
      holiday = c("seans-bday", "seans-bday"),
      lower_window = c(0, 0),
      upper_window = c(1, 1)
    )
    m <- prophet(DATA, holidays = bday_holidays, uncertainty.samples = 0)
    # predict() with no new dataframe must run without error.
    expect_error(predict(m), NA)
  })
  test_that("make_future_dataframe", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet(DATA[1:234, ])
    # Three future periods (history excluded) must land on the given dates.
    check_future <- function(freq, expected_dates) {
      fut <- make_future_dataframe(m, periods = 3, freq = freq,
                                   include_history = FALSE)
      expect_equal(fut$ds, prophet:::set_date(expected_dates))
    }
    check_future("day", c("2013-04-26", "2013-04-27", "2013-04-28"))
    check_future("month", c("2013-05-25", "2013-06-25", "2013-07-25"))
  })
  test_that("auto_weekly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare the stored weekly component settings with `expected`.
    check_weekly <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$weekly[[field]], expected[[field]])
      }
    }
    # 15 rows: the "auto" setting should enable weekly seasonality.
    train.w <- DATA[1:15, ]
    m <- prophet(train.w, fit = FALSE)
    expect_equal(m$weekly.seasonality, "auto")
    m <- fit.prophet(m, train.w)
    expect_true("weekly" %in% names(m$seasonalities))
    check_weekly(m, list(period = 7, fourier.order = 3, prior.scale = 10))
    # 9 rows: too short for the automatic setting ...
    train.w <- DATA[1:9, ]
    m <- prophet(train.w)
    expect_false("weekly" %in% names(m$seasonalities))
    # ... but TRUE forces the component on.
    expect_warning({
      # prophet warning: non-zero return code in optimizing
      m <- prophet(train.w, weekly.seasonality = TRUE)
      expect_true("weekly" %in% names(m$seasonalities))
    })
    # Every 7th row only: weekly spacing, so weekly seasonality stays off.
    train.w <- DATA[seq(1, nrow(DATA), 7), ]
    m <- prophet(train.w)
    expect_false("weekly" %in% names(m$seasonalities))
    # A number sets the Fourier order; the prior comes from
    # seasonality.prior.scale.
    m <- prophet(DATA, weekly.seasonality = 2, seasonality.prior.scale = 3)
    check_weekly(m, list(period = 7, fourier.order = 2, prior.scale = 3))
  })
  test_that("auto_yearly_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare the stored yearly component settings with `expected`.
    check_yearly <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$yearly[[field]], expected[[field]])
      }
    }
    # Full DATA: the "auto" setting should enable yearly seasonality.
    m <- prophet(DATA, fit = FALSE)
    expect_equal(m$yearly.seasonality, "auto")
    m <- fit.prophet(m, DATA)
    expect_true("yearly" %in% names(m$seasonalities))
    check_yearly(m,
                 list(period = 365.25, fourier.order = 10, prior.scale = 10))
    # First 240 rows: too short for the automatic setting ...
    short <- DATA[1:240, ]
    m <- prophet(short)
    expect_false("yearly" %in% names(m$seasonalities))
    # ... but TRUE forces the component on.
    m <- prophet(short, yearly.seasonality = TRUE)
    expect_true("yearly" %in% names(m$seasonalities))
    # A number sets the Fourier order; the prior comes from
    # seasonality.prior.scale.
    m <- prophet(DATA, yearly.seasonality = 7, seasonality.prior.scale = 3)
    check_yearly(m, list(period = 365.25, fourier.order = 7, prior.scale = 3))
  })
  test_that("auto_daily_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    # Compare the stored daily component settings with `expected`.
    check_daily <- function(model, expected) {
      for (field in names(expected)) {
        expect_equal(model$seasonalities$daily[[field]], expected[[field]])
      }
    }
    # Full DATA2: the "auto" setting should enable daily seasonality.
    m <- prophet(DATA2, fit = FALSE)
    expect_equal(m$daily.seasonality, "auto")
    m <- fit.prophet(m, DATA2)
    expect_true("daily" %in% names(m$seasonalities))
    check_daily(m, list(period = 1, fourier.order = 4, prior.scale = 10))
    # First 430 rows of DATA2: too short for the automatic setting ...
    short <- DATA2[1:430, ]
    m <- prophet(short)
    expect_false("daily" %in% names(m$seasonalities))
    # ... but TRUE forces the component on.
    m <- prophet(short, daily.seasonality = TRUE)
    expect_true("daily" %in% names(m$seasonalities))
    # A number sets the Fourier order; the prior comes from
    # seasonality.prior.scale.
    m <- prophet(DATA2, daily.seasonality = 7, seasonality.prior.scale = 3)
    check_daily(m, list(period = 1, fourier.order = 7, prior.scale = 3))
    # DATA is expected not to trigger daily seasonality automatically
    # (presumably because it is not sub-daily; confirm against data.csv).
    m <- prophet(DATA)
    expect_false("daily" %in% names(m$seasonalities))
  })
  test_that("test_subdaily_holidays", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    special <- data.frame(ds = c("2017-01-02"),
                          holiday = c("special_day"))
    fcst <- predict(prophet(DATA2, holidays = special))
    # 575 forecast rows are expected to have a zero special_day effect.
    expect_equal(sum(fcst$special_day == 0), 575)
  })
  test_that("custom_seasonality", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    holidays <- data.frame(ds = c("2017-01-02"),
                           holiday = c("special_day"),
                           prior_scale = c(4))
    m <- prophet(holidays = holidays)
    m <- add_seasonality(m, name = "monthly", period = 30, fourier.order = 5)
    true <- list(period = 30, fourier.order = 5, prior.scale = 10)
    for (name in names(true)) {
      expect_equal(m$seasonalities$monthly[[name]], true[[name]])
    }
    # A seasonality named like an existing holiday, or like the reserved
    # "trend" component, must be rejected. Spell the argument
    # `fourier.order`, as in the other add_seasonality() calls here. A
    # misspelled `fourier_order` can make the call fail for the wrong reason,
    # so the name check would not really be tested.
    expect_error(
      add_seasonality(m, name = "special_day", period = 30, fourier.order = 5)
    )
    expect_error(
      add_seasonality(m, name = "trend", period = 30, fourier.order = 5)
    )
    # Re-adding the built-in "weekly" name is expected to succeed.
    m <- add_seasonality(m, name = "weekly", period = 30, fourier.order = 5)
    # Test priors: monthly (order 5 -> 10 columns, prior 2), then weekly
    # (6 columns, default prior 10), then the holiday (prior_scale 4).
    m <- prophet(holidays = holidays, yearly.seasonality = FALSE)
    m <- add_seasonality(
      m, name = "monthly", period = 30, fourier.order = 5, prior.scale = 2)
    m <- fit.prophet(m, DATA)
    prior.scales <- prophet:::make_all_seasonality_features(
      m, m$history)$prior.scales
    expect_true(all(prior.scales == c(rep(2, 10), rep(10, 6), 4)))
  })
  test_that("added_regressors", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    m <- prophet()
    m <- add_regressor(m, "binary_feature", prior.scale = 0.2)
    m <- add_regressor(m, "numeric_feature", prior.scale = 0.5)
    m <- add_regressor(m, "binary_feature2", standardize = TRUE)
    df <- DATA
    df$binary_feature <- c(rep(0, 255), rep(1, 255))
    df$numeric_feature <- 0:509
    # Fitting must fail while binary_feature2 is still missing from df.
    expect_error(
      fit.prophet(m, df)
    )
    df$binary_feature2 <- c(rep(1, 100), rep(0, 410))
    m <- fit.prophet(m, df)

    # Compare the stored settings of regressor `reg` with `expected`; extra
    # arguments (e.g. tolerance) are forwarded to expect_equal().
    expect_regressor <- function(reg, expected, ...) {
      for (field in names(expected)) {
        expect_equal(expected[[field]], m$extra_regressors[[reg]][[field]],
                     ...)
      }
    }
    # binary_feature keeps standardize = "auto" and ends up unscaled
    # (mu = 0, std = 1).
    expect_regressor(
      "binary_feature",
      list(prior.scale = 0.2, mu = 0, std = 1, standardize = "auto"))
    expect_regressor(
      "numeric_feature",
      list(prior.scale = 0.5, mu = 254.5, std = 147.368585),
      tolerance = 1e-5)
    # binary_feature2 was added with standardize = TRUE, so it is centered
    # and scaled even though it only takes 0/1 values.
    expect_regressor(
      "binary_feature2",
      list(prior.scale = 10., mu = 0.1960784, std = 0.3974183),
      tolerance = 1e-5)

    # setup_dataframe() applies those standardizations to the columns.
    df2 <- prophet:::setup_dataframe(m, df)$df
    expect_equal(df2$binary_feature[1], 0)
    expect_equal(df2$numeric_feature[1], -1.726962, tolerance = 1e-4)
    expect_equal(df2$binary_feature2[1], 2.022859, tolerance = 1e-4)

    # Each regressor becomes a feature column; their prior scales occupy
    # positions 27:29 of the 29 columns.
    feature_set <- prophet:::make_all_seasonality_features(m, df2)
    seasonal.features <- feature_set$seasonal.features
    for (reg in c("binary_feature", "numeric_feature", "binary_feature2")) {
      expect_true(reg %in% colnames(seasonal.features))
    }
    expect_equal(ncol(seasonal.features), 29)
    expect_true(
      all(sort(feature_set$prior.scales[27:29]) == c(0.2, 0.5, 10.)))

    # predict() also requires every regressor column.
    newdata <- data.frame(
      ds = c("2014-06-01"), binary_feature = c(0), numeric_feature = c(10))
    expect_error(predict(m, newdata))
    newdata$binary_feature2 <- 0.
    fcst <- predict(m, newdata)
    expect_equal(ncol(fcst), 31)
    expect_equal(fcst$binary_feature[1], 0)
    # The forecast components must add up consistently.
    expect_equal(fcst$extra_regressors[1],
                 fcst$numeric_feature[1] + fcst$binary_feature2[1])
    expect_equal(fcst$seasonalities[1], fcst$yearly[1] + fcst$weekly[1])
    expect_equal(fcst$seasonal[1],
                 fcst$seasonalities[1] + fcst$extra_regressors[1])
    expect_equal(fcst$yhat[1], fcst$trend[1] + fcst$seasonal[1])

    # A constant regressor column is expected to make fitting fail.
    df$constant_feature <- 5
    constant_model <- add_regressor(prophet(), "constant_feature")
    expect_error(fit.prophet(constant_model, df))
  })
  test_that("copy", {
    skip_if_not(Sys.getenv("R_ARCH") != "/i386")
    df <- DATA
    df$cap <- 200.
    df$binary_feature <- c(rep(0, 255), rep(1, 255))
    inputs <- list(
      growth = c("linear", "logistic"),
      yearly.seasonality = c(TRUE, FALSE),
      weekly.seasonality = c(TRUE, FALSE),
      daily.seasonality = c(TRUE, FALSE),
      holidays = c("null", "insert_dataframe")
    )
    # One row per combination of the settings above (2^5 = 32 rows).
    products <- expand.grid(inputs)
    # Iterate over rows. length() of a data.frame counts its columns (5), so
    # 1:length(products) would cover only 5 of the 32 combinations. Because
    # expand.grid varies the last column slowest, it would also never reach
    # the holidays = "insert_dataframe" rows.
    for (i in seq_len(nrow(products))) {
      if (products$holidays[i] == "insert_dataframe") {
        holidays <- data.frame(ds = c("2016-12-25"), holiday = c("x"))
      } else {
        holidays <- NULL
      }
      m1 <- prophet(
        growth = as.character(products$growth[i]),
        changepoints = NULL,
        n.changepoints = 3,
        yearly.seasonality = products$yearly.seasonality[i],
        weekly.seasonality = products$weekly.seasonality[i],
        daily.seasonality = products$daily.seasonality[i],
        holidays = holidays,
        seasonality.prior.scale = 1.1,
        holidays.prior.scale = 1.1,
        changepoints.prior.scale = 0.1,
        mcmc.samples = 100,
        interval.width = 0.9,
        uncertainty.samples = 200,
        fit = FALSE
      )
      out <- prophet:::setup_dataframe(m1, df, initialize_scales = TRUE)
      m1 <- out$m
      m1$history <- out$df
      m1 <- prophet:::set_auto_seasonalities(m1)
      m2 <- prophet:::prophet_copy(m1)
      # Constructor arguments must be copied unchanged.
      args <- c("growth", "changepoints", "n.changepoints", "holidays",
                "seasonality.prior.scale", "holidays.prior.scale",
                "changepoints.prior.scale", "mcmc.samples", "interval.width",
                "uncertainty.samples")
      for (arg in args) {
        expect_equal(m1[[arg]], m2[[arg]])
      }
      # The copy has the seasonality flags set to FALSE; the components
      # themselves are carried over in m2$seasonalities instead.
      expect_equal(FALSE, m2$yearly.seasonality)
      expect_equal(FALSE, m2$weekly.seasonality)
      expect_equal(FALSE, m2$daily.seasonality)
      expect_equal(m1$yearly.seasonality,
                   "yearly" %in% names(m2$seasonalities))
      expect_equal(m1$weekly.seasonality,
                   "weekly" %in% names(m2$seasonalities))
      expect_equal(m1$daily.seasonality,
                   "daily" %in% names(m2$seasonalities))
    }
    # Check for cutoff and custom seasonality and extra regressors.
    changepoints <- seq.Date(as.Date("2012-06-15"), as.Date("2012-09-15"),
                             by = "d")
    cutoff <- as.Date("2012-07-25")
    m1 <- prophet(changepoints = changepoints)
    m1 <- add_seasonality(m1, "custom", 10, 5)
    m1 <- add_regressor(m1, "binary_feature")
    m1 <- fit.prophet(m1, df)
    m2 <- prophet:::prophet_copy(m1, cutoff)
    # Only changepoints up to the cutoff should survive the copy.
    changepoints <- changepoints[changepoints <= cutoff]
    expect_equal(prophet:::set_date(changepoints), m2$changepoints)
    expect_true("custom" %in% names(m2$seasonalities))
    expect_true("binary_feature" %in% names(m2$extra_regressors))
  })