From bfda9fc2f98967bb857dd2b3268481183bc89d0c Mon Sep 17 00:00:00 2001 From: Felipe Angelim Date: Sun, 15 Dec 2024 20:45:51 -0300 Subject: [PATCH] Add chain effect --- src/prophetverse/effects/__init__.py | 4 + src/prophetverse/effects/chain.py | 158 +++++++++++++++++++++++++++ tests/effects/test_chain.py | 109 ++++++++++++++++++ 3 files changed, 271 insertions(+) create mode 100644 src/prophetverse/effects/chain.py create mode 100644 tests/effects/test_chain.py diff --git a/src/prophetverse/effects/__init__.py b/src/prophetverse/effects/__init__.py index 901192a..b6b3075 100644 --- a/src/prophetverse/effects/__init__.py +++ b/src/prophetverse/effects/__init__.py @@ -1,6 +1,8 @@ """Effects that define relationships between variables and the target.""" +from .adstock import GeometricAdstockEffect from .base import BaseEffect +from .chain import ChainedEffects from .exact_likelihood import ExactLikelihood from .fourier import LinearFourierSeasonality from .hill import HillEffect @@ -16,4 +18,6 @@ "ExactLikelihood", "LiftExperimentLikelihood", "LinearFourierSeasonality", + "GeometricAdstockEffect", + "ChainedEffects", ] diff --git a/src/prophetverse/effects/chain.py b/src/prophetverse/effects/chain.py new file mode 100644 index 0000000..6740db0 --- /dev/null +++ b/src/prophetverse/effects/chain.py @@ -0,0 +1,158 @@ +"""Definition of Chained Effects class.""" + +from typing import Any, Dict, List + +import jax.numpy as jnp +from numpyro import handlers +from skbase.base import BaseMetaEstimatorMixin + +from prophetverse.effects.base import BaseEffect + +__all__ = ["ChainedEffects"] + + +class ChainedEffects(BaseMetaEstimatorMixin, BaseEffect): + """ + Chains multiple effects sequentially, applying them one after the other. + + Parameters + ---------- + steps : List[BaseEffect] + A list of effects to be applied sequentially. + """ + + _tags = { + "supports_multivariate": True, + "skip_predict_if_no_match": True, + "filter_indexes_with_forecating_horizon_at_transform": True, + } + + def __init__(self, steps: List[BaseEffect]): + self.steps = steps + super().__init__() + + def _fit(self, y: Any, X: Any, scale: float = 1.0): + """ + Fit all chained effects sequentially. + + Parameters + ---------- + y : Any + Target data (e.g., time series values). + X : Any + Exogenous variables. + scale : float, optional + Scale of the timeseries. + """ + for effect in self.steps: + effect.fit(y, X, scale) + + def _transform(self, X: Any, fh: Any) -> Any: + """ + Transform input data sequentially through all chained effects. + + Parameters + ---------- + X : Any + Input data (e.g., exogenous variables). + fh : Any + Forecasting horizon. + + Returns + ------- + Any + Transformed data after applying all effects. + """ + output = X + output = self.steps[0].transform(output, fh) + return output + + def _sample_params( + self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray] + ) -> Dict[str, jnp.ndarray]: + """ + Sample parameters for all chained effects. + + Parameters + ---------- + data : jnp.ndarray + Data obtained from the transformed method. + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects. + + Returns + ------- + Dict[str, jnp.ndarray] + A dictionary containing the sampled parameters for all effects. + """ + params = {} + for idx, effect in enumerate(self.steps): + with handlers.scope(prefix=f"{idx}"): + effect_params = effect.sample_params(data, predicted_effects) + params[f"effect_{idx}"] = effect_params + return params + + def _predict( + self, + data: jnp.ndarray, + predicted_effects: Dict[str, jnp.ndarray], + params: Dict[str, Dict[str, jnp.ndarray]], + ) -> jnp.ndarray: + """ + Apply all chained effects sequentially. + + Parameters + ---------- + data : jnp.ndarray + Data obtained from the transformed method (shape: T, 1). + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects. + params : Dict[str, Dict[str, jnp.ndarray]] + A dictionary containing the sampled parameters for each effect. + + Returns + ------- + jnp.ndarray + The transformed data after applying all effects. + """ + output = data + for idx, effect in enumerate(self.steps): + effect_params = params[f"effect_{idx}"] + output = effect._predict(output, predicted_effects, effect_params) + return output + + def _coerce_to_named_object_tuples(self, objs, clone=False, make_unique=True): + """Coerce sequence of objects or named objects to list of (str, obj) tuples. + + Input that is sequence of objects, list of (str, obj) tuples or + dict[str, object] will be coerced to list of (str, obj) tuples on return. + + Parameters + ---------- + objs : list of objects, list of (str, object tuples) or dict[str, object] + The input should be coerced to list of (str, object) tuples. Should + be a sequence of objects, or follow named object API. + clone : bool, default=False. + Whether objects in the returned list of (str, object) tuples are + cloned (True) or references (False). + make_unique : bool, default=True + Whether the str names in the returned list of (str, object) tuples + should be coerced to unique str values (if str names in input + are already unique they will not be changed). + + Returns + ------- + list[tuple[str, BaseObject]] + List of tuples following named object API. + + - If `objs` was already a list of (str, object) tuples then either the + same named objects (as with other cases cloned versions are + returned if ``clone=True``). + - If `objs` was a dict[str, object] then the named objects are unpacked + into a list of (str, object) tuples. + - If `objs` was a list of objects then string names were generated based + on the object's class names (with coercion to unique strings if + necessary). + """ + objs = [(f"effect_{idx}", obj) for idx, obj in enumerate(objs)] + return super()._coerce_to_named_object_tuples(objs, clone, make_unique) diff --git a/tests/effects/test_chain.py b/tests/effects/test_chain.py new file mode 100644 index 0000000..16f194e --- /dev/null +++ b/tests/effects/test_chain.py @@ -0,0 +1,109 @@ +"""Pytest for Chained Effects class.""" + +import jax.numpy as jnp +import numpyro +import pandas as pd +import pytest +from numpyro import handlers + +from prophetverse.effects.base import BaseEffect +from prophetverse.effects.chain import ChainedEffects + + +class MockEffect(BaseEffect): + def __init__(self, value): + self.value = value + super().__init__() + + self._transform_called = False + + def _transform(self, X, fh): + self._transform_called = True + return super()._transform(X, fh) + + def _sample_params(self, data, predicted_effects): + return { + "param": numpyro.sample("param", numpyro.distributions.Delta(self.value)) + } + + def _predict(self, data, predicted_effects, params): + return data * params["param"] + + +@pytest.fixture +def index(): + return pd.date_range("2021-01-01", periods=6) + + +@pytest.fixture +def y(index): + return pd.DataFrame(index=index, data=[1] * len(index)) + + +@pytest.fixture +def X(index): + return pd.DataFrame( + data={"exog": [10, 20, 30, 40, 50, 60]}, + index=index, + ) + + +def test_chained_effects_fit(X, y): + """Test the fit method of ChainedEffects.""" + effects = [MockEffect(2), MockEffect(3)] + chained = ChainedEffects(steps=effects) + + scale = 1 + chained.fit(y=y, X=X, scale=scale) + # Ensure no exceptions occur in fit + + +def test_chained_effects_transform(X, y): + """Test the transform method of ChainedEffects.""" + effects = [MockEffect(2), MockEffect(3)] + chained = ChainedEffects(steps=effects) + transformed = chained.transform(X, fh=X.index) + expected = MockEffect(2).transform(X, fh=X.index) + assert jnp.allclose(transformed, expected), "Chained transform output mismatch." + + +def test_chained_effects_sample_params(X, y): + """Test the sample_params method of ChainedEffects.""" + effects = [MockEffect(2), MockEffect(3)] + chained = ChainedEffects(steps=effects) + chained.fit(y=y, X=X, scale=1) + data = chained.transform(X, fh=X.index) + + with handlers.trace() as trace: + params = chained.sample_params(data, {}) + + assert "effect_0" in params, "Missing effect_0 params." + assert "effect_1" in params, "Missing effect_1 params." + assert params["effect_0"]["param"] == 2, "Incorrect effect_0 param." + assert params["effect_1"]["param"] == 3, "Incorrect effect_1 param." + + assert "0/param" in trace, "Missing effect_0 trace." + assert "1/param" in trace, "Missing effect_1 trace." + + +def test_chained_effects_predict(X, y): + """Test the predict method of ChainedEffects.""" + effects = [MockEffect(2), MockEffect(3)] + chained = ChainedEffects(steps=effects) + chained.fit(y=y, X=X, scale=1) + data = chained.transform(X, fh=X.index) + predicted_effects = {} + + predicted = chained.predict(data, predicted_effects) + expected = data * 2 * 3 + assert jnp.allclose(predicted, expected), "Chained predict output mismatch." + + +def test_get_params(): + effects = [MockEffect(2), MockEffect(3)] + chained = ChainedEffects(steps=effects) + + params = chained.get_params() + + assert params["effect_0__value"] == 2, "Incorrect effect_0 param." + assert params["effect_1__value"] == 3, "Incorrect effect_1 param."