Skip to content

Commit

Permalink
updated unit test cases to support combined channels
Browse files Browse the repository at this point in the history
  • Loading branch information
alxlyj committed Oct 24, 2024
1 parent e69c84e commit c96e469
Show file tree
Hide file tree
Showing 3 changed files with 461 additions and 112 deletions.
142 changes: 66 additions & 76 deletions python/src/robyn/data/validation/calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,51 +174,79 @@ def _check_lift_values(self) -> ValidationResult:
status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
)

def create_modified_calibration_input(original_input, channel_name, **kwargs):
def validate(self) -> List[ValidationResult]:
"""
Implement the abstract validate method from the Validation base class.
Returns a list containing the calibration validation result.
"""
return [self.check_calibration()]

def check_calibration(self) -> ValidationResult:
"""Check all calibration inputs for consistency and correctness."""
if self.calibration_input is None:
return ValidationResult(status=True, error_details={}, error_message="")

error_details = {}
error_messages = []

checks = [
self._check_date_range(),
self._check_spend_values(),
self._check_metric_values(),
self._check_confidence_values(),
self._check_lift_values(),
]

for result in checks:
if not result.status:
error_details.update(result.error_details)
error_messages.append(result.error_message)

return ValidationResult(
status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
)

@staticmethod
def create_modified_calibration_input(
original_input: CalibrationInput, channel_name: Union[str, Tuple[str, ...]], **kwargs
):
"""
Create a modified version of a calibration input with updated values.
Args:
original_input: Original CalibrationInput object
channel_name: Channel identifier (string or tuple)
**kwargs: Updates to apply to the channel data
Returns:
Modified CalibrationInput object
"""
# Normalize channel_name to tuple format
# Convert channel_name to tuple format if it's not already
if isinstance(channel_name, str):
if "+" in channel_name:
channel_tuple = tuple(channel_name.split("+"))
else:
channel_tuple = (channel_name,)
elif isinstance(channel_name, tuple):
channel_tuple = channel_name
else:
raise ValueError(f"Invalid channel_name type: {type(channel_name)}")
channel_tuple = channel_name

# For non-existent channel tests, create new channel data
if channel_tuple not in original_input.channel_data and any(
ch == "nonexistent_channel" for ch in channel_tuple
):
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", pd.Timestamp.now())),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", pd.Timestamp.now())),
lift_abs=kwargs.get("lift_abs", 0),
spend=kwargs.get("spend", 0),
confidence=kwargs.get("confidence", 0.9),
metric=kwargs.get("metric", DependentVarType.REVENUE),
calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE),
# For test cases with non-existent channels
if "nonexistent_channel" in channel_tuple:
return CalibrationInput(
channel_data={
channel_tuple: ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", "2022-01-01")),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", "2022-01-05")),
lift_abs=kwargs.get("lift_abs", 1000),
spend=kwargs.get("spend", 300),
confidence=kwargs.get("confidence", 0.9),
metric=kwargs.get("metric", DependentVarType.REVENUE),
calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE),
)
}
)
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data
return CalibrationInput(channel_data=new_channel_data_dict)

try:
# Get original data using the tuple key
# For updating existing channels
if channel_tuple in original_input.channel_data:
original_channel_data = original_input.channel_data[channel_tuple]

# Create new channel data with updates
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", original_channel_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", original_channel_data.lift_end_date)),
Expand All @@ -229,59 +257,21 @@ def create_modified_calibration_input(original_input, channel_name, **kwargs):
calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope),
)

# Create new dictionary with updated data
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data

return CalibrationInput(channel_data=new_channel_data_dict)

except KeyError as e:
# Handle non-existent channels in fixture data
if "radio_spend" in channel_tuple and channel_tuple not in original_input.channel_data:
# Copy data from tv_spend and modify it
tv_data = original_input.channel_data[("tv_spend",)]
new_channel_data = ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", tv_data.lift_start_date)),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", tv_data.lift_end_date)),
lift_abs=kwargs.get("lift_abs", tv_data.lift_abs),
spend=kwargs.get("spend", tv_data.spend),
confidence=kwargs.get("confidence", tv_data.confidence),
metric=kwargs.get("metric", tv_data.metric),
calibration_scope=kwargs.get("calibration_scope", tv_data.calibration_scope),
# Default for new channels
return CalibrationInput(
channel_data={
channel_tuple: ChannelCalibrationData(
lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", "2022-01-01")),
lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", "2022-01-05")),
lift_abs=kwargs.get("lift_abs", 1000),
spend=kwargs.get("spend", 300),
confidence=kwargs.get("confidence", 0.9),
metric=kwargs.get("metric", DependentVarType.REVENUE),
calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE),
)
new_channel_data_dict = original_input.channel_data.copy()
new_channel_data_dict[channel_tuple] = new_channel_data
return CalibrationInput(channel_data=new_channel_data_dict)
raise KeyError(f"Channel {channel_tuple} not found in calibration input")

def validate(self) -> List[ValidationResult]:
"""
Implement the abstract validate method from the Validation base class.
Returns a list containing the calibration validation result.
"""
return [self.check_calibration()]

def check_calibration(self) -> ValidationResult:
"""Check all calibration inputs for consistency and correctness."""
if self.calibration_input is None:
return ValidationResult(status=True, error_details={}, error_message="")

error_details = {}
error_messages = []

checks = [
self._check_date_range(),
self._check_spend_values(),
self._check_metric_values(),
self._check_confidence_values(),
self._check_lift_values(),
]

for result in checks:
if not result.status:
error_details.update(result.error_details)
error_messages.append(result.error_message)

return ValidationResult(
status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
}
)
Loading

0 comments on commit c96e469

Please sign in to comment.