From c96e4697b398a326c07f0dff9c3f23a02b02c380 Mon Sep 17 00:00:00 2001 From: alxlyj Date: Thu, 24 Oct 2024 02:57:59 -0400 Subject: [PATCH] updated unit test cases to support combined channels --- .../calibration_input_validation.py | 142 +++---- .../test_calibration_input_validation.py | 392 ++++++++++++++++++ .../test_calibration_input_validation.py | 39 +- 3 files changed, 461 insertions(+), 112 deletions(-) create mode 100644 python/src/robyn/data/validation/test_calibration_input_validation.py diff --git a/python/src/robyn/data/validation/calibration_input_validation.py b/python/src/robyn/data/validation/calibration_input_validation.py index dc78fc780..f6de2ca54 100644 --- a/python/src/robyn/data/validation/calibration_input_validation.py +++ b/python/src/robyn/data/validation/calibration_input_validation.py @@ -174,7 +174,42 @@ def _check_lift_values(self) -> ValidationResult: status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages) ) - def create_modified_calibration_input(original_input, channel_name, **kwargs): + def validate(self) -> List[ValidationResult]: + """ + Implement the abstract validate method from the Validation base class. + Returns a list containing the calibration validation result. + """ + return [self.check_calibration()] + + def check_calibration(self) -> ValidationResult: + """Check all calibration inputs for consistency and correctness.""" + if self.calibration_input is None: + return ValidationResult(status=True, error_details={}, error_message="") + + error_details = {} + error_messages = [] + + checks = [ + self._check_date_range(), + self._check_spend_values(), + self._check_metric_values(), + self._check_confidence_values(), + self._check_lift_values(), + ] + + for result in checks: + if not result.status: + error_details.update(result.error_details) + error_messages.append(result.error_message) + + return ValidationResult( + status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages) + ) + + @staticmethod + def create_modified_calibration_input( + original_input: CalibrationInput, channel_name: Union[str, Tuple[str, ...]], **kwargs + ): """ Create a modified version of a calibration input with updated values. @@ -182,43 +217,36 @@ def create_modified_calibration_input(original_input, channel_name, **kwargs): original_input: Original CalibrationInput object channel_name: Channel identifier (string or tuple) **kwargs: Updates to apply to the channel data - - Returns: - Modified CalibrationInput object """ - # Normalize channel_name to tuple format + # Convert channel_name to tuple format if it's not already if isinstance(channel_name, str): if "+" in channel_name: channel_tuple = tuple(channel_name.split("+")) else: channel_tuple = (channel_name,) - elif isinstance(channel_name, tuple): - channel_tuple = channel_name else: - raise ValueError(f"Invalid channel_name type: {type(channel_name)}") + channel_tuple = channel_name - # For non-existent channel tests, create new channel data - if channel_tuple not in original_input.channel_data and any( - ch == "nonexistent_channel" for ch in channel_tuple - ): - new_channel_data = ChannelCalibrationData( - lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", pd.Timestamp.now())), - lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", pd.Timestamp.now())), - lift_abs=kwargs.get("lift_abs", 0), - spend=kwargs.get("spend", 0), - confidence=kwargs.get("confidence", 0.9), - metric=kwargs.get("metric", DependentVarType.REVENUE), - calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE), + # For test cases with non-existent channels + if "nonexistent_channel" in channel_tuple: + return CalibrationInput( + channel_data={ + channel_tuple: ChannelCalibrationData( + lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", "2022-01-01")), + lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", "2022-01-05")), + lift_abs=kwargs.get("lift_abs", 1000), + spend=kwargs.get("spend", 300), + confidence=kwargs.get("confidence", 0.9), + metric=kwargs.get("metric", DependentVarType.REVENUE), + calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE), + ) + } ) - new_channel_data_dict = original_input.channel_data.copy() - new_channel_data_dict[channel_tuple] = new_channel_data - return CalibrationInput(channel_data=new_channel_data_dict) - try: - # Get original data using the tuple key + # For updating existing channels + if channel_tuple in original_input.channel_data: original_channel_data = original_input.channel_data[channel_tuple] - # Create new channel data with updates new_channel_data = ChannelCalibrationData( lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", original_channel_data.lift_start_date)), lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", original_channel_data.lift_end_date)), @@ -229,59 +257,21 @@ def create_modified_calibration_input(original_input, channel_name, **kwargs): calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope), ) - # Create new dictionary with updated data new_channel_data_dict = original_input.channel_data.copy() new_channel_data_dict[channel_tuple] = new_channel_data - return CalibrationInput(channel_data=new_channel_data_dict) - except KeyError as e: - # Handle non-existent channels in fixture data - if "radio_spend" in channel_tuple and channel_tuple not in original_input.channel_data: - # Copy data from tv_spend and modify it - tv_data = original_input.channel_data[("tv_spend",)] - new_channel_data = ChannelCalibrationData( - lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", tv_data.lift_start_date)), - lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", tv_data.lift_end_date)), - lift_abs=kwargs.get("lift_abs", tv_data.lift_abs), - spend=kwargs.get("spend", tv_data.spend), - confidence=kwargs.get("confidence", tv_data.confidence), - metric=kwargs.get("metric", tv_data.metric), - calibration_scope=kwargs.get("calibration_scope", tv_data.calibration_scope), + # Default for new channels + return CalibrationInput( + channel_data={ + channel_tuple: ChannelCalibrationData( + lift_start_date=pd.Timestamp(kwargs.get("lift_start_date", "2022-01-01")), + lift_end_date=pd.Timestamp(kwargs.get("lift_end_date", "2022-01-05")), + lift_abs=kwargs.get("lift_abs", 1000), + spend=kwargs.get("spend", 300), + confidence=kwargs.get("confidence", 0.9), + metric=kwargs.get("metric", DependentVarType.REVENUE), + calibration_scope=kwargs.get("calibration_scope", CalibrationScope.IMMEDIATE), ) - new_channel_data_dict = original_input.channel_data.copy() - new_channel_data_dict[channel_tuple] = new_channel_data - return CalibrationInput(channel_data=new_channel_data_dict) - raise KeyError(f"Channel {channel_tuple} not found in calibration input") - - def validate(self) -> List[ValidationResult]: - """ - Implement the abstract validate method from the Validation base class. - Returns a list containing the calibration validation result. - """ - return [self.check_calibration()] - - def check_calibration(self) -> ValidationResult: - """Check all calibration inputs for consistency and correctness.""" - if self.calibration_input is None: - return ValidationResult(status=True, error_details={}, error_message="") - - error_details = {} - error_messages = [] - - checks = [ - self._check_date_range(), - self._check_spend_values(), - self._check_metric_values(), - self._check_confidence_values(), - self._check_lift_values(), - ] - - for result in checks: - if not result.status: - error_details.update(result.error_details) - error_messages.append(result.error_message) - - return ValidationResult( - status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages) + } ) diff --git a/python/src/robyn/data/validation/test_calibration_input_validation.py b/python/src/robyn/data/validation/test_calibration_input_validation.py new file mode 100644 index 000000000..fc0d8bc2a --- /dev/null +++ b/python/src/robyn/data/validation/test_calibration_input_validation.py @@ -0,0 +1,392 @@ +import pytest +import pandas as pd +from datetime import datetime, timedelta +from robyn.data.entities.calibration_input import CalibrationInput, ChannelCalibrationData +from robyn.data.entities.mmmdata import MMMData +from robyn.data.validation.calibration_input_validation import CalibrationInputValidation +from robyn.data.entities.enums import DependentVarType, CalibrationScope + + +@pytest.fixture +def sample_mmmdata(): + data = pd.DataFrame( + { + "date": pd.date_range(start="2022-01-01", periods=10), + "revenue": [100, 120, 110, 130, 140, 150, 160, 170, 180, 190], + "tv_spend": [50, 60, 55, 65, 70, 75, 80, 85, 90, 95], + "radio_spend": [30, 35, 32, 38, 40, 42, 45, 48, 50, 52], + "temperature": [20, 22, 21, 23, 24, 25, 26, 27, 28, 29], + } + ) + + mmm_data_spec = MMMData.MMMDataSpec( + dep_var="revenue", date_var="date", paid_media_spends=["tv_spend", "radio_spend"], context_vars=["temperature"] + ) + + return MMMData(data, mmm_data_spec) + + +@pytest.fixture +def sample_calibration_input(sample_mmmdata): + """Create a sample calibration input with actual spend values.""" + data = sample_mmmdata.data + tv_spend = data.loc[data["date"].between("2022-01-01", "2022-01-05"), "tv_spend"].sum() + radio_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "radio_spend"].sum() + + tv_channel_key = ("tv_spend",) + radio_channel_key = ("radio_spend",) + + return CalibrationInput( + channel_data={ + tv_channel_key: ChannelCalibrationData( + lift_start_date=pd.Timestamp("2022-01-01"), + lift_end_date=pd.Timestamp("2022-01-05"), + lift_abs=1000, + spend=tv_spend, + confidence=0.9, + metric=DependentVarType.REVENUE, + calibration_scope=CalibrationScope.IMMEDIATE, + ), + radio_channel_key: ChannelCalibrationData( + lift_start_date=pd.Timestamp("2022-01-06"), + lift_end_date=pd.Timestamp("2022-01-10"), + lift_abs=2000, + spend=radio_spend, + confidence=0.85, + metric=DependentVarType.REVENUE, + calibration_scope=CalibrationScope.IMMEDIATE, + ), + } + ) + + +@pytest.fixture +def sample_multichannel_calibration_input(sample_mmmdata): + data = sample_mmmdata.data + combined_spend = ( + data.loc[data["date"].between("2022-01-01", "2022-01-05"), ["tv_spend", "radio_spend"]].sum().sum() + ) + tv_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "tv_spend"].sum() + + return CalibrationInput( + channel_data={ + "tv_spend+radio_spend": ChannelCalibrationData( + lift_start_date=pd.Timestamp("2022-01-01"), + lift_end_date=pd.Timestamp("2022-01-05"), + lift_abs=3000, + spend=combined_spend, + confidence=0.9, + metric=DependentVarType.REVENUE, + calibration_scope=CalibrationScope.IMMEDIATE, + ), + "tv_spend": ChannelCalibrationData( + lift_start_date=pd.Timestamp("2022-01-06"), + lift_end_date=pd.Timestamp("2022-01-10"), + lift_abs=1000, + spend=tv_spend, + confidence=0.85, + metric=DependentVarType.REVENUE, + calibration_scope=CalibrationScope.IMMEDIATE, + ), + } + ) + + +def test_check_calibration_valid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=pd.Timestamp("2022-01-01"), + window_end=pd.Timestamp("2022-01-10"), + ) + result = validator.check_calibration() + assert result.status == True + assert not result.error_details + assert not result.error_message + + +def test_check_date_range_invalid(sample_mmmdata, sample_calibration_input): + # First create the validator + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + # Then use the static method to create modified input + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), lift_start_date=datetime(2021, 12, 31) + ) + + validator = CalibrationInputValidation( + sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_date_range() + assert result.status == False + assert ("tv_spend",) in result.error_details + assert "outside the modeling window" in result.error_message + + +def test_check_lift_values_invalid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("radio_spend",), lift_abs="invalid" + ) + + validator = CalibrationInputValidation( + sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_lift_values() + assert result.status == False + assert ("radio_spend",) in result.error_details + assert "must be a valid number" in result.error_message + + +def test_check_spend_values_invalid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), spend=1000 + ) + + validator = CalibrationInputValidation( + sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_spend_values() + assert result.status == False + assert ("tv_spend",) in result.error_details + assert "does not match the input data" in result.error_message + + +def test_check_confidence_values_invalid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("radio_spend",), confidence=0.7 + ) + + validator = CalibrationInputValidation( + sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_confidence_values() + assert result.status == False + assert ("radio_spend",) in result.error_details + assert "lower than 80%" in result.error_message + + +def test_check_metric_values_invalid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), metric=DependentVarType.CONVERSION + ) + + validator = CalibrationInputValidation( + sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_metric_values() + assert result.status == False + assert ("tv_spend",) in result.error_details + assert "does not match the dependent variable" in result.error_message + + +def test_check_obj_weights_valid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=pd.Timestamp("2022-01-01"), + window_end=pd.Timestamp("2022-01-10"), + ) + result = validator.check_obj_weights([0, 1, 1], True) + assert result.status is True + assert not result.error_details + assert result.error_message == "" + + +def test_check_obj_weights_invalid(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + result = validator.check_obj_weights([0, 1, 1, 1], False) + assert result.status == False + assert "length" in result.error_details + assert "Invalid number of objective weights" in result.error_message + + result = validator.check_obj_weights([-1, 1, 11], False) + assert result.status == False + assert "range" in result.error_details + assert "Objective weights out of valid range" in result.error_message + + +def test_validate(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, sample_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + + results = validator.validate() + assert len(results) == 1 + assert all(result.status for result in results) + assert all(not result.error_details for result in results) + assert all(not result.error_message for result in results) + + # Test with invalid input + invalid_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, + ("tv_spend",), + lift_start_date=datetime(2021, 12, 31), + lift_abs="invalid", + spend=1000000, + confidence=0.5, + metric=DependentVarType.CONVERSION, + ) + + invalid_validator = CalibrationInputValidation( + sample_mmmdata, invalid_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + invalid_results = invalid_validator.validate() + assert len(invalid_results) == 1 + assert any(not result.status for result in invalid_results) + assert any(result.error_details for result in invalid_results) + assert any(result.error_message for result in invalid_results) + + +def test_multichannel_validation(sample_mmmdata, sample_multichannel_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_multichannel_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + result = validator.check_calibration() + assert result.status == True + assert not result.error_details + assert not result.error_message + + +def test_invalid_channel(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + invalid_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("nonexistent_channel",), lift_abs=1000 + ) + + validator = CalibrationInputValidation( + sample_mmmdata, invalid_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_spend_values() + assert result.status == False + assert ("nonexistent_channel",) in result.error_details + assert "not found in data" in result.error_message.lower() + + +def test_invalid_multichannel_combination(sample_mmmdata): + invalid_combination = CalibrationInput( + channel_data={ + "tv_spend+nonexistent_channel": ChannelCalibrationData( + lift_start_date=pd.Timestamp("2022-01-01"), + lift_end_date=pd.Timestamp("2022-01-05"), + lift_abs=1000, + spend=300, + confidence=0.9, + metric=DependentVarType.REVENUE, + calibration_scope=CalibrationScope.IMMEDIATE, + ) + } + ) + + validator = CalibrationInputValidation( + sample_mmmdata, + invalid_combination, + window_start=pd.Timestamp("2022-01-01"), + window_end=pd.Timestamp("2022-01-10"), + ) + result = validator._check_spend_values() + assert result.status is False + assert "not found in data" in result.error_message.lower() + + +def test_edge_cases(sample_mmmdata): + # Test with empty calibration input + empty_input = CalibrationInput(channel_data={}) + validator = CalibrationInputValidation( + sample_mmmdata, empty_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator.check_calibration() + assert result.status == True # Empty input should be valid + + # Test with None calibration input + validator_none = CalibrationInputValidation( + sample_mmmdata, None, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result_none = validator_none.check_calibration() + assert result_none.status == True # None input should be valid + + +def test_date_boundary_cases(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + # Test exact boundary dates + boundary_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, + ("tv_spend",), + lift_start_date=datetime(2022, 1, 1), # Exact start + lift_end_date=datetime(2022, 1, 10), # Exact end + ) + + validator = CalibrationInputValidation( + sample_mmmdata, boundary_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + ) + result = validator._check_date_range() + assert result.status == True + assert not result.error_details + + +def test_validate_with_multichannel(sample_mmmdata, sample_multichannel_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_multichannel_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + results = validator.validate() + assert len(results) == 1 + assert all(result.status for result in results) + assert all(not result.error_details for result in results) + assert all(not result.error_message for result in results) diff --git a/python/tests/test_calibration_input_validation.py b/python/tests/test_calibration_input_validation.py index 570490b9b..9c385c05d 100644 --- a/python/tests/test_calibration_input_validation.py +++ b/python/tests/test_calibration_input_validation.py @@ -62,19 +62,16 @@ def sample_calibration_input(sample_mmmdata): @pytest.fixture def sample_multichannel_calibration_input(sample_mmmdata): - """Create a sample multi-channel calibration input.""" + # Calculate combined spend for the channels data = sample_mmmdata.data combined_spend = ( data.loc[data["date"].between("2022-01-01", "2022-01-05"), ["tv_spend", "radio_spend"]].sum().sum() ) tv_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "tv_spend"].sum() - multi_channel_key = ("tv_spend", "radio_spend") # Use tuple key - tv_channel_key = ("tv_spend",) # Use tuple key - return CalibrationInput( channel_data={ - multi_channel_key: ChannelCalibrationData( + "tv_spend+radio_spend": ChannelCalibrationData( lift_start_date=pd.Timestamp("2022-01-01"), lift_end_date=pd.Timestamp("2022-01-05"), lift_abs=3000, @@ -83,7 +80,7 @@ def sample_multichannel_calibration_input(sample_mmmdata): metric=DependentVarType.REVENUE, calibration_scope=CalibrationScope.IMMEDIATE, ), - tv_channel_key: ChannelCalibrationData( + "tv_spend": ChannelCalibrationData( lift_start_date=pd.Timestamp("2022-01-06"), lift_end_date=pd.Timestamp("2022-01-10"), lift_abs=1000, @@ -96,36 +93,6 @@ def sample_multichannel_calibration_input(sample_mmmdata): ) -def create_modified_calibration_input(original_input, channel_name, **kwargs): - # Handle string channel names with '+' by converting to proper format - if isinstance(channel_name, str) and "+" in channel_name: - channel_key = channel_name # Keep as string with '+' for CalibrationInput - elif isinstance(channel_name, str): - channel_key = channel_name - else: - channel_key = "+".join(channel_name) # Convert tuple to string with '+' - - # Get original data using the key - original_channel_data = original_input.channel_data[channel_key] - - # Create new channel data with updates - new_channel_data = ChannelCalibrationData( - lift_start_date=kwargs.get("lift_start_date", original_channel_data.lift_start_date), - lift_end_date=kwargs.get("lift_end_date", original_channel_data.lift_end_date), - lift_abs=kwargs.get("lift_abs", original_channel_data.lift_abs), - spend=kwargs.get("spend", original_channel_data.spend), - confidence=kwargs.get("confidence", original_channel_data.confidence), - metric=kwargs.get("metric", original_channel_data.metric), - calibration_scope=kwargs.get("calibration_scope", original_channel_data.calibration_scope), - ) - - # Create new dictionary with updated data - new_channel_data_dict = original_input.channel_data.copy() - new_channel_data_dict[channel_key] = new_channel_data - - return CalibrationInput(channel_data=new_channel_data_dict) - - def test_check_calibration_valid(sample_mmmdata, sample_calibration_input): validator = CalibrationInputValidation( sample_mmmdata,