Parcourir la source

Loop debug (#168)

* Load Stan models on package load and keep in environment

* Load models on package import
Ben Letham il y a 8 ans
Parent
commit
5971a2369b

+ 6 - 3
R/R/prophet.R

@@ -539,11 +539,15 @@ fit.prophet <- function(m, df, ...) {
   # Run stan
   if (m$growth == 'linear') {
     kinit <- linear_growth_init(history)
-    model <- get_prophet_stan_model('linear')
   } else {
     dat$cap <- history$cap_scaled  # Add capacities to the Stan data
     kinit <- logistic_growth_init(history)
-    model <- get_prophet_stan_model('logistic')
+  }
+
+  if (exists(".prophet.stan.models")) {
+    model <- .prophet.stan.models[[m$growth]]
+  } else {
+    model <- get_prophet_stan_model(m$growth)
   }
 
   stan_init <- function() {
@@ -592,7 +596,6 @@ fit.prophet <- function(m, df, ...) {
     m$params$k <- m$params$k + m$params$delta[, 1]
     m$params$delta <- matrix(rep(0, length(m$params$delta)), nrow = n.iteration)
   }
-  gc() ## This is hack.  We don't know why it works but it solves issue #93.
   return(m)
 }
 

+ 14 - 0
R/R/zzz.R

@@ -0,0 +1,14 @@
+## Copyright (c) 2017-present, Facebook, Inc.
+## All rights reserved.
+
+## This source code is licensed under the BSD-style license found in the
+## LICENSE file in the root directory of this source tree. An additional grant
+## of patent rights can be found in the PATENTS file in the same directory.
+
+.onLoad <- function(libname, pkgname) {
+  .prophet.stan.models <- list(
+    "linear"=get_prophet_stan_model("linear"),
+    "logistic"=get_prophet_stan_model("logistic"))
+  assign(".prophet.stan.models", .prophet.stan.models,
+         envir=parent.env(environment()))
+}

+ 0 - 5
R/tests/testthat/test_prophet.R

@@ -7,11 +7,6 @@ N <- nrow(DATA)
 train <- DATA[1:floor(N / 2), ]
 future <- DATA[(ceiling(N/2) + 1):N, ]
 
-test_that("load_models", {
-  expect_error(prophet:::get_prophet_stan_model('linear'), NA)
-  expect_error(prophet:::get_prophet_stan_model('logistic'), NA)
-})
-
 test_that("fit_predict", {
   skip_if_not(Sys.getenv('R_ARCH') != '/i386')
   m <- prophet(train)

+ 3 - 32
python/fbprophet/forecaster.py

@@ -12,7 +12,6 @@ from __future__ import unicode_literals
 
 from collections import defaultdict
 from datetime import timedelta
-import pickle
 
 from matplotlib import pyplot as plt
 from matplotlib.dates import MonthLocator, num2date
@@ -21,9 +20,7 @@ from matplotlib.ticker import FuncFormatter
 import numpy as np
 import pandas as pd
 
-# fb-block 1 start
-import pkg_resources
-# fb-block 1 end
+from fbprophet.models import prophet_stan_models
 
 try:
     import pystan
@@ -151,32 +148,6 @@ class Prophet(object):
                          'trend']:
                     raise ValueError('Holiday name {} reserved.'.format(h))
 
-    @classmethod
-    def get_linear_model(cls):
-        """Load compiled linear trend Stan model"""
-        # fb-block 3
-        # fb-block 4 start
-        model_file = pkg_resources.resource_filename(
-            'fbprophet',
-            'stan_models/linear_growth.pkl'
-        )
-        # fb-block 4 end
-        with open(model_file, 'rb') as f:
-            return pickle.load(f)
-
-    @classmethod
-    def get_logistic_model(cls):
-        """Load compiled logistic trend Stan model"""
-        # fb-block 5
-        # fb-block 6 start
-        model_file = pkg_resources.resource_filename(
-            'fbprophet',
-            'stan_models/logistic_growth.pkl'
-        )
-        # fb-block 6 end
-        with open(model_file, 'rb') as f:
-            return pickle.load(f)
-
     def setup_dataframe(self, df, initialize_scales=False):
         """Prepare dataframe for fitting or predicting.
 
@@ -531,11 +502,11 @@ class Prophet(object):
 
         if self.growth == 'linear':
             kinit = self.linear_growth_init(history)
-            model = self.get_linear_model()
         else:
             dat['cap'] = history['cap_scaled']
             kinit = self.logistic_growth_init(history)
-            model = self.get_logistic_model()
+
+        model = prophet_stan_models[self.growth]
 
         def stan_init():
             return {

+ 34 - 0
python/fbprophet/models.py

@@ -0,0 +1,34 @@
+# Copyright (c) 2017-present, Facebook, Inc.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree. An additional grant 
+# of patent rights can be found in the PATENTS file in the same directory.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import pickle
+
+# fb-block 1 start
+import pkg_resources
+# fb-block 1 end
+
+def get_prophet_stan_model(model):
+    """Load compiled Stan model"""
+    # fb-block 3
+    # fb-block 4 start
+    model_file = pkg_resources.resource_filename(
+        'fbprophet',
+        'stan_models/{}_growth.pkl'.format(model),
+    )
+    # fb-block 4 end
+    with open(model_file, 'rb') as f:
+        return pickle.load(f)
+
+prophet_stan_models = {
+    'linear': get_prophet_stan_model('linear'),
+    'logistic': get_prophet_stan_model('logistic'),
+}

+ 0 - 4
python/fbprophet/tests/test_prophet.py

@@ -25,10 +25,6 @@ from fbprophet import Prophet
 # fb-block 2
 
 class TestProphet(TestCase):
-    def test_load_models(self):
-        forecaster = Prophet()
-        forecaster.get_linear_model()
-        forecaster.get_logistic_model()
 
     def test_fit_predict(self):
         N = DATA.shape[0]