Skip to content

Commit

Permalink
Add chain effect
Browse files Browse the repository at this point in the history
  • Loading branch information
felipeangelimvieira committed Dec 15, 2024
1 parent b08152f commit bfda9fc
Show file tree
Hide file tree
Showing 3 changed files with 271 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/prophetverse/effects/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
"""Effects that define relationships between variables and the target."""

from .adstock import GeometricAdstockEffect
from .base import BaseEffect
from .chain import ChainedEffects
from .exact_likelihood import ExactLikelihood
from .fourier import LinearFourierSeasonality
from .hill import HillEffect
Expand All @@ -16,4 +18,6 @@
"ExactLikelihood",
"LiftExperimentLikelihood",
"LinearFourierSeasonality",
"GeometricAdstockEffect",
"ChainedEffects",
]
158 changes: 158 additions & 0 deletions src/prophetverse/effects/chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Definition of Chained Effects class."""

from typing import Any, Dict, List

import jax.numpy as jnp
from numpyro import handlers
from skbase.base import BaseMetaEstimatorMixin

from prophetverse.effects.base import BaseEffect

__all__ = ["ChainedEffects"]


class ChainedEffects(BaseMetaEstimatorMixin, BaseEffect):
"""
Chains multiple effects sequentially, applying them one after the other.
Parameters
----------
steps : List[BaseEffect]
A list of effects to be applied sequentially.
"""

_tags = {
"supports_multivariate": True,
"skip_predict_if_no_match": True,
"filter_indexes_with_forecating_horizon_at_transform": True,
}

def __init__(self, steps: List[BaseEffect]):
self.steps = steps
super().__init__()

def _fit(self, y: Any, X: Any, scale: float = 1.0):
"""
Fit all chained effects sequentially.
Parameters
----------
y : Any
Target data (e.g., time series values).
X : Any
Exogenous variables.
scale : float, optional
Scale of the timeseries.
"""
for effect in self.steps:
effect.fit(y, X, scale)

def _transform(self, X: Any, fh: Any) -> Any:
"""
Transform input data sequentially through all chained effects.
Parameters
----------
X : Any
Input data (e.g., exogenous variables).
fh : Any
Forecasting horizon.
Returns
-------
Any
Transformed data after applying all effects.
"""
output = X
output = self.steps[0].transform(output, fh)
return output

def _sample_params(
self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
"""
Sample parameters for all chained effects.
Parameters
----------
data : jnp.ndarray
Data obtained from the transformed method.
predicted_effects : Dict[str, jnp.ndarray]
A dictionary containing the predicted effects.
Returns
-------
Dict[str, jnp.ndarray]
A dictionary containing the sampled parameters for all effects.
"""
params = {}
for idx, effect in enumerate(self.steps):
with handlers.scope(prefix=f"{idx}"):
effect_params = effect.sample_params(data, predicted_effects)
params[f"effect_{idx}"] = effect_params
return params

def _predict(
self,
data: jnp.ndarray,
predicted_effects: Dict[str, jnp.ndarray],
params: Dict[str, Dict[str, jnp.ndarray]],
) -> jnp.ndarray:
"""
Apply all chained effects sequentially.
Parameters
----------
data : jnp.ndarray
Data obtained from the transformed method (shape: T, 1).
predicted_effects : Dict[str, jnp.ndarray]
A dictionary containing the predicted effects.
params : Dict[str, Dict[str, jnp.ndarray]]
A dictionary containing the sampled parameters for each effect.
Returns
-------
jnp.ndarray
The transformed data after applying all effects.
"""
output = data
for idx, effect in enumerate(self.steps):
effect_params = params[f"effect_{idx}"]
output = effect._predict(output, predicted_effects, effect_params)
return output

def _coerce_to_named_object_tuples(self, objs, clone=False, make_unique=True):
"""Coerce sequence of objects or named objects to list of (str, obj) tuples.
Input that is sequence of objects, list of (str, obj) tuples or
dict[str, object] will be coerced to list of (str, obj) tuples on return.
Parameters
----------
objs : list of objects, list of (str, object tuples) or dict[str, object]
The input should be coerced to list of (str, object) tuples. Should
be a sequence of objects, or follow named object API.
clone : bool, default=False.
Whether objects in the returned list of (str, object) tuples are
cloned (True) or references (False).
make_unique : bool, default=True
Whether the str names in the returned list of (str, object) tuples
should be coerced to unique str values (if str names in input
are already unique they will not be changed).
Returns
-------
list[tuple[str, BaseObject]]
List of tuples following named object API.
- If `objs` was already a list of (str, object) tuples then either the
same named objects (as with other cases cloned versions are
returned if ``clone=True``).
- If `objs` was a dict[str, object] then the named objects are unpacked
into a list of (str, object) tuples.
- If `objs` was a list of objects then string names were generated based
on the object's class names (with coercion to unique strings if
necessary).
"""
objs = [(f"effect_{idx}", obj) for idx, obj in enumerate(objs)]
return super()._coerce_to_named_object_tuples(objs, clone, make_unique)
109 changes: 109 additions & 0 deletions tests/effects/test_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
"""Pytest for Chained Effects class."""

import jax.numpy as jnp
import numpyro
import pandas as pd
import pytest
from numpyro import handlers

from prophetverse.effects.base import BaseEffect
from prophetverse.effects.chain import ChainedEffects


class MockEffect(BaseEffect):
def __init__(self, value):
self.value = value
super().__init__()

self._transform_called = False

def _transform(self, X, fh):
self._transform_called = True
return super()._transform(X, fh)

def _sample_params(self, data, predicted_effects):
return {
"param": numpyro.sample("param", numpyro.distributions.Delta(self.value))
}

def _predict(self, data, predicted_effects, params):
return data * params["param"]


@pytest.fixture
def index():
return pd.date_range("2021-01-01", periods=6)


@pytest.fixture
def y(index):
return pd.DataFrame(index=index, data=[1] * len(index))


@pytest.fixture
def X(index):
return pd.DataFrame(
data={"exog": [10, 20, 30, 40, 50, 60]},
index=index,
)


def test_chained_effects_fit(X, y):
"""Test the fit method of ChainedEffects."""
effects = [MockEffect(2), MockEffect(3)]
chained = ChainedEffects(steps=effects)

scale = 1
chained.fit(y=y, X=X, scale=scale)
# Ensure no exceptions occur in fit


def test_chained_effects_transform(X, y):
"""Test the transform method of ChainedEffects."""
effects = [MockEffect(2), MockEffect(3)]
chained = ChainedEffects(steps=effects)
transformed = chained.transform(X, fh=X.index)
expected = MockEffect(2).transform(X, fh=X.index)
assert jnp.allclose(transformed, expected), "Chained transform output mismatch."


def test_chained_effects_sample_params(X, y):
"""Test the sample_params method of ChainedEffects."""
effects = [MockEffect(2), MockEffect(3)]
chained = ChainedEffects(steps=effects)
chained.fit(y=y, X=X, scale=1)
data = chained.transform(X, fh=X.index)

with handlers.trace() as trace:
params = chained.sample_params(data, {})

assert "effect_0" in params, "Missing effect_0 params."
assert "effect_1" in params, "Missing effect_1 params."
assert params["effect_0"]["param"] == 2, "Incorrect effect_0 param."
assert params["effect_1"]["param"] == 3, "Incorrect effect_1 param."

assert "0/param" in trace, "Missing effect_0 trace."
assert "1/param" in trace, "Missing effect_1 trace."


def test_chained_effects_predict(X, y):
"""Test the predict method of ChainedEffects."""
effects = [MockEffect(2), MockEffect(3)]
chained = ChainedEffects(steps=effects)
chained.fit(y=y, X=X, scale=1)
data = chained.transform(X, fh=X.index)
predicted_effects = {}

predicted = chained.predict(data, predicted_effects)
expected = data * 2 * 3
assert jnp.allclose(predicted, expected), "Chained predict output mismatch."


def test_get_params():
effects = [MockEffect(2), MockEffect(3)]
chained = ChainedEffects(steps=effects)

params = chained.get_params()

assert params["effect_0__value"] == 2, "Incorrect effect_0 param."
assert params["effect_1__value"] == 3, "Incorrect effect_1 param."

0 comments on commit bfda9fc

Please sign in to comment.