Skip to content

Commit

Permalink
Add geometric adstock
Browse files Browse the repository at this point in the history
  • Loading branch information
felipeangelimvieira committed Dec 15, 2024
1 parent a4aa653 commit b08152f
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 0 deletions.
132 changes: 132 additions & 0 deletions src/prophetverse/effects/adstock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
"""Definition of Geometric Adstock Effect class."""

from typing import Dict

import jax
import jax.numpy as jnp
import numpyro
from numpyro import distributions as dist

from prophetverse.effects.base import BaseEffect

__all__ = ["GeometricAdstockEffect"]


class GeometricAdstockEffect(BaseEffect):
"""Represents a Geometric Adstock effect in a time series model.
Parameters
----------
decay_prior : Distribution, optional
Prior distribution for the decay parameter (controls the rate of decay).
rase_error_if_fh_changes : bool, optional
Whether to raise an error if the forecasting horizon changes during predict
"""

_tags = {
"supports_multivariate": False,
"skip_predict_if_no_match": True,
"filter_indexes_with_forecating_horizon_at_transform": True,
}

def __init__(
self,
decay_prior: dist.Distribution = None,
raise_error_if_fh_changes: bool = True,
):
self.decay_prior = decay_prior or dist.Beta(
2, 2
) # Default Beta distribution for decay rate.
self.raise_errror_if_fh_changes = raise_error_if_fh_changes
super().__init__()

self._min_date = None

def _transform(self, X, fh):
"""Transform the dataframe and horizon to array.
Parameters
----------
X : pd.DataFrame
dataframe with exogenous variables
fh : pd.Index
Forecast horizon
Returns
-------
jnp.ndarray
the array with data for _predict
Raises
------
ValueError
If the forecasting horizon is different during predict and fit.
"""
if self._min_date is None:
self._min_date = X.index.min()
else:
if self._min_date != X.index.min() and self.raise_errror_if_fh_changes:
raise ValueError(
"The X dataframe and forecat horizon"
"must be start at the same"
"date as the previous one"
)
return super()._transform(X, fh)

def _sample_params(
self, data: jnp.ndarray, predicted_effects: Dict[str, jnp.ndarray]
) -> Dict[str, jnp.ndarray]:
"""
Sample the parameters of the effect.
Parameters
----------
data : jnp.ndarray
Data obtained from the transformed method.
predicted_effects : Dict[str, jnp.ndarray]
A dictionary containing the predicted effects.
Returns
-------
Dict[str, jnp.ndarray]
A dictionary containing the sampled parameters of the effect.
"""
return {
"decay": numpyro.sample("decay", self.decay_prior),
}

def _predict(
self,
data: jnp.ndarray,
predicted_effects: Dict[str, jnp.ndarray],
params: Dict[str, jnp.ndarray],
) -> jnp.ndarray:
"""
Apply and return the geometric adstock effect values.
Parameters
----------
data : jnp.ndarray
Data obtained from the transformed method (shape: T, 1).
predicted_effects : Dict[str, jnp.ndarray]
A dictionary containing the predicted effects.
params : Dict[str, jnp.ndarray]
A dictionary containing the sampled parameters of the effect.
Returns
-------
jnp.ndarray
An array with shape (T, 1) for univariate timeseries.
"""
decay = params["decay"]

# Apply geometric adstock using jax.lax.scan for efficiency
def adstock_step(carry, current):
prev_adstock = carry
new_adstock = current + decay * prev_adstock
return new_adstock, new_adstock

_, adstock = jax.lax.scan(
adstock_step, init=jnp.array([0], dtype=data.dtype), xs=data
)
return adstock.reshape(-1, 1)
70 changes: 70 additions & 0 deletions tests/effects/test_adstock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Pytest for Geometric Adstock Effect class."""

import jax.numpy as jnp
import pandas as pd
import pytest
from numpyro import handlers
from numpyro.distributions import Beta

from prophetverse.effects.adstock import GeometricAdstockEffect


def test_geometric_adstock_sampling():
"""Test parameter sampling using numpyro.handlers.trace."""
effect = GeometricAdstockEffect(decay_prior=Beta(2, 2))
data = jnp.ones((10, 1)) # Dummy data
predicted_effects = {}

with handlers.trace() as trace, handlers.seed(rng_seed=0):
effect._sample_params(data, predicted_effects)

# Verify trace contains decay site
assert "decay" in trace, "Decay parameter not found in trace."

# Verify decay is sampled from the correct prior
assert trace["decay"]["type"] == "sample", "Decay parameter not sampled."
assert isinstance(
trace["decay"]["fn"], Beta
), "Decay parameter not sampled from Beta distribution."


def test_geometric_adstock_predict():
"""Test the predict method for correctness with predefined parameters."""
effect = GeometricAdstockEffect()

# Define mock data and parameters
data = jnp.array([[10.0], [20.0], [30.0]]) # Example input data (T, 1)
params = {"decay": jnp.array(0.5)}
predicted_effects = {}

# Call _predict
result = effect._predict(data, predicted_effects, params)

# Expected adstock output
expected = jnp.array(
[
[10.0],
[20.0 + 0.5 * 10.0],
[30.0 + 0.5 * (20.0 + 0.5 * 10.0)],
]
)

# Verify output shape
assert result.shape == data.shape, "Output shape mismatch."

# Verify output values
assert jnp.allclose(result, expected), "Adstock computation incorrect."


def test_error_when_different_fh():
effect = GeometricAdstockEffect()
X = pd.DataFrame(
data={"exog": [10.0, 20.0, 30.0, 30.0, 40.0, 50.0]},
index=pd.date_range("2021-01-01", periods=6),
)
fh = X.index
effect.transform(X=X, fh=fh)

effect.transform(X=X.iloc[:1], fh=fh[:1])
with pytest.raises(ValueError):
effect.transform(X=X, fh=fh[1:])

0 comments on commit b08152f

Please sign in to comment.