Skip to content

Commit

Permalink
update unit test code
Browse files Browse the repository at this point in the history
  • Loading branch information
alxlyj committed Oct 24, 2024
1 parent d5b5c1a commit 85afda3
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 48 deletions.
106 changes: 63 additions & 43 deletions python/src/robyn/data/validation/calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,51 +174,71 @@ def _check_lift_values(self) -> ValidationResult:
status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
)

def create_modified_calibration_input(original_input, channel_name, **kwargs):
"""Create a modified version of a calibration input with updated values."""
# Convert the channel_name to a tuple regardless of input format
if isinstance(channel_name, str):
if "+" in channel_name:
channel_tuple = tuple(channel_name.split("+"))
else:
channel_tuple = (channel_name,)
elif isinstance(channel_name, tuple):
channel_tuple = channel_name
else:
raise ValueError(f"Invalid channel_name type: {type(channel_name)}")

def create_modified_calibration_input(original_input, channel_name, **kwargs):
"""Create a modified version of a calibration input with updated values."""
# Convert the channel_name to the format stored in channel_data
if isinstance(channel_name, str):
channel_tuple = (channel_name,)
elif isinstance(channel_name, tuple):
channel_tuple = channel_name
else:
raise ValueError(f"Invalid channel_name type: {type(channel_name)}")

# Handle multi-channel cases
if len(channel_tuple) > 1:
# For multi-channel cases, check if it exists in original_input
# Try to find the key in original_input.channel_data
if channel_tuple not in original_input.channel_data:
# Try the '+' joined version
channel_key = "+".join(channel_tuple)
if channel_key not in original_input.channel_data:
raise KeyError(f"Channel combination not found: {channel_tuple}")
else:
channel_key = channel_tuple
else:
# For single channel cases
channel_key = channel_tuple

# Get original data
original_channel_data = original_input.channel_data[channel_key]

# Create new channel data with updates
new_channel_data = ChannelCalibrationData(
lift_start_date=kwargs.get("lift_start_date", original_channel_data.lift_start_date),
lift_end_date=kwargs.get("lift_end_date", original_channel_data.lift_end_date),
lift_abs=kwargs.get("lift_abs", original_channel_data.lift_abs),
spend=kwargs.get("spend", original_channel_data.spend),
confidence=kwargs.get("confidence", original_channel_data.confidence),
metric=kwargs.get("metric", original_channel_data.metric),
calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope),
)

# Create new dictionary with updated data
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_key] = new_channel_data

return CalibrationInput(channel_data=new_channel_data_dict)
raise KeyError(f"Channel key not found: {channel_tuple}")

# Get original data using the tuple key
original_channel_data = original_input.channel_data[channel_tuple]

# Create new channel data with updates
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", original_channel_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", original_channel_data.lift_end_date)),
lift_abs=kwargs.get("lift_abs", original_channel_data.lift_abs),
spend=kwargs.get("spend", original_channel_data.spend),
confidence=kwargs.get("confidence", original_channel_data.confidence),
metric=kwargs.get("metric", original_channel_data.metric),
calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope),
)

# Create new dictionary with updated data
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data

return CalibrationInput(channel_data=new_channel_data_dict)

def validate(self) -> List[ValidationResult]:
"""Run all validations."""
"""
Implement the abstract validate method from the Validation base class.
Returns a list containing the calibration validation result.
"""
return [self.check_calibration()]

def check_calibration(self) -> ValidationResult:
"""Check all calibration inputs for consistency and correctness."""
if self.calibration_input is None:
return ValidationResult(status=True, error_details={}, error_message="")

error_details = {}
error_messages = []

checks = [
self._check_date_range(),
self._check_spend_values(),
self._check_metric_values(),
self._check_confidence_values(),
self._check_lift_values(),
]

for result in checks:
if not result.status:
error_details.update(result.error_details)
error_messages.append(result.error_message)

return ValidationResult(
status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
)
10 changes: 5 additions & 5 deletions python/tests/test_calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,27 @@ def sample_mmmdata():

@pytest.fixture
def sample_calibration_input(sample_mmmdata):
# Calculate actual spends from the data
"""Create a sample calibration input with actual spend values."""
data = sample_mmmdata.data
tv_spend = data.loc[data["date"].between("2022-01-01", "2022-01-05"), "tv_spend"].sum()
radio_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "radio_spend"].sum()

return CalibrationInput(
channel_data={
"tv_spend": ChannelCalibrationData(
("tv_spend",): ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-01"),
lift_end_date=pd.Timestamp("2022-01-05"),
lift_abs=1000,
spend=tv_spend, # Use actual spend
spend=tv_spend,
confidence=0.9,
metric=DependentVarType.REVENUE,
calibration_scope=CalibrationScope.IMMEDIATE,
),
"radio_spend": ChannelCalibrationData(
("radio_spend",): ChannelCalibrationData(
lift_start_date=pd.Timestamp("2022-01-06"),
lift_end_date=pd.Timestamp("2022-01-10"),
lift_abs=2000,
spend=radio_spend, # Use actual spend
spend=radio_spend,
confidence=0.85,
metric=DependentVarType.REVENUE,
calibration_scope=CalibrationScope.IMMEDIATE,
Expand Down

0 comments on commit 85afda3

Please sign in to comment.