diff --git a/pyproject.toml b/pyproject.toml index 2e88d2d..6eb5192 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ numpyro = "<0.16" optax = ">=0.2" graphviz = "^0.20.3" scikit-base = "^0.11.0" +skpro = "^2.8.0" ipykernel = { version = ">=6.26.0,<7.0.0", optional = true } pytest = { version = ">=8.0.0,<9.0.0", optional = true } @@ -42,6 +43,7 @@ seaborn = {version = "^0.13.2", optional = true} statsmodels = {version = "^0.14.4", optional = true} + [tool.poetry.extras] dev = [ "ipykernel", diff --git a/src/prophetverse/experimental/simulate.py b/src/prophetverse/experimental/simulate.py new file mode 100644 index 0000000..89cf974 --- /dev/null +++ b/src/prophetverse/experimental/simulate.py @@ -0,0 +1,69 @@ +"""Simulate data from a model, and intervene optionally.""" + +from typing import Dict, Optional, Union + +import jax.numpy as jnp +import numpy as np +import numpyro +import pandas as pd +from jax.random import PRNGKey + +from prophetverse.sktime.base import BaseProphetForecaster + + +def simulate( + model: BaseProphetForecaster, + fh: pd.Index, + X: Optional[pd.DataFrame] = None, + y: Optional[pd.DataFrame] = None, + do: Optional[Dict[str, Union[jnp.ndarray, float]]] = None, + num_samples: int = 10, +): + """ + Simulate data from a model. + + **EXPERIMENTAL FEATURE** + This feature allow to do prior predictive checks and to intervene to + obtain simulated data. + + Parameters + ---------- + model : BaseProphetForecaster + The probabilistic model to perform inference on. + fh : pd.Index + The forecasting horizon as a pandas Index. + X : pd.DataFrame, optional + The input DataFrame containing the exogenous variables. + y : pd.DataFrame, optional + The timeseries dataframe. This is used by effects that implement `_fit` and + use the target timeseries to initialize some parameters. If not provided, + a dummy y will be created. + do : Dict, optional + A dictionary with the variables to intervene and their values. + num_samples : int, optional + The number of samples to generate. Defaults to 10. + + Returns + ------- + Dict + A dictionary with the simulated data. + """ + # Fit model, creating a dummy y if it is not provided + if y is None: + y = pd.DataFrame(index=fh, data=np.random.rand(len(fh)) * 10, columns=["dummy"]) + + model.fit(X=X, y=y) + + # Get predict data to call predictive model + predict_data = model._get_predict_data(X=X, fh=fh) + predict_data["y"] = None + from numpyro.infer import Predictive + + predictive_model = model.model + if do is not None: + predictive_model = numpyro.handlers.do(predictive_model, data=do) + + # predictive_model = model.model + predictive_model = Predictive(model=predictive_model, num_samples=num_samples) + predictive_output = predictive_model(PRNGKey(0), **predict_data) + return predictive_output diff --git a/src/prophetverse/models.py b/src/prophetverse/models.py index 2f7a302..17ac179 100644 --- a/src/prophetverse/models.py +++ b/src/prophetverse/models.py @@ -248,7 +248,9 @@ def _compute_mean_univariate( predicted_effects: Dict[str, jnp.ndarray] = {} - trend = trend_model(data=trend_data, predicted_effects=predicted_effects) + with numpyro.handlers.scope(prefix="trend"): + trend = trend_model(data=trend_data, predicted_effects=predicted_effects) + predicted_effects["trend"] = trend numpyro.deterministic("trend", trend) diff --git a/src/prophetverse/sktime/base.py b/src/prophetverse/sktime/base.py index 4180452..68e6e20 100644 --- a/src/prophetverse/sktime/base.py +++ b/src/prophetverse/sktime/base.py @@ -104,6 +104,11 @@ def __init__( self.inference_engine = inference_engine super().__init__() + if self.scale: + self._scale = self.scale + else: + self._scale = None + if self.inference_engine is not None: self._inference_engine = self.inference_engine else: diff --git a/src/prophetverse/sktime/multivariate.py b/src/prophetverse/sktime/multivariate.py index ba8b6ee..6ea33be 100644 --- a/src/prophetverse/sktime/multivariate.py +++ b/src/prophetverse/sktime/multivariate.py @@ -314,9 +314,7 @@ def _get_predict_data(self, X: pd.DataFrame, fh: ForecastingHorizon) -> np.ndarr np.ndarray Predicted samples. """ - fh_dates = fh.to_absolute( - cutoff=self.internal_y_indexes_.get_level_values(-1).max() - ) + fh_dates = self.fh_to_index(fh) fh_as_index = pd.Index(list(fh_dates.to_numpy())) if not isinstance(fh, ForecastingHorizon): diff --git a/tests/experimental/test_simulate.py b/tests/experimental/test_simulate.py new file mode 100644 index 0000000..3ba351c --- /dev/null +++ b/tests/experimental/test_simulate.py @@ -0,0 +1,30 @@ +import jax.numpy as jnp +import pandas as pd +import pytest + +from prophetverse.experimental.simulate import simulate +from prophetverse.sktime import HierarchicalProphet, Prophetverse + + +@pytest.mark.parametrize( + "model,do", + [ + (HierarchicalProphet(), {"exogenous_variables_effect/coefs": jnp.array([1])}), + (Prophetverse(), {"exogenous_variables_effect/coefs": jnp.array([1])}), + ], +) +def test_simulate(model, do): + + fh = pd.period_range(start="2022-01-01", periods=50, freq="M") + X = pd.DataFrame(index=fh, data={"x1": list(range(len(fh)))}) + num_samples = 10 + samples = simulate(model=model, fh=fh, X=X, do=do, num_samples=num_samples) + assert isinstance(samples, dict) + assert samples["obs"].shape[0] == num_samples + assert samples["obs"].shape[1] == len(fh) + + expected_intervention = jnp.arange(len(fh)).reshape((-1, 1)) + assert jnp.all( + samples["exogenous_variables_effect"] + == jnp.tile(expected_intervention, (num_samples, 1, 1)) + )