Prechádzať zdrojové kódy

Move copy from method to function in diagnostics file

Ben Letham 7 rokov pred
rodič
commit
3afdaaf4e1

+ 47 - 0
R/R/diagnostics.R

@@ -144,3 +144,50 @@ cross_validation <- function(
   }
   return(simulated_historical_forecasts(model, horizon, units, k, period))
 }
+
+#' Copy Prophet object.
+#'
+#' @param m Prophet model object.
+#' @param cutoff Date, possibly as string. Changepoints are only retained if
+#'  changepoints <= cutoff.
+#'
+#' @return An unfitted Prophet model object with the same parameters as the
+#'  input model.
+#'
+#' @keywords internal
+prophet_copy <- function(m, cutoff = NULL) {
+  if (is.null(m$history)) {
+    stop("This is for copying a fitted Prophet object.")
+  }
+
+  if (m$specified.changepoints) {
+    changepoints <- m$changepoints
+    if (!is.null(cutoff)) {
+      cutoff <- set_date(cutoff)
+      changepoints <- changepoints[changepoints <= cutoff]
+    }
+  } else {
+    changepoints <- NULL
+  }
+  # Auto seasonalities are set to FALSE because they are already set in
+  # m$seasonalities.
+  m2 <- prophet(
+    growth = m$growth,
+    changepoints = changepoints,
+    n.changepoints = m$n.changepoints,
+    yearly.seasonality = FALSE,
+    weekly.seasonality = FALSE,
+    daily.seasonality = FALSE,
+    holidays = m$holidays,
+    seasonality.prior.scale = m$seasonality.prior.scale,
+    changepoint.prior.scale = m$changepoint.prior.scale,
+    holidays.prior.scale = m$holidays.prior.scale,
+    mcmc.samples = m$mcmc.samples,
+    interval.width = m$interval.width,
+    uncertainty.samples = m$uncertainty.samples,
+    fit = FALSE
+  )
+  m2$extra_regressors <- m$extra_regressors
+  m2$seasonalities <- m$seasonalities
+  return(m2)
+}

+ 0 - 47
R/R/prophet.R

@@ -1374,51 +1374,4 @@ make_future_dataframe <- function(m, periods, freq = 'day',
   return(data.frame(ds = dates))
 }
 
-#' Copy Prophet object.
-#'
-#' @param m Prophet model object.
-#' @param cutoff Date, possibly as string. Changepoints are only retained if
-#'  changepoints <= cutoff.
-#'
-#' @return An unfitted Prophet model object with the same parameters as the
-#'  input model.
-#'
-#' @keywords internal
-prophet_copy <- function(m, cutoff = NULL) {
-  if (is.null(m$history)) {
-    stop("This is for copying a fitted Prophet object.")
-  }
-
-  if (m$specified.changepoints) {
-    changepoints <- m$changepoints
-    if (!is.null(cutoff)) {
-      cutoff <- set_date(cutoff)
-      changepoints <- changepoints[changepoints <= cutoff]
-    }
-  } else {
-    changepoints <- NULL
-  }
-  # Auto seasonalities are set to FALSE because they are already set in
-  # m$seasonalities.
-  m2 <- prophet(
-    growth = m$growth,
-    changepoints = changepoints,
-    n.changepoints = m$n.changepoints,
-    yearly.seasonality = FALSE,
-    weekly.seasonality = FALSE,
-    daily.seasonality = FALSE,
-    holidays = m$holidays,
-    seasonality.prior.scale = m$seasonality.prior.scale,
-    changepoint.prior.scale = m$changepoint.prior.scale,
-    holidays.prior.scale = m$holidays.prior.scale,
-    mcmc.samples = m$mcmc.samples,
-    interval.width = m$interval.width,
-    uncertainty.samples = m$uncertainty.samples,
-    fit = FALSE
-  )
-  m2$extra_regressors <- m$extra_regressors
-  m2$seasonalities <- m$seasonalities
-  return(m2)
-}
-
 # fb-block 3

+ 0 - 0
R/R/utils.R


+ 1 - 1
R/man/prophet_copy.Rd

@@ -1,5 +1,5 @@
 % Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/prophet.R
+% Please edit documentation in R/diagnostics.R
 \name{prophet_copy}
 \alias{prophet_copy}
 \title{Copy Prophet object.}

+ 55 - 5
python/fbprophet/diagnostics.py

@@ -10,13 +10,15 @@ from __future__ import division
 from __future__ import print_function
 from __future__ import unicode_literals
 
+from copy import deepcopy
+from functools import reduce
 import logging
 
-logger = logging.getLogger(__name__)
-
 import numpy as np
 import pandas as pd
-from functools import reduce
+
+
+logger = logging.getLogger(__name__)
 
 
 def _cutoffs(df, horizon, k, period):
@@ -88,7 +90,7 @@ def simulated_historical_forecasts(model, horizon, k, period=None):
     predicts = []
     for cutoff in cutoffs:
         # Generate new object with copying fitting options
-        m = model.copy(cutoff)
+        m = prophet_copy(model, cutoff)
         # Train model
         m.fit(df[df['ds'] <= cutoff])
         # Calculate yhat
@@ -146,6 +148,54 @@ def cross_validation(model, horizon, period=None, initial=None):
             'Not enough data for specified horizon, period, and initial.')
     return simulated_historical_forecasts(model, horizon, k, period)
 
+
+def prophet_copy(m, cutoff=None):
+    """Copy Prophet object
+
+    Parameters
+    ----------
+    m: Prophet model.
+    cutoff: pd.Timestamp or None, default None.
+        cuttoff Timestamp for changepoints member variable.
+        changepoints are only retained if 'changepoints <= cutoff'
+
+    Returns
+    -------
+    Prophet class object with the same parameter with model variable
+    """
+    if m.history is None:
+        raise Exception('This is for copying a fitted Prophet object.')
+
+    if m.specified_changepoints:
+        changepoints = m.changepoints
+        if cutoff is not None:
+            # Filter change points '<= cutoff'
+            changepoints = changepoints[changepoints <= cutoff]
+    else:
+        changepoints = None
+
+    # Auto seasonalities are set to False because they are already set in
+    # m.seasonalities.
+    m2 = m.__class__(
+        growth=m.growth,
+        n_changepoints=m.n_changepoints,
+        changepoints=changepoints,
+        yearly_seasonality=False,
+        weekly_seasonality=False,
+        daily_seasonality=False,
+        holidays=m.holidays,
+        seasonality_prior_scale=m.seasonality_prior_scale,
+        changepoint_prior_scale=m.changepoint_prior_scale,
+        holidays_prior_scale=m.holidays_prior_scale,
+        mcmc_samples=m.mcmc_samples,
+        interval_width=m.interval_width,
+        uncertainty_samples=m.uncertainty_samples,
+    )
+    m2.extra_regressors = deepcopy(m.extra_regressors)
+    m2.seasonalities = deepcopy(m.seasonalities)
+    return m2
+
+
 def me(df):
     return((df['yhat'] - df['y']).sum()/len(df['yhat']))
 def mse(df):
@@ -209,4 +259,4 @@ def all_metrics(model, df_cv = None):
             'MAE': mae(df), 
             'MPE': mpe(df), 
             'MAPE': mape(df)
-            }
+            }

+ 6 - 43
python/fbprophet/forecaster.py

@@ -11,7 +11,6 @@ from __future__ import print_function
 from __future__ import unicode_literals
 
 from collections import defaultdict
-from copy import deepcopy
 from datetime import timedelta
 import logging
 import warnings
@@ -30,6 +29,7 @@ from fbprophet.plot import (
     plot_yearly,
     plot_seasonality,
 )
+from fbprophet.diagnostics import prophet_copy
 # fb-block 1 end
 
 logging.basicConfig()
@@ -1340,46 +1340,9 @@ class Prophet(object):
         )
 
     def copy(self, cutoff=None):
-        """Copy Prophet object
-
-        Parameters
-        ----------
-        cutoff: pd.Timestamp or None, default None.
-            cuttoff Timestamp for changepoints member variable.
-            changepoints are only retained if 'changepoints <= cutoff'
-
-        Returns
-        -------
-        Prophet class object with the same parameter with model variable
-        """
-        if self.history is None:
-            raise Exception('This is for copying a fitted Prophet object.')
-
-        if self.specified_changepoints:
-            changepoints = self.changepoints
-            if cutoff is not None:
-                # Filter change points '<= cutoff'
-                changepoints = changepoints[changepoints <= cutoff]
-        else:
-            changepoints = None
-
-        # Auto seasonalities are set to False because they are already set in
-        # self.seasonalities.
-        m = Prophet(
-            growth=self.growth,
-            n_changepoints=self.n_changepoints,
-            changepoints=changepoints,
-            yearly_seasonality=False,
-            weekly_seasonality=False,
-            daily_seasonality=False,
-            holidays=self.holidays,
-            seasonality_prior_scale=self.seasonality_prior_scale,
-            changepoint_prior_scale=self.changepoint_prior_scale,
-            holidays_prior_scale=self.holidays_prior_scale,
-            mcmc_samples=self.mcmc_samples,
-            interval_width=self.interval_width,
-            uncertainty_samples=self.uncertainty_samples,
+        warnings.warn(
+            'This method will be removed in the next version. '
+            'Please use fbprophet.diagnostics.prophet_copy. ',
+            DeprecationWarning,
         )
-        m.extra_regressors = deepcopy(self.extra_regressors)
-        m.seasonalities = deepcopy(self.seasonalities)
-        return m
+        return prophet_copy(m=self, cutoff=cutoff)