Ben Letham 7 лет назад
Родитель
Сommit
8066634cb4

+ 51 - 34
R/R/prophet.R

@@ -333,40 +333,7 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
   df <- df %>%
     dplyr::arrange(ds)
 
-  if (initialize_scales) {
-    if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
-      m$logistic.floor <- TRUE
-      floor <- df$floor
-    } else {
-      floor <- 0
-    }
-    m$y.scale <- max(abs(df$y - floor))
-    if (m$y.scale == 0) {
-      m$y.scale <- 1
-    }
-    m$start <- min(df$ds)
-    m$t.scale <- time_diff(max(df$ds), m$start, "secs")
-    for (name in names(m$extra_regressors)) {
-      standardize <- m$extra_regressors[[name]]$standardize
-      if (standardize == 'auto') {
-        if (all(sort(unique(df[[name]])) == c(0, 1))) {
-          # Don't standardize binary variables
-          standardize <- FALSE
-        } else {
-          standardize <- TRUE
-        }
-      }
-      if (standardize) {
-        mu <- mean(df[[name]])
-        std <- stats::sd(df[[name]])
-        if (std == 0) {
-          std <- mu
-        }
-        m$extra_regressors[[name]]$mu <- mu
-        m$extra_regressors[[name]]$std <- std
-      }
-    }
-  }
+  m <- initialize_scales_fn(m, initialize_scales, df)
 
   if (m$logistic.floor) {
     if (!('floor' %in% colnames(df))) {
@@ -400,6 +367,56 @@ setup_dataframe <- function(m, df, initialize_scales = FALSE) {
   return(list("m" = m, "df" = df))
 }
 
+#' Initialize model scales.
+#'
+#' Sets model scaling factors using df.
+#'
+#' @param m Prophet object.
+#' @param initialize_scales Boolean set the scales or not.
+#' @param df Dataframe for setting scales.
+#'
+#' @return Prophet object with scales set.
+#'
+#' @keywords internal
+initialize_scales_fn <- function(m, initialize_scales, df) {
+  if (!initialize_scales) {
+    return(m)
+  }
+  if ((m$growth == 'logistic') && ('floor' %in% colnames(df))) {
+    m$logistic.floor <- TRUE
+    floor <- df$floor
+  } else {
+    floor <- 0
+  }
+  m$y.scale <- max(abs(df$y - floor))
+  if (m$y.scale == 0) {
+    m$y.scale <- 1
+  }
+  m$start <- min(df$ds)
+  m$t.scale <- time_diff(max(df$ds), m$start, "secs")
+  for (name in names(m$extra_regressors)) {
+    standardize <- m$extra_regressors[[name]]$standardize
+    if (standardize == 'auto') {
+      if (all(sort(unique(df[[name]])) == c(0, 1))) {
+        # Don't standardize binary variables
+        standardize <- FALSE
+      } else {
+        standardize <- TRUE
+      }
+    }
+    if (standardize) {
+      mu <- mean(df[[name]])
+      std <- stats::sd(df[[name]])
+      if (std == 0) {
+        std <- mu
+      }
+      m$extra_regressors[[name]]$mu <- mu
+      m$extra_regressors[[name]]$std <- std
+    }
+  }
+  return(m)
+}
+
 #' Set changepoints
 #'
 #' Sets m$changepoints to the dates of changepoints. Either:

+ 22 - 0
R/man/initialize_scales_fn.Rd

@@ -0,0 +1,22 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/prophet.R
+\name{initialize_scales_fn}
+\alias{initialize_scales_fn}
+\title{Initialize model scales.}
+\usage{
+initialize_scales_fn(m, initialize_scales, df)
+}
+\arguments{
+\item{m}{Prophet object.}
+
+\item{initialize_scales}{Boolean set the scales or not.}
+
+\item{df}{Dataframe for setting scales.}
+}
+\value{
+Prophet object with scales set.
+}
+\description{
+Sets model scaling factors using df.
+}
+\keyword{internal}

+ 9 - 0
python/fbprophet/forecaster.py

@@ -255,6 +255,15 @@ class Prophet(object):
         return df
 
     def initialize_scales(self, initialize_scales, df):
+        """Initialize model scales.
+
+        Sets model scaling factors using df.
+
+        Parameters
+        ----------
+        initialize_scales: Boolean set the scales or not.
+        df: pd.DataFrame for setting scales.
+        """
         if not initialize_scales:
             return
         if self.growth == 'logistic' and 'floor' in df:

+ 5 - 1
python/fbprophet/tests/test_prophet.py

@@ -472,7 +472,11 @@ class TestProphet(TestCase):
         m.fit(DATA.copy())
         seasonal_features, prior_scales = m.make_all_seasonality_features(
             m.history)
-        self.assertEqual(prior_scales, [2.] * 10 + [10.] * 6 + [4.])
+        if seasonal_features.columns[0] == 'monthly_delim_1':
+            true = [2.] * 10 + [10.] * 6 + [4.]
+        else:
+            true = [10.] * 6 + [2.] * 10 + [4.]
+        self.assertEqual(prior_scales, true)
 
     def test_added_regressors(self):
         m = Prophet()