3030)
3131from causalpy .plot_utils import plot_xY
3232from causalpy .pymc_models import PyMCModel
33- from causalpy .utils import _is_variable_dummy_coded , convert_to_string , round_num
33+ from causalpy .utils import (
34+ _is_variable_dummy_coded ,
35+ convert_to_string ,
36+ get_interaction_terms ,
37+ round_num ,
38+ )
3439
3540from .base import BaseExperiment
3641
@@ -52,6 +57,8 @@ class DifferenceInDifferences(BaseExperiment):
5257 Name of the data column for the time variable
5358 :param group_variable_name:
5459 Name of the data column for the group variable
60+ :param post_treatment_variable_name:
61+ Name of the data column indicating post-treatment period (default: "post_treatment")
5562 :param model:
5663 A PyMC model for difference in differences
5764
@@ -84,6 +91,7 @@ def __init__(
8491 formula : str ,
8592 time_variable_name : str ,
8693 group_variable_name : str ,
94+ post_treatment_variable_name : str = "post_treatment" ,
8795 model = None ,
8896 ** kwargs ,
8997 ) -> None :
@@ -95,6 +103,7 @@ def __init__(
95103 self .formula = formula
96104 self .time_variable_name = time_variable_name
97105 self .group_variable_name = group_variable_name
106+ self .post_treatment_variable_name = post_treatment_variable_name
98107 self .input_validation ()
99108
100109 y , X = dmatrices (formula , self .data )
@@ -128,6 +137,12 @@ def __init__(
128137 }
129138 self .model .fit (X = self .X , y = self .y , coords = COORDS )
130139 elif isinstance (self .model , RegressorMixin ):
140+ # For scikit-learn models, automatically set fit_intercept=False
141+ # This ensures the intercept is included in the coefficients array rather than being a separate intercept_ attribute
142+ # without this, the intercept is not included in the coefficients array hence would be displayed as 0 in the model summary
143+ # TODO: later, this should be handled in ScikitLearnAdaptor itself
144+ if hasattr (self .model , "fit_intercept" ):
145+ self .model .fit_intercept = False
131146 self .model .fit (X = self .X , y = self .y )
132147 else :
133148 raise ValueError ("Model type not recognized" )
@@ -173,7 +188,7 @@ def __init__(
173188 # just the treated group
174189 .query (f"{ self .group_variable_name } == 1" )
175190 # just the treatment period(s)
176- .query ("post_treatment == True" )
191+ .query (f" { self . post_treatment_variable_name } == True" )
177192 # drop the outcome variable
178193 .drop (self .outcome_variable_name , axis = 1 )
179194 # We may have multiple units per time point, we only want one time point
@@ -189,7 +204,10 @@ def __init__(
189204 # INTERVENTION: set the interaction term between the group and the
190205 # post_treatment variable to zero. This is the counterfactual.
191206 for i , label in enumerate (self .labels ):
192- if "post_treatment" in label and self .group_variable_name in label :
207+ if (
208+ self .post_treatment_variable_name in label
209+ and self .group_variable_name in label
210+ ):
193211 new_x .iloc [:, i ] = 0
194212 self .y_pred_counterfactual = self .model .predict (np .asarray (new_x ))
195213
@@ -198,31 +216,44 @@ def __init__(
198216 # This is the coefficient on the interaction term
199217 coeff_names = self .model .idata .posterior .coords ["coeffs" ].data
200218 for i , label in enumerate (coeff_names ):
201- if "post_treatment" in label and self .group_variable_name in label :
219+ if (
220+ self .post_treatment_variable_name in label
221+ and self .group_variable_name in label
222+ ):
202223 self .causal_impact = self .model .idata .posterior ["beta" ].isel (
203224 {"coeffs" : i }
204225 )
205226 elif isinstance (self .model , RegressorMixin ):
206227 # This is the coefficient on the interaction term
207- # TODO: CHECK FOR CORRECTNESS
208- self .causal_impact = (
209- self .y_pred_treatment [1 ] - self .y_pred_counterfactual [0 ]
210- ).item ()
228+ # Store the coefficient into dictionary {intercept:value}
229+ coef_map = dict (zip (self .labels , self .model .get_coeffs ()))
230+ # Create and find the interaction term based on the values user provided
231+ interaction_term = (
232+ f"{ self .group_variable_name } :{ self .post_treatment_variable_name } "
233+ )
234+ matched_key = next ((k for k in coef_map if interaction_term in k ), None )
235+ att = coef_map .get (matched_key )
236+ self .causal_impact = att
211237 else :
212238 raise ValueError ("Model type not recognized" )
213239
214240 return
215241
216242 def input_validation (self ):
243+ # Validate formula structure and interaction terms
244+ self ._validate_formula_interaction_terms ()
245+
217246 """Validate the input data and model formula for correctness"""
218- if "post_treatment" not in self .formula :
247+ # Check if post_treatment_variable_name is in formula
248+ if self .post_treatment_variable_name not in self .formula :
219249 raise FormulaException (
220- "A predictor called `post_treatment` should be in the formula"
250+ f"Missing required variable ' { self . post_treatment_variable_name } ' in formula"
221251 )
222252
223- if "post_treatment" not in self .data .columns :
253+ # Check if post_treatment_variable_name is in data columns
254+ if self .post_treatment_variable_name not in self .data .columns :
224255 raise DataException (
225- "Require a boolean column labelling observations which are `treated` "
256+ f"Missing required column ' { self . post_treatment_variable_name } ' in dataset "
226257 )
227258
228259 if "unit" not in self .data .columns :
@@ -236,6 +267,36 @@ def input_validation(self):
236267 coded. Consisting of 0's and 1's only."""
237268 )
238269
def _validate_formula_interaction_terms(self):
    """Reject formulas whose interaction structure is not supported.

    The interaction terms are obtained by calling ``get_interaction_terms``
    on ``self.formula``. Each returned term is inspected in order, and the
    total number of terms is checked afterwards.

    :raises FormulaException:
        If any interaction term contains two or more ``*`` / ``:`` symbols
        (for example ``a*b*c`` or ``a:b:c``), or if more than one
        interaction term is present in the formula.
    """
    # Symbols whose presence marks an interaction between variables
    interaction_symbols = ("*", ":")

    found_terms = get_interaction_terms(self.formula)

    # The per-term check runs before the count check, so a higher-order
    # term is reported first even when several terms are present.
    for candidate in found_terms:
        symbol_count = 0
        for symbol in interaction_symbols:
            symbol_count += candidate.count(symbol)
        if symbol_count < 2:
            # Zero or one symbol: at most a two-variable interaction
            continue
        raise FormulaException(
            f"Formula contains interaction term with more than 2 variables: {candidate} . "
            "Three-way or higher-order interactions are not supported as they complicate interpretation of the causal effect."
        )

    n_terms = len(found_terms)
    if n_terms <= 1:
        return
    raise FormulaException(
        f"Formula contains {n_terms} interaction terms: {found_terms} . "
        "Multiple interaction terms are not currently supported as they complicate interpretation of the causal effect."
    )
299+
239300 def summary (self , round_to = None ) -> None :
240301 """Print summary of main results and model coefficients.
241302
0 commit comments