Skip to content

Draft (new feature) : Model to estimate when a intervention had effect #480

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
10a017e
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
69d79b3
New feature : Model to estimate when a intervention had effect
JeanVanDyk May 28, 2025
bf4eaaa
Minor fix in docstring
JeanVanDyk May 29, 2025
3420c9a
Minor fix in docstring
JeanVanDyk May 29, 2025
3dc23b3
Minor fix in docstring
JeanVanDyk May 29, 2025
d739b4a
Minor fix in docstring
JeanVanDyk May 29, 2025
d48f0c3
Minor fix in docstring
JeanVanDyk May 29, 2025
14afe09
Minor fix in docstring
JeanVanDyk May 29, 2025
60357a5
Minor fix in docstring
JeanVanDyk May 29, 2025
7f57b13
Minor fix in docstring
JeanVanDyk May 29, 2025
2cb92fc
Minor fix in docstring
JeanVanDyk May 29, 2025
d9c06ac
Minor fix in docstring
JeanVanDyk May 29, 2025
52cc0fa
Minor fix in docstring
JeanVanDyk May 29, 2025
faf085b
Minor fix in docstring
JeanVanDyk May 29, 2025
cc9a1f4
Minor fix in docstring
JeanVanDyk May 29, 2025
dea9d6e
Minor fix in docstring
JeanVanDyk May 29, 2025
5e9cde6
fix : hiding progressbar
JeanVanDyk May 30, 2025
ee701f2
Enhancement : Adding the possibility for the user to indicate priors …
JeanVanDyk May 30, 2025
5ee3cb4
Minor fix in docstring
JeanVanDyk Jun 4, 2025
08c520c
updating example notebook
JeanVanDyk Jun 4, 2025
b1681da
updating example notebook
JeanVanDyk Jun 4, 2025
fcfd059
Supporting Date format and adding exceptions for model related issues
JeanVanDyk Jun 4, 2025
64c97b7
changing column index restriction to label restriction
JeanVanDyk Jun 5, 2025
2996331
codespell
JeanVanDyk Jun 17, 2025
1da80fd
resolved merge
JeanVanDyk Jun 17, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions causalpy/custom_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,10 @@ class DataException(Exception):

def __init__(self, message: str):
self.message = message


class ModelException(Exception):
"""Exception raised given when there is some error in user-provided model"""

def __init__(self, message: str):
self.message = message
258 changes: 223 additions & 35 deletions causalpy/experiments/interrupted_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,163 @@
from patsy import build_design_matrices, dmatrices
from sklearn.base import RegressorMixin

from causalpy.custom_exceptions import BadIndexException
from causalpy.custom_exceptions import BadIndexException, ModelException
from causalpy.experiments.base import BaseExperiment
from causalpy.plot_utils import get_hdi_to_df, plot_xY
from causalpy.pymc_models import PyMCModel
from causalpy.utils import round_num

from .base import BaseExperiment

LEGEND_FONT_SIZE = 12


class HandlerUTT:
"""
Handle data preprocessing, postprocessing, and plotting steps for models
with unknown treatment intervention times.
"""

def data_preprocessing(self, data, treatment_time, formula, model):
"""
Preprocess the data using patsy for fittng into the model and update the model with required infos
"""
y, X = dmatrices(formula, data)
# Restrict model's treatment time inference to given range
model.set_time_range(treatment_time, data)
# Needed to track time evolution across model predictions
model.set_timeline(X.design_info.column_names.index("t"))
return y, X

def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
"""
Postprocess the data accordingly to the inferred treatment time for calculation and plot purpose
"""
# Retrieve posterior mean of inferred treatment time
treatment_time_mean = idata.posterior["treatment_time"].mean().item()
inferred_time = int(treatment_time_mean)

# Safety check: ensure the inferred time is present in the dataset
if inferred_time not in data["t"].values:
raise ValueError(
f"Inferred treatment time {inferred_time} not found in data['t']."
)

# Convert the inferred time to its corresponding DataFrame index
inferred_index = data[data["t"] == inferred_time].index[0]

# Retrieve HDI bounds of treatment time (uncertainty interval)
hdi_bounds = az.hdi(idata, var_names=["treatment_time"])[
"treatment_time"
].values
hdi_start_time = int(hdi_bounds[0])

# Convert HDI lower bound to DataFrame index for slicing
if hdi_start_time not in data["t"].values:
raise ValueError(f"HDI start time {hdi_start_time} not found in data['t'].")

hdi_start_idx_df = data[data["t"] == hdi_start_time].index[0]
hdi_start_idx_np = data.index.get_loc(hdi_start_idx_df)

# Slice both pandas and numpy objects accordingly
df_pre = data[data.index < hdi_start_idx_df]
df_post = data[data.index >= hdi_start_idx_df]
truncated_y = pre_y[:hdi_start_idx_np]
truncated_X = pre_X[:hdi_start_idx_np]

return df_pre, df_post, truncated_y, truncated_X, inferred_index

def plot_intervention_line(self, ax, idata, datapost, treatment_time):
"""
Plot a vertical line at the inferred treatment time, along with a shaded area
representing the Highest Density Interval (HDI) of the inferred time.
"""
# Extract the HDI (uncertainty interval) of the treatment time
hdi = az.hdi(idata, var_names=["treatment_time"])["treatment_time"].values
x1 = datapost.index[datapost["t"] == int(hdi[0])][0]
x2 = datapost.index[datapost["t"] == int(hdi[1])][0]

for i in [0, 1, 2]:
ymin, ymax = ax[i].get_ylim()

# Vertical line for inferred treatment time
ax[i].plot(
[treatment_time, treatment_time],
[ymin, ymax],
ls="-",
lw=3,
color="r",
solid_capstyle="butt",
)

# Shaded region for HDI of treatment time
ax[i].fill_betweenx(
y=[ymin, ymax],
x1=x1,
x2=x2,
alpha=0.1,
color="r",
)

def plot_treated_counterfactual(
self, ax, handles, labels, datapost, post_pred, post_y
):
"""
Plot the inferred post-intervention trajectory (with treatment effect).
"""
# --- Plot predicted trajectory under treatment (with HDI)
h_line, h_patch = plot_xY(
datapost.index,
post_pred["posterior_predictive"].mu_ts,
ax=ax[0],
plot_hdi_kwargs={"color": "yellowgreen"},
)
handles.append((h_line, h_patch))
labels.append("treated counterfactual")


class HandlerKTT:
"""
Handles data preprocessing, postprocessing, and plotting logic for models
where the treatment time is known in advance.
"""

def data_preprocessing(self, data, treatment_time, formula, model):
"""
Preprocess the data using patsy for fitting into the model
"""
# Use only data before treatment for training the model
return dmatrices(formula, data[data.index < treatment_time])

def data_postprocessing(self, data, idata, treatment_time, pre_y, pre_X):
"""
Postprocess data by splitting it into pre- and post-intervention periods, using the known treatment time.
"""
return (
data[data.index < treatment_time],
data[data.index >= treatment_time],
pre_y,
pre_X,
treatment_time,
)

def plot_intervention_line(self, ax, idata, datapost, treatment_time):
"""
Plot a vertical line at the known treatment time.
"""
# --- Plot a vertical line at the known treatment time
for i in [0, 1, 2]:
ax[i].axvline(
x=treatment_time, ls="-", lw=3, color="r", solid_capstyle="butt"
)

def plot_treated_counterfactual(
self, sax, handles, labels, datapost, post_pred, post_y
):
"""
Placeholder method to maintain interface compatibility.
"""
pass


class InterruptedTimeSeries(BaseExperiment):
"""
The class for interrupted time series analysis.
Expand Down Expand Up @@ -79,31 +226,66 @@ class InterruptedTimeSeries(BaseExperiment):
def __init__(
self,
data: pd.DataFrame,
treatment_time: Union[int, float, pd.Timestamp],
treatment_time: Union[int, float, pd.Timestamp, tuple, None],
formula: str,
model=None,
**kwargs,
) -> None:
super().__init__(model=model)

# rename the index to "obs_ind"
data.index.name = "obs_ind"
self.input_validation(data, treatment_time)
self.input_validation(data, treatment_time, model)
self.treatment_time = treatment_time
# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"
# split data in to pre and post intervention
self.datapre = data[data.index < self.treatment_time]
self.datapost = data[data.index >= self.treatment_time]

self.treatment_time = treatment_time
self.formula = formula

# Getting the right handler
if treatment_time is None or isinstance(treatment_time, tuple):
self.handler = HandlerUTT()
else:
self.handler = HandlerKTT()

# set experiment type - usually done in subclasses
self.expt_type = "Pre-Post Fit"

# Preprocessing based on handler type
y, X = self.handler.data_preprocessing(
data, self.treatment_time, formula, self.model
)

# set things up with pre-intervention data
y, X = dmatrices(formula, self.datapre)
self.outcome_variable_name = y.design_info.column_names[0]
self._y_design_info = y.design_info
self._x_design_info = X.design_info
self.labels = X.design_info.column_names
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)

# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.pre_X, y=self.pre_y)
else:
raise ValueError("Model type not recognized")

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)

# Postprocessing with handler
self.datapre, self.datapost, self.pre_y, self.pre_X, self.treatment_time = (
self.handler.data_postprocessing(
data, self.idata, treatment_time, self.pre_y, self.pre_X
)
)

# get the model predictions of the observed (pre-intervention) data
self.pre_pred = self.model.predict(X=self.pre_X)

# process post-intervention data
(new_y, new_x) = build_design_matrices(
[self._y_design_info, self._x_design_info], self.datapost
Expand Down Expand Up @@ -138,21 +320,6 @@ def __init__(
coords={"obs_ind": self.datapost.index},
)

# fit the model to the observed (pre-intervention) data
if isinstance(self.model, PyMCModel):
COORDS = {"coeffs": self.labels, "obs_ind": np.arange(self.pre_X.shape[0])}
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
elif isinstance(self.model, RegressorMixin):
self.model.fit(X=self.pre_X, y=self.pre_y)
else:
raise ValueError("Model type not recognized")

# score the goodness of fit to the pre-intervention data
self.score = self.model.score(X=self.pre_X, y=self.pre_y)

# get the model predictions of the observed (pre-intervention) data
self.pre_pred = self.model.predict(X=self.pre_X)

# calculate the counterfactual
self.post_pred = self.model.predict(X=self.post_X)
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
Expand All @@ -161,16 +328,24 @@ def __init__(
self.post_impact
)

def input_validation(self, data, treatment_time):
def input_validation(self, data, treatment_time, model):
"""Validate the input data and model formula for correctness"""
if treatment_time is None and not hasattr(model, "set_time_range"):
raise ModelException(
"If treatment_time is None, provided model must have a 'set_time_range' method"
)
if isinstance(treatment_time, tuple) and not hasattr(model, "set_time_range"):
raise ModelException(
"If treatment_time is a tuple, provided model must have a 'set_time_range' method"
)
if isinstance(data.index, pd.DatetimeIndex) and not isinstance(
treatment_time, pd.Timestamp
treatment_time, (pd.Timestamp, tuple, type(None))
):
raise BadIndexException(
"If data.index is DatetimeIndex, treatment_time must be pd.Timestamp."
)
if not isinstance(data.index, pd.DatetimeIndex) and isinstance(
treatment_time, pd.Timestamp
treatment_time, (pd.Timestamp)
):
raise BadIndexException(
"If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp." # noqa: E501
Expand Down Expand Up @@ -199,6 +374,7 @@ def _bayesian_plot(

fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
# TOP PLOT --------------------------------------------------

# pre-intervention period
h_line, h_patch = plot_xY(
self.datapre.index,
Expand All @@ -213,6 +389,11 @@ def _bayesian_plot(
handles.append(h)
labels.append("Observations")

# Green line for treated counterfactual (if unknown treatment time)
self.handler.plot_treated_counterfactual(
ax, handles, labels, self.datapost, self.post_pred, self.post_y
)

# post intervention period
h_line, h_patch = plot_xY(
self.datapost.index,
Expand Down Expand Up @@ -277,14 +458,10 @@ def _bayesian_plot(
)
ax[2].axhline(y=0, c="k")

# Intervention line
for i in [0, 1, 2]:
ax[i].axvline(
x=self.treatment_time,
ls="-",
lw=3,
color="r",
)
# Plot vertical line marking treatment time (with HDI if it's inferred)
self.handler.plot_intervention_line(
ax, self.idata, self.datapost, self.treatment_time
)

ax[0].legend(
handles=(h_tuple for h_tuple in handles),
Expand Down Expand Up @@ -429,3 +606,14 @@ def get_plot_data_ols(self) -> pd.DataFrame:
self.plot_data = pd.concat([pre_data, post_data])

return self.plot_data

def plot_treatment_time(self):
"""
display the posterior estimates of the treatment time
"""
if "treatment_time" not in self.idata.posterior.data_vars:
raise ValueError(
"Variable 'treatment_time' not found in inference data (idata)."
)

az.plot_trace(self.idata, var_names="treatment_time")
Loading
Loading