|
@@ -11,33 +11,54 @@ from pkg_resources import (
|
|
)
|
|
)
|
|
from setuptools import setup
|
|
from setuptools import setup
|
|
from setuptools.command.build_py import build_py
|
|
from setuptools.command.build_py import build_py
|
|
|
|
+from setuptools.command.develop import develop
|
|
from setuptools.command.test import test as test_command
|
|
from setuptools.command.test import test as test_command
|
|
|
|
|
|
|
|
+
|
|
|
|
+PLATFORM = 'unix'
|
|
|
|
+if platform.platform().startswith('Win'):
|
|
|
|
+ PLATFORM = 'win'
|
|
|
|
+
|
|
|
|
+SETUP_DIR = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
+MODELS_DIR = os.path.join(SETUP_DIR, 'stan', PLATFORM)
|
|
|
|
+MODELS_TARGET_DIR = os.path.join('fbprophet', 'stan_models')
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+def build_stan_models(target_dir, models_dir=MODELS_DIR):
|
|
|
|
+ from pystan import StanModel
|
|
|
|
+ for model_type in ['linear', 'logistic']:
|
|
|
|
+ model_name = 'prophet_{}_growth.stan'.format(model_type)
|
|
|
|
+ target_name = '{}_growth.pkl'.format(model_type)
|
|
|
|
+ with open(os.path.join(models_dir, model_name)) as f:
|
|
|
|
+ model_code = f.read()
|
|
|
|
+ sm = StanModel(model_code=model_code)
|
|
|
|
+ with open(os.path.join(target_dir, target_name), 'wb') as f:
|
|
|
|
+ pickle.dump(sm, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
+
|
|
|
|
+
|
|
class BuildPyCommand(build_py):
|
|
class BuildPyCommand(build_py):
|
|
"""Custom build command to pre-compile Stan models."""
|
|
"""Custom build command to pre-compile Stan models."""
|
|
|
|
|
|
def run(self):
|
|
def run(self):
|
|
if not self.dry_run:
|
|
if not self.dry_run:
|
|
- self.build_stan_models()
|
|
|
|
|
|
+ target_dir = os.path.join(self.build_lib, MODELS_TARGET_DIR)
|
|
|
|
+ self.mkpath(target_dir)
|
|
|
|
+ build_stan_models(target_dir)
|
|
|
|
|
|
build_py.run(self)
|
|
build_py.run(self)
|
|
|
|
|
|
- def build_stan_models(self):
|
|
|
|
- from pystan import StanModel
|
|
|
|
- target_dir = os.path.join(self.build_lib, 'fbprophet/stan_models')
|
|
|
|
- self.mkpath(target_dir)
|
|
|
|
|
|
|
|
- if platform.platform().startswith('Win'):
|
|
|
|
- plat = 'win'
|
|
|
|
- else:
|
|
|
|
- plat = 'unix'
|
|
|
|
|
|
+class DevelopCommand(develop):
|
|
|
|
+ """Custom develop command to pre-compile Stan models in-place."""
|
|
|
|
+
|
|
|
|
+ def run(self):
|
|
|
|
+ if not self.dry_run:
|
|
|
|
+ target_dir = os.path.join(self.setup_path, MODELS_TARGET_DIR)
|
|
|
|
+ self.mkpath(target_dir)
|
|
|
|
+ build_stan_models(target_dir)
|
|
|
|
+
|
|
|
|
+ develop.run(self)
|
|
|
|
|
|
- for model_type in ['linear', 'logistic']:
|
|
|
|
- with open('stan/{}/prophet_{}_growth.stan'.format(plat, model_type)) as f:
|
|
|
|
- model_code = f.read()
|
|
|
|
- sm = StanModel(model_code=model_code)
|
|
|
|
- with open(os.path.join(target_dir, '{}_growth.pkl'.format(model_type)), 'wb') as f:
|
|
|
|
- pickle.dump(sm, f, protocol=pickle.HIGHEST_PROTOCOL)
|
|
|
|
|
|
|
|
class TestCommand(test_command):
|
|
class TestCommand(test_command):
|
|
"""We must run tests on the build directory, not source."""
|
|
"""We must run tests on the build directory, not source."""
|
|
@@ -96,6 +117,7 @@ setup(
|
|
include_package_data=True,
|
|
include_package_data=True,
|
|
cmdclass={
|
|
cmdclass={
|
|
'build_py': BuildPyCommand,
|
|
'build_py': BuildPyCommand,
|
|
|
|
+ 'develop': DevelopCommand,
|
|
'test': TestCommand,
|
|
'test': TestCommand,
|
|
},
|
|
},
|
|
test_suite='fbprophet.tests.test_prophet',
|
|
test_suite='fbprophet.tests.test_prophet',
|