|
@@ -12,7 +12,6 @@ from __future__ import unicode_literals
|
|
|
|
|
|
from collections import defaultdict
|
|
|
from datetime import timedelta
|
|
|
-import pickle
|
|
|
|
|
|
from matplotlib import pyplot as plt
|
|
|
from matplotlib.dates import MonthLocator, num2date
|
|
@@ -21,9 +20,7 @@ from matplotlib.ticker import FuncFormatter
|
|
|
import numpy as np
|
|
|
import pandas as pd
|
|
|
|
|
|
-# fb-block 1 start
|
|
|
-import pkg_resources
|
|
|
-# fb-block 1 end
|
|
|
+from fbprophet.models import prophet_stan_models
|
|
|
|
|
|
try:
|
|
|
import pystan
|
|
@@ -151,32 +148,6 @@ class Prophet(object):
|
|
|
'trend']:
|
|
|
raise ValueError('Holiday name {} reserved.'.format(h))
|
|
|
|
|
|
- @classmethod
|
|
|
- def get_linear_model(cls):
|
|
|
- """Load compiled linear trend Stan model"""
|
|
|
- # fb-block 3
|
|
|
- # fb-block 4 start
|
|
|
- model_file = pkg_resources.resource_filename(
|
|
|
- 'fbprophet',
|
|
|
- 'stan_models/linear_growth.pkl'
|
|
|
- )
|
|
|
- # fb-block 4 end
|
|
|
- with open(model_file, 'rb') as f:
|
|
|
- return pickle.load(f)
|
|
|
-
|
|
|
- @classmethod
|
|
|
- def get_logistic_model(cls):
|
|
|
- """Load compiled logistic trend Stan model"""
|
|
|
- # fb-block 5
|
|
|
- # fb-block 6 start
|
|
|
- model_file = pkg_resources.resource_filename(
|
|
|
- 'fbprophet',
|
|
|
- 'stan_models/logistic_growth.pkl'
|
|
|
- )
|
|
|
- # fb-block 6 end
|
|
|
- with open(model_file, 'rb') as f:
|
|
|
- return pickle.load(f)
|
|
|
-
|
|
|
def setup_dataframe(self, df, initialize_scales=False):
|
|
|
"""Prepare dataframe for fitting or predicting.
|
|
|
|
|
@@ -531,11 +502,11 @@ class Prophet(object):
|
|
|
|
|
|
if self.growth == 'linear':
|
|
|
kinit = self.linear_growth_init(history)
|
|
|
- model = self.get_linear_model()
|
|
|
else:
|
|
|
dat['cap'] = history['cap_scaled']
|
|
|
kinit = self.logistic_growth_init(history)
|
|
|
- model = self.get_logistic_model()
|
|
|
+
|
|
|
+ model = prophet_stan_models[self.growth]
|
|
|
|
|
|
def stan_init():
|
|
|
return {
|