فهرست منبع

make_future_dataframe return correct number of periods

Ben Letham 8 سال پیش
والد
کامیت
1d37f7f6fe
4فایلهای تغییر یافته به همراه40 افزوده شده و 4 حذف شده
  1. 2 1
      R/R/prophet.R
  2. 14 0
      R/tests/testthat/test_prophet.R
  3. 4 3
      python/fbprophet/forecaster.py
  4. 20 0
      python/fbprophet/tests/test_prophet.py

+ 2 - 1
R/R/prophet.R

@@ -824,7 +824,8 @@ sample_predictive_trend <- function(model, df, iteration) {
 #' @export
 make_future_dataframe <- function(m, periods, freq = 'd',
                                   include_history = TRUE) {
-  dates <- seq(max(m$history$ds), length.out = periods, by = freq)[2:periods]
+  dates <- seq(max(m$history$ds), length.out = periods + 1, by = freq)
+  dates <- dates[2:(periods + 1)]  # Drop the first, which is max(history$ds)
   if (include_history) {
     dates <- c(m$history$ds, dates)
   }

+ 14 - 0
R/tests/testthat/test_prophet.R

@@ -206,3 +206,17 @@ test_that("fit_with_holidays", {
   m <- prophet(DATA, holidays = holidays, uncertainty.samples = 0)
   expect_error(predict(m), NA)
 })
+
+test_that("make_future_dataframe", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(train)
+  future <- make_future_dataframe(m, periods = 3, freq = 'd',
+                                  include_history = FALSE)
+  correct <- as.Date(c('2013-04-26', '2013-04-27', '2013-04-28'))
+  expect_equal(future$ds, correct)
+
+  future <- make_future_dataframe(m, periods = 3, freq = 'm',
+                                  include_history = FALSE)
+  correct <- as.Date(c('2013-05-25', '2013-06-25', '2013-07-25'))
+  expect_equal(future$ds, correct)
+})

+ 4 - 3
python/fbprophet/forecaster.py

@@ -601,9 +601,10 @@ class Prophet(object):
         last_date = self.history['ds'].max()
         dates = pd.date_range(
             start=last_date,
-            periods=periods + 1,  # closed='right' removes a period
-            freq=freq,
-            closed='right')  # omits the start date
+            periods=periods + 1,  # An extra in case we include start
+            freq=freq)
+        dates = dates[dates > last_date]  # Drop start if equals last_date
+        dates = dates[:periods]  # Return correct number of periods
 
         if include_history:
             dates = np.concatenate((np.array(self.history['ds']), dates))

+ 20 - 0
python/fbprophet/tests/test_prophet.py

@@ -236,6 +236,26 @@ class TestProphet(TestCase):
         model = Prophet(holidays=holidays, uncertainty_samples=0)
         model.fit(DATA).predict()
 
+    def test_make_future_dataframe(self):
+        N = DATA.shape[0]
+        train = DATA.head(N // 2)
+        forecaster = Prophet()
+        forecaster.fit(train)
+        future = forecaster.make_future_dataframe(periods=3, freq='D',
+                                                  include_history=False)
+        correct = pd.DatetimeIndex(['2013-04-26', '2013-04-27', '2013-04-28'])
+        self.assertEqual(len(future), 3)
+        for i in range(3):
+            self.assertEqual(future.iloc[i]['ds'], correct[i])
+
+        future = forecaster.make_future_dataframe(periods=3, freq='M',
+                                                  include_history=False)
+        correct = pd.DatetimeIndex(['2013-04-30', '2013-05-31', '2013-06-30'])
+        self.assertEqual(len(future), 3)
+        for i in range(3):
+            self.assertEqual(future.iloc[i]['ds'], correct[i])
+
+
 DATA = pd.read_csv(StringIO("""
 ds,y
 2012-05-18,38.23