|
@@ -10,13 +10,15 @@ from __future__ import division
|
|
|
from __future__ import print_function
|
|
|
from __future__ import unicode_literals
|
|
|
|
|
|
+from copy import deepcopy
|
|
|
+from functools import reduce
|
|
|
import logging
|
|
|
|
|
|
-logger = logging.getLogger(__name__)
|
|
|
-
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
-from functools import reduce
|
|
|
+
|
|
|
+
|
|
|
+logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
def _cutoffs(df, horizon, k, period):
|
|
@@ -88,7 +90,7 @@ def simulated_historical_forecasts(model, horizon, k, period=None):
|
|
|
predicts = []
|
|
|
for cutoff in cutoffs:
|
|
|
# Generate new object with copying fitting options
|
|
|
- m = model.copy(cutoff)
|
|
|
+ m = prophet_copy(model, cutoff)
|
|
|
# Train model
|
|
|
m.fit(df[df['ds'] <= cutoff])
|
|
|
# Calculate yhat
|
|
@@ -146,6 +148,54 @@ def cross_validation(model, horizon, period=None, initial=None):
|
|
|
'Not enough data for specified horizon, period, and initial.')
|
|
|
return simulated_historical_forecasts(model, horizon, k, period)
|
|
|
|
|
|
+
|
|
|
+def prophet_copy(m, cutoff=None):
|
|
|
+ """Copy Prophet object
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ m: Prophet model.
|
|
|
+ cutoff: pd.Timestamp or None, default None.
|
|
|
+ cuttoff Timestamp for changepoints member variable.
|
|
|
+ changepoints are only retained if 'changepoints <= cutoff'
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ Prophet class object with the same parameter with model variable
|
|
|
+ """
|
|
|
+ if m.history is None:
|
|
|
+ raise Exception('This is for copying a fitted Prophet object.')
|
|
|
+
|
|
|
+ if m.specified_changepoints:
|
|
|
+ changepoints = m.changepoints
|
|
|
+ if cutoff is not None:
|
|
|
+ # Filter change points '<= cutoff'
|
|
|
+ changepoints = changepoints[changepoints <= cutoff]
|
|
|
+ else:
|
|
|
+ changepoints = None
|
|
|
+
|
|
|
+ # Auto seasonalities are set to False because they are already set in
|
|
|
+ # m.seasonalities.
|
|
|
+ m2 = m.__class__(
|
|
|
+ growth=m.growth,
|
|
|
+ n_changepoints=m.n_changepoints,
|
|
|
+ changepoints=changepoints,
|
|
|
+ yearly_seasonality=False,
|
|
|
+ weekly_seasonality=False,
|
|
|
+ daily_seasonality=False,
|
|
|
+ holidays=m.holidays,
|
|
|
+ seasonality_prior_scale=m.seasonality_prior_scale,
|
|
|
+ changepoint_prior_scale=m.changepoint_prior_scale,
|
|
|
+ holidays_prior_scale=m.holidays_prior_scale,
|
|
|
+ mcmc_samples=m.mcmc_samples,
|
|
|
+ interval_width=m.interval_width,
|
|
|
+ uncertainty_samples=m.uncertainty_samples,
|
|
|
+ )
|
|
|
+ m2.extra_regressors = deepcopy(m.extra_regressors)
|
|
|
+ m2.seasonalities = deepcopy(m.seasonalities)
|
|
|
+ return m2
|
|
|
+
|
|
|
+
|
|
|
def me(df):
|
|
|
return((df['yhat'] - df['y']).sum()/len(df['yhat']))
|
|
|
def mse(df):
|
|
@@ -209,4 +259,4 @@ def all_metrics(model, df_cv = None):
|
|
|
'MAE': mae(df),
|
|
|
'MPE': mpe(df),
|
|
|
'MAPE': mape(df)
|
|
|
- }
|
|
|
+ }
|