Ver código fonte

Additional kwargs to Stan in Python

Ben Letham 8 anos atrás
pai
commit
e08cfd2176
1 arquivos alterados com 16 adições e 3 exclusões
  1. 16 3
      python/fbprophet/forecaster.py

+ 16 - 3
python/fbprophet/forecaster.py

@@ -330,7 +330,20 @@ class Prophet(object):
         return (k, m)
 
     # fb-block 7
-    def fit(self, df):
+    def fit(self, df, **kwargs):
+        """Fit the Prophet model to data.
+
+        Parameters
+        ----------
+        df: pd.DataFrame containing history. Must have columns 'ds', 'y', and
+            if logistic growth, 'cap'.
+        kwargs: Additional arguments passed to Stan's sampling or optimizing
+            function, as appropriate.
+
+        Returns
+        -------
+        The fitted Prophet object.
+        """
         history = df[df['y'].notnull()].copy()
         history.reset_index(inplace=True, drop=True)
 
@@ -377,14 +390,14 @@ class Prophet(object):
             stan_fit = model.sampling(
                 dat,
                 init=stan_init,
-                chains=1,
                 iter=self.mcmc_samples,
+                **kwargs
             )
             for par in stan_fit.model_pars:
                 self.params[par] = stan_fit[par]
 
         else:
-            params = model.optimizing(dat, init=stan_init, iter=1e4)
+            params = model.optimizing(dat, init=stan_init, iter=1e4, **kwargs)
             for par in params:
                 self.params[par] = params[par].reshape((1, -1))