Skip to content

Commit

Permalink
Add basis bumps effect to MMM model (#1475)
Browse files Browse the repository at this point in the history
* add a helper function to create days from reference

* add ability to add events to MMM class

* check with multiple add_events

* add some validation and guardrails

* add to the docstrings

* change order of check

* support of out of sample

* test reference_date agnostic

* test for out of sample

---------

Co-authored-by: Juan Orduz <[email protected]>
  • Loading branch information
wd60622 and juanitorduz authored Feb 9, 2025
1 parent eec024d commit 23ec02e
Show file tree
Hide file tree
Showing 5 changed files with 494 additions and 10 deletions.
5 changes: 5 additions & 0 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,11 @@ def variable_mapping(self) -> dict[str, str]:
for parameter in self.default_priors.keys()
}

@property
def combined_dims(self) -> tuple[str, ...]:
    """Return the combined dims of all the parameters as a tuple."""
    core_dims = self._infer_output_core_dims()
    return tuple(core_dims)

def _infer_output_core_dims(self) -> tuple[str, ...]:
parameter_dims = sorted(
[
Expand Down
59 changes: 57 additions & 2 deletions pymc_marketing/mmm/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,15 @@ def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
"""

from typing import cast

import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, validate_call
from pydantic import BaseModel, Field, InstanceOf, model_validator, validate_call
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.deserialize import deserialize, register_deserialization
Expand Down Expand Up @@ -175,7 +179,28 @@ class EventEffect(BaseModel):

basis: InstanceOf[Basis]
effect_size: InstanceOf[Prior]
dims: tuple[str, ...]
dims: str | tuple[str, ...]

@model_validator(mode="before")
@classmethod
def _dims_to_tuple(cls, data):
    """Wrap a bare string ``dims`` in a one-element tuple before validation.

    Parameters
    ----------
    data
        The raw input handed to the model before field validation. Only
        dictionaries carrying a string under ``"dims"`` are changed.

    Returns
    -------
    The input, with ``"dims"`` replaced by a one-element tuple when it was
    a string; otherwise the input unchanged.
    """
    # A "before" validator sees raw input: it may not be a dict and "dims"
    # may be absent. Leave those cases alone so field validation reports
    # them instead of an unrelated KeyError/TypeError escaping from here.
    if isinstance(data, dict) and isinstance(data.get("dims"), str):
        # Build a new mapping rather than mutating the caller's dictionary.
        return {**data, "dims": (data["dims"],)}

    return data

@model_validator(mode="after")
def _validate_dims(self):
    """Check that ``dims`` is non-empty and covers the component dimensions.

    Returns
    -------
    The validated instance, unchanged.

    Raises
    ------
    ValueError
        If ``dims`` is empty, or if it does not include every dimension of
        the basis or every dimension of the effect size.
    """
    if not self.dims:
        raise ValueError("The dims must not be empty.")

    declared_dims = set(self.dims)

    if not set(self.basis.combined_dims).issubset(declared_dims):
        raise ValueError("The dims must contain all dimensions of the basis.")

    if not set(self.effect_size.dims).issubset(declared_dims):
        raise ValueError("The dims must contain all dimensions of the effect size.")

    return self

def apply(self, X: pt.TensorLike, name: str = "event") -> TensorVariable:
"""Apply the event effect to the data."""
Expand Down Expand Up @@ -229,3 +254,33 @@ def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
default_priors = {
"sigma": Prior("Gamma", mu=7, sigma=1),
}


def days_from_reference(
dates: pd.Series | pd.DatetimeIndex,
reference_date: str | pd.Timestamp,
) -> npt.NDArray[np.int64]:
"""Calculate the difference in days between dates and a reference date.
Parameters
----------
dates : pd.Series | pd.DatetimeIndex
Dates to calculate the difference from the reference date.
reference_date : str | pd.Timestamp
Reference date.
Returns
-------
np.ndarray
Difference in days between dates and the reference date.
"""
reference_date = cast(pd.Timestamp, pd.to_datetime(reference_date))
dates = pd.to_datetime(dates)

diff = dates - reference_date

if isinstance(diff, pd.Series):
diff = diff.dt # type: ignore

return diff.days.to_numpy() # type: ignore
200 changes: 194 additions & 6 deletions pymc_marketing/mmm/multidimensional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
"""Multidimensional Marketing Mix Model class."""

from __future__ import annotations

import json
import warnings
from typing import Any, Literal
from typing import Any, Literal, Protocol

import arviz as az
import numpy as np
Expand All @@ -36,12 +38,13 @@
SaturationTransformation,
saturation_from_dict,
)
from pymc_marketing.mmm.events import EventEffect, days_from_reference
from pymc_marketing.mmm.fourier import YearlyFourier
from pymc_marketing.mmm.plot import MMMPlotSuite
from pymc_marketing.mmm.tvp import infer_time_index
from pymc_marketing.model_builder import ModelBuilder, _handle_deprecate_pred_argument
from pymc_marketing.model_config import parse_model_config
from pymc_marketing.prior import Prior
from pymc_marketing.prior import Prior, create_dim_handler

PYMC_MARKETING_ISSUE = "https://github.com/pymc-labs/pymc-marketing/issues/new"
warning_msg = (
Expand All @@ -52,6 +55,143 @@
warnings.warn(warning_msg, FutureWarning, stacklevel=1)


class MuEffect(Protocol):
    """Structural interface for an arbitrary additive effect on ``mu``.

    An implementation provides three hooks, each receiving the MMM instance.
    """

    def create_data(self, mmm: MMM) -> None:
        """Create the required data in the model."""
        ...

    def create_effect(self, mmm: MMM) -> pt.TensorVariable:
        """Create the additive effect in the model."""
        ...

    def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
        """Set the data for new predictions."""
        ...


def create_event_mu_effect(
    df_events: pd.DataFrame,
    prefix: str,
    effect: EventEffect,
) -> MuEffect:
    """Create an event effect for the MMM.

    The returned object has the ability to create data and mean effects for
    the MMM model.

    Parameters
    ----------
    df_events : pd.DataFrame
        The DataFrame containing the event data. Required columns:

        * `name`: name of the event. Used as the model coordinates.
        * `start_date`: start date of the event
        * `end_date`: end date of the event

        Any additional columns are ignored.
    prefix : str
        The prefix to use for the event effect and associated variables.
    effect : EventEffect
        The event effect to apply. Its ``basis.prefix`` is overwritten with
        ``prefix``.

    Returns
    -------
    MuEffect
        The event effect which is used in the MMM.

    Raises
    ------
    ValueError
        If any of the required columns is missing from ``df_events``.

    """
    # Look for required columns that are absent from the frame (not the
    # other way round, which would flag harmless extra columns instead).
    required_columns = ("start_date", "end_date", "name")
    if missing_columns := [
        column for column in required_columns if column not in df_events.columns
    ]:
        raise ValueError(f"Columns {missing_columns} are missing in df_events.")

    effect.basis.prefix = prefix

    # Every date is expressed as a day offset from this fixed anchor. Only
    # differences between offsets are used below, so the anchor is arbitrary.
    reference_date = "2025-01-01"
    start_dates = pd.to_datetime(df_events["start_date"])
    end_dates = pd.to_datetime(df_events["end_date"])

    class Effect:
        """Event effect class for the MMM."""

        def create_data(self, mmm: MMM) -> None:
            """Create the required data in the model.

            Parameters
            ----------
            mmm : MMM
                The MMM model instance.

            """
            model: pm.Model = mmm.model

            model_dates = pd.to_datetime(model.coords["date"])

            model.add_coord(prefix, df_events["name"].to_numpy())

            # "days" is shared by every event effect added to the same
            # model, so it is only registered once.
            if "days" not in model:
                pm.Data(
                    "days",
                    days_from_reference(model_dates, reference_date),
                    dims="date",
                )

            pm.Data(
                f"{prefix}_start_diff",
                days_from_reference(start_dates, reference_date),
                dims=prefix,
            )
            pm.Data(
                f"{prefix}_end_diff",
                days_from_reference(end_dates, reference_date),
                dims=prefix,
            )

        def create_effect(self, mmm: MMM) -> pt.TensorVariable:
            """Create the event effect in the model.

            Parameters
            ----------
            mmm : MMM
                The MMM model instance.

            Returns
            -------
            pt.TensorVariable
                The average event effect in the model.

            """
            model: pm.Model = mmm.model

            # Signed day distance of every model date (rows) to the start
            # and the end of every event (columns).
            s_ref = model["days"][:, None] - model[f"{prefix}_start_diff"]
            e_ref = model["days"][:, None] - model[f"{prefix}_end_diff"]

            def create_basis_matrix(s_ref, e_ref):
                # Zero while inside the event window; otherwise the signed
                # distance to whichever boundary is closer.
                return pt.where(
                    (s_ref >= 0) & (e_ref <= 0),
                    0,
                    pt.where(pt.abs(s_ref) < pt.abs(e_ref), s_ref, e_ref),
                )

            X = create_basis_matrix(s_ref, e_ref)
            event_effect = effect.apply(X, name=prefix)

            # Sum the contributions of the individual events per date.
            total_effect = pm.Deterministic(
                f"{prefix}_total_effect",
                event_effect.sum(axis=1),
                dims="date",
            )

            dim_handler = create_dim_handler(("date", *mmm.dims))
            return dim_handler(total_effect, "date")

        def set_data(self, mmm: MMM, model: pm.Model, X: xr.Dataset) -> None:
            """Set the data for new predictions.

            Only the ``"days"`` data is refreshed, using the ``"date"``
            coordinate currently stored on ``model``.
            """
            new_dates = pd.to_datetime(model.coords["date"])

            new_data = {
                "days": days_from_reference(new_dates, reference_date),
            }
            pm.set_data(new_data=new_data, model=model)

    return Effect()


class MMM(ModelBuilder):
"""Marketing Mix Model class for estimating the impact of marketing channels on a target variable.
Expand Down Expand Up @@ -149,13 +289,51 @@ def __init__(
variable_name="gamma_fourier",
)

self.mu_effects: list[MuEffect] = []

@property
def default_sampler_config(self) -> dict:
    """Return the default sampler configuration, which is empty."""
    return dict()

# No-op data setter: the body is a bare ellipsis, so ``X`` and ``y`` are
# ignored and ``None`` is returned.
# NOTE(review): presumably present to satisfy the ModelBuilder interface — confirm.
def _data_setter(self, X, y=None): ...

def add_events(
    self,
    df_events: pd.DataFrame,
    prefix: str,
    effect: EventEffect,
) -> None:
    """Add event effects to the model.

    This must be called before building the model.

    Parameters
    ----------
    df_events : pd.DataFrame
        The DataFrame containing the event data.

        * `name`: name of the event. Used as the model coordinates.
        * `start_date`: start date of the event
        * `end_date`: end date of the event
    prefix : str
        The prefix to use for the event effect and associated variables.
    effect : EventEffect
        The event effect to apply.

    Raises
    ------
    ValueError
        If the event effect uses a dimension that is neither the prefix nor
        one of the model dimensions.

    """
    # Unpack the model dims: comparing against ``(prefix, self.dims)`` would
    # test membership against the dims tuple as a single element and reject
    # every effect that uses a model dimension.
    allowed_dims = (prefix, *self.dims)
    if not set(effect.dims).issubset(allowed_dims):
        raise ValueError(
            f"Event effect dims {effect.dims} must contain {prefix} and {self.dims}"
        )

    event_effect = create_event_mu_effect(df_events, prefix, effect)
    self.mu_effects.append(event_effect)

@property
def _serializable_model_config(self) -> dict[str, Any]:
def ndarray_to_list(d: dict) -> dict:
Expand Down Expand Up @@ -822,6 +1000,9 @@ def build_model(
dims=("date", *self.dims),
)

for mu_effect in self.mu_effects:
mu_effect.create_data(self)

if self.time_varying_intercept | self.time_varying_media:
time_index = pm.Data(
name="time_index",
Expand Down Expand Up @@ -936,6 +1117,9 @@ def create_deterministic(x: pt.TensorVariable) -> None:
)
mu_var += yearly_seasonality_contribution

for mu_effect in self.mu_effects:
mu_var += mu_effect.create_effect(self)

mu_var.name = "mu"
mu_var.dims = ("date", *self.dims)

Expand Down Expand Up @@ -1200,14 +1384,18 @@ def sample_posterior_predictive(
# Update model data with xarray
if X is None:
raise ValueError("X values must be provided")
dataset_xarray = self._posterior_predictive_data_transformation(
X=X,
include_last_observations=include_last_observations,
)
model = self._set_xarray_data(
self._posterior_predictive_data_transformation(
X=X,
include_last_observations=include_last_observations,
),
dataset_xarray=dataset_xarray,
clone_model=clone_model,
)

for mu_effect in self.mu_effects:
mu_effect.set_data(self, model, dataset_xarray)

with model:
# Sample from posterior predictive
post_pred = pm.sample_posterior_predictive(
Expand Down
Loading

0 comments on commit 23ec02e

Please sign in to comment.