|
@@ -330,7 +330,20 @@ class Prophet(object):
|
|
|
return (k, m)
|
|
|
|
|
|
# fb-block 7
|
|
|
- def fit(self, df):
|
|
|
+ def fit(self, df, **kwargs):
|
|
|
+ """Fit the Prophet model to data.
|
|
|
+
|
|
|
+ Parameters
|
|
|
+ ----------
|
|
|
+ df: pd.DataFrame containing history. Must have columns 'ds', 'y', and
|
|
|
+ if logistic growth, 'cap'.
|
|
|
+ kwargs: Additional arguments passed to Stan's sampling or optimizing
|
|
|
+ function, as appropriate.
|
|
|
+
|
|
|
+ Returns
|
|
|
+ -------
|
|
|
+ The fitted Prophet object.
|
|
|
+ """
|
|
|
history = df[df['y'].notnull()].copy()
|
|
|
history.reset_index(inplace=True, drop=True)
|
|
|
|
|
@@ -377,14 +390,14 @@ class Prophet(object):
|
|
|
stan_fit = model.sampling(
|
|
|
dat,
|
|
|
init=stan_init,
|
|
|
- chains=1,
|
|
|
iter=self.mcmc_samples,
|
|
|
+ **kwargs
|
|
|
)
|
|
|
for par in stan_fit.model_pars:
|
|
|
self.params[par] = stan_fit[par]
|
|
|
|
|
|
else:
|
|
|
- params = model.optimizing(dat, init=stan_init, iter=1e4)
|
|
|
+ params = model.optimizing(dat, init=stan_init, iter=1e4, **kwargs)
|
|
|
for par in params:
|
|
|
self.params[par] = params[par].reshape((1, -1))
|
|
|
|