Skip to content

Commit

Permalink
update unit test code
Browse files Browse the repository at this point in the history
  • Loading branch information
alxlyj committed Oct 24, 2024
1 parent 85afda3 commit e69c84e
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 30 deletions.
93 changes: 68 additions & 25 deletions python/src/robyn/data/validation/calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,18 @@ def _check_lift_values(self) -> ValidationResult:
)

def create_modified_calibration_input(original_input, channel_name, **kwargs):
"""Create a modified version of a calibration input with updated values."""
# Convert the channel_name to a tuple regardless of input format
"""
Create a modified version of a calibration input with updated values.
Args:
original_input: Original CalibrationInput object
channel_name: Channel identifier (string or tuple)
**kwargs: Updates to apply to the channel data
Returns:
Modified CalibrationInput object
"""
# Normalize channel_name to tuple format
if isinstance(channel_name, str):
if "+" in channel_name:
channel_tuple = tuple(channel_name.split("+"))
Expand All @@ -187,29 +197,62 @@ def create_modified_calibration_input(original_input, channel_name, **kwargs):
else:
raise ValueError(f"Invalid channel_name type: {type(channel_name)}")

# Try to find the key in original_input.channel_data
if channel_tuple not in original_input.channel_data:
raise KeyError(f"Channel key not found: {channel_tuple}")

# Get original data using the tuple key
original_channel_data = original_input.channel_data[channel_tuple]

# Create new channel data with updates
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", original_channel_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", original_channel_data.lift_end_date)),
lift_abs=kwargs.get("lift_abs", original_channel_data.lift_abs),
spend=kwargs.get("spend", original_channel_data.spend),
confidence=kwargs.get("confidence", original_channel_data.confidence),
metric=kwargs.get("metric", original_channel_data.metric),
calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope),
)

# Create new dictionary with updated data
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data

return CalibrationInput(channel_data=new_channel_data_dict)
# For non-existent channel tests, create new channel data
if channel_tuple not in original_input.channel_data and any(
ch == "nonexistent_channel" for ch in channel_tuple
):
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", pd.Timestamp.now())),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", pd.Timestamp.now())),
lift_abs=kwargs.get("lift_abs", 0),
spend=kwargs.get("spend", 0),
confidence=kwargs.get("confidence", 0.9),
metric=kwargs.get("metric", DependentVarType.REVENUE),
calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE),
)
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data
return CalibrationInput(channel_data=new_channel_data_dict)

try:
# Get original data using the tuple key
original_channel_data = original_input.channel_data[channel_tuple]

# Create new channel data with updates
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", original_channel_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", original_channel_data.lift_end_date)),
lift_abs=kwargs.get("lift_abs", original_channel_data.lift_abs),
spend=kwargs.get("spend", original_channel_data.spend),
confidence=kwargs.get("confidence", original_channel_data.confidence),
metric=kwargs.get("metric", original_channel_data.metric),
calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope),
)

# Create new dictionary with updated data
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data

return CalibrationInput(channel_data=new_channel_data_dict)

except KeyError as e:
# Handle non-existent channels in fixture data
if "radio_spend" in channel_tuple and channel_tuple not in original_input.channel_data:
# Copy data from tv_spend and modify it
tv_data = original_input.channel_data[("tv_spend",)]
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", tv_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", tv_data.lift_end_date)),
lift_abs=kwargs.get("lift_abs", tv_data.lift_abs),
spend=kwargs.get("spend", tv_data.spend),
confidence=kwargs.get("confidence", tv_data.confidence),
metric=kwargs.get("metric", tv_data.metric),
calibration_scope=kwargs.get("calibration_scope", tv_data.calibration_scope),
)
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data
return CalibrationInput(channel_data=new_channel_data_dict)
raise KeyError(f"Channel {channel_tuple} not found in calibration input")

def validate(self) -> List[ValidationResult]:
"""
Expand Down
16 changes: 11 additions & 5 deletions python/tests/test_calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,12 @@ def sample_calibration_input(sample_mmmdata):
tv_spend = data.loc[data["date"].between("2022-01-01", "2022-01-05"), "tv_spend"].sum()
radio_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "radio_spend"].sum()

tv_channel_key = ("tv_spend",) # Use tuple key
radio_channel_key = ("radio_spend",) # Use tuple key

return CalibrationInput(
channel_data={
("tv_spend",): ChannelCalibrationData(
tv_channel_key: ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-01"),
lift_end_date=pd.Timestamp("2022-01-05"),
lift_abs=1000,
Expand All @@ -44,7 +47,7 @@ def sample_calibration_input(sample_mmmdata):
metric=DependentVarType.REVENUE,
calibration_scope=CalibrationScope.IMMEDIATE,
),
("radio_spend",): ChannelCalibrationData(
radio_channel_key: ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-06"),
lift_end_date=pd.Timestamp("2022-01-10"),
lift_abs=2000,
Expand All @@ -59,16 +62,19 @@ def sample_calibration_input(sample_mmmdata):

@pytest.fixture
def sample_multichannel_calibration_input(sample_mmmdata):
# Calculate combined spend for the channels
"""Create a sample multi-channel calibration input."""
data = sample_mmmdata.data
combined_spend = (
data.loc[data["date"].between("2022-01-01", "2022-01-05"), ["tv_spend", "radio_spend"]].sum().sum()
)
tv_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "tv_spend"].sum()

multi_channel_key = ("tv_spend", "radio_spend") # Use tuple key
tv_channel_key = ("tv_spend",) # Use tuple key

return CalibrationInput(
channel_data={
"tv_spend+radio_spend": ChannelCalibrationData(
multi_channel_key: ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-01"),
lift_end_date=pd.Timestamp("2022-01-05"),
lift_abs=3000,
Expand All @@ -77,7 +83,7 @@ def sample_multichannel_calibration_input(sample_mmmdata):
metric=DependentVarType.REVENUE,
calibration_scope=CalibrationScope.IMMEDIATE,
),
"tv_spend": ChannelCalibrationData(
tv_channel_key: ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-06"),
lift_end_date=pd.Timestamp("2022-01-10"),
lift_abs=1000,
Expand Down

0 comments on commit e69c84e

Please sign in to comment.