From 72d834e0bc4f2cedf5d4c5b7f9f4d4679edcc3b7 Mon Sep 17 00:00:00 2001
From: alxlyj
Date: Thu, 24 Oct 2024 02:18:40 -0400
Subject: [PATCH] update validation methods

---
 .../calibration_input_validation.py | 190 ++++++++----------
 1 file changed, 79 insertions(+), 111 deletions(-)

diff --git a/python/src/robyn/data/validation/calibration_input_validation.py b/python/src/robyn/data/validation/calibration_input_validation.py
index 5f2031690..f9c63035c 100644
--- a/python/src/robyn/data/validation/calibration_input_validation.py
+++ b/python/src/robyn/data/validation/calibration_input_validation.py
@@ -3,7 +3,7 @@
 from robyn.data.entities.mmmdata import MMMData
 from robyn.data.validation.validation import Validation, ValidationResult
 import pandas as pd
-from typing import List, Tuple
+from typing import List, Tuple, Set
 
 from robyn.data.entities.enums import CalibrationScope, DependentVarType
 
@@ -19,45 +19,36 @@ def __init__(
         self.calibration_input = calibration_input
         self.window_start = window_start
         self.window_end = window_end
+        # Initialize valid_channels from paid_media_spends
+        self.valid_channels: Set[str] = set(self.mmmdata.mmmdata_spec.paid_media_spends)
 
     def check_calibration(self) -> ValidationResult:
-        """
-        This function checks the calibration input data for consistency and correctness.
-        It verifies that the input data contains the required columns, that the date range
-        is within the modeling window, and that the spend values match the input data.
-        """
-
         if self.calibration_input is None:
             return ValidationResult(status=True, error_details={}, error_message="")
 
         error_details = {}
         error_messages = []
 
-        # check data range
         date_range_result = self._check_date_range()
         if not date_range_result.status:
             error_details.update(date_range_result.error_details)
             error_messages.append(date_range_result.error_message)
 
-        # Check lift values
-        lift_result = self._check_lift_values()
-        if not lift_result.status:
-            error_details.update(lift_result.error_details)
-            error_messages.append(lift_result.error_message)
-
-        # Check spend values
         spend_result = self._check_spend_values()
         if not spend_result.status:
             error_details.update(spend_result.error_details)
             error_messages.append(spend_result.error_message)
 
-        # Check confidence values
+        lift_result = self._check_lift_values()
+        if not lift_result.status:
+            error_details.update(lift_result.error_details)
+            error_messages.append(lift_result.error_message)
+
         confidence_result = self._check_confidence_values()
         if not confidence_result.status:
             error_details.update(confidence_result.error_details)
             error_messages.append(confidence_result.error_message)
 
-        # Check metric values
         metric_result = self._check_metric_values()
         if not metric_result.status:
             error_details.update(metric_result.error_details)
             error_messages.append(metric_result.error_message)
@@ -67,120 +58,70 @@ def check_calibration(self) -> ValidationResult:
             status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
         )
 
-    def check_obj_weights(self, objective_weights: List[float], refresh: bool) -> ValidationResult:
-        """
-        Check the objective weights for validity.
-
-        :param objective_weights: List of objective weights
-        :param refresh: Boolean indicating if this is a refresh run
-        :return: ValidationResult with the status and any error messages
-        """
-        error_details = {}
-        error_messages = []
-
-        if objective_weights is None:
-            if refresh:
-                obj_len = 3  # Assuming 3 objectives for refresh runs
-                objective_weights = [0, 1, 1]
-            else:
-                return ValidationResult(status=True, error_details={}, error_message="")
-
-        if len(objective_weights) not in [2, 3]:
-            error_details["length"] = f"Expected 2 or 3 objective weights, got {len(objective_weights)}"
-            error_messages.append("Invalid number of objective weights.")
-
-        if any(weight < 0 or weight > 10 for weight in objective_weights):
-            error_details["range"] = "Objective weights must be >= 0 & <= 10"
-            error_messages.append("Objective weights out of valid range.")
-
-        return ValidationResult(
-            status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
-        )
-
     def _validate_channel_exists(self, channel_key: Tuple[str, ...]) -> ValidationResult:
-        """
-        Validate that all channels in the key exist in the data.
-
-        Args:
-            channel_key: Tuple of channel names
-
-        Returns:
-            ValidationResult indicating if all channels exist
-        """
+        """Validate that all channels in the key exist in the data."""
         missing_channels = [ch for ch in channel_key if ch not in self.valid_channels]
         if missing_channels:
+            error_msg = f"Channels not found in data: {', '.join(missing_channels)}"
             return ValidationResult(
                 status=False,
-                error_details={channel_key: f"Channels not found in data: {', '.join(missing_channels)}"},
-                error_message=f"The following channels for {channel_key} are not in the input data: {', '.join(missing_channels)}",
+                error_details={channel_key: error_msg},
+                error_message=f"The following channels are not in the input data: {', '.join(missing_channels)}",
             )
         return ValidationResult(status=True)
 
+    def _get_channel_spend(
+        self, channel_key: Tuple[str, ...], start_date: pd.Timestamp, end_date: pd.Timestamp
+    ) -> float:
+        """Calculate total spend for single or combined channels in the given date range."""
+        date_var = self.mmmdata.mmmdata_spec.date_var
+        data = self.mmmdata.data
+
+        date_mask = (data[date_var] >= start_date) & (data[date_var] <= end_date)
+        return sum(data.loc[date_mask, ch].sum() for ch in channel_key)
+
     def _check_date_range(self) -> ValidationResult:
         error_details = {}
         error_messages = []
 
-        for channel, data in self.calibration_input.channel_data.items():
+        for channel_key, data in self.calibration_input.channel_data.items():
+            # Validate channel exists before checking dates
+            channel_validation = self._validate_channel_exists(channel_key)
+            if not channel_validation.status:
+                return channel_validation
+
             if data.lift_start_date < self.window_start or data.lift_end_date > self.window_end:
-                error_details[channel] = (
-                    f"Date range {data.lift_start_date} to {data.lift_end_date} is outside modeling window {self.window_start} to {self.window_end}"
+                error_details[channel_key] = (
+                    f"Date range {data.lift_start_date} to {data.lift_end_date} "
+                    f"is outside modeling window {self.window_start} to {self.window_end}"
+                )
+                error_messages.append(
+                    f"Calibration date range for {'+'.join(channel_key)} is outside the modeling window."
                 )
-                error_messages.append(f"Calibration date range for {channel} is outside the modeling window.")
 
             if data.lift_start_date > data.lift_end_date:
-                error_details[channel] = f"Start date {data.lift_start_date} is after end date {data.lift_end_date}"
-                error_messages.append(f"Invalid date range for {channel}: start date is after end date.")
-
-        return ValidationResult(
-            status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
-        )
-
-    def _check_lift_values(self) -> ValidationResult:
-        error_details = {}
-        error_messages = []
-
-        for channel, data in self.calibration_input.channel_data.items():
-            if not isinstance(data.lift_abs, (int, float)) or pd.isna(data.lift_abs):
-                error_details[channel] = f"Invalid lift value: {data.lift_abs}"
-                error_messages.append(f"Lift value for {channel} must be a valid number.")
+                error_details[channel_key] = (
+                    f"Start date {data.lift_start_date} is after end date {data.lift_end_date}"
+                )
+                error_messages.append(f"Invalid date range for {'+'.join(channel_key)}: start date is after end date.")
 
         return ValidationResult(
             status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
         )
 
-    def _get_channel_spend(
-        self, channel_key: Tuple[str, ...], start_date: pd.Timestamp, end_date: pd.Timestamp
-    ) -> float:
-        """
-        Calculate total spend for single or combined channels in the given date range.
-        """
-        date_var = self.mmmdata.mmmdata_spec.date_var
-        data = self.mmmdata.data
-
-        # Get data for the date range
-        date_mask = (data[date_var] >= start_date) & (data[date_var] <= end_date)
-
-        # Sum the spend for all channels in the key
-        total_spend = sum(data.loc[date_mask, channel].sum() for channel in channel_key)
-        return total_spend
-
     def _check_spend_values(self) -> ValidationResult:
         error_details = {}
         error_messages = []
 
         for channel_key, cal_data in self.calibration_input.channel_data.items():
-            # Validate that all channels exist
+            # Validate channel exists before checking spend
            channel_validation = self._validate_channel_exists(channel_key)
             if not channel_validation.status:
-                error_details.update(channel_validation.error_details)
-                error_messages.append(channel_validation.error_message)
-                continue
+                return channel_validation
 
-            # Calculate actual spend from data
             actual_spend = self._get_channel_spend(channel_key, cal_data.lift_start_date, cal_data.lift_end_date)
 
-            # Check if spend matches within tolerance
             if abs(actual_spend - cal_data.spend) > 0.1 * cal_data.spend:
                 error_details[channel_key] = (
                     f"Spend mismatch: calibration input ({cal_data.spend}) " f"vs. data ({actual_spend})"
@@ -193,15 +134,39 @@ def _check_spend_values(self) -> ValidationResult:
             status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
         )
 
+    def _check_lift_values(self) -> ValidationResult:
+        error_details = {}
+        error_messages = []
+
+        for channel_key, data in self.calibration_input.channel_data.items():
+            # Validate channel exists before checking lift
+            channel_validation = self._validate_channel_exists(channel_key)
+            if not channel_validation.status:
+                return channel_validation
+
+            if not isinstance(data.lift_abs, (int, float)) or pd.isna(data.lift_abs):
+                error_details[channel_key] = f"Invalid lift value: {data.lift_abs}"
+                error_messages.append(f"Lift value for {'+'.join(channel_key)} must be a valid number.")
+
+        return ValidationResult(
+            status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
+        )
+
     def _check_confidence_values(self) -> ValidationResult:
         error_details = {}
         error_messages = []
 
-        for channel, data in self.calibration_input.channel_data.items():
+        for channel_key, data in self.calibration_input.channel_data.items():
+            # Validate channel exists before checking confidence
+            channel_validation = self._validate_channel_exists(channel_key)
+            if not channel_validation.status:
+                return channel_validation
+
             if data.confidence < 0.8:
-                error_details[channel] = f"Low confidence: {data.confidence}"
+                error_details[channel_key] = f"Low confidence: {data.confidence}"
                 error_messages.append(
-                    f"Confidence for {channel} is lower than 80%, which is considered low confidence."
+                    f"Confidence for {'+'.join(channel_key)} is lower than 80%, "
+                    f"which is considered low confidence."
                 )
 
         return ValidationResult(
@@ -214,19 +179,22 @@ def _check_metric_values(self) -> ValidationResult:
 
         dep_var = self.mmmdata.mmmdata_spec.dep_var
 
-        for channel, data in self.calibration_input.channel_data.items():
+        for channel_key, data in self.calibration_input.channel_data.items():
+            # Validate channel exists before checking metric
+            channel_validation = self._validate_channel_exists(channel_key)
+            if not channel_validation.status:
+                return channel_validation
+
             if data.metric != DependentVarType(dep_var):
-                error_details[channel] = f"Metric mismatch: {data.metric} vs. {dep_var}"
-                error_messages.append(f"Metric for {channel} does not match the dependent variable ({dep_var}).")
+                error_details[channel_key] = f"Metric mismatch: {data.metric} vs. {dep_var}"
+                error_messages.append(
+                    f"Metric for {'+'.join(channel_key)} does not match the dependent variable ({dep_var})."
+                )
 
         return ValidationResult(
             status=len(error_details) == 0, error_details=error_details, error_message="\n".join(error_messages)
         )
 
-    def validate(self) -> ValidationResult:
-        """
-        Perform all validations and return the results.
-
-        :return: A dictionary containing the results of all validations.
-        """
+    def validate(self) -> List[ValidationResult]:
+        """Perform all validations and return the results."""
        return [self.check_calibration()]
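
Below is a minimal, self-contained sketch of the spend check that the relocated _get_channel_spend and the updated _check_spend_values implement above: mask the lift window, sum each channel column of the tuple key, and flag the calibration entry when the recorded spend deviates from the data by more than 10%. The DataFrame, column names, dates, and spend figures are made up for illustration and are not part of the Robyn codebase.

import pandas as pd

# Hypothetical weekly data; column names and values are illustrative only.
data = pd.DataFrame(
    {
        "ds": pd.date_range("2023-01-01", periods=6, freq="W"),
        "tv_spend": [100.0, 120.0, 80.0, 90.0, 110.0, 95.0],
        "search_spend": [50.0, 55.0, 60.0, 45.0, 40.0, 65.0],
    }
)

# A calibration entry keyed by a tuple of channels, as in channel_data.
channel_key = ("tv_spend", "search_spend")
lift_start, lift_end = pd.Timestamp("2023-01-08"), pd.Timestamp("2023-01-29")
reported_spend = 520.0  # spend recorded in the calibration input

# Mirrors _get_channel_spend: mask the lift window, sum every channel in the key.
date_mask = (data["ds"] >= lift_start) & (data["ds"] <= lift_end)
actual_spend = sum(data.loc[date_mask, ch].sum() for ch in channel_key)

# Mirrors _check_spend_values: flag a mismatch beyond a 10% relative tolerance.
if abs(actual_spend - reported_spend) > 0.1 * reported_spend:
    print(f"Spend mismatch: calibration input ({reported_spend}) vs. data ({actual_spend})")
else:
    print(f"Spend within tolerance: calibration input ({reported_spend}) vs. data ({actual_spend})")

With these illustrative numbers the observed spend is 600.0 against a reported 520.0, which exceeds the 10% tolerance, so the sketch reports a mismatch; in the patched validator that condition becomes an error detail keyed by channel_key.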