diff --git a/python/src/robyn/data/validation/test_calibration_input_validation.py b/python/src/robyn/data/validation/test_calibration_input_validation.py deleted file mode 100644 index fc0d8bc2a..000000000 --- a/python/src/robyn/data/validation/test_calibration_input_validation.py +++ /dev/null @@ -1,392 +0,0 @@ -import pytest -import pandas as pd -from datetime import datetime, timedelta -from robyn.data.entities.calibration_input import CalibrationInput, ChannelCalibrationData -from robyn.data.entities.mmmdata import MMMData -from robyn.data.validation.calibration_input_validation import CalibrationInputValidation -from robyn.data.entities.enums import DependentVarType, CalibrationScope - - -@pytest.fixture -def sample_mmmdata(): - data = pd.DataFrame( - { - "date": pd.date_range(start="2022-01-01", periods=10), - "revenue": [100, 120, 110, 130, 140, 150, 160, 170, 180, 190], - "tv_spend": [50, 60, 55, 65, 70, 75, 80, 85, 90, 95], - "radio_spend": [30, 35, 32, 38, 40, 42, 45, 48, 50, 52], - "temperature": [20, 22, 21, 23, 24, 25, 26, 27, 28, 29], - } - ) - - mmm_data_spec = MMMData.MMMDataSpec( - dep_var="revenue", date_var="date", paid_media_spends=["tv_spend", "radio_spend"], context_vars=["temperature"] - ) - - return MMMData(data, mmm_data_spec) - - -@pytest.fixture -def sample_calibration_input(sample_mmmdata): - """Create a sample calibration input with actual spend values.""" - data = sample_mmmdata.data - tv_spend = data.loc[data["date"].between("2022-01-01", "2022-01-05"), "tv_spend"].sum() - radio_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "radio_spend"].sum() - - tv_channel_key = ("tv_spend",) - radio_channel_key = ("radio_spend",) - - return CalibrationInput( - channel_data={ - tv_channel_key: ChannelCalibrationData( - lift_start_date=pd.Timestamp("2022-01-01"), - lift_end_date=pd.Timestamp("2022-01-05"), - lift_abs=1000, - spend=tv_spend, - confidence=0.9, - metric=DependentVarType.REVENUE, - calibration_scope=CalibrationScope.IMMEDIATE, - ), - radio_channel_key: ChannelCalibrationData( - lift_start_date=pd.Timestamp("2022-01-06"), - lift_end_date=pd.Timestamp("2022-01-10"), - lift_abs=2000, - spend=radio_spend, - confidence=0.85, - metric=DependentVarType.REVENUE, - calibration_scope=CalibrationScope.IMMEDIATE, - ), - } - ) - - -@pytest.fixture -def sample_multichannel_calibration_input(sample_mmmdata): - data = sample_mmmdata.data - combined_spend = ( - data.loc[data["date"].between("2022-01-01", "2022-01-05"), ["tv_spend", "radio_spend"]].sum().sum() - ) - tv_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "tv_spend"].sum() - - return CalibrationInput( - channel_data={ - "tv_spend+radio_spend": ChannelCalibrationData( - lift_start_date=pd.Timestamp("2022-01-01"), - lift_end_date=pd.Timestamp("2022-01-05"), - lift_abs=3000, - spend=combined_spend, - confidence=0.9, - metric=DependentVarType.REVENUE, - calibration_scope=CalibrationScope.IMMEDIATE, - ), - "tv_spend": ChannelCalibrationData( - lift_start_date=pd.Timestamp("2022-01-06"), - lift_end_date=pd.Timestamp("2022-01-10"), - lift_abs=1000, - spend=tv_spend, - confidence=0.85, - metric=DependentVarType.REVENUE, - calibration_scope=CalibrationScope.IMMEDIATE, - ), - } - ) - - -def test_check_calibration_valid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=pd.Timestamp("2022-01-01"), - window_end=pd.Timestamp("2022-01-10"), - ) - result = validator.check_calibration() - assert result.status == True - assert not result.error_details - assert not result.error_message - - -def test_check_date_range_invalid(sample_mmmdata, sample_calibration_input): - # First create the validator - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - # Then use the static method to create modified input - new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), lift_start_date=datetime(2021, 12, 31) - ) - - validator = CalibrationInputValidation( - sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_date_range() - assert result.status == False - assert ("tv_spend",) in result.error_details - assert "outside the modeling window" in result.error_message - - -def test_check_lift_values_invalid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("radio_spend",), lift_abs="invalid" - ) - - validator = CalibrationInputValidation( - sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_lift_values() - assert result.status == False - assert ("radio_spend",) in result.error_details - assert "must be a valid number" in result.error_message - - -def test_check_spend_values_invalid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), spend=1000 - ) - - validator = CalibrationInputValidation( - sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_spend_values() - assert result.status == False - assert ("tv_spend",) in result.error_details - assert "does not match the input data" in result.error_message - - -def test_check_confidence_values_invalid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("radio_spend",), confidence=0.7 - ) - - validator = CalibrationInputValidation( - sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_confidence_values() - assert result.status == False - assert ("radio_spend",) in result.error_details - assert "lower than 80%" in result.error_message - - -def test_check_metric_values_invalid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), metric=DependentVarType.CONVERSION - ) - - validator = CalibrationInputValidation( - sample_mmmdata, new_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_metric_values() - assert result.status == False - assert ("tv_spend",) in result.error_details - assert "does not match the dependent variable" in result.error_message - - -def test_check_obj_weights_valid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=pd.Timestamp("2022-01-01"), - window_end=pd.Timestamp("2022-01-10"), - ) - result = validator.check_obj_weights([0, 1, 1], True) - assert result.status is True - assert not result.error_details - assert result.error_message == "" - - -def test_check_obj_weights_invalid(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - result = validator.check_obj_weights([0, 1, 1, 1], False) - assert result.status == False - assert "length" in result.error_details - assert "Invalid number of objective weights" in result.error_message - - result = validator.check_obj_weights([-1, 1, 11], False) - assert result.status == False - assert "range" in result.error_details - assert "Objective weights out of valid range" in result.error_message - - -def test_validate(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, sample_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - - results = validator.validate() - assert len(results) == 1 - assert all(result.status for result in results) - assert all(not result.error_details for result in results) - assert all(not result.error_message for result in results) - - # Test with invalid input - invalid_calibration_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, - ("tv_spend",), - lift_start_date=datetime(2021, 12, 31), - lift_abs="invalid", - spend=1000000, - confidence=0.5, - metric=DependentVarType.CONVERSION, - ) - - invalid_validator = CalibrationInputValidation( - sample_mmmdata, invalid_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - invalid_results = invalid_validator.validate() - assert len(invalid_results) == 1 - assert any(not result.status for result in invalid_results) - assert any(result.error_details for result in invalid_results) - assert any(result.error_message for result in invalid_results) - - -def test_multichannel_validation(sample_mmmdata, sample_multichannel_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_multichannel_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - result = validator.check_calibration() - assert result.status == True - assert not result.error_details - assert not result.error_message - - -def test_invalid_channel(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - invalid_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, ("nonexistent_channel",), lift_abs=1000 - ) - - validator = CalibrationInputValidation( - sample_mmmdata, invalid_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_spend_values() - assert result.status == False - assert ("nonexistent_channel",) in result.error_details - assert "not found in data" in result.error_message.lower() - - -def test_invalid_multichannel_combination(sample_mmmdata): - invalid_combination = CalibrationInput( - channel_data={ - "tv_spend+nonexistent_channel": ChannelCalibrationData( - lift_start_date=pd.Timestamp("2022-01-01"), - lift_end_date=pd.Timestamp("2022-01-05"), - lift_abs=1000, - spend=300, - confidence=0.9, - metric=DependentVarType.REVENUE, - calibration_scope=CalibrationScope.IMMEDIATE, - ) - } - ) - - validator = CalibrationInputValidation( - sample_mmmdata, - invalid_combination, - window_start=pd.Timestamp("2022-01-01"), - window_end=pd.Timestamp("2022-01-10"), - ) - result = validator._check_spend_values() - assert result.status is False - assert "not found in data" in result.error_message.lower() - - -def test_edge_cases(sample_mmmdata): - # Test with empty calibration input - empty_input = CalibrationInput(channel_data={}) - validator = CalibrationInputValidation( - sample_mmmdata, empty_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator.check_calibration() - assert result.status == True # Empty input should be valid - - # Test with None calibration input - validator_none = CalibrationInputValidation( - sample_mmmdata, None, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result_none = validator_none.check_calibration() - assert result_none.status == True # None input should be valid - - -def test_date_boundary_cases(sample_mmmdata, sample_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - # Test exact boundary dates - boundary_input = CalibrationInputValidation.create_modified_calibration_input( - sample_calibration_input, - ("tv_spend",), - lift_start_date=datetime(2022, 1, 1), # Exact start - lift_end_date=datetime(2022, 1, 10), # Exact end - ) - - validator = CalibrationInputValidation( - sample_mmmdata, boundary_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) - ) - result = validator._check_date_range() - assert result.status == True - assert not result.error_details - - -def test_validate_with_multichannel(sample_mmmdata, sample_multichannel_calibration_input): - validator = CalibrationInputValidation( - sample_mmmdata, - sample_multichannel_calibration_input, - window_start=datetime(2022, 1, 1), - window_end=datetime(2022, 1, 10), - ) - - results = validator.validate() - assert len(results) == 1 - assert all(result.status for result in results) - assert all(not result.error_details for result in results) - assert all(not result.error_message for result in results) diff --git a/python/tests/test_calibration_input_validation.py b/python/tests/test_calibration_input_validation.py index 9c385c05d..11b0f5f51 100644 --- a/python/tests/test_calibration_input_validation.py +++ b/python/tests/test_calibration_input_validation.py @@ -33,8 +33,8 @@ def sample_calibration_input(sample_mmmdata): tv_spend = data.loc[data["date"].between("2022-01-01", "2022-01-05"), "tv_spend"].sum() radio_spend = data.loc[data["date"].between("2022-01-06", "2022-01-10"), "radio_spend"].sum() - tv_channel_key = ("tv_spend",) # Use tuple key - radio_channel_key = ("radio_spend",) # Use tuple key + tv_channel_key = ("tv_spend",) + radio_channel_key = ("radio_spend",) return CalibrationInput( channel_data={ @@ -62,7 +62,6 @@ def sample_calibration_input(sample_mmmdata): @pytest.fixture def sample_multichannel_calibration_input(sample_mmmdata): - # Calculate combined spend for the channels data = sample_mmmdata.data combined_spend = ( data.loc[data["date"].between("2022-01-01", "2022-01-05"), ["tv_spend", "radio_spend"]].sum().sum() @@ -107,9 +106,17 @@ def test_check_calibration_valid(sample_mmmdata, sample_calibration_input): def test_check_date_range_invalid(sample_mmmdata, sample_calibration_input): - # Use tuple for channel key - new_calibration_input = create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), lift_start_date=datetime(2021, 12, 31) # Changed to tuple + # First create the validator + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + # Then use the static method to create modified input + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), lift_start_date=datetime(2021, 12, 31) ) validator = CalibrationInputValidation( @@ -117,14 +124,20 @@ def test_check_date_range_invalid(sample_mmmdata, sample_calibration_input): ) result = validator._check_date_range() assert result.status == False - # Check using tuple key in error_details assert ("tv_spend",) in result.error_details assert "outside the modeling window" in result.error_message def test_check_lift_values_invalid(sample_mmmdata, sample_calibration_input): - new_calibration_input = create_modified_calibration_input( - sample_calibration_input, ("radio_spend",), lift_abs="invalid" # Changed to tuple + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("radio_spend",), lift_abs="invalid" ) validator = CalibrationInputValidation( @@ -137,8 +150,15 @@ def test_check_lift_values_invalid(sample_mmmdata, sample_calibration_input): def test_check_spend_values_invalid(sample_mmmdata, sample_calibration_input): - new_calibration_input = create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), spend=1000 # Changed to tuple + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), spend=1000 ) validator = CalibrationInputValidation( @@ -151,8 +171,15 @@ def test_check_spend_values_invalid(sample_mmmdata, sample_calibration_input): def test_check_confidence_values_invalid(sample_mmmdata, sample_calibration_input): - new_calibration_input = create_modified_calibration_input( - sample_calibration_input, ("radio_spend",), confidence=0.7 # Changed to tuple + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("radio_spend",), confidence=0.7 ) validator = CalibrationInputValidation( @@ -165,8 +192,15 @@ def test_check_confidence_values_invalid(sample_mmmdata, sample_calibration_inpu def test_check_metric_values_invalid(sample_mmmdata, sample_calibration_input): - new_calibration_input = create_modified_calibration_input( - sample_calibration_input, ("tv_spend",), metric=DependentVarType.CONVERSION # Changed to tuple + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + new_calibration_input = CalibrationInputValidation.create_modified_calibration_input( + sample_calibration_input, ("tv_spend",), metric=DependentVarType.CONVERSION ) validator = CalibrationInputValidation( @@ -193,16 +227,17 @@ def test_check_obj_weights_valid(sample_mmmdata, sample_calibration_input): def test_check_obj_weights_invalid(sample_mmmdata, sample_calibration_input): validator = CalibrationInputValidation( - sample_mmmdata, sample_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), ) - # Test case 1: Incorrect number of weights result = validator.check_obj_weights([0, 1, 1, 1], False) assert result.status == False assert "length" in result.error_details assert "Invalid number of objective weights" in result.error_message - # Test case 2: Weights out of valid range result = validator.check_obj_weights([-1, 1, 11], False) assert result.status == False assert "range" in result.error_details @@ -210,7 +245,6 @@ def test_check_obj_weights_invalid(sample_mmmdata, sample_calibration_input): def test_validate(sample_mmmdata, sample_calibration_input): - # Test with valid input validator = CalibrationInputValidation( sample_mmmdata, sample_calibration_input, window_start=datetime(2022, 1, 1), window_end=datetime(2022, 1, 10) ) @@ -222,9 +256,9 @@ def test_validate(sample_mmmdata, sample_calibration_input): assert all(not result.error_message for result in results) # Test with invalid input - invalid_calibration_input = create_modified_calibration_input( + invalid_calibration_input = CalibrationInputValidation.create_modified_calibration_input( sample_calibration_input, - ("tv_spend",), # Changed to tuple + ("tv_spend",), lift_start_date=datetime(2021, 12, 31), lift_abs="invalid", spend=1000000, @@ -256,8 +290,14 @@ def test_multichannel_validation(sample_mmmdata, sample_multichannel_calibration def test_invalid_channel(sample_mmmdata, sample_calibration_input): - # Test with non-existent channel - invalid_input = create_modified_calibration_input( + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + + invalid_input = CalibrationInputValidation.create_modified_calibration_input( sample_calibration_input, ("nonexistent_channel",), lift_abs=1000 ) @@ -314,8 +354,15 @@ def test_edge_cases(sample_mmmdata): def test_date_boundary_cases(sample_mmmdata, sample_calibration_input): + validator = CalibrationInputValidation( + sample_mmmdata, + sample_calibration_input, + window_start=datetime(2022, 1, 1), + window_end=datetime(2022, 1, 10), + ) + # Test exact boundary dates - boundary_input = create_modified_calibration_input( + boundary_input = CalibrationInputValidation.create_modified_calibration_input( sample_calibration_input, ("tv_spend",), lift_start_date=datetime(2022, 1, 1), # Exact start @@ -330,7 +377,6 @@ def test_date_boundary_cases(sample_mmmdata, sample_calibration_input): assert not result.error_details -# Modify your existing test_validate to include multi-channel cases def test_validate_with_multichannel(sample_mmmdata, sample_multichannel_calibration_input): validator = CalibrationInputValidation( sample_mmmdata,