Jelajahi Sumber

Add cross-validation functions in R

bletham 8 tahun lalu
induk
melakukan
3c09448018

+ 2 - 0
R/NAMESPACE

@@ -3,10 +3,12 @@
 S3method(plot,prophet)
 S3method(plot,prophet)
 S3method(predict,prophet)
 S3method(predict,prophet)
 export(add_seasonality)
 export(add_seasonality)
+export(cross_validation)
 export(fit.prophet)
 export(fit.prophet)
 export(make_future_dataframe)
 export(make_future_dataframe)
 export(predictive_samples)
 export(predictive_samples)
 export(prophet)
 export(prophet)
 export(prophet_plot_components)
 export(prophet_plot_components)
+export(simulated_historical_forecasts)
 import(Rcpp)
 import(Rcpp)
 importFrom(dplyr,"%>%")
 importFrom(dplyr,"%>%")

+ 132 - 0
R/R/diagnostics.R

@@ -0,0 +1,132 @@
+## Copyright (c) 2017-present, Facebook, Inc.
+## All rights reserved.
+
+## This source code is licensed under the BSD-style license found in the
+## LICENSE file in the root directory of this source tree. An additional grant
+## of patent rights can be found in the PATENTS file in the same directory.
+
+## Makes R CMD CHECK happy due to dplyr syntax below
+globalVariables(c(
+  "ds", "y", "cap", "yhat", "yhat_lower", "yhat_upper"))
+
+#' Generate cutoff dates
+#'
+#' @param df Dataframe with historical data
+#' @param horizon timediff forecast horizon
+#' @param k integer number of forecast points
+#' @param period timediff Simulated forecasts are done with this period.
+#'
+#' @return Array of datetimes
+#'
+#' @keywords internal
+generate_cutoffs <- function(df, horizon, k, period) {
+  # Last cutoff is (latest date in data) - (horizon).
+  cutoff <- max(df$ds) - horizon
+  if (cutoff < min(df$ds)) {
+    stop('Less data than horizon.')
+  }
+  tzone <- attr(cutoff, "tzone")  # Timezone is wiped by putting in array
+  result <- c(cutoff)
+  for (i in 2:k) {
+    cutoff <- cutoff - period
+    # If data does not exist in data range (cutoff, cutoff + horizon]
+    if (!any((df$ds > cutoff) & (df$ds <= cutoff + horizon))) {
+      # Next cutoff point is 'closest date before cutoff in data - horizon'
+      closest.date <- max(df$ds[df$ds <= cutoff])
+      cutoff <- closest.date - horizon
+    }
+    if (cutoff < min(df$ds)) {
+      warning('Not enough data for requested number of cutoffs! Using ', i)
+      break
+    }
+    result <- c(result, cutoff)
+  }
+  # Reset timezones
+  attr(result, "tzone") <- tzone
+  return(rev(result))
+}
+
+#' Simulated historical forecasts.
+#' Make forecasts from k historical cutoff dates, and compare forecast values
+#' to actual values.
+#'
+#' @param model Fitted Prophet model.
+#' @param horizon Integer size of the horizon
+#' @param units String unit of the horizon, e.g., "days", "secs".
+#' @param k integer number of forecast points
+#' @param period Integer amount of time between cutoff dates. Same units as
+#'  horizon. If not provided, will use 0.5 * horizon.
+#'
+#' @return A dataframe with the forecast, actual value, and cutoff date.
+#'
+#' @export
+simulated_historical_forecasts <- function(model, horizon, units, k,
+                                           period = NULL) {
+  df <- model$history
+  horizon <- as.difftime(horizon, units = units)
+  if (is.null(period)) {
+    period <- horizon / 2
+  } else {
+    period <- as.difftime(period, units = units)
+  }
+  cutoffs <- generate_cutoffs(df, horizon, k, period)
+  predicts <- data.frame()
+  for (i in 1:length(cutoffs)) {
+    cutoff <- cutoffs[i]
+    # Copy the model
+    m <- prophet_copy(model, cutoff)
+    # Train model
+    history.c <- dplyr::filter(df, ds <= cutoff)
+    m <- fit.prophet(m, history.c)
+    # Calculate yhat
+    df.predict <- dplyr::filter(df, ds > cutoff, ds <= cutoff + horizon)
+    if (m$growth == 'logistic') {
+      future <- dplyr::select(df.predict, ds, cap)
+    } else{
+      future <- dplyr::select(df.predict, ds)
+    }
+    yhat <- stats::predict(m, future)
+    # Merge yhat, y, and cutoff.
+    df.c <- dplyr::inner_join(df.predict, yhat, by = "ds")
+    df.c <- dplyr::select(df.c, ds, y, yhat, yhat_lower, yhat_upper)
+    df.c$cutoff <- cutoff
+    predicts <- rbind(predicts, df.c)
+  }
+  return(predicts)
+}
+
+#' Cross-validation for time series.
+#' Computes forecast error with cutoffs at the specified period. When the
+#' period is the time interval of the data, is the procedure described in
+#' https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
+#' a cutoff every "period" amount of time, going back to "initial".
+#'
+#' @param model Fitted Prophet model.
+#' @param horizon Integer size of the horizon
+#' @param units String unit of the horizon, e.g., "days", "secs".
+#' @param period Integer amount of time between cutoff dates. Same units as
+#'  horizon.
+#' @param initial Integer size of the first training period. If not provided,
+#'  3 * horizon is used. Same units as horizon.
+#'
+#' @return A dataframe with the forecast, actual value, and cutoff date.
+#'
+#' @export
+cross_validation <- function(model, horizon, units, period, initial = NULL) {
+  te <- max(model$history$ds)
+  ts <- min(model$history$ds)
+  if (is.null(initial)) {
+    initial <- 3 * horizon
+  }
+  horizon.dt <- as.difftime(horizon, units = units)
+  initial.dt <- as.difftime(initial, units = units)
+  period.dt <- as.difftime(period, units = units)
+  k <- ceiling(
+    as.double((te - horizon.dt) - (ts + initial.dt), units='secs') /
+    as.double(period.dt, units = 'secs')
+  )
+  if (k < 1) {
+    stop('Not enough data for specified horizon and initial.')
+  }
+  return(simulated_historical_forecasts(model, horizon, units, k, period))
+}

+ 40 - 0
R/R/prophet.R

@@ -109,6 +109,7 @@ prophet <- function(df = NULL,
     mcmc.samples = mcmc.samples,
     mcmc.samples = mcmc.samples,
     interval.width = interval.width,
     interval.width = interval.width,
     uncertainty.samples = uncertainty.samples,
     uncertainty.samples = uncertainty.samples,
+    specified.changepoints = !is.null(changepoints),
     start = NULL,  # This and following attributes are set during fitting
     start = NULL,  # This and following attributes are set during fitting
     y.scale = NULL,
     y.scale = NULL,
     t.scale = NULL,
     t.scale = NULL,
@@ -240,6 +241,7 @@ set_date <- function(ds = NULL, tz = "GMT") {
   } else {
   } else {
     ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
     ds <- as.POSIXct(ds, format = "%Y-%m-%d %H:%M:%S", tz = tz)
   }
   }
+  attr(ds, "tzone") <- tz
   return(ds)
   return(ds)
 }
 }
 
 
@@ -1411,4 +1413,42 @@ plot_seasonality <- function(m, name, uncertainty = TRUE) {
   return(gg.s)
   return(gg.s)
 }
 }
 
 
+#' Copy Prophet object.
+#'
+#' @param m Prophet model object.
+#' @param cutoff Date, possibly as string. Changepoints are only retained if
+#'  changepoints <= cutoff.
+#'
+#' @return An unfitted Prophet model object with the same parameters as the
+#'  input model.
+#'
+#' @keywords internal
+prophet_copy <- function(m, cutoff = NULL) {
+  if (m$specified.changepoints) {
+    changepoints <- m$changepoints
+    if (!is.null(cutoff)) {
+      cutoff <- set_date(cutoff)
+      changepoints <- changepoints[changepoints <= cutoff]
+    }
+  } else {
+    changepoints <- NULL
+  }
+  return(prophet(
+    growth = m$growth,
+    changepoints = changepoints,
+    n.changepoints = m$n.changepoints,
+    yearly.seasonality = m$yearly.seasonality,
+    weekly.seasonality = m$weekly.seasonality,
+    daily.seasonality = m$daily.seasonality,
+    holidays = m$holidays,
+    seasonality.prior.scale = m$seasonality.prior.scale,
+    changepoint.prior.scale = m$changepoint.prior.scale,
+    holidays.prior.scale = m$holidays.prior.scale,
+    mcmc.samples = m$mcmc.samples,
+    interval.width = m$interval.width,
+    uncertainty.samples = m$uncertainty.samples,
+    fit = FALSE,
+  ))
+}
+
 # fb-block 3
 # fb-block 3

+ 2 - 1
R/man/add_seasonality.Rd

@@ -21,5 +21,6 @@ The prophet model with the seasonality added.
 }
 }
 \description{
 \description{
 Increasing the number of Fourier components allows the seasonality to change
 Increasing the number of Fourier components allows the seasonality to change
-more quickly (at risk of overfitting).
+more quickly (at risk of overfitting). Default values for yearly and weekly
+seasonalities are 10 and 3 respectively.
 }
 }

+ 35 - 0
R/man/cross_validation.Rd

@@ -0,0 +1,35 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{cross_validation}
+\alias{cross_validation}
+\title{Cross-validation for time series.
+Computes forecast error with cutoffs at the specified period. When the
+period is the time interval of the data, is the procedure described in
+https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
+a cutoff every "period" amount of time, going back to "initial".}
+\usage{
+cross_validation(model, horizon, units, period, initial = NULL)
+}
+\arguments{
+\item{model}{Fitted Prophet model.}
+
+\item{horizon}{Integer size of the horizon}
+
+\item{units}{String unit of the horizon, e.g., "days", "secs".}
+
+\item{period}{Integer amount of time between cutoff dates. Same units as
+horizon.}
+
+\item{initial}{Integer size of the first training period. If not provided,
+3 * horizon is used. Same units as horizon.}
+}
+\value{
+A dataframe with the forecast, actual value, and cutoff date.
+}
+\description{
+Cross-validation for time series.
+Computes forecast error with cutoffs at the specified period. When the
+period is the time interval of the data, is the procedure described in
+https://robjhyndman.com/hyndsight/tscv/. Beginning from end-horizon, makes
+a cutoff every "period" amount of time, going back to "initial".
+}

+ 24 - 0
R/man/generate_cutoffs.Rd

@@ -0,0 +1,24 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{generate_cutoffs}
+\alias{generate_cutoffs}
+\title{Generate cutoff dates}
+\usage{
+generate_cutoffs(df, horizon, k, period)
+}
+\arguments{
+\item{df}{Dataframe with historical data}
+
+\item{horizon}{timediff forecast horizon}
+
+\item{k}{integer number of forecast points}
+
+\item{period}{timediff Simulated forecasts are done with this period.}
+}
+\value{
+Array of datetimes
+}
+\description{
+Generate cutoff dates
+}
+\keyword{internal}

+ 22 - 0
R/man/prophet_copy.Rd

@@ -0,0 +1,22 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{prophet_copy}
+\alias{prophet_copy}
+\title{Copy Prophet object.}
+\usage{
+prophet_copy(m, cutoff = NULL)
+}
+\arguments{
+\item{m}{Prophet model object.}
+
+\item{cutoff}{Date, possibly as string. Changepoints are only retained if
+changepoints <= cutoff.}
+}
+\value{
+An unfitted Prophet model object with the same parameters as the
+ input model.
+}
+\description{
+Copy Prophet object.
+}
+\keyword{internal}

+ 30 - 0
R/man/simulated_historical_forecasts.Rd

@@ -0,0 +1,30 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/diagnostics.R
+\name{simulated_historical_forecasts}
+\alias{simulated_historical_forecasts}
+\title{Simulated historical forecasts.
+Make forecasts from k historical cutoff dates, and compare forecast values
+to actual values.}
+\usage{
+simulated_historical_forecasts(model, horizon, units, k, period = NULL)
+}
+\arguments{
+\item{model}{Fitted Prophet model.}
+
+\item{horizon}{Integer size of the horizon}
+
+\item{units}{String unit of the horizon, e.g., "days", "secs".}
+
+\item{k}{integer number of forecast points}
+
+\item{period}{Integer amount of time between cutoff dates. Same units as
+horizon. If not provided, will use 0.5 * horizon.}
+}
+\value{
+A dataframe with the forecast, actual value, and cutoff date.
+}
+\description{
+Simulated historical forecasts.
+Make forecasts from k historical cutoff dates, and compare forecast values
+to actual values.
+}

+ 86 - 0
R/tests/testthat/test_diagnostics.R

@@ -0,0 +1,86 @@
+library(prophet)
+context("Prophet diagnostics tests")
+
+## Makes R CMD CHECK happy due to dplyr syntax below
+globalVariables(c("y", "yhat"))
+
+DATA <- head(read.csv('data.csv'), 100)
+DATA$ds <- as.Date(DATA$ds)
+
+test_that("simulated_historical_forecasts", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(DATA)
+  k <- 2
+  for (p in c(1, 10)) {
+    for (h in c(1, 3)) {
+      df.shf <- simulated_historical_forecasts(
+        m, horizon = h, units = 'days', k = k, period = p)
+      # All cutoff dates should be less than ds dates
+      expect_true(all(df.shf$cutoff < df.shf$ds))
+      # The unique size of output cutoff should be equal to 'k'
+      expect_equal(length(unique(df.shf$cutoff)), k)
+      expect_equal(max(df.shf$ds - df.shf$cutoff),
+                   as.difftime(h, units = 'days'))
+      dc <- diff(df.shf$cutoff)
+      dc <- min(dc[dc > 0])
+      expect_true(dc >= as.difftime(p, units = 'days'))
+      # Each y in df_shf and DATA with same ds should be equal
+      df.merged <- dplyr::left_join(df.shf, m$history, by="ds")
+      expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0)
+    }
+  }
+})
+
+test_that("simulated_historical_forecasts_logistic", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  df <- DATA
+  df$cap <- 40
+  m <- prophet(df, growth='logistic')
+  df.shf <- simulated_historical_forecasts(
+    m, horizon = 3, units = 'days', k = 2, period = 3)
+  # All cutoff dates should be less than ds dates
+  expect_true(all(df.shf$cutoff < df.shf$ds))
+  # The unique size of output cutoff should be equal to 'k'
+  expect_equal(length(unique(df.shf$cutoff)), 2)
+  # Each y in df_shf and DATA with same ds should be equal
+  df.merged <- dplyr::left_join(df.shf, m$history, by="ds")
+  expect_equal(sum((df.merged$y.x - df.merged$y.y) ** 2), 0)
+})
+
+test_that("simulated_historical_forecasts_default_value_check", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(DATA)
+  df.shf1 <- simulated_historical_forecasts(
+    m, horizon = 10, units = 'days', k = 1)
+  df.shf2 <- simulated_historical_forecasts(
+    m, horizon = 10, units = 'days', k = 1, period = 5)
+  expect_equal(sum(dplyr::select(df.shf1 - df.shf2, y, yhat)), 0)
+})
+
+test_that("cross_validation", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(DATA)
+  # Calculate the number of cutoff points
+  te <- max(DATA$ds)
+  ts <- min(DATA$ds)
+  horizon <- as.difftime(4, units = "days")
+  period <- as.difftime(10, units = "days")
+  k <- 5
+  df.cv <- cross_validation(
+    m, horizon = 4, units = "days", period = 10, initial = 90)
+  expect_equal(length(unique(df.cv$cutoff)), k)
+  expect_equal(max(df.cv$ds - df.cv$cutoff), horizon)
+  dc <- diff(df.cv$cutoff)
+  dc <- min(dc[dc > 0])
+  expect_true(dc >= period)
+})
+
+test_that("cross_validation_default_value_check", {
+  skip_if_not(Sys.getenv('R_ARCH') != '/i386')
+  m <- prophet(DATA)
+  df.cv1 <- cross_validation(
+    m, horizon = 32, units = "days", period = 10)
+  df.cv2 <- cross_validation(
+    m, horizon = 32, units = 'days', period = 10, initial = 96)
+  expect_equal(sum(dplyr::select(df.cv1 - df.cv2, y, yhat)), 0)
+})

+ 54 - 0
R/tests/testthat/test_prophet.R

@@ -330,3 +330,57 @@ test_that("custom_seasonality", {
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
   m <- add_seasonality(m, name='monthly', period=30, fourier.order=5)
   expect_equal(m$seasonalities[['monthly']], c(30, 5))
   expect_equal(m$seasonalities[['monthly']], c(30, 5))
 })
 })
+
+test_that("copy", {
+  inputs <- list(
+    growth = c('linear', 'logistic'),
+    changepoints = c(NULL, c('2016-12-25')),
+    n.changepoints = c(3),
+    yearly.seasonality = c(TRUE, FALSE),
+    weekly.seasonality = c(TRUE, FALSE),
+    daily.seasonality = c(TRUE, FALSE),
+    holidays = c(NULL, 'insert_dataframe'),
+    seasonality.prior.scale = c(1.1),
+    holidays.prior.scale = c(1.1),
+    changepoints.prior.scale = c(0.1),
+    mcmc.samples = c(100),
+    interval.width = c(0.9),
+    uncertainty.samples = c(200)
+  )
+  products <- expand.grid(inputs)
+  for (i in 1:length(products)) {
+    if (products$holidays[i] == 'insert_dataframe') {
+      holidays <- data.frame(ds=c('2016-12-25'), holiday=c('x'))
+    } else {
+      holidays <- NULL
+    }
+    m1 <- prophet(
+      growth = products$growth[i],
+      changepoints = products$changepoints[i],
+      n.changepoints = products$n.changepoints[i],
+      yearly.seasonality = products$yearly.seasonality[i],
+      weekly.seasonality = products$weekly.seasonality[i],
+      daily.seasonality = products$daily.seasonality[i],
+      holidays = holidays,
+      seasonality.prior.scale = products$seasonality.prior.scale[i],
+      holidays.prior.scale = products$holidays.prior.scale[i],
+      changepoints.prior.scale = products$changepoints.prior.scale[i],
+      mcmc.samples = products$mcmc.samples[i],
+      interval.width = products$interval.width[i],
+      uncertainty.samples = products$uncertainty.samples[i],
+      fit = FALSE
+    )
+    m2 <- prophet:::prophet_copy(m1)
+    # Values should be copied correctly
+    for (arg in names(inputs)) {
+      expect_equal(m1[[arg]], m2[[arg]])
+    }
+  }
+  # Check for cutoff
+  changepoints <- seq.Date(as.Date('2012-06-15'), as.Date('2012-09-15'), by='d')
+  cutoff <- as.Date('2012-07-25')
+  m1 <- prophet(DATA, changepoints = changepoints)
+  m2 <- prophet:::prophet_copy(m1, cutoff)
+  changepoints <- changepoints[changepoints <= cutoff]
+  expect_equal(prophet:::set_date(changepoints), m2$changepoints)
+})

+ 6 - 2
python/fbprophet/diagnostics.py

@@ -38,6 +38,8 @@ def _cutoffs(df, horizon, k, period):
     """
     """
     # Last cutoff is 'latest date in data - horizon' date
     # Last cutoff is 'latest date in data - horizon' date
     cutoff = df['ds'].max() - horizon
     cutoff = df['ds'].max() - horizon
+    if cutoff < df['ds'].min():
+        raise ValueError('Less data than horizon.')
     result = [cutoff]
     result = [cutoff]
 
 
     for i in range(1, k):
     for i in range(1, k):
@@ -48,7 +50,7 @@ def _cutoffs(df, horizon, k, period):
             closest_date = df[df['ds'] <= cutoff].max()['ds']
             closest_date = df[df['ds'] <= cutoff].max()['ds']
             cutoff = closest_date - horizon
             cutoff = closest_date - horizon
         if cutoff < df['ds'].min():
         if cutoff < df['ds'].min():
-            logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(k))
+            logger.warning('Not enough data for requested number of cutoffs! Using {}.'.format(i))
             break
             break
         result.append(cutoff)
         result.append(cutoff)
 
 
@@ -127,5 +129,7 @@ def cross_validation(model, horizon, period, initial=None):
     horizon = pd.Timedelta(horizon)
     horizon = pd.Timedelta(horizon)
     period = pd.Timedelta(period)
     period = pd.Timedelta(period)
     initial = 3 * horizon if initial is None else pd.Timedelta(initial)
     initial = 3 * horizon if initial is None else pd.Timedelta(initial)
-    k = int(np.floor(((te - horizon) - (ts + initial)) / period))
+    k = int(np.ceil(((te - horizon) - (ts + initial)) / period))
+    if k < 1:
+        raise ValueError('Not enough data for specified horizon and initial.')
     return simulated_historical_forecasts(model, horizon, k, period)
     return simulated_historical_forecasts(model, horizon, k, period)

+ 11 - 8
python/fbprophet/forecaster.py

@@ -100,8 +100,10 @@ class Prophet(object):
         self.changepoints = pd.to_datetime(changepoints)
         self.changepoints = pd.to_datetime(changepoints)
         if self.changepoints is not None:
         if self.changepoints is not None:
             self.n_changepoints = len(self.changepoints)
             self.n_changepoints = len(self.changepoints)
+            self.specified_changepoints = True
         else:
         else:
             self.n_changepoints = n_changepoints
             self.n_changepoints = n_changepoints
+            self.specified_changepoints = False
 
 
         self.yearly_seasonality = yearly_seasonality
         self.yearly_seasonality = yearly_seasonality
         self.weekly_seasonality = weekly_seasonality
         self.weekly_seasonality = weekly_seasonality
@@ -1420,21 +1422,24 @@ class Prophet(object):
         ----------
         ----------
         cutoff: pd.Timestamp or None, default None.
         cutoff: pd.Timestamp or None, default None.
             cuttoff Timestamp for changepoints member variable.
             cuttoff Timestamp for changepoints member variable.
-            changepoints are only remained if 'changepoints <= cutoff'
+            changepoints are only retained if 'changepoints <= cutoff'
 
 
         Returns
         Returns
         -------
         -------
         Prophet class object with the same parameter with model variable
         Prophet class object with the same parameter with model variable
         """
         """
-        if self.changepoints is not None and cutoff is not None:
-            # Filter change points '<= cutoff'
-            self.changepoints = self.changepoints[self.changepoints <= cutoff]
-            self.n_changepoints = len(self.changepoints)
+        if self.specified_changepoints:
+            changepoints = self.changepoints
+            if cutoff is not None:
+                # Filter change points '<= cutoff'
+                changepoints = changepoints[changepoints <= cutoff]
+        else:
+            changepoints = None
 
 
         return Prophet(
         return Prophet(
             growth=self.growth,
             growth=self.growth,
             n_changepoints=self.n_changepoints,
             n_changepoints=self.n_changepoints,
-            changepoints=self.changepoints,
+            changepoints=changepoints,
             yearly_seasonality=self.yearly_seasonality,
             yearly_seasonality=self.yearly_seasonality,
             weekly_seasonality=self.weekly_seasonality,
             weekly_seasonality=self.weekly_seasonality,
             daily_seasonality=self.daily_seasonality,
             daily_seasonality=self.daily_seasonality,
@@ -1446,5 +1451,3 @@ class Prophet(object):
             interval_width=self.interval_width,
             interval_width=self.interval_width,
             uncertainty_samples=self.uncertainty_samples
             uncertainty_samples=self.uncertainty_samples
         )
         )
-
-

+ 3 - 3
python/fbprophet/tests/test_diagnostics.py

@@ -77,9 +77,9 @@ class TestDiagnostics(TestCase):
         ts = self.__df['ds'].min()
         ts = self.__df['ds'].min()
         horizon = pd.Timedelta('4 days')
         horizon = pd.Timedelta('4 days')
         period = pd.Timedelta('10 days')
         period = pd.Timedelta('10 days')
-        initial = pd.Timedelta('90 days')
-        k = int(np.floor(((te - horizon) - (ts + initial)) / period))
-        df_cv = diagnostics.cross_validation(m, horizon=horizon, period=period, initial=initial)
+        k = 5
+        df_cv = diagnostics.cross_validation(
+            m, horizon='4 days', period='10 days', initial='90 days')
         # The unique size of output cutoff should be equal to 'k'
         # The unique size of output cutoff should be equal to 'k'
         self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
         self.assertEqual(len(np.unique(df_cv['cutoff'])), k)
         self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)
         self.assertEqual(max(df_cv['ds'] - df_cv['cutoff']), horizon)

+ 3 - 2
python/fbprophet/tests/test_prophet.py

@@ -490,9 +490,10 @@ class TestProphet(TestCase):
             self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
             self.assertEqual(m1.uncertainty_samples, m2.uncertainty_samples)
 
 
         # Check for cutoff
         # Check for cutoff
-        changepoints = pd.date_range('2016-12-15', '2017-01-15')
-        cutoff = pd.Timestamp('2016-12-25')
+        changepoints = pd.date_range('2012-06-15', '2012-09-15')
+        cutoff = pd.Timestamp('2012-07-25')
         m1 = Prophet(changepoints=changepoints)
         m1 = Prophet(changepoints=changepoints)
+        m1.fit(DATA)
         m2 = m1.copy(cutoff=cutoff)
         m2 = m1.copy(cutoff=cutoff)
         changepoints = changepoints[changepoints <= cutoff]
         changepoints = changepoints[changepoints <= cutoff]
         self.assertTrue((changepoints == m2.changepoints).all())
         self.assertTrue((changepoints == m2.changepoints).all())