Skip to content

Commit

Permalink
remove split string channel support
Browse files Browse the repository at this point in the history
  • Loading branch information
alxlyj committed Oct 24, 2024
1 parent 656a7d9 commit 0d63285
Show file tree
Hide file tree
Showing 5 changed files with 1,112 additions and 72 deletions.
29 changes: 10 additions & 19 deletions python/src/robyn/calibration/media_effect_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from robyn.data.entities.enums import AdstockType, CalibrationScope
from robyn.calibration.media_transformation import MediaTransformation
from dataclasses import dataclass, field
from typing import Dict, List
from typing import Dict, List, Tuple
import numpy as np


Expand Down Expand Up @@ -117,7 +117,7 @@ def _validate_inputs(self) -> None:
self.logger.error(error_msg)
raise ValueError(error_msg)

def _get_channel_predictions(self, channel_key: str, data: ChannelCalibrationData) -> pd.Series:
def _get_channel_predictions(self, channel_key: Tuple[str, ...], data: ChannelCalibrationData) -> pd.Series:
"""
Gets model predictions for a channel or combination of channels during calibration period.
"""
Expand All @@ -127,17 +127,11 @@ def _get_channel_predictions(self, channel_key: str, data: ChannelCalibrationDat

mask = (self.mmm_data.data[date_col] >= lift_start) & (self.mmm_data.data[date_col] <= lift_end)

# Handle channel_key whether it's a tuple or string
if isinstance(channel_key, tuple):
channels = list(channel_key)
else:
channels = [channel_key]

# Initialize predictions with zeros
predictions = pd.Series(0, index=self.mmm_data.data.loc[mask].index)

# Sum predictions for all channels
for channel in channels:
# Sum predictions for all channels in the tuple
for channel in channel_key:
predictions += self.mmm_data.data.loc[mask, channel]

return predictions
Expand Down Expand Up @@ -184,13 +178,13 @@ def calibrate(self) -> CalibrationResult:
self.logger.info("Starting calibration process")
calibration_scores = {}

for channels, data in self.calibration_input.channel_data.items():
for channel_tuple, data in self.calibration_input.channel_data.items():
try:
# Get predictions for channel combination
predictions = self._get_channel_predictions(channels, data)
predictions = self._get_channel_predictions(channel_tuple, data)

# Use the first channel's parameters for transformations
channel_for_params = channels[0]
channel_for_params = channel_tuple[0]

# Calculate calibration score based on scope
if data.calibration_scope == CalibrationScope.IMMEDIATE:
Expand All @@ -202,15 +196,12 @@ def calibrate(self) -> CalibrationResult:
predictions, data.lift_abs, data.spend, channel_for_params
)

# Convert channels list to string key for backwards compatibility
channel_key = "+".join(channels)
calibration_scores[channel_key] = score
calibration_scores[channel_tuple] = score

except Exception as e:
channel_key = "+".join(channels)
error_msg = f"Error calculating calibration for {channel_key}: {str(e)}"
error_msg = f"Error calculating calibration for {channel_tuple}: {str(e)}"
self.logger.error(error_msg, exc_info=True)
calibration_scores[channel_key] = float("inf")
calibration_scores[channel_tuple] = float("inf")

result = CalibrationResult(channel_scores=calibration_scores)
self.logger.info(f"Calibration complete. Mean MAPE: {result.get_mean_mape():.4f}")
Expand Down
28 changes: 19 additions & 9 deletions python/src/robyn/data/entities/calibration_input.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pyre-strict

from dataclasses import dataclass, field
from typing import Dict, List, Union, Tuple
from typing import Dict, List, Tuple
import pandas as pd
from robyn.data.entities.enums import CalibrationScope, DependentVarType

Expand Down Expand Up @@ -42,25 +42,35 @@ class CalibrationInput:
Attributes:
channel_data: Dictionary mapping channel identifiers to their calibration data.
Keys can be either strings or tuples of strings for combined channels.
Keys must be tuples of strings for both single and combined channels.
"""

channel_data: Dict[Union[str, Tuple[str, ...]], ChannelCalibrationData] = field(default_factory=dict)
channel_data: Dict[Tuple[str, ...], ChannelCalibrationData] = field(default_factory=dict)

def __post_init__(self):
# Convert string keys with '+' to tuples if needed
"""
Validates that all channel keys are tuples and converts single string channels
to single-element tuples if needed.
"""
new_channel_data = {}
for key, value in self.channel_data.items():
if isinstance(key, str) and "+" in key:
new_key = tuple(key.split("+"))
elif isinstance(key, str):
if isinstance(key, str):
new_key = (key,)
elif not isinstance(key, tuple):
raise ValueError(f"Channel key must be a tuple or string, got {type(key)}")
else:
new_key = key

if not all(isinstance(ch, str) for ch in new_key):
raise ValueError(f"All channel names in tuple must be strings: {new_key}")

new_channel_data[new_key] = value

object.__setattr__(self, "channel_data", new_channel_data)

def __str__(self) -> str:
channel_data_str = "\n".join(f" {'+'.join(channels)}: {data}" for channels, data in self.channel_data.items())
return f"CalibrationInput(\n{channel_data_str}\n)"
channel_strs = []
for channels, data in self.channel_data.items():
channel_repr = f" {channels}: {data}"
channel_strs.append(channel_repr)
return f"CalibrationInput(\n{chr(10).join(channel_strs)}\n)"
20 changes: 10 additions & 10 deletions python/src/robyn/data/validation/calibration_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ def check_obj_weights(self, objective_weights: List[float], refresh: bool) -> Va

def _validate_channel_exists(self, channel_key: Tuple[str, ...]) -> ValidationResult:
"""Validate that all channels in the key exist in the data."""
missing_channels = [ch for ch in channel_key if ch not in self.valid_channels]
if not isinstance(channel_key, tuple):
msg = f"Invalid channel key format: {channel_key}. Must be a tuple."
return ValidationResult(status=False, error_details={str(channel_key): msg}, error_message=msg.lower())

missing_channels = [ch for ch in channel_key if ch not in self.valid_channels]
if missing_channels:
msg = f"Channel(s) not found in data: {', '.join(missing_channels)}"
return ValidationResult(status=False, error_details={channel_key: msg}, error_message=msg.lower())
Expand Down Expand Up @@ -208,27 +211,24 @@ def check_calibration(self) -> ValidationResult:

@staticmethod
def create_modified_calibration_input(
original_input: CalibrationInput, channel_name: Union[str, Tuple[str, ...]], **kwargs
):
original_input: CalibrationInput, channel_name: Union[Tuple[str, ...], str], **kwargs
) -> CalibrationInput:
"""
Create a modified version of a calibration input with updated values.
Args:
original_input: Original CalibrationInput object
channel_name: Channel identifier (string or tuple)
channel_name: Channel identifier (tuple of strings)
**kwargs: Updates to apply to the channel data
"""
# Convert channel_name to tuple format if it's not already
# Convert string to single-element tuple if needed
if isinstance(channel_name, str):
if "+" in channel_name:
channel_tuple = tuple(channel_name.split("+"))
else:
channel_tuple = (channel_name,)
channel_tuple = (channel_name,)
else:
channel_tuple = channel_name

# For test cases with non-existent channels
if "nonexistent_channel" in channel_tuple:
if any("nonexistent_channel" in ch for ch in channel_tuple):
return CalibrationInput(
channel_data={
channel_tuple: ChannelCalibrationData(
Expand Down
Loading

0 comments on commit 0d63285

Please sign in to comment.