Skip to content

Commit d92b135

Browse files
roesta07Rojan Shresthadrbenvincent
authored
Refactored scikit-learn flavour of DifferenceInDifferences and allowed custom column names for post_treatment variable. (#515)
* Added post_treatment_variable_name parameter and sklearn model summary for did * Refactor DiD validation: segregate FormulaException and DataException * added validations for interactions, test coverage expanded to test interaction terms,more generic messages * get pre-commit checks to pass * Refactor interaction term extraction in DiD and utils Moved the interaction term extraction logic from DifferenceInDifferences to a new get_interaction_terms utility function. Updated relevant imports and tests to use the new function, improving code reuse and maintainability. * update exception message when we detect more than one interaction term * updates to FormulaException wording --------- Co-authored-by: Rojan Shrestha <[email protected]> Co-authored-by: Benjamin T. Vincent <[email protected]>
1 parent 1d69dee commit d92b135

File tree

5 files changed

+286
-18
lines changed

5 files changed

+286
-18
lines changed

causalpy/experiments/diff_in_diff.py

Lines changed: 73 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,12 @@
3030
)
3131
from causalpy.plot_utils import plot_xY
3232
from causalpy.pymc_models import PyMCModel
33-
from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num
33+
from causalpy.utils import (
34+
_is_variable_dummy_coded,
35+
convert_to_string,
36+
get_interaction_terms,
37+
round_num,
38+
)
3439

3540
from .base import BaseExperiment
3641

@@ -52,6 +57,8 @@ class DifferenceInDifferences(BaseExperiment):
5257
Name of the data column for the time variable
5358
:param group_variable_name:
5459
Name of the data column for the group variable
60+
:param post_treatment_variable_name:
61+
Name of the data column indicating post-treatment period (default: "post_treatment")
5562
:param model:
5663
A PyMC model for difference in differences
5764
@@ -84,6 +91,7 @@ def __init__(
8491
formula: str,
8592
time_variable_name: str,
8693
group_variable_name: str,
94+
post_treatment_variable_name: str = "post_treatment",
8795
model=None,
8896
**kwargs,
8997
) -> None:
@@ -95,6 +103,7 @@ def __init__(
95103
self.formula = formula
96104
self.time_variable_name = time_variable_name
97105
self.group_variable_name = group_variable_name
106+
self.post_treatment_variable_name = post_treatment_variable_name
98107
self.input_validation()
99108

100109
y, X = dmatrices(formula, self.data)
@@ -128,6 +137,12 @@ def __init__(
128137
}
129138
self.model.fit(X=self.X, y=self.y, coords=COORDS)
130139
elif isinstance(self.model, RegressorMixin):
140+
# For scikit-learn models, automatically set fit_intercept=False
141+
# This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute
142+
# without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary
143+
# TODO: later, this should be handled in ScikitLearnAdaptor itself
144+
if hasattr(self.model, "fit_intercept"):
145+
self.model.fit_intercept = False
131146
self.model.fit(X=self.X, y=self.y)
132147
else:
133148
raise ValueError("Model type not recognized")
@@ -173,7 +188,7 @@ def __init__(
173188
# just the treated group
174189
.query(f"{self.group_variable_name} == 1")
175190
# just the treatment period(s)
176-
.query("post_treatment == True")
191+
.query(f"{self.post_treatment_variable_name} == True")
177192
# drop the outcome variable
178193
.drop(self.outcome_variable_name, axis=1)
179194
# We may have multiple units per time point, we only want one time point
@@ -189,7 +204,10 @@ def __init__(
189204
# INTERVENTION: set the interaction term between the group and the
190205
# post_treatment variable to zero. This is the counterfactual.
191206
for i, label in enumerate(self.labels):
192-
if "post_treatment" in label and self.group_variable_name in label:
207+
if (
208+
self.post_treatment_variable_name in label
209+
and self.group_variable_name in label
210+
):
193211
new_x.iloc[:, i] = 0
194212
self.y_pred_counterfactual = self.model.predict(np.asarray(new_x))
195213

@@ -198,31 +216,44 @@ def __init__(
198216
# This is the coefficient on the interaction term
199217
coeff_names = self.model.idata.posterior.coords["coeffs"].data
200218
for i, label in enumerate(coeff_names):
201-
if "post_treatment" in label and self.group_variable_name in label:
219+
if (
220+
self.post_treatment_variable_name in label
221+
and self.group_variable_name in label
222+
):
202223
self.causal_impact = self.model.idata.posterior["beta"].isel(
203224
{"coeffs": i}
204225
)
205226
elif isinstance(self.model, RegressorMixin):
206227
# This is the coefficient on the interaction term
207-
# TODO: CHECK FOR CORRECTNESS
208-
self.causal_impact = (
209-
self.y_pred_treatment[1] - self.y_pred_counterfactual[0]
210-
).item()
228+
# Store the coefficient into dictionary {intercept:value}
229+
coef_map = dict(zip(self.labels, self.model.get_coeffs()))
230+
# Create and find the interaction term based on the values user provided
231+
interaction_term = (
232+
f"{self.group_variable_name}:{self.post_treatment_variable_name}"
233+
)
234+
matched_key = next((k for k in coef_map if interaction_term in k), None)
235+
att = coef_map.get(matched_key)
236+
self.causal_impact = att
211237
else:
212238
raise ValueError("Model type not recognized")
213239

214240
return
215241

216242
def input_validation(self):
243+
# Validate formula structure and interaction interaction terms
244+
self._validate_formula_interaction_terms()
245+
217246
"""Validate the input data and model formula for correctness"""
218-
if "post_treatment" not in self.formula:
247+
# Check if post_treatment_variable_name is in formula
248+
if self.post_treatment_variable_name not in self.formula:
219249
raise FormulaException(
220-
"A predictor called `post_treatment` should be in the formula"
250+
f"Missing required variable '{self.post_treatment_variable_name}' in formula"
221251
)
222252

223-
if "post_treatment" not in self.data.columns:
253+
# Check if post_treatment_variable_name is in data columns
254+
if self.post_treatment_variable_name not in self.data.columns:
224255
raise DataException(
225-
"Require a boolean column labelling observations which are `treated`"
256+
f"Missing required column '{self.post_treatment_variable_name}' in dataset"
226257
)
227258

228259
if "unit" not in self.data.columns:
@@ -236,6 +267,36 @@ def input_validation(self):
236267
coded. Consisting of 0's and 1's only."""
237268
)
238269

270+
def _validate_formula_interaction_terms(self):
271+
"""
272+
Validate that the formula contains at most one interaction term and no three-way or higher-order interactions.
273+
Raises FormulaException if more than one interaction term is found or if any interaction term has more than 2 variables.
274+
"""
275+
# Define interaction indicators
276+
INTERACTION_INDICATORS = ["*", ":"]
277+
278+
# Get interaction terms
279+
interaction_terms = get_interaction_terms(self.formula)
280+
281+
# Check for interaction terms with more than 2 variables (more than one '*' or ':')
282+
for term in interaction_terms:
283+
total_indicators = sum(
284+
term.count(indicator) for indicator in INTERACTION_INDICATORS
285+
)
286+
if (
287+
total_indicators >= 2
288+
): # 3 or more variables (e.g., a*b*c or a:b:c has 2 symbols)
289+
raise FormulaException(
290+
f"Formula contains interaction term with more than 2 variables: {term}. "
291+
"Three-way or higher-order interactions are not supported as they complicate interpretation of the causal effect."
292+
)
293+
294+
if len(interaction_terms) > 1:
295+
raise FormulaException(
296+
f"Formula contains {len(interaction_terms)} interaction terms: {interaction_terms}. "
297+
"Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
298+
)
299+
239300
def summary(self, round_to=None) -> None:
240301
"""Print summary of main results and model coefficients.
241302

causalpy/tests/test_input_validation.py

Lines changed: 117 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,18 +30,29 @@
3030

3131

3232
def test_did_validation_post_treatment_formula():
33-
"""Test that we get a FormulaException if do not include post_treatment in the
34-
formula"""
33+
"""Test that we get a FormulaException for invalid formulas and missing post_treatment variables"""
3534
df = pd.DataFrame(
3635
{
3736
"group": [0, 0, 1, 1],
3837
"t": [0, 1, 0, 1],
3938
"unit": [0, 0, 1, 1],
4039
"post_treatment": [0, 1, 0, 1],
40+
"male": [0, 1, 0, 1], # Additional variable for testing
4141
"y": [1, 2, 3, 4],
4242
}
4343
)
4444

45+
df_with_custom = pd.DataFrame(
46+
{
47+
"group": [0, 0, 1, 1],
48+
"t": [0, 1, 0, 1],
49+
"unit": [0, 0, 1, 1],
50+
"custom_post": [0, 1, 0, 1], # Custom column name
51+
"y": [1, 2, 3, 4],
52+
}
53+
)
54+
55+
# Test 1: Missing post_treatment variable in formula
4556
with pytest.raises(FormulaException):
4657
_ = cp.DifferenceInDifferences(
4758
df,
@@ -51,6 +62,7 @@ def test_did_validation_post_treatment_formula():
5162
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
5263
)
5364

65+
# Test 2: Missing post_treatment variable in formula (duplicate test)
5466
with pytest.raises(FormulaException):
5567
_ = cp.DifferenceInDifferences(
5668
df,
@@ -60,6 +72,88 @@ def test_did_validation_post_treatment_formula():
6072
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
6173
)
6274

75+
# Test 3: Custom post_treatment_variable_name but formula uses different name
76+
with pytest.raises(FormulaException):
77+
_ = cp.DifferenceInDifferences(
78+
df_with_custom,
79+
formula="y ~ 1 + group*post_treatment", # Formula uses 'post_treatment'
80+
time_variable_name="t",
81+
group_variable_name="group",
82+
post_treatment_variable_name="custom_post", # But user specifies 'custom_post'
83+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
84+
)
85+
86+
# Test 4: Default post_treatment_variable_name but formula uses different name
87+
with pytest.raises(FormulaException):
88+
_ = cp.DifferenceInDifferences(
89+
df,
90+
formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post'
91+
time_variable_name="t",
92+
group_variable_name="group",
93+
# post_treatment_variable_name defaults to "post_treatment"
94+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
95+
)
96+
97+
# Test 5: Repeated interaction terms (should be invalid)
98+
with pytest.raises(FormulaException):
99+
_ = cp.DifferenceInDifferences(
100+
df,
101+
formula="y ~ 1 + group + group*post_treatment + group*post_treatment",
102+
time_variable_name="t",
103+
group_variable_name="group",
104+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
105+
)
106+
107+
# Test 6: Three-way interactions using * (should be invalid)
108+
with pytest.raises(FormulaException):
109+
_ = cp.DifferenceInDifferences(
110+
df,
111+
formula="y ~ 1 + group + group*post_treatment*male",
112+
time_variable_name="t",
113+
group_variable_name="group",
114+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
115+
)
116+
117+
# Test 7: Three-way interactions using : (should be invalid)
118+
with pytest.raises(FormulaException):
119+
_ = cp.DifferenceInDifferences(
120+
df,
121+
formula="y ~ 1 + group + group:post_treatment:male",
122+
time_variable_name="t",
123+
group_variable_name="group",
124+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
125+
)
126+
127+
# Test 8: Multiple different interaction terms using * (should be invalid)
128+
with pytest.raises(FormulaException):
129+
_ = cp.DifferenceInDifferences(
130+
df,
131+
formula="y ~ 1 + group + group*post_treatment + group*male",
132+
time_variable_name="t",
133+
group_variable_name="group",
134+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
135+
)
136+
137+
# Test 9: Multiple different interaction terms using : (should be invalid)
138+
with pytest.raises(FormulaException):
139+
_ = cp.DifferenceInDifferences(
140+
df,
141+
formula="y ~ 1 + group + group:post_treatment + group:male",
142+
time_variable_name="t",
143+
group_variable_name="group",
144+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
145+
)
146+
147+
# Test 10: Mixed issues - multiple terms + three-way interaction (should be invalid)
148+
with pytest.raises(FormulaException):
149+
_ = cp.DifferenceInDifferences(
150+
df,
151+
formula="y ~ 1 + group + group*post_treatment + group:post_treatment:male",
152+
time_variable_name="t",
153+
group_variable_name="group",
154+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
155+
)
156+
63157

64158
def test_did_validation_post_treatment_data():
65159
"""Test that we get a DataException if do not include post_treatment in the data"""
@@ -91,6 +185,27 @@ def test_did_validation_post_treatment_data():
91185
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
92186
)
93187

188+
# Test 2: Custom post_treatment_variable_name but column doesn't exist in data
189+
df_with_post = pd.DataFrame(
190+
{
191+
"group": [0, 0, 1, 1],
192+
"t": [0, 1, 0, 1],
193+
"unit": [0, 0, 1, 1],
194+
"post_treatment": [0, 1, 0, 1], # Data has 'post_treatment'
195+
"y": [1, 2, 3, 4],
196+
}
197+
)
198+
199+
with pytest.raises(DataException):
200+
_ = cp.DifferenceInDifferences(
201+
df_with_post,
202+
formula="y ~ 1 + group*custom_post", # Formula uses 'custom_post'
203+
time_variable_name="t",
204+
group_variable_name="group",
205+
post_treatment_variable_name="custom_post", # User specifies 'custom_post'
206+
model=cp.pymc_models.LinearRegression(sample_kwargs=sample_kwargs),
207+
)
208+
94209

95210
def test_did_validation_unit_data():
96211
"""Test that we get a DataException if do not include unit in the data"""

causalpy/tests/test_utils.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,12 @@
1717

1818
import pandas as pd
1919

20-
from causalpy.utils import _is_variable_dummy_coded, _series_has_2_levels, round_num
20+
from causalpy.utils import (
21+
_is_variable_dummy_coded,
22+
_series_has_2_levels,
23+
get_interaction_terms,
24+
round_num,
25+
)
2126

2227

2328
def test_dummy_coding():
@@ -57,3 +62,43 @@ def test_round_num():
5762
assert round_num(123.456, 5) == "123.46"
5863
assert round_num(123.456, 6) == "123.456"
5964
assert round_num(123.456, 7) == "123.456"
65+
66+
67+
def test_get_interaction_terms():
68+
"""Test if the function to extract interaction terms from formulas works correctly"""
69+
# No interaction terms
70+
assert get_interaction_terms("y ~ x1 + x2 + x3") == []
71+
assert get_interaction_terms("y ~ 1 + x1 + x2") == []
72+
73+
# Single interaction term with '*'
74+
assert get_interaction_terms("y ~ x1 + x2*x3") == ["x2*x3"]
75+
assert get_interaction_terms("y ~ 1 + group*post_treatment") == [
76+
"group*post_treatment"
77+
]
78+
79+
# Single interaction term with ':'
80+
assert get_interaction_terms("y ~ x1 + x2:x3") == ["x2:x3"]
81+
assert get_interaction_terms("y ~ 1 + group:post_treatment") == [
82+
"group:post_treatment"
83+
]
84+
85+
# Multiple interaction terms
86+
assert get_interaction_terms("y ~ x1*x2 + x3*x4") == ["x1*x2", "x3*x4"]
87+
assert get_interaction_terms("y ~ a:b + c*d") == ["a:b", "c*d"]
88+
89+
# Three-way interaction
90+
assert get_interaction_terms("y ~ x1*x2*x3") == ["x1*x2*x3"]
91+
assert get_interaction_terms("y ~ a:b:c") == ["a:b:c"]
92+
93+
# Formula with spaces (should be handled correctly)
94+
assert get_interaction_terms("y ~ x1 + x2 * x3") == ["x2*x3"]
95+
assert get_interaction_terms("y ~ 1 + group * post_treatment") == [
96+
"group*post_treatment"
97+
]
98+
99+
# Mixed main effects and interactions
100+
assert get_interaction_terms("y ~ 1 + x1 + x2 + x1*x2") == ["x1*x2"]
101+
assert get_interaction_terms("y ~ x1 + x2*x3 + x4") == ["x2*x3"]
102+
103+
# Formula with subtraction (edge case)
104+
assert get_interaction_terms("y ~ x1*x2 - x3") == ["x1*x2"]

0 commit comments

Comments
 (0)