Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MMM] Model events as gaussian bumps #1465

Merged
merged 8 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
231 changes: 231 additions & 0 deletions pymc_marketing/mmm/events.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
# Copyright 2022 - 2025 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Event transformations.

This module provides event transformations for use in Marketing Mix Models.

.. plot::
:context: close-figs

import numpy as np
import pandas as pd
import pymc as pm

import matplotlib.pyplot as plt

from pymc_marketing.mmm.events import EventEffect, GaussianBasis
from pymc_marketing.plot import plot_curve
from pymc_marketing.prior import Prior

seed = sum(map(ord, "Events"))
rng = np.random.default_rng(seed)

df_events = pd.DataFrame(
{
"event": ["single day", "multi day"],
"start_date": pd.to_datetime(["2025-01-01", "2025-01-20"]),
"end_date": pd.to_datetime(["2025-01-02", "2025-01-25"]),
}
)

def difference_in_days(model_dates, event_dates):
if hasattr(model_dates, "to_numpy"):
model_dates = model_dates.to_numpy()
if hasattr(event_dates, "to_numpy"):
event_dates = event_dates.to_numpy()

one_day = np.timedelta64(1, "D")
return (model_dates[:, None] - event_dates) / one_day


def create_basis_matrix(df_events: pd.DataFrame, model_dates: np.ndarray):
start_dates = df_events["start_date"]
end_dates = df_events["end_date"]

s_ref = difference_in_days(model_dates, start_dates)
e_ref = difference_in_days(model_dates, end_dates)

return np.where(
(s_ref >= 0) & (e_ref <= 0),
0,
np.where(np.abs(s_ref) < np.abs(e_ref), s_ref, e_ref),
)


gaussian = GaussianBasis(
priors={
"sigma": Prior("Gamma", mu=7, sigma=1, dims="event"),
}
)
effect_size = Prior("Normal", mu=1, sigma=1, dims="event")
effect = EventEffect(basis=gaussian, effect_size=effect_size, dims=("event",))

dates = pd.date_range("2024-12-01", periods=3 * 31, freq="D")

X = create_basis_matrix(df_events, model_dates=dates)

coords = {"date": dates, "event": df_events["event"].to_numpy()}
with pm.Model(coords=coords) as model:
pm.Deterministic("effect", effect.apply(X), dims=("date", "event"))

idata = pm.sample_prior_predictive(random_seed=rng)

fig, axes = idata.prior.effect.pipe(
plot_curve,
{"date"},
subplot_kwargs={"ncols": 1},
sample_kwargs={"rng": rng},
)
fig.suptitle("Gaussian Event Effect")
plt.show()

"""

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import BaseModel, Field, InstanceOf, validate_call
from pytensor.tensor.variable import TensorVariable

from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.mmm.components.base import Transformation, create_registration_meta
from pymc_marketing.prior import Prior, create_dim_handler

BASIS_TRANSFORMATIONS: dict = {}
BasisMeta = create_registration_meta(BASIS_TRANSFORMATIONS)


class Basis(Transformation, metaclass=BasisMeta): # type: ignore[misc]
"""Basis transformation associated with an event model."""

prefix: str = "basis"
lookup_name: str

@validate_call
def sample_curve(
self,
parameters: InstanceOf[xr.Dataset] = Field(
..., description="Parameters of the saturation transformation."
),
days: int = Field(0, ge=0, description="Number of days around basis."),
) -> xr.DataArray:
"""Sample the curve of the saturation transformation given parameters.

Parameters
----------
parameters : xr.Dataset
Dataset with the parameters of the saturation transformation.
days : int
Number of days around basis.

Returns
-------
xr.DataArray
Curve of the saturation transformation.

"""
x = np.linspace(-days, days, 100)

Check warning on line 139 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L139

Added line #L139 was not covered by tests

coords = {"x": x}

Check warning on line 141 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L141

Added line #L141 was not covered by tests

return self._sample_curve(

Check warning on line 143 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L143

Added line #L143 was not covered by tests
var_name="saturation",
parameters=parameters,
x=x,
coords=coords,
)


def basis_from_dict(data: dict) -> Basis:
"""Create a basis transformation from a dictionary."""
data = data.copy()
lookup_name = data.pop("lookup_name")
cls = BASIS_TRANSFORMATIONS[lookup_name]

Check warning on line 155 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L153-L155

Added lines #L153 - L155 were not covered by tests

if "priors" in data:
data["priors"] = {k: deserialize(v) for k, v in data["priors"].items()}

Check warning on line 158 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L157-L158

Added lines #L157 - L158 were not covered by tests

return cls(**data)

Check warning on line 160 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L160

Added line #L160 was not covered by tests


def _is_basis(data):
return "lookup_name" in data and data["lookup_name"] in BASIS_TRANSFORMATIONS

Check warning on line 164 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L164

Added line #L164 was not covered by tests


register_deserialization(
is_type=_is_basis,
deserialize=basis_from_dict,
)


class EventEffect(BaseModel):
"""Event effect associated with an event model."""

basis: InstanceOf[Basis]
effect_size: InstanceOf[Prior]
dims: tuple[str, ...]

def apply(self, X: pt.TensorLike, name: str = "event") -> TensorVariable:
"""Apply the event effect to the data."""
dim_handler = create_dim_handler(("x", *self.dims))
return self.basis.apply(X, dims=self.dims) * dim_handler(

Check warning on line 183 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L182-L183

Added lines #L182 - L183 were not covered by tests
self.effect_size.create_variable(f"{name}_effect_size"),
self.effect_size.dims,
)

def to_dict(self) -> dict:
"""Convert the event effect to a dictionary."""
return {

Check warning on line 190 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L190

Added line #L190 was not covered by tests
"class": "EventEffect",
"data": {
"basis": self.basis.to_dict(),
"effect_size": self.effect_size.to_dict(),
"dims": self.dims,
},
}

@classmethod
def from_dict(cls, data: dict) -> "EventEffect":
"""Create an event effect from a dictionary."""
return cls(

Check warning on line 202 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L202

Added line #L202 was not covered by tests
basis=deserialize(data["basis"]),
effect_size=deserialize(data["effect_size"]),
dims=data["dims"],
)


def _is_event_effect(data: dict) -> bool:
"""Check if the data is an event effect."""
return data["class"] == "EventEffect"

Check warning on line 211 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L211

Added line #L211 was not covered by tests


register_deserialization(
is_type=_is_event_effect,
deserialize=lambda data: EventEffect.from_dict(data["data"]),
)


class GaussianBasis(Basis):
"""Gaussian basis transformation."""

lookup_name = "gaussian"

def function(self, x: pt.TensorLike, sigma: pt.TensorLike) -> TensorVariable:
"""Gaussian bump function."""
return pm.math.exp(-0.5 * (x / sigma) ** 2)

Check warning on line 227 in pymc_marketing/mmm/events.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/events.py#L227

Added line #L227 was not covered by tests

default_priors = {
"sigma": Prior("Gamma", mu=7, sigma=1),
}
Loading