From c5752c32320568409eca3351f281b8707bf3e8cb Mon Sep 17 00:00:00 2001 From: Felipe Angelim Date: Sun, 15 Dec 2024 12:42:40 -0300 Subject: [PATCH] Add sample_params to existing effects --- extension_templates/effect.py | 81 ++++++++++++++++++-- src/prophetverse/effects/base.py | 49 +++++++++--- src/prophetverse/effects/exact_likelihood.py | 2 +- src/prophetverse/effects/hill.py | 19 ++++- src/prophetverse/effects/lift_likelihood.py | 51 ++++++++---- src/prophetverse/effects/linear.py | 8 +- src/prophetverse/effects/log.py | 16 +++- src/prophetverse/effects/trend/flat.py | 4 +- src/prophetverse/effects/trend/piecewise.py | 29 ++++++- 9 files changed, 214 insertions(+), 45 deletions(-) diff --git a/extension_templates/effect.py b/extension_templates/effect.py index f08f8e9..02319fe 100644 --- a/extension_templates/effect.py +++ b/extension_templates/effect.py @@ -9,6 +9,66 @@ from prophetverse.utils.frame_to_array import series_to_tensor_or_array +class MySimpleEffectName(BaseEffect): + """Base class for effects.""" + + _tags = { + # Supports multivariate data? Can this + # Effect be used with Multiariate prophet? + "supports_multivariate": False, + # If no columns are found, should + # _predict be skipped? + "skip_predict_if_no_match": True, + # Should only the indexes related to the forecasting horizon be passed to + # _transform? + "filter_indexes_with_forecating_horizon_at_transform": True, + } + + def __init__(self, param1: Any, param2: Any): + self.param1 = param1 + self.param2 = param2 + + def _sample_params(self, data, predicted_effects): + # call numpyro.sample to sample the parameters of the effect + # return a dictionary with the sampled parameters, where + # key is the name of the parameter and value is the sampled value + return {} + + def _predict( + self, + data: Any, + predicted_effects: Dict[str, jnp.ndarray], + params: Dict[str, jnp.ndarray], + ) -> jnp.ndarray: + """Apply and return the effect values. + + Parameters + ---------- + data : Any + Data obtained from the transformed method. + + predicted_effects : Dict[str, jnp.ndarray], optional + A dictionary containing the predicted effects, by default None. + + params : Dict[str, jnp.ndarray] + A dictionary containing the sampled parameters of the effect. + + Returns + ------- + jnp.ndarray + An array with shape (T,1) for univariate timeseries, or (N, T, 1) for + multivariate timeseries, where T is the number of timepoints and N is the + number of series. + """ + # predicted effects come with the following shapes: + # (T, 1) shaped array for univariate timeseries + # (N, T, 1) shaped array for multivariate timeseries, where N is the number of + # series + + # Here you use the params to compute the effect. + raise NotImplementedError("Subclasses must implement _predict()") + + class MyEffectName(BaseEffect): """Base class for effects.""" @@ -76,10 +136,17 @@ def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Any: array = series_to_tensor_or_array(X) return array - def predict( + def _sample_params(self, data, predicted_effects): + # call numpyro.sample to sample the parameters of the effect + # return a dictionary with the sampled parameters, where + # key is the name of the parameter and value is the sampled value + return {} + + def _predict( self, - data: Dict, + data: Any, predicted_effects: Dict[str, jnp.ndarray], + params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Apply and return the effect values. @@ -91,6 +158,9 @@ def predict( predicted_effects : Dict[str, jnp.ndarray], optional A dictionary containing the predicted effects, by default None. + params : Dict[str, jnp.ndarray] + A dictionary containing the sampled parameters of the effect. + Returns ------- jnp.ndarray @@ -98,11 +168,10 @@ def predict( multivariate timeseries, where T is the number of timepoints and N is the number of series. """ - # Get the trend + # predicted effects come with the following shapes: # (T, 1) shaped array for univariate timeseries # (N, T, 1) shaped array for multivariate timeseries, where N is the number of # series - # trend: jnp.ndarray = predicted_effects["trend"] - # Or user predicted_effects.get("trend") to return None if the trend is - # not found + + # Here you use the params to compute the effect. raise NotImplementedError("Subclasses must implement _predict()") diff --git a/src/prophetverse/effects/base.py b/src/prophetverse/effects/base.py index 14cb9fd..bfcc0c6 100644 --- a/src/prophetverse/effects/base.py +++ b/src/prophetverse/effects/base.py @@ -273,13 +273,33 @@ def sample_params( Dict A dictionary containing the sampled parameters. """ + if predicted_effects is None: + predicted_effects = {} + return self._sample_params(data, predicted_effects) def _sample_params( self, - data: Dict, - predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, + data: Any, + predicted_effects: Dict[str, jnp.ndarray], ): + """Sample parameters from the prior distribution. + + Should be implemented by subclasses to provide the actual sampling logic. + + Parameters + ---------- + data : Any + The data to be used for sampling the parameters, obtained from + `transform` method. + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects, by default None. + + Returns + ------- + Dict + A dictionary containing the sampled parameters. + """ return {} def _predict( @@ -336,9 +356,10 @@ class BaseAdditiveOrMultiplicativeEffect(BaseEffect): or "multiplicative". """ - def __init__(self, effect_mode="additive"): + def __init__(self, effect_mode="additive", base_effect_name: str = "trend"): self.effect_mode = effect_mode + self.base_effect_name = base_effect_name if effect_mode not in ["additive", "multiplicative"]: raise ValueError( @@ -372,23 +393,29 @@ def predict( number of series. """ if predicted_effects is None: + predicted_effects = {} + + if params is None: + params = self.sample_params(data, predicted_effects) + + if ( + self.base_effect_name not in predicted_effects + and self.effect_mode == "multiplicative" + ): raise ValueError( "BaseAdditiveOrMultiplicativeEffect requires trend in" + " predicted_effects" ) - trend = predicted_effects["trend"] - if trend.ndim == 1: - trend = trend.reshape((-1, 1)) - - if params is None: - params = self.sample_params(data, predicted_effects) x = super().predict( data=data, predicted_effects=predicted_effects, params=params ) - x = x.reshape(trend.shape) if self.effect_mode == "additive": return x - return trend * x + base_effect = predicted_effects[self.base_effect_name] + if base_effect.ndim == 1: + base_effect = base_effect.reshape((-1, 1)) + x = x.reshape(base_effect.shape) + return base_effect * x diff --git a/src/prophetverse/effects/exact_likelihood.py b/src/prophetverse/effects/exact_likelihood.py index 2d44651..ce2e3f5 100644 --- a/src/prophetverse/effects/exact_likelihood.py +++ b/src/prophetverse/effects/exact_likelihood.py @@ -133,7 +133,7 @@ def _predict( with numpyro.handlers.mask(mask=obs_mask): numpyro.sample( - "lift_experiment", + "exact_likelihood:ignore", dist.Normal(x, self.prior_scale), obs=observed_reference_value, ) diff --git a/src/prophetverse/effects/hill.py b/src/prophetverse/effects/hill.py index e12f9f9..278df88 100644 --- a/src/prophetverse/effects/hill.py +++ b/src/prophetverse/effects/hill.py @@ -44,7 +44,24 @@ def __init__( super().__init__(effect_mode=effect_mode) - def _sample_params(self, data, predicted_effects=None): + def _sample_params( + self, data, predicted_effects: Dict[str, jnp.ndarray] + ) -> Dict[str, jnp.ndarray]: + """ + Sample the parameters of the effect. + + Parameters + ---------- + data : Any + Data obtained from the transformed method. + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects + + Returns + ------- + Dict[str, jnp.ndarray] + A dictionary containing the sampled parameters of the effect. + """ return { "half_max": numpyro.sample("half_max", self.half_max_prior), "slope": numpyro.sample("slope", self.slope_prior), diff --git a/src/prophetverse/effects/lift_likelihood.py b/src/prophetverse/effects/lift_likelihood.py index 6031402..d86f5dd 100644 --- a/src/prophetverse/effects/lift_likelihood.py +++ b/src/prophetverse/effects/lift_likelihood.py @@ -81,7 +81,8 @@ def fit(self, y: pd.DataFrame, X: pd.DataFrame, scale: float = 1): ------- None """ - self.effect.fit(X=X, y=y, scale=scale) + self.effect_ = self.effect.clone() + self.effect_.fit(X=X, y=y, scale=scale) self.timeseries_scale = scale super().fit(X=X, y=y, scale=scale) @@ -106,8 +107,13 @@ def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Dict[str, Any]: Dictionary with data for the lift and for the inner effect """ data_dict = {} - data_dict["inner_effect_data"] = self.effect._transform(X, fh=fh) + data_dict["inner_effect_data"] = self.effect_._transform(X, fh=fh) + # Check if fh and self.lift_test_results have same index type + if not isinstance(fh, self.lift_test_results.index.__class__): + raise TypeError( + "fh and self.lift_test_results must have the same index type" + ) X_lift = self.lift_test_results.reindex(fh, fill_value=jnp.nan) data_dict["observed_lift"] = ( @@ -119,8 +125,25 @@ def _transform(self, X: pd.DataFrame, fh: pd.Index) -> Dict[str, Any]: return data_dict - def _sample_params(self, data, predicted_effects=None): - return self.effect.sample_params( + def _sample_params(self, data, predicted_effects): + """ + Sample the parameters of the effect. + + Calls the sample_params method of the inner effect. + + Parameters + ---------- + data : Any + Data obtained from the transformed method. + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects + + Returns + ------- + Dict[str, jnp.ndarray] + A dictionary containing the sampled parameters of the effect. + """ + return self.effect_.sample_params( data=data["inner_effect_data"], predicted_effects=predicted_effects ) @@ -150,35 +173,35 @@ def _predict( x_end = data["x_end"].reshape((-1, 1)) obs_mask = data["obs_mask"] - effect_params = self.effect.sample_params( - data=data["inner_effect_data"], - predicted_effects=predicted_effects, - ) - predicted_effects_masked = { k: v[obs_mask] for k, v in predicted_effects.items() } - x = self.effect.predict( + # Call the effect a first time + x = self.effect_.predict( data=data["inner_effect_data"], predicted_effects=predicted_effects, params=params, ) - y_start = self.effect.predict( + # Get the start and end values + y_start = self.effect_.predict( data=x_start, predicted_effects=predicted_effects_masked, - params=effect_params, + params=params, ) - y_end = self.effect.predict( - data=x_end, predicted_effects=predicted_effects_masked, params=effect_params + y_end = self.effect_.predict( + data=x_end, predicted_effects=predicted_effects_masked, params=params ) + # Calculate the delta_y delta_y = jnp.abs(y_end - y_start) with numpyro.handlers.scale(scale=self.likelihood_scale): distribution = GammaReparametrized(delta_y, self.prior_scale) + # Add :ignore so that the model removes this + # sample when organizing the output dataframe numpyro.sample( "lift_experiment:ignore", distribution, diff --git a/src/prophetverse/effects/linear.py b/src/prophetverse/effects/linear.py index 6b2f1ea..f6c26fb 100644 --- a/src/prophetverse/effects/linear.py +++ b/src/prophetverse/effects/linear.py @@ -40,7 +40,7 @@ def __init__( super().__init__(effect_mode=effect_mode) - def _sample_params(self, data, predicted_effects=None): + def _sample_params(self, data, predicted_effects): n_features = data.shape[-1] @@ -54,7 +54,7 @@ def _sample_params(self, data, predicted_effects=None): def _predict( self, data: Any, - predicted_effects: Optional[Dict[str, jnp.ndarray]], + predicted_effects: Dict[str, jnp.ndarray], params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Apply and return the effect values. @@ -64,8 +64,8 @@ def _predict( data : Any Data obtained from the transformed method. - predicted_effects : Dict[str, jnp.ndarray], optional - A dictionary containing the predicted effects, by default None. + predicted_effects : Dict[str, jnp.ndarray] + A dictionary containing the predicted effects Returns ------- diff --git a/src/prophetverse/effects/log.py b/src/prophetverse/effects/log.py index dee6fba..c6235e7 100644 --- a/src/prophetverse/effects/log.py +++ b/src/prophetverse/effects/log.py @@ -38,10 +38,19 @@ def __init__( self.rate_prior = rate_prior or dist.Gamma(1, 1) super().__init__(effect_mode=effect_mode) + def _sample_params(self, data, predicted_effects): + scale = numpyro.sample("log_scale", self.scale_prior) + rate = numpyro.sample("log_rate", self.rate_prior) + return { + "scale": scale, + "rate": rate, + } + def _predict( # type: ignore[override] self, data: jnp.ndarray, - predicted_effects: Optional[Dict[str, jnp.ndarray]] = None, + predicted_effects: Dict[str, jnp.ndarray], + params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """Apply and return the effect values. @@ -60,8 +69,9 @@ def _predict( # type: ignore[override] multivariate timeseries, where T is the number of timepoints and N is the number of series. """ - scale = numpyro.sample("log_scale", self.scale_prior) - rate = numpyro.sample("log_rate", self.rate_prior) + scale = params["scale"] + rate = params["rate"] + effect = scale * jnp.log(jnp.clip(rate * data + 1, 1e-8, None)) return effect diff --git a/src/prophetverse/effects/trend/flat.py b/src/prophetverse/effects/trend/flat.py index 7cad722..d44fec1 100644 --- a/src/prophetverse/effects/trend/flat.py +++ b/src/prophetverse/effects/trend/flat.py @@ -1,5 +1,7 @@ """Flat trend model.""" +from typing import Any, Dict + import jax.numpy as jnp import numpyro import numpyro.distributions as dist @@ -56,7 +58,7 @@ def _transform(self, X: pd.DataFrame, fh: pd.Index) -> dict: idx = X.index return jnp.ones((len(idx), 1)) - def _sample_params(self, data, predicted_effects=None): + def _sample_params(self, data: Any, predicted_effects: Dict[str, jnp.ndarray]): """Sample parameters from the prior distribution. Parameters diff --git a/src/prophetverse/effects/trend/piecewise.py b/src/prophetverse/effects/trend/piecewise.py index 164b055..42b9010 100644 --- a/src/prophetverse/effects/trend/piecewise.py +++ b/src/prophetverse/effects/trend/piecewise.py @@ -5,7 +5,7 @@ """ import itertools -from typing import Dict, Tuple, Union +from typing import Any, Dict, Tuple, Union import jax.numpy as jnp import numpy as np @@ -157,7 +157,7 @@ def _transform(self, X: pd.DataFrame, fh: pd.Index) -> dict: idx = self._fh_to_index(fh) return self.get_changepoint_matrix(idx) - def _sample_params(self, data, predicted_effects=None): + def _sample_params(self, data: Any, predicted_effects: Dict[str, jnp.ndarray]): changepoint_matrix = data @@ -546,7 +546,25 @@ def _suggest_global_trend_and_offset( return global_rates, offset - def _sample_params(self, data, predicted_effects=None): + def _sample_params(self, data, predicted_effects): + """ + Sample params for the effect. + + Use super to sample the changepoint coefficients and offset, and then sample + the capacity using the capacity prior. + + Parameters + ---------- + data : Any + The input data. + predicted_effects : Dict[str, jnp.ndarray] + The predicted effects + + Returns + ------- + dict + The sampled parameters. + """ with numpyro.plate("series", self.n_series, dim=-3): capacity = numpyro.sample("capacity", self.capacity_prior) @@ -556,7 +574,10 @@ def _sample_params(self, data, predicted_effects=None): } def _predict( # type: ignore[override] - self, data: jnp.ndarray, predicted_effects, params + self, + data: Any, + predicted_effects: Dict[str, jnp.ndarray], + params: Dict[str, jnp.ndarray], ) -> jnp.ndarray: """ Compute the trend for the given changepoint matrix.