diff --git a/python/src/robyn/modeling/convergence/convergence.py b/python/src/robyn/modeling/convergence/convergence.py index b0d87a8a5..cdb4c8aef 100644 --- a/python/src/robyn/modeling/convergence/convergence.py +++ b/python/src/robyn/modeling/convergence/convergence.py @@ -68,19 +68,22 @@ def calculate_convergence(self, trials: List[Trial]) -> Dict[str, Any]: # Create visualization plots self.logger.info("Creating visualization plots") - moo_distrb_plot = self.visualizer.create_moo_distrb_plot( - dt_objfunc_cvg, conv_msg - ) - moo_cloud_plot = self.visualizer.create_moo_cloud_plot( - df, conv_msg, calibrated - ) - ts_validation_plot = None # self.visualizer.create_ts_validation_plot(trials) #Disabled for testing. #Sandeep + plots_dict = { + "moo_distrb_plot": self.visualizer.create_moo_distrb_plot( + dt_objfunc_cvg, conv_msg + ), + "moo_cloud_plot": self.visualizer.create_moo_cloud_plot( + df, conv_msg, calibrated + ), + "ts_validation_plot": None # Disabled for testing + } + + # Display the plots + self.visualizer.display_convergence_plots(plots_dict) self.logger.info("Convergence calculation completed successfully") return { - "moo_distrb_plot": moo_distrb_plot, - "moo_cloud_plot": moo_cloud_plot, - "ts_validation_plot": ts_validation_plot, + **plots_dict, "errors": errors, "conv_msg": conv_msg, } diff --git a/python/src/robyn/reporting/onepager_reporting.py b/python/src/robyn/reporting/onepager_reporting.py index e09ef27e0..000e74961 100644 --- a/python/src/robyn/reporting/onepager_reporting.py +++ b/python/src/robyn/reporting/onepager_reporting.py @@ -39,9 +39,9 @@ def __init__( # Default plots using PlotType enum directly self.default_plots = [ - PlotType.SPEND_EFFECT, PlotType.WATERFALL, PlotType.FITTED_VS_ACTUAL, + PlotType.SPEND_EFFECT, PlotType.BOOTSTRAP, PlotType.ADSTOCK, PlotType.IMMEDIATE_CARRYOVER, diff --git a/python/src/robyn/visualization/base_visualizer.py b/python/src/robyn/visualization/base_visualizer.py index 0971e0067..26ff878e7 100644 --- a/python/src/robyn/visualization/base_visualizer.py +++ b/python/src/robyn/visualization/base_visualizer.py @@ -1,16 +1,14 @@ # pyre-strict import logging - from abc import ABC, abstractmethod from typing import Dict, Optional, Tuple, Union, List from pathlib import Path -from IPython.display import Image, display - import matplotlib.pyplot as plt import numpy as np import base64 import io +from IPython.display import Image, display # Configure logger logger = logging.getLogger(__name__) @@ -18,76 +16,170 @@ class BaseVisualizer(ABC): """ - Base class for all Robyn visualization components. - Provides common plotting functionality and styling. + Enhanced base class for all Robyn visualization components. + Provides standardized plotting functionality and styling. """ def __init__(self, style: str = "bmh"): """ - Initialize BaseVisualizer with common plot settings. + Initialize BaseVisualizer with standardized plot settings. 
Args: style: matplotlib style to use (default: "bmh") """ logger.info("Initializing BaseVisualizer with style: %s", style) - # Store style settings self.style = style - self.default_figsize = (12, 8) + # Standard figure sizes + self.figure_sizes = { + "default": (12, 8), + "wide": (16, 8), + "square": (10, 10), + "tall": (8, 12), + "small": (8, 6), + "large": (15, 10), + "medium": (10, 6) + } - # Enhanced color schemes + # Standardized color schemes self.colors = { + # Primary colors for main data series "primary": "#4688C7", # Steel blue "secondary": "#FF9F1C", # Orange + "tertiary": "#37B067", # Green + # Status colors "positive": "#2ECC71", # Green "negative": "#E74C3C", # Red "neutral": "#95A5A6", # Gray - "current": "lightgray", # For current values - "optimal": "#4688C7", # For optimal values - "grid": "#E0E0E0", # For grid lines + # Chart elements + "grid": "#E0E0E0", # Light gray for grid lines + "baseline": "#CCCCCC", # Medium gray for baseline/reference lines + "annotation": "#666666", # Dark gray for annotations + # Channel-specific colors (for consistency across plots) + "channels": { + "facebook": "#3B5998", + "search": "#4285F4", + "display": "#34A853", + "youtube": "#FF0000", + "twitter": "#1DA1F2", + "email": "#DB4437", + "print": "#9C27B0", + "tv": "#E91E63", + "radio": "#795548", + "ooh": "#607D8B", + }, + } + + # Standard line styles + self.line_styles = { + "solid": "-", + "dashed": "--", + "dotted": ":", + "dashdot": "-.", } - logger.debug("Color scheme initialized: %s", self.colors) - - # Plot settings - self.font_sizes = { - "title": 14, - "subtitle": 12, - "label": 12, - "tick": 10, - "annotation": 9, - "legend": 10, + + # Standard markers + self.markers = { + "circle": "o", + "square": "s", + "triangle": "^", + "diamond": "D", + "plus": "+", + "cross": "x", + "star": "*", } - logger.debug("Font sizes configured: %s", self.font_sizes) - # Default alpha values - self.alpha = {"primary": 0.7, "secondary": 0.5, "grid": 0.3, "annotation": 0.7} + # Font configurations + self.fonts = { + "family": "sans-serif", + "sizes": { + "title": 14, + "subtitle": 12, + "label": 11, + "tick": 10, + "annotation": 9, + "legend": 10, + "small": 8, + }, + } - # Default spacing - self.spacing = {"tight_layout_pad": 1.05, "subplot_adjust_hspace": 0.4} + # Common alpha values + self.alpha = { + "primary": 0.8, + "secondary": 0.6, + "grid": 0.3, + "annotation": 0.7, + "highlight": 0.9, + "background": 0.2, + } + + # Standard spacing + self.spacing = { + "tight_layout_pad": 1.05, + "subplot_adjust_hspace": 0.4, + "label_pad": 10, + "title_pad": 20, + } # Initialize plot tracking self.current_figure: Optional[plt.Figure] = None self.current_axes: Optional[Union[plt.Axes, np.ndarray]] = None - # Apply default style + # Apply default style and settings self._setup_plot_style() logger.info("BaseVisualizer initialization completed") + def format_number(self, x: float, pos=None) -> str: + """Format large numbers with K/M/B abbreviations. 
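+
+        For example (illustrative): format_number(950) returns "950.0",
+        format_number(1_500) returns "1.5K", format_number(2_500_000)
+        returns "2.5M", and format_number(3_000_000_000) returns "3.0B".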
+ + Args: + x: Number to format + pos: Position parameter (required by matplotlib formatter but not used) + + Returns: + Formatted string representation of the number + """ + try: + if abs(x) >= 1e9: + return f"{x/1e9:.1f}B" + elif abs(x) >= 1e6: + return f"{x/1e6:.1f}M" + elif abs(x) >= 1e3: + return f"{x/1e3:.1f}K" + else: + return f"{x:.1f}" + except (TypeError, ValueError): + return str(x) + def _setup_plot_style(self) -> None: """Configure default plotting style.""" - logger.debug("Setting up plot style with style: %s", self.style) + logger.debug("Setting up plot style") try: plt.style.use(self.style) plt.rcParams.update( { - "figure.figsize": self.default_figsize, + # Figure settings + "figure.figsize": self.figure_sizes["default"], + "figure.facecolor": "white", + # Font settings + "font.family": self.fonts["family"], + "font.size": self.fonts["sizes"]["label"], + # Axes settings "axes.grid": True, "axes.spines.top": False, "axes.spines.right": False, - "font.size": self.font_sizes["label"], + "axes.labelsize": self.fonts["sizes"]["label"], + "axes.titlesize": self.fonts["sizes"]["title"], + # Grid settings "grid.alpha": self.alpha["grid"], "grid.color": self.colors["grid"], + # Legend settings + "legend.fontsize": self.fonts["sizes"]["legend"], + "legend.framealpha": self.alpha["annotation"], + # Tick settings + "xtick.labelsize": self.fonts["sizes"]["tick"], + "ytick.labelsize": self.fonts["sizes"]["tick"], } ) logger.debug("Plot style parameters updated successfully") @@ -191,6 +283,117 @@ def setup_axis( logger.error("Failed to setup axis: %s", str(e)) raise + def _add_standardized_grid( + self, + ax: plt.Axes, + axis: str = "both", + alpha: Optional[float] = None, + color: Optional[str] = None, + linestyle: Optional[str] = None + ) -> None: + """Add standardized grid to plot.""" + ax.grid( + True, + axis=axis, + alpha=alpha or self.alpha["grid"], + color=color or self.colors["grid"], + linestyle=linestyle or self.line_styles["solid"], + zorder=0 + ) + ax.set_axisbelow(True) + + def _add_standardized_legend( + self, + ax: plt.Axes, + title: Optional[str] = None, + loc: str = "lower right", + ncol: int = 1, + handles: Optional[List] = None, + labels: Optional[List[str]] = None, + ) -> None: + """Add standardized legend to plot. 
+ + Args: + ax: Matplotlib axes to add legend to + title: Optional legend title + loc: Legend location + ncol: Number of columns in legend + handles: Optional list of legend handles + labels: Optional list of legend labels + """ + legend_handles = handles if handles is not None else ax.get_legend_handles_labels()[0] + legend_labels = labels if labels is not None else ax.get_legend_handles_labels()[1] + + legend = ax.legend( + handles=legend_handles, + labels=legend_labels, + title=title, + loc=loc, + ncol=ncol, + fontsize=self.fonts["sizes"]["legend"], + framealpha=self.alpha["annotation"], + title_fontsize=self.fonts["sizes"]["subtitle"] + ) + if legend: + legend.get_frame().set_linewidth(0.5) + legend.get_frame().set_edgecolor(self.colors["grid"]) + + def _set_standardized_labels( + self, + ax: plt.Axes, + xlabel: Optional[str] = None, + ylabel: Optional[str] = None, + title: Optional[str] = None + ) -> None: + """Set standardized labels for plot.""" + if xlabel: + ax.set_xlabel( + xlabel, + fontsize=self.fonts["sizes"]["label"], + labelpad=self.spacing["label_pad"] + ) + if ylabel: + ax.set_ylabel( + ylabel, + fontsize=self.fonts["sizes"]["label"], + labelpad=self.spacing["label_pad"] + ) + if title: + ax.set_title( + title, + fontsize=self.fonts["sizes"]["title"], + pad=self.spacing["title_pad"] + ) + + def _format_standardized_ticks( + self, + ax: plt.Axes, + x_rotation: int = 0, + y_rotation: int = 0 + ) -> None: + """Format tick labels with standardized styling.""" + ax.tick_params( + axis='both', + labelsize=self.fonts["sizes"]["tick"] + ) + plt.setp( + ax.get_xticklabels(), + rotation=x_rotation, + ha='right' if x_rotation > 0 else 'center' + ) + plt.setp( + ax.get_yticklabels(), + rotation=y_rotation, + va='center' + ) + + def _set_standardized_spines(self, ax: plt.Axes, spines: List[str] = None) -> None: + """Configure plot spines with standardized styling.""" + if spines is None: + spines = ['top', 'right'] + for spine in spines: + ax.spines[spine].set_visible(False) + def add_percentage_annotation( self, ax: plt.Axes, diff --git a/python/src/robyn/visualization/feature_visualization.py b/python/src/robyn/visualization/feature_visualization.py index 348deb3d9..acf9158aa 100644 --- a/python/src/robyn/visualization/feature_visualization.py +++ b/python/src/robyn/visualization/feature_visualization.py @@ -92,29 +92,6 @@ def plot_spend_exposure( ) -> Dict[str, plt.Figure]: """ Generates a spend-exposure plot for a specified channel. - - Parameters: - ----------- - channel : str - The name of the channel for which the spend-exposure plot is to be generated. - - Returns: - -------- - plt.Figure - The matplotlib Figure object containing the spend-exposure plot. - - Raises: - ------- - ValueError - If no spend-exposure data or plot data is available for the specified channel. - Exception - If any other error occurs during the plot generation process. - - Notes: - ------ - The function retrieves the model results and plot data for the specified channel from the featurized_mmmdata attribute. - It creates a scatter plot of the actual data and a fitted line plot. The plot includes model information such as - model type, R-squared value, and model-specific parameters (e.g., Vmax and Km for Michaelis-Menten model or coefficient for linear model). 
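+
+        Returns a dict mapping "spend-exposure" to the generated figure;
+        raises ValueError when no spend-exposure data exists for the channel.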
""" logger.info("Generating spend-exposure plot for channel: %s", channel) @@ -128,67 +105,83 @@ def plot_spend_exposure( ), None, ) - logger.info("Found result for channel %s", channel) if res is None: - logger.error("Channel %s not found in featurized data results", channel) - raise ValueError( - f"No spend-exposure data available for channel: {channel}" - ) + raise ValueError(f"No spend-exposure data available for channel: {channel}") + plot_data = self.featurized_mmmdata.modNLS["plots"].get(channel) if plot_data is None: - logger.error("Plot data for channel %s not found", channel) raise ValueError(f"No plot data available for channel: {channel}") - fig, ax = plt.subplots(figsize=(10, 6)) + + # Create figure using base visualizer methods + fig, ax = self.create_figure(figsize=self.figure_sizes["medium"]) + # Plot scatter of actual data sns.scatterplot( x="spend", y="exposure", data=plot_data, ax=ax, - alpha=0.6, + alpha=self.alpha["primary"], label="Actual", + color=self.colors["primary"] ) - logger.debug("Created scatter plot for actual data") + # Plot fitted line sns.lineplot( - x="spend", y="yhat", data=plot_data, ax=ax, color="red", label="Fitted" + x="spend", + y="yhat", + data=plot_data, + ax=ax, + color=self.colors["secondary"], + label="Fitted" ) - logger.debug("Added fitted line to plot") - ax.set_xlabel(f"Spend [{channel}]") - ax.set_ylabel(f"Exposure [{channel}]") - ax.set_title(f"Spend vs Exposure for {channel}") - # Add model information to the plot + + # Set labels and title using base visualizer methods + self._set_standardized_labels( + ax, + xlabel=f"Spend [{channel}]", + ylabel=f"Exposure [{channel}]", + title=f"Spend vs Exposure for {channel}" + ) + + # Add model information model_type = res["model_type"] rsq = res["rsq"] - logger.debug("Model type: %s, R-squared: %f", model_type, rsq) + if model_type == "nls": Vmax, Km = res["Vmax"], res["Km"] - ax.text( - 0.05, - 0.95, - f"Model: Michaelis-Menten\nR² = {rsq:.4f}\nVmax = {Vmax:.2f}\nKm = {Km:.2f}", - transform=ax.transAxes, - verticalalignment="top", - bbox=dict(boxstyle="round", facecolor="white", alpha=0.7), - ) - logger.debug("Added NLS model parameters: Vmax=%f, Km=%f", Vmax, Km) + text = f"Model: Michaelis-Menten\nR² = {rsq:.4f}\nVmax = {Vmax:.2f}\nKm = {Km:.2f}" else: coef = res["coef_lm"] - ax.text( - 0.05, - 0.95, - f"Model: Linear\nR² = {rsq:.4f}\nCoefficient = {coef:.4f}", - transform=ax.transAxes, - verticalalignment="top", - bbox=dict(boxstyle="round", facecolor="white", alpha=0.7), - ) - logger.debug("Added linear model parameters: coefficient=%f", coef) - plt.legend() - plt.tight_layout() - plt.close() - logger.info( - "Successfully generated spend-exposure plot for channel %s", channel + text = f"Model: Linear\nR² = {rsq:.4f}\nCoefficient = {coef:.4f}" + + # Add text box with model information + ax.text( + 0.05, + 0.95, + text, + transform=ax.transAxes, + verticalalignment="top", + bbox=dict( + boxstyle="round", + facecolor="white", + alpha=self.alpha["annotation"], + edgecolor=self.colors["grid"] + ), + fontsize=self.fonts["sizes"]["annotation"] ) + + # Add grid and style using base visualizer methods + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + self._add_standardized_legend(ax, loc='lower right') + + # Finalize the figure + self.finalize_figure(tight_layout=True) + + logger.info("Successfully generated spend-exposure plot for channel %s", channel) + + self.cleanup() return {"spend-exposure": fig} except Exception as e: logger.error( @@ -199,71 +192,34 @@ def 
plot_spend_exposure(
             )
             raise
-
-    def plot_feature_importance(
-        self, feature_importance: Dict[str, float], display: bool = True
-    ) -> Dict[str, plt.Figure]:
-        """
-        Plot the importance of different features in the model.
-
-        Args:
-            feature_importance (Dict[str, float]): Dictionary of feature importances.
-
-        Returns:
-            plt.Figure: A matplotlib Figure object containing the feature importance plot.
-        """
-        logger.info("Generating feature importance plot")
-        logger.debug("Feature importance data: %s", feature_importance)
-        try:
-            # Implementation placeholder
-            logger.warning("plot_feature_importance method not implemented yet")
-
-        except Exception as e:
-            logger.error("Failed to generate feature importance plot: %s", str(e))
-            raise
-
-    def plot_response_curves(self, display: bool = True) -> Dict[str, plt.Figure]:
-        """
-        Plot response curves for different channels.
-
-        Args:
-            self.featurized_mmmdata (FeaturizedMMMData): The featurized data after feature engineering.
-
-        Returns:
-            Dict[str, plt.Figure]: Dictionary mapping channel names to their respective response curve plots.
-        """
-        logger.info("Generating response curves")
-        logger.debug("Processing featurized data: %s", self.featurized_mmmdata)
-        try:
-            dt_mod = self.featurized_mmmdata.dt_mod
-            logger.debug("Modified data: %s", dt_mod)
-            # Rest of the method implementation
-            logger.warning("plot_response_curves method not fully implemented yet")
-
-        except Exception as e:
-            logger.error("Failed to generate response curves: %s", str(e))
-            raise
-
     def plot_all(
         self, display_plots: bool = True, export_location: Union[str, Path] = None
     ) -> Dict[str, plt.Figure]:
         """
-        Override the abstract method plot_all from BaseVisualizer.
+        Generate all plots available in the feature plotter.
         """
         logger.info("Generating all plots")
         plot_collect: Dict[str, plt.Figure] = {}
+
        try:
-            for item in self.featurized_mmmdata.modNLS["results"]:
-                channel = item["channel"]
-                self.plot_adstock(channel, display_plots)
-                # plot_collect.update(self.plot_adstock(channel, display))
-                # plot_collect.update(self.plot_saturation(channel, display))
-                plot_collect[channel] = self.plot_spend_exposure(
-                    channel, display_plots
-                )["spend-exposure"]
+            # Create plots for each channel only once
+            channels = {item["channel"] for item in self.featurized_mmmdata.modNLS["results"]}
+
+            for channel in channels:
+                spend_exposure_plot = self.plot_spend_exposure(channel, display=False)
+                plot_collect[f"{channel}_spend_exposure"] = spend_exposure_plot["spend-exposure"]
 
-            # plot_collect.update(self.plot_feature_importance({}, display))
+            if display_plots:
+                self.display_plots(plot_collect)
 
-            super().display_plots(plot_collect)
+            if export_location:
+                self.export_plots_fig(export_location, plot_collect)
+
+            return plot_collect
         except Exception as e:
             logger.error("Failed to generate all plots: %s", str(e))
             raise
+
+    def __del__(self):
+        """Cleanup when the plotter is destroyed."""
+        self.cleanup()
\ No newline at end of file
diff --git a/python/src/robyn/visualization/model_convergence_visualizer.py b/python/src/robyn/visualization/model_convergence_visualizer.py
index 8165a3105..934b1a4d3 100644
--- a/python/src/robyn/visualization/model_convergence_visualizer.py
+++ b/python/src/robyn/visualization/model_convergence_visualizer.py
@@ -2,22 +2,18 @@
 import numpy as np
 import pandas as pd
 import matplotlib.pyplot as plt
-import matplotlib
-from IPython.display import Image, display
-
-matplotlib.use("Agg")
 import seaborn as sns
-from typing import List, Optional, Union
-import io
-import base64
+from typing import Dict, List, Optional, Union, Any
+import io  # retained: _convert_plot_to_base64 still needs io/base64
+import base64
 import logging
 from robyn.modeling.entities.modeloutputs import Trial
+from robyn.visualization.base_visualizer import BaseVisualizer
 
-# Initialize logger for this module
 logger = logging.getLogger(__name__)
 
 
-class ModelConvergenceVisualizer:
+class ModelConvergenceVisualizer(BaseVisualizer):
     def __init__(
         self,
         n_cuts: Optional[int] = None,
@@ -26,6 +20,7 @@ def __init__(
         moo_cloud_plot: Optional[pd.DataFrame] = None,
         moo_distrb_plot: Optional[pd.DataFrame] = None,
     ):
+        super().__init__()  # Initialize BaseVisualizer
         self.n_cuts = n_cuts
         self.nrmse_win = nrmse_win
         self.ts_validation_plot = ts_validation_plot
@@ -35,7 +30,7 @@ def __init__(
 
     def create_moo_distrb_plot(
         self, dt_objfunc_cvg: pd.DataFrame, conv_msg: List[str]
-    ) -> str:
+    ) -> Dict[str, plt.Figure]:
         logger.debug(
             "Starting moo distribution plot creation with data shape: %s",
             dt_objfunc_cvg.shape,
         )
@@ -47,10 +42,8 @@ def create_moo_distrb_plot(
                 dt_objfunc_cvg["cuts"],
                 categories=sorted(dt_objfunc_cvg["cuts"].unique(), reverse=True),
             )
+            # Clip values based on quantiles
-            logger.debug(
-                "Processing error types: %s", dt_objfunc_cvg["error_type"].unique()
-            )
             for error_type in dt_objfunc_cvg["error_type"].unique():
                 mask = dt_objfunc_cvg["error_type"] == error_type
                 original_values = dt_objfunc_cvg.loc[mask, "value"]
@@ -62,11 +55,11 @@ def create_moo_distrb_plot(
                     quantiles[0],
                     quantiles[1],
                 )
-            # Set the style and color palette
-            sns.set_style("whitegrid")
-            sns.set_palette("Set2")
-            # Create the violin plot with a larger figure size
-            fig, ax = plt.subplots(figsize=(14, 10))
+
+            # Create figure using base visualizer methods
+            fig, ax = self.create_figure(figsize=self.figure_sizes["default"])
+
+            # Create the violin plot
             sns.violinplot(
                 data=dt_objfunc_cvg,
                 x="value",
                 y="cuts",
                 hue="error_type",
                 split=True,
                 inner="quartile",
                 ax=ax,
             )
@@ -76,29 +69,33 @@ def create_moo_distrb_plot(
-            ax.set_xlabel("Objective functions", fontsize=12, ha="left", x=0)
-            ax.set_ylabel("Iterations [#]", fontsize=12)
-            ax.set_title(
-                "Objective convergence by iterations quantiles",
-                fontsize=14,
-                fontweight="bold",
+
+            # Set labels and styling using base visualizer methods
+            self._set_standardized_labels(
+                ax,
+                xlabel="Objective functions",
+                ylabel="Iterations [#]",
+                title="Objective convergence by iterations quantiles"
             )
-            ax.grid(True, linestyle="--", linewidth=0.5)
-            # Adjust layout to make room for figtext on the bottom right
-            plt.subplots_adjust(right=0.75, bottom=0.15)
-            # Add text annotations on the bottom right
+            self._add_standardized_grid(ax)
+            self._set_standardized_spines(ax)
+            self._add_standardized_legend(ax, loc='lower right')
+
+            # Add convergence messages
             plt.figtext(
                 0.98,
-                0,
+                0.02,
                 "\n".join(conv_msg),
                 ha="right",
                 va="bottom",
-                fontsize=8,
+                fontsize=self.fonts["sizes"]["small"],
                 wrap=True,
             )
-            plt.tight_layout()
+
+            self.finalize_figure(tight_layout=True)
             logger.info("Successfully created moo distribution plot")
-            return self._convert_plot_to_base64(fig)
+            return {"moo_distribution": fig}
+
         except Exception as e:
             logger.error(
                 "Failed to create moo distribution plot: %s", str(e), exc_info=True
             )
@@ -107,7 +104,7 @@
 
     def create_moo_cloud_plot(
         self, df: pd.DataFrame, conv_msg: List[str], calibrated: bool
-    ) -> str:
+    ) -> Dict[str, plt.Figure]:
         logger.debug(
             "Starting moo cloud plot creation with data shape: %s, calibrated=%s",
             df.shape,
@@ -119,21 +116,19 @@ def create_moo_cloud_plot(
             original_nrmse = df["nrmse"]
             quantiles = np.quantile(original_nrmse, self.nrmse_win)
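+            # Winsorize NRMSE to the nrmse_win quantile bounds so extreme
+            # outliers do not dominate the scatter's axis scaling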
df["nrmse"] = np.clip(original_nrmse, *quantiles) - logger.debug( - "Clipped NRMSE values: min=%f, max=%f", quantiles[0], quantiles[1] - ) - # Set the style and color palette - sns.set_style("whitegrid") - sns.set_palette("Set2") - # Create the scatter plot - fig, ax = plt.subplots(figsize=(12, 10)) + + # Create figure using base visualizer methods + fig, ax = self.create_figure(figsize=self.figure_sizes["default"]) + + # Create scatter plot scatter = ax.scatter( df["nrmse"], df["decomp.rssd"], c=df["ElapsedAccum"], cmap="viridis", - alpha=0.7, + alpha=self.alpha["primary"] ) + if calibrated and "mape" in df.columns: logger.debug("Adding calibrated MAPE visualization") sizes = (df["mape"] - df["mape"].min()) / ( @@ -144,43 +139,49 @@ def create_moo_cloud_plot( df["nrmse"], df["decomp.rssd"], s=sizes, - alpha=0.5, + alpha=self.alpha["secondary"], edgecolor="w", linewidth=0.5, ) + + # Add colorbar plt.colorbar(scatter, label="Time [s]") - ax.set_xlabel("NRMSE", fontsize=12, ha="left", x=0) - ax.set_ylabel("DECOMP.RSSD", fontsize=12) - ax.set_title( - "Multi-objective evolutionary performance", - fontsize=14, - fontweight="bold", + + # Set labels and styling using base visualizer methods + self._set_standardized_labels( + ax, + xlabel="NRMSE", + ylabel="DECOMP.RSSD", + title="Multi-objective evolutionary performance" ) - # Add text annotations on the bottom right + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + + # Add convergence messages plt.figtext( 0.98, - 0, + 0.02, "\n".join(conv_msg), ha="right", va="bottom", - fontsize=8, + fontsize=self.fonts["sizes"]["small"], wrap=True, ) - plt.tight_layout() + self.finalize_figure(tight_layout=True) logger.info("Successfully created moo cloud plot") - return self._convert_plot_to_base64(fig) + return {"moo_cloud": fig} except Exception as e: logger.error("Failed to create moo cloud plot: %s", str(e), exc_info=True) raise - def create_ts_validation_plot(self, trials: List[Trial]) -> str: + def create_ts_validation_plot(self, trials: List[Trial]) -> Dict[str, plt.Figure]: logger.debug( "Starting time-series validation plot creation with %d trials", len(trials) ) try: - # Concatenate trial data + # Prepare data result_hyp_param = pd.concat( [trial.result_hyp_param for trial in trials], ignore_index=True ) @@ -188,68 +189,58 @@ def create_ts_validation_plot(self, trials: List[Trial]) -> str: result_hyp_param.groupby("sol_id").cumcount() + 1 ) result_hyp_param["iteration"] = result_hyp_param.index + 1 - logger.debug("Processing metrics for validation plot") + + # Process metrics result_hyp_param_long = result_hyp_param.melt( id_vars=["sol_id", "trial", "train_size", "iteration"], value_vars=[ - "rsq_train", - "rsq_val", - "rsq_test", - "nrmse_train", - "nrmse_val", - "nrmse_test", + "rsq_train", "rsq_val", "rsq_test", + "nrmse_train", "nrmse_val", "nrmse_test" ], var_name="metric", - value_name="value", + value_name="value" ) + # Extract dataset and metric type - result_hyp_param_long["dataset"] = ( - result_hyp_param_long["metric"].str.split("_").str[-1] - ) - result_hyp_param_long["metric_type"] = ( - result_hyp_param_long["metric"].str.split("_").str[0] - ) + result_hyp_param_long["dataset"] = result_hyp_param_long["metric"].str.split("_").str[-1] + result_hyp_param_long["metric_type"] = result_hyp_param_long["metric"].str.split("_").str[0] + # Winsorize the data - logger.debug("Winsorizing metric values") - result_hyp_param_long["value"] = result_hyp_param_long.groupby( - "metric_type" - )["value"].transform( + 
result_hyp_param_long["value"] = result_hyp_param_long.groupby("metric_type")["value"].transform( lambda x: np.clip( x, np.percentile(x, self.nrmse_win[0] * 100), np.percentile(x, self.nrmse_win[1] * 100), ) ) - # Set the style and color palette - sns.set_style("whitegrid") - sns.set_palette("Set2") - # Determine the number of trials + + # Create figure using base visualizer methods num_trials = result_hyp_param["trial"].nunique() - # Create subplots - fig, axes = plt.subplots( - num_trials + 1, - 1, - figsize=(12, 5 * (num_trials + 1)), - gridspec_kw={"height_ratios": [3] * num_trials + [1]}, - ) + fig = plt.figure(figsize=self.figure_sizes["default"]) + + # Create grid for subplots + gs = fig.add_gridspec(num_trials + 1, 1, height_ratios=[3] * num_trials + [1]) + # NRMSE plots for each trial - for i, (trial, ax) in enumerate( - zip(result_hyp_param["trial"].unique(), axes[:-1]) - ): + for i, trial in enumerate(result_hyp_param["trial"].unique()): + ax = fig.add_subplot(gs[i]) nrmse_data = result_hyp_param_long[ (result_hyp_param_long["metric_type"] == "nrmse") & (result_hyp_param_long["trial"] == trial) ] + + # Create plots sns.scatterplot( data=nrmse_data, x="iteration", y="value", hue="dataset", style="dataset", - markers=["o", "s", "D"], # Different markers for train, val, test + markers=["o", "s", "D"], ax=ax, - alpha=0.6, + alpha=self.alpha["primary"] ) + sns.lineplot( data=nrmse_data, x="iteration", @@ -257,33 +248,60 @@ def create_ts_validation_plot(self, trials: List[Trial]) -> str: hue="dataset", ax=ax, legend=False, - linewidth=1, + linewidth=1 ) - ax.set_ylabel(f"NRMSE [Trial {trial}]", fontsize=12, fontweight="bold") - ax.set_xlabel("Iteration", fontsize=12, fontweight="bold") - ax.legend(title="Dataset", loc="upper right") + + # Style the subplot + self._set_standardized_labels( + ax, + ylabel=f"NRMSE [Trial {trial}]", + xlabel="Iteration" if i == num_trials - 1 else "" + ) + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + self._add_standardized_legend(ax, loc='lower right') + + # Only show x-label on bottom plot + if i < num_trials - 1: + ax.set_xlabel("") + # Train Size plot + ax = fig.add_subplot(gs[-1]) sns.scatterplot( data=result_hyp_param, x="iteration", y="train_size", hue="trial", - ax=axes[-1], + ax=ax, legend=False, ) - axes[-1].set_ylabel("Train Size", fontsize=12, fontweight="bold") - axes[-1].set_xlabel("Iteration", fontsize=12, fontweight="bold") - axes[-1].set_ylim(0, 1) - axes[-1].yaxis.set_major_formatter( + + # Style the train size plot + self._set_standardized_labels( + ax, + xlabel="Iteration", + ylabel="Train Size" + ) + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + + ax.set_ylim(0, 1) + ax.yaxis.set_major_formatter( plt.FuncFormatter(lambda y, _: "{:.0%}".format(y)) ) - # Set the overall title - plt.suptitle( - "Time-series validation & Convergence", fontsize=14, fontweight="bold" + + # Set overall title + fig.suptitle( + "Time-series validation & Convergence", + fontsize=self.fonts["sizes"]["title"], + fontweight="bold", + y=1.02 ) - plt.tight_layout() + + self.finalize_figure(tight_layout=True) logger.info("Successfully created time-series validation plot") - return self._convert_plot_to_base64(fig) + return {"ts_validation": fig} + except Exception as e: logger.error( "Failed to create time-series validation plot: %s", @@ -291,7 +309,7 @@ def create_ts_validation_plot(self, trials: List[Trial]) -> str: exc_info=True, ) raise - + def _convert_plot_to_base64(self, fig: plt.Figure) -> str: 
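+        """Render the figure to a base64-encoded image string."""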
logger.debug("Converting plot to base64") try: @@ -307,24 +325,72 @@ def _convert_plot_to_base64(self, fig: plt.Figure) -> str: logger.error("Failed to convert plot to base64: %s", str(e), exc_info=True) raise - def display_moo_distrb_plot(self): - """Display the MOO Distribution Plot.""" - self._display_base64_image(self.moo_distrb_plot) - - def display_moo_cloud_plot(self): - """Display the MOO Cloud Plot.""" - self._display_base64_image(self.moo_cloud_plot) + def display_convergence_plots(self, plots_dict: Dict[str, Any]) -> None: + """ + Display all convergence plots from a dictionary. + """ + logger.info("Displaying convergence plots") + try: + if 'moo_distrb_plot' in plots_dict and plots_dict['moo_distrb_plot']: + logger.info("Displaying MOO distribution plot") + for name, fig in plots_dict['moo_distrb_plot'].items(): + plt.figure(fig.number) + plt.show() - def display_ts_validation_plot(self): - """Display the Time-Series Validation Plot.""" - self._display_base64_image(self.ts_validation_plot) + if 'moo_cloud_plot' in plots_dict and plots_dict['moo_cloud_plot']: + logger.info("Displaying MOO cloud plot") + for name, fig in plots_dict['moo_cloud_plot'].items(): + plt.figure(fig.number) + plt.show() - def _display_base64_image(self, base64_image: str): - """Helper method to display a base64-encoded image.""" - display(Image(data=base64.b64decode(base64_image))) + if 'ts_validation_plot' in plots_dict and plots_dict['ts_validation_plot']: + logger.info("Displaying time series validation plot") + for name, fig in plots_dict['ts_validation_plot'].items(): + plt.figure(fig.number) + plt.show() + except Exception as e: + logger.error(f"Error displaying plots: {str(e)}") + raise def plot_all( self, display_plots: bool = True, export_location: Union[str, Path] = None - ) -> None: + ) -> Dict[str, plt.Figure]: + """ + Generate all available plots. 
+ """ + logger.info("Generating all plots") + plot_collect: Dict[str, plt.Figure] = {} + + try: + # Generate plots if data is available + if self.moo_distrb_plot is not None: + logger.info("Creating moo distribution plot") + plot_collect.update(self.create_moo_distrb_plot(self.moo_distrb_plot, [])) + + if self.moo_cloud_plot is not None: + logger.info("Creating moo cloud plot") + plot_collect.update(self.create_moo_cloud_plot(self.moo_cloud_plot, [], False)) + + if self.ts_validation_plot is not None: + logger.info("Creating time series validation plot") + plot_collect.update(self.create_ts_validation_plot(self.ts_validation_plot)) + + if display_plots: + logger.info(f"Displaying plots: {list(plot_collect.keys())}") + for plot_name, fig in plot_collect.items(): + plt.figure(fig.number) + plt.show() + + if export_location: + logger.info(f"Exporting plots to: {export_location}") + self.export_plots_fig(export_location, plot_collect) + + return plot_collect + + except Exception as e: + logger.error("Failed to generate all plots: %s", str(e)) + raise - logger.warning("this method is not yet implemented") + def __del__(self): + """Cleanup when the visualizer is destroyed.""" + self.cleanup() \ No newline at end of file diff --git a/python/src/robyn/visualization/pareto_visualizer.py b/python/src/robyn/visualization/pareto_visualizer.py index ad07e54d7..13156d7fa 100644 --- a/python/src/robyn/visualization/pareto_visualizer.py +++ b/python/src/robyn/visualization/pareto_visualizer.py @@ -1,28 +1,29 @@ from pathlib import Path import re -from typing import Dict, List, Optional, Union +from typing import Dict, List, Optional, Union, Any from matplotlib import ticker, transforms import matplotlib.pyplot as plt import numpy as np import pandas as pd -from robyn.modeling.entities.modeloutputs import ModelOutputs import seaborn as sns import logging -from robyn.data.entities.enums import ProphetVariableType +from robyn.data.entities.enums import ProphetVariableType, DependentVarType from robyn.data.entities.holidays_data import HolidaysData from robyn.modeling.entities.featurized_mmm_data import FeaturizedMMMData from robyn.modeling.entities.pareto_result import ParetoResult from robyn.data.entities.hyperparameters import AdstockType, Hyperparameters from robyn.data.entities.mmmdata import MMMData from robyn.visualization.base_visualizer import BaseVisualizer -from robyn.data.entities.enums import DependentVarType -import math +from robyn.modeling.entities.modeloutputs import ModelOutputs import matplotlib.dates as mdates +import math logger = logging.getLogger(__name__) class ParetoVisualizer(BaseVisualizer): + """Visualizer for Pareto optimization results.""" + def __init__( self, pareto_result: ParetoResult, @@ -33,6 +34,17 @@ def __init__( unfiltered_pareto_result: Optional[ParetoResult] = None, model_outputs: Optional[ModelOutputs] = None, ): + """Initialize ParetoVisualizer. 
+ + Args: + pareto_result: Pareto optimization results + mmm_data: Marketing mix model data + holiday_data: Holiday data for prophet variables + hyperparameter: Model hyperparameters + featurized_mmm_data: Featurized marketing mix model data + unfiltered_pareto_result: Unfiltered Pareto results + model_outputs: Model outputs data + """ super().__init__() self.pareto_result = pareto_result self.mmm_data = mmm_data @@ -41,198 +53,237 @@ def __init__( self.featurized_mmm_data = featurized_mmm_data self.unfiltered_pareto_result = unfiltered_pareto_result self.model_outputs = model_outputs + logger.info("Initialized ParetoVisualizer") - def _baseline_vars( - self, baseline_level, prophet_vars: List[ProphetVariableType] = [] - ) -> list: - """ - Returns a list of baseline variables based on the provided level. + def _baseline_vars(self, baseline_level: int, prophet_vars: List[ProphetVariableType] = []) -> list: + """Returns a list of baseline variables based on the provided level. + Args: - InputCollect (dict): A dictionary containing various input data. - baseline_level (int): The level of baseline variables to include. + baseline_level: The level of baseline variables to include (0-5) + prophet_vars: List of prophet variables to include + Returns: - list: A list of baseline variable names. + List of baseline variable names """ - # Check if baseline_level is valid + logger.debug(f"Getting baseline variables for level {baseline_level}") + if baseline_level < 0 or baseline_level > 5: - raise ValueError("baseline_level must be an integer between 0 and 5") + raise ValueError("baseline_level must be between 0 and 5") + baseline_variables = [] + # Level 1: Include intercept variables if baseline_level >= 1: baseline_variables.extend(["(Intercept)", "intercept"]) + # Level 2: Include trend variables if baseline_level >= 2: baseline_variables.append("trend") + # Level 3: Include prophet variables if baseline_level >= 3: baseline_variables.extend(list(set(baseline_variables + prophet_vars))) + # Level 4: Include context variables if baseline_level >= 4: baseline_variables.extend(self.mmm_data.mmmdata_spec.context_vars) + # Level 5: Include organic variables if baseline_level >= 5: baseline_variables.extend(self.mmm_data.mmmdata_spec.organic_vars) - return list(set(baseline_variables)) - def format_number(self, x: float, pos=None) -> str: - """Format large numbers with K/M/B abbreviations. + return list(set(baseline_variables)) + def _validate_solution_id(self, solution_id: str) -> None: + """Validate that a solution ID exists in the Pareto results. + Args: - x: Number to format - pos: Position (required by matplotlib FuncFormatter but not used) + solution_id: Solution ID to validate + + Raises: + ValueError: If solution ID is invalid + """ + if solution_id not in self.pareto_result.plot_data_collect: + raise ValueError(f"Invalid solution ID: {solution_id}") + def _setup_plot_defaults(self) -> Dict[str, Any]: + """Set up default plotting parameters. 
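+
+        Bundles the shared colors, alphas, line styles, markers, and font
+        sizes into a single dict for use by subclass plot methods.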
+ Returns: - Formatted string + Dictionary of default plotting parameters """ - if abs(x) >= 1e9: - return f"{x/1e9:.1f}B" - elif abs(x) >= 1e6: - return f"{x/1e6:.1f}M" - elif abs(x) >= 1e3: - return f"{x/1e3:.1f}K" - else: - return f"{x:.1f}" + return { + "colors": { + "actual": self.colors["secondary"], + "predicted": self.colors["primary"], + "positive": self.colors["positive"], + "negative": self.colors["negative"], + "neutral": self.colors["neutral"] + }, + "alphas": { + "main": self.alpha["primary"], + "secondary": self.alpha["secondary"], + "background": self.alpha["background"] + }, + "line_styles": self.line_styles, + "markers": self.markers, + "font_sizes": self.fonts["sizes"] + } + def _create_standard_figure(self, figsize: Optional[tuple] = None) -> tuple: + """Create a figure with standard styling. + + Args: + figsize (Optional[tuple]): Optional figure size tuple (width, height) + + Returns: + tuple: (figure, axes) + """ + fig, ax = self.create_figure(figsize=figsize if figsize else self.figure_sizes["default"]) + ax.set_facecolor("white") + return fig, ax + def generate_waterfall( - self, solution_id: str, ax: Optional[plt.Axes] = None, baseline_level: int = 0 + self, solution_id: str, ax: Optional[plt.Axes] = None ) -> Optional[plt.Figure]: """Generate waterfall chart for specific solution.""" - logger.debug("Starting generation of waterfall plot") - if solution_id not in self.pareto_result.plot_data_collect: - raise ValueError(f"Invalid solution ID: {solution_id}") + + try: + if solution_id not in self.pareto_result.plot_data_collect: + raise ValueError(f"Invalid solution ID: {solution_id}") + + # Get data for specific solution + plot_data = self.pareto_result.plot_data_collect[solution_id] + waterfall_data = plot_data["plot2data"]["plotWaterfallLoop"].copy() + + # Get baseline variables + prophet_vars = self.holiday_data.prophet_vars if self.holiday_data else [] + baseline_vars = self._baseline_vars(baseline_level=0, prophet_vars=prophet_vars) + + # Transform baseline variables + waterfall_data["rn"] = np.where( + waterfall_data["rn"].isin(baseline_vars), + f"Baseline_L0", + waterfall_data["rn"], + ) - # Get data for specific solution - plot_data = self.pareto_result.plot_data_collect[solution_id] - waterfall_data = plot_data["plot2data"]["plotWaterfallLoop"].copy() - - # Get baseline variables - prophet_vars = self.holiday_data.prophet_vars if self.holiday_data else [] - baseline_vars = self._baseline_vars(baseline_level, prophet_vars) - - # Transform baseline variables - waterfall_data["rn"] = np.where( - waterfall_data["rn"].isin(baseline_vars), - f"Baseline_L{baseline_level}", - waterfall_data["rn"], - ) - - # Group and summarize - waterfall_data = ( - waterfall_data.groupby("rn", as_index=False) - .agg({"xDecompAgg": "sum", "xDecompPerc": "sum"}) - .reset_index() - ) - - # Sort by percentage contribution - waterfall_data = waterfall_data.sort_values("xDecompPerc", ascending=True) - - # Calculate waterfall positions - waterfall_data["end"] = 1 - waterfall_data["xDecompPerc"].cumsum() - waterfall_data["start"] = waterfall_data["end"].shift(1) - waterfall_data["start"] = waterfall_data["start"].fillna(1) - waterfall_data["sign"] = np.where( - waterfall_data["xDecompPerc"] >= 0, "Positive", "Negative" - ) - - # Create figure if no axes provided - if ax is None: - fig, ax = plt.subplots(figsize=(12, 8)) - else: - fig = None - - # Define colors - colors = {"Positive": "#59B3D2", "Negative": "#E5586E"} - - # Create categorical y-axis positions - y_pos = 
np.arange(len(waterfall_data)) - - # Create horizontal bars - bars = ax.barh( - y=y_pos, - width=waterfall_data["start"] - waterfall_data["end"], - left=waterfall_data["end"], - color=[colors[sign] for sign in waterfall_data["sign"]], - height=0.6, - ) - - # Add text labels - for idx, row in enumerate(waterfall_data.itertuples()): - # Format label text - if abs(row.xDecompAgg) >= 1e9: - formatted_num = f"{row.xDecompAgg/1e9:.1f}B" - elif abs(row.xDecompAgg) >= 1e6: - formatted_num = f"{row.xDecompAgg/1e6:.1f}M" - elif abs(row.xDecompAgg) >= 1e3: - formatted_num = f"{row.xDecompAgg/1e3:.1f}K" + # Group and summarize data + waterfall_data = ( + waterfall_data.groupby("rn", as_index=False) + .agg({"xDecompAgg": "sum", "xDecompPerc": "sum"}) + .reset_index() + ) + + # Sort by percentage contribution + waterfall_data = waterfall_data.sort_values("xDecompPerc", ascending=True) + + # Calculate waterfall positions + waterfall_data["end"] = 1 - waterfall_data["xDecompPerc"].cumsum() + waterfall_data["start"] = waterfall_data["end"].shift(1) + waterfall_data["start"] = waterfall_data["start"].fillna(1) + waterfall_data["sign"] = np.where( + waterfall_data["xDecompPerc"] >= 0, "Positive", "Negative" + ) + + # Create figure using BaseVisualizer methods + if ax is None: + fig, ax = self.create_figure(figsize=self.figure_sizes["medium"]) else: - formatted_num = f"{row.xDecompAgg:.1f}" + fig = None - # Calculate x-position as the middle of the bar - x_pos = (row.start + row.end) / 2 + # Define colors using BaseVisualizer color scheme + colors = { + "Positive": self.colors["positive"], + "Negative": self.colors["negative"] + } - # Use y_pos[idx] to ensure alignment with bars - ax.text( - x_pos, - y_pos[idx], # Use the same y-position as the corresponding bar - f"{formatted_num}\n{row.xDecompPerc*100:.1f}%", - ha="center", # Center align horizontally - va="center", # Center align vertically - fontsize=9, - linespacing=0.9, - ) - - # Set y-ticks and labels - ax.set_yticks(y_pos) - ax.set_yticklabels(waterfall_data["rn"]) - - # Format x-axis as percentage - ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: "{:.0%}".format(x))) - ax.set_xticks(np.arange(0, 1.1, 0.2)) - - # Set plot limits - ax.set_xlim(0, 1) - ax.set_ylim(-0.5, len(waterfall_data) - 0.5) - - # Add legend at top - from matplotlib.patches import Patch - - legend_elements = [ - Patch(facecolor=colors["Positive"], label="Positive"), - Patch(facecolor=colors["Negative"], label="Negative"), - ] - - # Create legend with white background - legend = ax.legend( - handles=legend_elements, - title="Sign", - loc="upper left", - bbox_to_anchor=(0, 1.15), - ncol=2, - frameon=True, - framealpha=1.0, - ) - - # Set title - ax.set_title("Response Decomposition Waterfall", pad=30, x=0.5, y=1.05) - - # Label axes - ax.set_xlabel("Contribution") - ax.set_ylabel(None) - - # Customize grid - ax.grid(True, axis="x", alpha=0.2) - ax.set_axisbelow(True) - - logger.debug("Successfully generated waterfall plot") - # Adjust layout - if fig: - plt.subplots_adjust(right=0.85, top=0.85) - fig = plt.gcf() - plt.close(fig) - return fig + # Create categorical y-axis positions + y_pos = np.arange(len(waterfall_data)) + + # Create horizontal bars + bars = ax.barh( + y=y_pos, + width=waterfall_data["start"] - waterfall_data["end"], + left=waterfall_data["end"], + color=[colors[sign] for sign in waterfall_data["sign"]], + height=0.6, + alpha=self.alpha["primary"] + ) + + # Add text labels + for idx, row in enumerate(waterfall_data.itertuples()): + # Format value using 
BaseVisualizer formatter + formatted_num = self.format_number(row.xDecompAgg) + + # Calculate x-position as the middle of the bar + x_pos = (row.start + row.end) / 2 + + # Add label with standardized font size + ax.text( + x_pos, + y_pos[idx], + f"{formatted_num}\n{row.xDecompPerc*100:.1f}%", + ha="center", + va="center", + fontsize=self.fonts["sizes"]["annotation"], + alpha=self.alpha["annotation"] + ) - return None + # Set y-ticks and labels + ax.set_yticks(y_pos) + ax.set_yticklabels(waterfall_data["rn"]) + + # Format x-axis as percentage + ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: "{:.0%}".format(x))) + ax.set_xticks(np.arange(0, 1.1, 0.2)) + + # Set plot limits + ax.set_xlim(0, 1) + ax.set_ylim(-0.5, len(waterfall_data) - 0.5) + + # Add standardized styling using BaseVisualizer methods + self._set_standardized_labels( + ax, + xlabel="Contribution", + ylabel=None, + title=f"Response Decomposition Waterfall (Solution {solution_id})" + ) + self._add_standardized_grid(ax, axis='x') + self._set_standardized_spines(ax) + + # Create legend elements + legend_handles = [ + plt.Rectangle((0, 0), 1, 1, facecolor=colors["Positive"], label="Positive"), + plt.Rectangle((0, 0), 1, 1, facecolor=colors["Negative"], label="Negative") + ] + legend_labels = ["Positive", "Negative"] + + # Add legend using BaseVisualizer method + self._add_standardized_legend( + ax, + loc='lower right', + ncol=2, + handles=legend_handles, + labels=legend_labels + ) + + # Set white background + ax.set_facecolor("white") + + # Finalize the figure + self.finalize_figure(tight_layout=True) + + logger.debug("Successfully generated waterfall plot") + + if fig: + plt.close(fig) + return fig + return None + + except Exception as e: + logger.error(f"Failed to generate waterfall plot: {str(e)}") + raise def generate_fitted_vs_actual( self, solution_id: str, ax: Optional[plt.Axes] = None @@ -240,159 +291,165 @@ def generate_fitted_vs_actual( """Generate time series plot comparing fitted vs actual values. Args: - ax: Optional matplotlib axes to plot on. If None, creates new figure + solution_id (str): The solution ID to generate the plot for + ax (Optional[plt.Axes]): Matplotlib axes to plot on. 
If None, creates new figure Returns: - Optional[plt.Figure]: Generated matplotlib Figure object + Optional[plt.Figure]: Generated matplotlib Figure object if ax is None """ - logger.debug("Starting generation of fitted vs actual plot") - if solution_id not in self.pareto_result.plot_data_collect: - raise ValueError(f"Invalid solution ID: {solution_id}") + try: + if solution_id not in self.pareto_result.plot_data_collect: + raise ValueError(f"Invalid solution ID: {solution_id}") - # Get data for specific solution - plot_data = self.pareto_result.plot_data_collect[solution_id] - ts_data = plot_data["plot5data"]["xDecompVecPlotMelted"].copy() + # Get data for specific solution + plot_data = self.pareto_result.plot_data_collect[solution_id] + ts_data = plot_data["plot5data"]["xDecompVecPlotMelted"].copy() - # Ensure ds column is datetime and remove any NaT values - ts_data["ds"] = pd.to_datetime(ts_data["ds"]) - ts_data = ts_data.dropna(subset=["ds"]) # Remove rows with NaT dates + # Ensure ds column is datetime and remove any NaT values + ts_data["ds"] = pd.to_datetime(ts_data["ds"]) + ts_data = ts_data.dropna(subset=["ds"]) - if ts_data.empty: - logger.warning(f"No valid date data found for solution {solution_id}") - return None + if ts_data.empty: + logger.warning(f"No valid date data found for solution {solution_id}") + return None - ts_data["linetype"] = np.where( - ts_data["variable"] == "predicted", "solid", "dotted" - ) - ts_data["variable"] = ts_data["variable"].str.title() - - # Get train_size from x_decomp_agg - train_size_series = self.pareto_result.x_decomp_agg[ - self.pareto_result.x_decomp_agg["sol_id"] == solution_id - ]["train_size"] - - if not train_size_series.empty: - train_size = float(train_size_series.iloc[0]) - else: - train_size = 0 - - if ax is None: - fig, ax = plt.subplots(figsize=(20, 10)) - else: - fig = None - - colors = { - "Actual": "#FF6B00", # Darker orange - "Predicted": "#0066CC", # Darker blue - } + # Prepare line styles + ts_data["linetype"] = np.where(ts_data["variable"] == "predicted", "solid", "dotted") + ts_data["variable"] = ts_data["variable"].str.title() - # Plot lines with different styles for predicted vs actual - for var in ts_data["variable"].unique(): - var_data = ts_data[ts_data["variable"] == var] - linestyle = "solid" if var_data["linetype"].iloc[0] == "solid" else "dotted" - ax.plot( - var_data["ds"], - var_data["value"], - label=var, - linestyle=linestyle, - linewidth=1, - color=colors[var], - ) - - # Format y-axis with abbreviations - ax.yaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) - - # Set y-axis limits with some padding - y_min, y_max = ax.get_ylim() - ax.set_ylim(y_min, y_max * 1.2) # Add 20% padding at the top - - # Add training/validation/test splits if train_size exists and is valid - if train_size > 0: - try: - # Get unique sorted dates, excluding NaT - unique_dates = sorted(ts_data["ds"].dropna().unique()) - total_days = len(unique_dates) - - if total_days > 0: - # Calculate split points - train_cut = int(total_days * train_size) - val_cut = train_cut + int(total_days * (1 - train_size) / 2) - - # Get dates for splits - splits = [ - (train_cut, "Train", train_size), - (val_cut, "Validation", (1 - train_size) / 2), - (total_days - 1, "Test", (1 - train_size) / 2), - ] - - # Get y-axis limits for text placement - y_min, y_max = ax.get_ylim() - - # Add vertical lines and labels - for idx, label, size in splits: - if 0 <= idx < len(unique_dates): # Ensure index is valid - date = unique_dates[idx] - if 
pd.notna(date): # Check if date is valid - # Add vertical line - extend beyond the top of the plot - ax.axvline( - date, color="#39638b", alpha=0.8, ymin=0, ymax=1.1 - ) - - # Add rotated text label - ax.text( - date, - y_max, - f"{label}: {size*100:.1f}%", - rotation=270, - color="#39638b", - alpha=0.5, - size=9, - ha="left", - va="top", - ) - except Exception as e: - logger.warning(f"Error adding split lines: {str(e)}") - # Continue with the rest of the plot even if split lines fail - - # Set title and labels - ax.set_title("Actual vs. Predicted Response", pad=20) - ax.set_xlabel("Date") - ax.set_ylabel("Response") - - # Configure legend - ax.legend( - bbox_to_anchor=(0.01, 1.02), # Position at top-left - loc="lower left", - ncol=2, # Two columns side by side - borderaxespad=0, - frameon=False, - fontsize=7, - handlelength=2, # Length of the legend lines - handletextpad=0.5, # Space between line and text - columnspacing=1.0, # Space between columns - ) - - # Grid styling - ax.grid(True, alpha=0.2) - ax.set_axisbelow(True) - ax.set_facecolor("white") + # Create figure using BaseVisualizer methods + if ax is None: + fig, ax = self.create_figure(figsize=self.figure_sizes["medium"]) + else: + fig = None + + # Define colors using BaseVisualizer color scheme + colors = { + "Actual": self.colors["secondary"], # Orange from BaseVisualizer + "Predicted": self.colors["primary"], # Blue from BaseVisualizer + } + + # Plot lines with different styles for predicted vs actual + for var in ts_data["variable"].unique(): + var_data = ts_data[ts_data["variable"] == var] + linestyle = "solid" if var_data["linetype"].iloc[0] == "solid" else "dotted" + ax.plot( + var_data["ds"], + var_data["value"], + label=var, + linestyle=linestyle, + linewidth=1.5, + color=colors[var], + alpha=self.alpha["primary"] + ) - # Format dates on x-axis using datetime locator and formatter - years = mdates.YearLocator() - years_fmt = mdates.DateFormatter("%Y") - ax.xaxis.set_major_locator(years) - ax.xaxis.set_major_formatter(years_fmt) + # Format y-axis with abbreviations using BaseVisualizer formatter + ax.yaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) + + # Set y-axis limits with padding + y_min, y_max = ax.get_ylim() + ax.set_ylim(y_min, y_max * 1.1) + + # Add training/validation/test splits if train_size exists + train_size_series = self.pareto_result.x_decomp_agg[ + self.pareto_result.x_decomp_agg["sol_id"] == solution_id + ]["train_size"] + + if not train_size_series.empty: + train_size = float(train_size_series.iloc[0]) + + if train_size > 0: + try: + # Get unique sorted dates + unique_dates = sorted(ts_data["ds"].dropna().unique()) + total_days = len(unique_dates) + + if total_days > 0: + # Calculate split points + train_cut = int(total_days * train_size) + val_cut = train_cut + int(total_days * (1 - train_size) / 2) + + # Get dates for splits + splits = [ + (train_cut, "Train", train_size), + (val_cut, "Validation", (1 - train_size) / 2), + (total_days - 1, "Test", (1 - train_size) / 2), + ] + + y_min, y_max = ax.get_ylim() + + # Add vertical lines and labels + for idx, label, size in splits: + if 0 <= idx < len(unique_dates): + date = unique_dates[idx] + if pd.notna(date): + # Add vertical line + ax.axvline( + date, + color=self.colors["grid"], + alpha=self.alpha["grid"], + ymin=0, + ymax=1.1, + linestyle='--' + ) + + # Add rotated text label + ax.text( + date, + y_max, + f"{label}: {size*100:.1f}%", + rotation=270, + color=self.colors["annotation"], + alpha=self.alpha["annotation"], + 
fontsize=self.fonts["sizes"]["annotation"], + ha="left", + va="top" + ) + except Exception as e: + logger.warning(f"Error adding split lines: {str(e)}") + + # Add standardized styling using BaseVisualizer methods + self._set_standardized_labels( + ax, + xlabel="Date", + ylabel="Response", + title=f"Actual vs. Predicted Response (Solution {solution_id})" + ) + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + self._add_standardized_legend( + ax, + loc='lower right', + ncol=2, + ) - logger.debug("Successfully generated fitted vs actual plot") - if fig: - plt.tight_layout() - plt.subplots_adjust(top=0.85) - fig = plt.gcf() - plt.close(fig) - return fig - return None + # Format dates on x-axis + years = mdates.YearLocator() + years_fmt = mdates.DateFormatter("%Y") + ax.xaxis.set_major_locator(years) + ax.xaxis.set_major_formatter(years_fmt) + + # Rotate x-axis labels for better readability + plt.setp(ax.get_xticklabels(), rotation=45, ha='right') + + # Set white background + ax.set_facecolor("white") + + # Finalize the figure + self.finalize_figure(tight_layout=True) + + logger.debug("Successfully generated fitted vs actual plot") + + if fig: + plt.close(fig) + return fig + return None + + except Exception as e: + logger.error(f"Failed to generate fitted vs actual plot: {str(e)}") + raise def generate_diagnostic_plot( self, solution_id: str, ax: Optional[plt.Axes] = None @@ -400,200 +457,289 @@ def generate_diagnostic_plot( """Generate diagnostic scatter plot of fitted vs residual values. Args: - ax: Optional matplotlib axes to plot on. If None, creates new figure + solution_id (str): The solution ID to generate the plot for + ax (Optional[plt.Axes]): Matplotlib axes to plot on. If None, creates new figure Returns: - Optional[plt.Figure]: Generated matplotlib Figure object + Optional[plt.Figure]: Generated matplotlib Figure object if ax is None """ - logger.debug("Starting generation of diagnostic plot") - if solution_id not in self.pareto_result.plot_data_collect: - raise ValueError(f"Invalid solution ID: {solution_id}") - - # Get data for specific solution - plot_data = self.pareto_result.plot_data_collect[solution_id] - diag_data = plot_data["plot6data"]["xDecompVecPlot"].copy() + try: + if solution_id not in self.pareto_result.plot_data_collect: + raise ValueError(f"Invalid solution ID: {solution_id}") - # Calculate residuals - diag_data["residuals"] = diag_data["actual"] - diag_data["predicted"] + # Get data for specific solution + plot_data = self.pareto_result.plot_data_collect[solution_id] + diag_data = plot_data["plot6data"]["xDecompVecPlot"].copy() - # Create figure if no axes provided - if ax is None: - fig, ax = plt.subplots(figsize=(16, 10)) - else: - fig = None - - # Create scatter plot - ax.scatter( - diag_data["predicted"], diag_data["residuals"], alpha=0.5, color="steelblue" - ) - - # Add horizontal line at y=0 - ax.axhline(y=0, color="black", linestyle="-", linewidth=0.8) + # Calculate residuals + diag_data["residuals"] = diag_data["actual"] - diag_data["predicted"] + + # Create figure using BaseVisualizer methods + if ax is None: + fig, ax = self.create_figure(figsize=self.figure_sizes["medium"]) + else: + fig = None + + # Create scatter plot with BaseVisualizer colors + scatter = ax.scatter( + diag_data["predicted"], + diag_data["residuals"], + alpha=self.alpha["primary"], + color=self.colors["primary"], + label="Residuals" + ) - # Add smoothed line with confidence interval - from scipy.stats import gaussian_kde + # Add horizontal line at y=0 + 
ax.axhline( + y=0, + color=self.colors["baseline"], + linestyle=self.line_styles["solid"], + linewidth=0.8, + alpha=self.alpha["annotation"] + ) - x_smooth = np.linspace( - diag_data["predicted"].min(), diag_data["predicted"].max(), 100 - ) + # Fit LOWESS smoother + from statsmodels.nonparametric.smoothers_lowess import lowess + + # Calculate smooth line + smoothed = lowess( + diag_data["residuals"], + diag_data["predicted"], + frac=0.2, + return_sorted=True + ) - # Fit LOWESS - from statsmodels.nonparametric.smoothers_lowess import lowess + # Plot smoothed line + ax.plot( + smoothed[:, 0], + smoothed[:, 1], + color=self.colors["secondary"], + linewidth=2, + alpha=self.alpha["secondary"], + label="Smoothed trend" + ) - smoothed = lowess(diag_data["residuals"], diag_data["predicted"], frac=0.2) + # Calculate and plot confidence intervals + residual_std = np.std(diag_data["residuals"]) + ax.fill_between( + smoothed[:, 0], + smoothed[:, 1] - 2 * residual_std, + smoothed[:, 1] + 2 * residual_std, + color=self.colors["secondary"], + alpha=self.alpha["background"], + label="95% Confidence interval" + ) - # Plot smoothed line - ax.plot(smoothed[:, 0], smoothed[:, 1], color="red", linewidth=2, alpha=0.8) + # Add statistical annotations + stats_text = ( + f"Standard deviation: {residual_std:.2f}\n" + f"Mean residual: {np.mean(diag_data['residuals']):.2f}\n" + f"Median residual: {np.median(diag_data['residuals']):.2f}" + ) + + ax.text( + 0.02, 0.98, + stats_text, + transform=ax.transAxes, + verticalalignment='top', + bbox=dict( + boxstyle='round', + facecolor='white', + alpha=self.alpha["annotation"], + edgecolor=self.colors["grid"] + ), + fontsize=self.fonts["sizes"]["annotation"] + ) - # Calculate confidence intervals (using standard error bands) - residual_std = np.std(diag_data["residuals"]) - ax.fill_between( - smoothed[:, 0], - smoothed[:, 1] - 2 * residual_std, - smoothed[:, 1] + 2 * residual_std, - color="red", - alpha=0.1, - ) + # Format axes with abbreviations using BaseVisualizer formatter + ax.xaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) + ax.yaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) - # Format axes with abbreviations - ax.xaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) - ax.yaxis.set_major_formatter(ticker.FuncFormatter(self.format_number)) + # Add standardized styling using BaseVisualizer methods + self._set_standardized_labels( + ax, + xlabel="Fitted Values", + ylabel="Residuals", + title=f"Diagnostic Plot: Fitted vs Residuals (Solution {solution_id})" + ) + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + self._add_standardized_legend( + ax, + loc='lower right', + title="Components" + ) - # Set labels and title - ax.set_xlabel("Fitted") - ax.set_ylabel("Residual") - ax.set_title("Fitted vs. 
Residual") + # Set white background + ax.set_facecolor("white") - # Customize grid - ax.grid(True, alpha=0.2) - ax.set_axisbelow(True) + # Calculate and set reasonable axis limits with padding + x_range = diag_data["predicted"].max() - diag_data["predicted"].min() + y_range = diag_data["residuals"].max() - diag_data["residuals"].min() + + ax.set_xlim( + diag_data["predicted"].min() - 0.05 * x_range, + diag_data["predicted"].max() + 0.05 * x_range + ) + ax.set_ylim( + diag_data["residuals"].min() - 0.05 * y_range, + diag_data["residuals"].max() + 0.05 * y_range + ) - # Use white background - ax.set_facecolor("white") + # Add diagnostic lines + mean_residual = np.mean(diag_data["residuals"]) + if abs(mean_residual) > residual_std * 0.1: # Only show if mean is notably different from 0 + ax.axhline( + y=mean_residual, + color=self.colors["annotation"], + linestyle=self.line_styles["dotted"], + alpha=self.alpha["annotation"], + label="Mean residual" + ) - logger.debug("Successfully generated of diagnostic plot") + # Finalize the figure + self.finalize_figure(tight_layout=True) - if fig: - plt.tight_layout() - fig = plt.gcf() - plt.close(fig) - return fig - return None + logger.debug("Successfully generated diagnostic plot") + + if fig: + plt.close(fig) + return fig + return None + except Exception as e: + logger.error(f"Failed to generate diagnostic plot: {str(e)}") + raise + def generate_immediate_vs_carryover( self, solution_id: str, ax: Optional[plt.Axes] = None ) -> Optional[plt.Figure]: """Generate stacked bar chart comparing immediate vs carryover effects. Args: - ax: Optional matplotlib axes to plot on. If None, creates new figure + solution_id: Solution ID to visualize + ax: Optional matplotlib axes to plot on Returns: - plt.Figure if ax is None, else None + Optional[plt.Figure]: Generated figure if ax is None """ - logger.debug("Starting generation of immediate vs carryover plot") - if solution_id not in self.pareto_result.plot_data_collect: - raise ValueError(f"Invalid solution ID: {solution_id}") + try: + self._validate_solution_id(solution_id) + + # Get and prepare data + plot_data = self.pareto_result.plot_data_collect[solution_id] + df_imme_caov = plot_data["plot7data"].copy() + + # Ensure percentage is numeric + df_imme_caov["percentage"] = pd.to_numeric(df_imme_caov["percentage"], errors="coerce") + + # Sort channels alphabetically + df_imme_caov = df_imme_caov.sort_values("rn", ascending=True) - plot_data = self.pareto_result.plot_data_collect[solution_id] - df_imme_caov = plot_data["plot7data"].copy() - - # Ensure percentage is numeric - df_imme_caov["percentage"] = pd.to_numeric( - df_imme_caov["percentage"], errors="coerce" - ) - - # Sort channels alphabetically - df_imme_caov = df_imme_caov.sort_values("rn", ascending=True) - - # Set up type factor levels matching R plot order - df_imme_caov["type"] = pd.Categorical( - df_imme_caov["type"], categories=["Immediate", "Carryover"], ordered=True - ) - - if ax is None: - fig, ax = plt.subplots(figsize=(16, 10)) - else: - fig = None - - colors = {"Immediate": "#59B3D2", "Carryover": "coral"} - - bottom = np.zeros(len(df_imme_caov["rn"].unique())) - y_pos = range(len(df_imme_caov["rn"].unique())) - channels = df_imme_caov["rn"].unique() - types = ["Immediate", "Carryover"] # Order changed to Immediate first - - # Normalize percentages to sum to 100% for each channel - for channel in channels: - mask = df_imme_caov["rn"] == channel - total = df_imme_caov.loc[mask, "percentage"].sum() - if total > 0: # Avoid division by zero - 
df_imme_caov.loc[mask, "percentage"] = (
-                    df_imme_caov.loc[mask, "percentage"] / total
+            # Set up type factor levels
+            df_imme_caov["type"] = pd.Categorical(
+                df_imme_caov["type"],
+                categories=["Immediate", "Carryover"],
+                ordered=True
+            )
+
+            # Create figure
+            if ax is None:
+                fig, ax = self._create_standard_figure(figsize=self.figure_sizes["medium"])
+            else:
+                fig = None
+
+            # Set up colors using BaseVisualizer scheme
+            colors = {
+                "Immediate": self.colors["primary"],
+                "Carryover": self.colors["secondary"]
+            }
+
+            # Initialize variables for stacked bars
+            bottom = np.zeros(len(df_imme_caov["rn"].unique()))
+            y_pos = range(len(df_imme_caov["rn"].unique()))
+            channels = df_imme_caov["rn"].unique()
+            types = ["Immediate", "Carryover"]
+
+            # Normalize percentages to sum to 100% for each channel
+            for channel in channels:
+                mask = df_imme_caov["rn"] == channel
+                total = df_imme_caov.loc[mask, "percentage"].sum()
+                if total > 0:
+                    df_imme_caov.loc[mask, "percentage"] = (
+                        df_imme_caov.loc[mask, "percentage"] / total
+                    )
+
+            # Create stacked bars
+            for type_name in types:
+                type_data = df_imme_caov[df_imme_caov["type"] == type_name]
+                percentages = type_data["percentage"].values
+
+                bars = ax.barh(
+                    y_pos,
+                    percentages,
+                    left=bottom,
+                    height=0.5,
+                    label=type_name,
+                    color=colors[type_name],
+                    alpha=self.alpha["primary"]
                 )

-        for type_name in types:
-            type_data = df_imme_caov[df_imme_caov["type"] == type_name]
-            percentages = type_data["percentage"].values
+                # Add percentage labels
+                for i, (rect, percentage) in enumerate(zip(bars, percentages)):
+                    width = rect.get_width()
+                    x_pos = bottom[i] + width / 2
+                    # Guard against NaN left behind by pd.to_numeric(..., errors="coerce")
+                    percentage_text = (
+                        f"{percentage*100:.0f}%" if pd.notna(percentage) else "0%"
+                    )
+                    ax.text(
+                        x_pos,
+                        i,
+                        percentage_text,
+                        ha="center",
+                        va="center",
+                        fontsize=self.fonts["sizes"]["annotation"],
+                        color=self.colors["annotation"]
+                    )
+
+                bottom += percentages

-            bars = ax.barh(
-                y_pos,
-                percentages,
-                left=bottom,
-                height=0.5,
-                label=type_name,
-                color=colors[type_name],
+            # Set up axes
+            ax.set_yticks(y_pos)
+            ax.set_yticklabels(channels)
+            ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%"))
+            ax.set_xlim(0, 1)
+
+            # Add styling using BaseVisualizer methods
+            self._set_standardized_labels(
+                ax,
+                xlabel="Response Percentage",
+                ylabel=None,
+                title=f"Immediate vs. Carryover Response (Solution {solution_id})"
             )
+            self._add_standardized_grid(ax, axis='x')
+            self._set_standardized_spines(ax)
+            self._add_standardized_legend(
+                ax,
+                loc='lower right',
+                ncol=2,
             )

-            for i, (rect, percentage) in enumerate(zip(bars, percentages)):
-                width = rect.get_width()
-                x_pos = bottom[i] + width / 2
-                try:
-                    percentage_text = f"{round(float(percentage) * 100)}%"
-                except (ValueError, TypeError):
-                    percentage_text = "0%"
-                ax.text(x_pos, i, percentage_text, ha="center", va="center")
-
-            bottom += percentages
-
-        ax.set_yticks(y_pos)
-        ax.set_yticklabels(channels)
-
-        ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%"))
-        ax.set_xlim(0, 1)
-
-        # Reduced legend size
-        ax.legend(
-            title=None,
-            bbox_to_anchor=(0, 1.02, 0.15, 0.1),  # Reduced width from 0.3 to 0.2
-            loc="lower left",
-            ncol=2,
-            mode="expand",
-            borderaxespad=0,
-            frameon=False,
-            fontsize=7,  # Reduced from 8 to 7
-        )
-
-        ax.set_xlabel("% Response")
-        ax.set_ylabel(None)
-        ax.set_title("Immediate vs. 
Carryover Response Percentage", pad=50, y=1.2) - - ax.grid(True, axis="x", alpha=0.2) - ax.grid(False, axis="y") - ax.set_axisbelow(True) - ax.set_facecolor("white") + # Finalize figure + self.finalize_figure(tight_layout=True) - if fig: - plt.tight_layout() - plt.subplots_adjust(top=0.85) - fig = plt.gcf() - plt.close(fig) - return fig - return None + logger.debug("Successfully generated immediate vs carryover plot") + + if fig: + plt.close(fig) + return fig + return None + + except Exception as e: + logger.error(f"Failed to generate immediate vs carryover plot: {str(e)}") + raise def generate_adstock_rate( self, solution_id: str, ax: Optional[plt.Axes] = None @@ -601,405 +747,685 @@ def generate_adstock_rate( """Generate adstock rate visualization based on adstock type. Args: - solution_id: ID of solution to visualize - ax: Optional matplotlib axes to plot on. If None, creates new figure + solution_id: Solution ID to visualize + ax: Optional matplotlib axes to plot on Returns: - Optional[plt.Figure]: Generated figure if ax is None, otherwise None + Optional[plt.Figure]: Generated figure if ax is None """ - logger.debug("Starting generation of adstock plot") - plot_data = self.pareto_result.plot_data_collect[solution_id] - adstock_data = plot_data["plot3data"] + try: + self._validate_solution_id(solution_id) + plot_data = self.pareto_result.plot_data_collect[solution_id] + adstock_data = plot_data["plot3data"] - if ax is None: - fig, ax = plt.subplots(figsize=(16, 10)) - else: - fig = None + if ax is None: + fig, ax = self._create_standard_figure(figsize=self.figure_sizes["medium"]) + else: + fig = None + + if self.hyperparameter.adstock == AdstockType.GEOMETRIC: + # Handle Geometric Adstock + dt_geometric = adstock_data["dt_geometric"].copy() + dt_geometric = dt_geometric.sort_values("channels", ascending=True) + + # Create horizontal bars + bars = ax.barh( + y=range(len(dt_geometric)), + width=dt_geometric["thetas"], + height=0.5, + color=self.colors["secondary"], + alpha=self.alpha["primary"] + ) - # Handle different adstock types - if self.hyperparameter.adstock == AdstockType.GEOMETRIC: - dt_geometric = adstock_data["dt_geometric"].copy() + # Add value labels + for i, theta in enumerate(dt_geometric["thetas"]): + ax.text( + theta + 0.01, + i, + f"{theta*100:.1f}%", + va="center", + fontsize=self.fonts["sizes"]["annotation"], + color=self.colors["annotation"] + ) - # Sort data alphabetically by channel - dt_geometric = dt_geometric.sort_values("channels", ascending=True) + # Set up axes + ax.set_yticks(range(len(dt_geometric))) + ax.set_yticklabels(dt_geometric["channels"]) + ax.xaxis.set_major_formatter(plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%")) + ax.set_xlim(0, 1) + ax.set_xticks(np.arange(0, 1.25, 0.25)) - bars = ax.barh( - y=range(len(dt_geometric)), - width=dt_geometric["thetas"], - height=0.5, - color="coral", - ) + # Add styling + interval_type = ( + self.mmm_data.mmmdata_spec.interval_type if self.mmm_data else "day" + ) + self._set_standardized_labels( + ax, + xlabel=f"Thetas [by {interval_type}]", + ylabel=None, + title=f"Geometric Adstock: Fixed Rate Over Time (Solution {solution_id})" + ) - for i, theta in enumerate(dt_geometric["thetas"]): - ax.text( - theta + 0.01, i, f"{theta*100:.1f}%", va="center", fontweight="bold" + elif self.hyperparameter.adstock in [AdstockType.WEIBULL_CDF, AdstockType.WEIBULL_PDF]: + # Handle Weibull Adstock + weibull_data = adstock_data["weibullCollect"] + channels = sorted(weibull_data["channel"].unique()) + rows = (len(channels) + 2) // 
3  # Calculate rows needed for 3 columns
+
+                # `ax` was already assigned above when the caller passed None, so
+                # branch on `fig` to tell our own figure from a caller-supplied axes
+                if fig is not None:
+                    # Replace the provisional single-axes figure with a subplot grid
+                    plt.close(fig)
+                    fig = plt.figure(figsize=self.figure_sizes["wide"])
+                    gs = fig.add_gridspec(rows, 3, hspace=0.4, wspace=0.3)
+                    axes = []
+                    for i in range(rows):
+                        for j in range(3):
+                            axes.append(fig.add_subplot(gs[i, j]))
+                else:
+                    # Caller supplied an axes; plot into it and leave its figure alone
+                    axes = [ax]
+
+                # Create plots for each channel
+                for idx, channel in enumerate(channels):
+                    if idx < len(axes):
+                        ax_sub = axes[idx]
+                        channel_data = weibull_data[weibull_data["channel"] == channel]
+
+                        # Plot decay curve
+                        ax_sub.plot(
+                            channel_data["x"],
+                            channel_data["decay_accumulated"],
+                            color=self.colors["primary"],
+                            alpha=self.alpha["primary"],
+                            linewidth=2
+                        )
+
+                        # Add halflife line
+                        ax_sub.axhline(
+                            y=0.5,
+                            color=self.colors["grid"],
+                            linestyle=self.line_styles["dashed"],
+                            alpha=self.alpha["grid"]
+                        )
+
+                        # Add halflife label at the sample closest to 50% decay
+                        # (an exact hit inside a 0.49-0.51 band is not guaranteed)
+                        half_idx = (channel_data["decay_accumulated"] - 0.5).abs().idxmin()
+                        halflife_x = channel_data.loc[half_idx, "x"]
+
+                        ax_sub.text(
+                            halflife_x * 1.1,
+                            0.52,
+                            f"Halflife: {halflife_x:.1f}",
+                            color=self.colors["annotation"],
+                            fontsize=self.fonts["sizes"]["annotation"],
+                            va="bottom",
+                            ha="left"
+                        )
+
+                        # Style subplot
+                        self._set_standardized_labels(
+                            ax_sub,
+                            xlabel="Time",
+                            ylabel="Decay Rate" if idx % 3 == 0 else None,
+                            title=channel
+                        )
+                        self._add_standardized_grid(ax_sub)
+                        self._set_standardized_spines(ax_sub)
+
+                        # Set axis limits
+                        ax_sub.set_ylim(0, 1.1)
+                        ax_sub.set_xlim(0, max(channel_data["x"]) * 1.2)
+
+                # Hide unused subplots
+                for idx in range(len(channels), len(axes)):
+                    axes[idx].set_visible(False)
+
+                # Add overall title (fall back to the caller's figure when one was supplied)
+                suptitle_fig = fig if fig is not None else ax.figure
+                suptitle_fig.suptitle(
+                    f"Weibull {self.hyperparameter.adstock.value} Adstock Decay Curves (Solution {solution_id})",
+                    fontsize=self.fonts["sizes"]["title"],
+                    y=1.02
                 )

-            ax.set_yticks(range(len(dt_geometric)))
-            ax.set_yticklabels(dt_geometric["channels"])
+            else:
+                logger.warning(f"Unsupported adstock type: {self.hyperparameter.adstock}")
+                if fig:
+                    plt.close(fig)
+                return None
+
+            # Add common styling (only the geometric case draws on the shared axes)
+            if self.hyperparameter.adstock == AdstockType.GEOMETRIC:
+                self._add_standardized_grid(ax, axis='x')
+                self._set_standardized_spines(ax)
+                ax.set_facecolor("white")
+
+            # Finalize figure
+            self.finalize_figure(tight_layout=True)
+
+            logger.debug("Successfully generated adstock plot")
+
+            if fig:
+                plt.close(fig)
+                return fig
+            return None
+
+        except Exception as e:
+            logger.error(f"Failed to generate adstock plot: {str(e)}")
+            raise

-            # Format x-axis with 25% increments
-            ax.xaxis.set_major_formatter(
-                plt.FuncFormatter(lambda x, p: f"{x*100:.0f}%")
-            )
-            ax.set_xlim(0, 1)
-            ax.set_xticks(np.arange(0, 1.25, 0.25))  # Changed to 0.25 increments
+    def create_prophet_decomposition_plot(self) -> Optional[plt.Figure]:
+        """Create Prophet Decomposition Plot showing model components."""
+        logger.debug("Starting generation of prophet decomposition plot")

-            # Set title and labels
-            interval_type = (
-                self.mmm_data.mmmdata_spec.interval_type if self.mmm_data else "day"
+        try:
+            # Get prophet variables
+            prophet_vars = (
+                [ProphetVariableType(var) for var in self.holiday_data.prophet_vars]
+                if self.holiday_data and self.holiday_data.prophet_vars
+                else []
             )
-            ax.set_title(
-                f"Geometric Adstock: Fixed Rate Over Time (Solution {solution_id})"
-            )
-            ax.set_xlabel(f"Thetas [by {interval_type}]")
-            ax.set_ylabel(None)
+            factor_vars = self.mmm_data.mmmdata_spec.factor_vars if self.mmm_data else []

-        elif self.hyperparameter.adstock in [
-            AdstockType.WEIBULL_CDF,
-            AdstockType.WEIBULL_PDF,
-        ]:
-            # [Weibull code remains the same]
- 
weibull_data = adstock_data["weibullCollect"] - wb_type = adstock_data["wb_type"] + if not (prophet_vars or factor_vars): + logger.info("No prophet or factor variables found") + return None - channels = sorted( - weibull_data["channel"].unique() - ) # Sort channels alphabetically - rows = (len(channels) + 2) // 3 + # Prepare data + df = self.featurized_mmm_data.dt_mod.copy() + prophet_vars_str = [variable.value for variable in prophet_vars] + prophet_vars_str.sort(reverse=True) - if ax is None: - fig, axes = plt.subplots(rows, 3, figsize=(15, 4 * rows), squeeze=False) - axes = axes.flatten() - else: - gs = ax.get_gridspec() - subfigs = ax.figure.subfigures(rows, 3) - axes = [subfig.subplots() for subfig in subfigs] - axes = [ax for sublist in axes for ax in sublist] - - for idx, channel in enumerate(channels): - ax_sub = axes[idx] - channel_data = weibull_data[weibull_data["channel"] == channel] - - ax_sub.plot( - channel_data["x"], - channel_data["decay_accumulated"], - color="steelblue", + # Combine variables + value_variables = ( + [ + "dep_var" + if hasattr(df, "dep_var") + else self.mmm_data.mmmdata_spec.dep_var + ] + + factor_vars + + prophet_vars_str + ) + + # Prepare long format data + df_long = df.melt( + id_vars=["ds"], + value_vars=value_variables, + var_name="variable", + value_name="value", + ) + df_long["ds"] = pd.to_datetime(df_long["ds"]) + + # Create figure with subplots + n_vars = len(df_long["variable"].unique()) + fig = plt.figure(figsize=(12, 3 * n_vars)) + + # Create gridspec with tighter spacing + gs = fig.add_gridspec( + n_vars, + 1, + height_ratios=[1] * n_vars, + hspace=0.7 # Adjust vertical space between subplots + ) + + # Create subplot for each variable + for i, var in enumerate(df_long["variable"].unique()): + ax = fig.add_subplot(gs[i]) + var_data = df_long[df_long["variable"] == var] + + # Plot time series + ax.plot( + var_data["ds"], + var_data["value"], + color=self.colors["primary"], + alpha=self.alpha["primary"] ) - ax_sub.axhline(y=0.5, color="gray", linestyle="--", alpha=0.5) - ax_sub.text( - max(channel_data["x"]), - 0.5, - "Halflife", - color="gray", - va="bottom", - ha="right", + # Style subplot + self._set_standardized_labels( + ax, + xlabel=None, + ylabel=None, + title=var ) + self._add_standardized_grid(ax) + self._set_standardized_spines(ax) + ax.set_facecolor("white") + + # Format x-axis dates + ax.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m')) + plt.setp(ax.get_xticklabels(), rotation=45, ha='right') + + # Adjust subplot padding + ax.margins(x=0.02) + + # Add overall title with adjusted position + fig.suptitle( + "Prophet Decomposition", + fontsize=self.fonts["sizes"]["title"], + y=1 # Moved title closer to the first subplot + ) - ax_sub.set_title(channel) - ax_sub.grid(True, alpha=0.2) - ax_sub.set_ylim(0, 1) + # Fine-tune the layout + plt.tight_layout() + + # Adjust layout to accommodate the title + plt.subplots_adjust(top=0.95) # Adjust top margin to prevent title overlap + + logger.debug("Successfully generated prophet decomposition plot") + return fig - # Customize grid - if self.hyperparameter.adstock == AdstockType.GEOMETRIC: - ax.grid(True, axis="x", alpha=0.2) - ax.grid(False, axis="y") - ax.set_axisbelow(True) + except Exception as e: + logger.error(f"Failed to generate prophet decomposition plot: {str(e)}") + raise - ax.set_facecolor("white") + def create_hyperparameter_sampling_distribution(self) -> Optional[plt.Figure]: + """Create Hyperparameter Sampling Distribution Plot. 
+ + Returns: + Optional[plt.Figure]: Generated figure showing hyperparameter distributions + """ + logger.debug("Starting generation of hyperparameter sampling distribution plot") + + try: + if self.unfiltered_pareto_result is None: + logger.info("No unfiltered Pareto results available") + return None + + # Get hyperparameter data + result_hyp_param = self.unfiltered_pareto_result.result_hyp_param + hp_names = list(self.hyperparameter.hyperparameters.keys()) + hp_names = [name.replace("lambda", "lambda_hp") for name in hp_names] + + # Find matching columns + matching_columns = [ + col + for col in result_hyp_param.columns + if any(re.search(pattern, col, re.IGNORECASE) for pattern in hp_names) + ] + matching_columns.sort() + + if not matching_columns: + logger.info("No matching hyperparameter columns found") + return None + + # Prepare data + hyp_df = result_hyp_param[matching_columns] + melted_df = hyp_df.melt(var_name="variable", value_name="value") + melted_df["variable"] = melted_df["variable"].replace("lambda_hp", "lambda") + + # Parse variable names + def parse_variable(variable): + parts = variable.split("_") + return {"type": parts[-1], "channel": "_".join(parts[:-1])} + + parsed_vars = melted_df["variable"].apply(parse_variable).apply(pd.Series) + melted_df[["type", "channel"]] = parsed_vars + + # Create categorical variables + melted_df["type"] = pd.Categorical( + melted_df["type"], + categories=melted_df["type"].unique() + ) + melted_df["channel"] = pd.Categorical( + melted_df["channel"], + categories=melted_df["channel"].unique()[::-1] + ) + + # Create figure using facet grid + g = sns.FacetGrid( + melted_df, + col="type", + sharex=False, + height=6, + aspect=1 + ) + + def violin_plot(x, y, **kwargs): + sns.violinplot( + x=x, + y=y, + **kwargs, + alpha=self.alpha["primary"], + linewidth=0, + palette=sns.color_palette("Set2") + ) + + g.map_dataframe( + violin_plot, + x="value", + y="channel", + hue="channel" + ) + + # Style facets + g.set_titles("{col_name}") + g.set_xlabels("Hyperparameter space") + g.set_ylabels("") - logger.debug("Successfully generated adstock plot") + # Add titles + g.figure.suptitle( + "Hyperparameters Optimization Distributions", + y=1.05, + fontsize=self.fonts["sizes"]["title"] + ) + subtitle_text = ( + f"Sample distribution, iterations = " + f"{self.model_outputs.iterations} x {len(self.model_outputs.trials)} trial" + ) + g.figure.text( + 0.5, + 0.98, + subtitle_text, + ha="center", + fontsize=self.fonts["sizes"]["subtitle"] + ) - if fig: + # Finalize figure + plt.subplots_adjust(top=0.9) plt.tight_layout() - fig = plt.gcf() - plt.close(fig) - return fig - return None - - def create_prophet_decomposition_plot(self): - """Create Prophet Decomposition Plot.""" - prophet_vars = ( - [ProphetVariableType(var) for var in self.holiday_data.prophet_vars] - if self.holiday_data and self.holiday_data.prophet_vars - else [] - ) - factor_vars = self.mmm_data.mmmdata_spec.factor_vars if self.mmm_data else [] - if not (prophet_vars or factor_vars): - return None - df = self.featurized_mmm_data.dt_mod.copy() - prophet_vars_str = [variable.value for variable in prophet_vars] - prophet_vars_str.sort(reverse=True) - value_variables = ( - [ - ( - "dep_var" - if hasattr(df, "dep_var") - else self.mmm_data.mmmdata_spec.dep_var + + logger.debug("Successfully generated hyperparameter sampling distribution plot") + return g.figure + + except Exception as e: + logger.error(f"Failed to generate hyperparameter sampling distribution plot: {str(e)}") + raise + + def 
create_pareto_front_plot(self, is_calibrated: bool) -> Optional[plt.Figure]: + """Create Pareto Front Plot showing optimization performance. + + Args: + is_calibrated: Whether the model is calibrated + + Returns: + Optional[plt.Figure]: Generated figure showing Pareto fronts + """ + logger.debug(f"Starting generation of Pareto front plot (calibrated={is_calibrated})") + + try: + unfiltered_pareto_results = self.unfiltered_pareto_result + result_hyp_param = unfiltered_pareto_results.result_hyp_param + pareto_fronts = self.pareto_result.pareto_fronts + + # Create figure using BaseVisualizer + fig, ax = self._create_standard_figure(figsize=self.figure_sizes["medium"]) + + # Handle calibrated case + if is_calibrated: + result_hyp_param["iterations"] = np.where( + result_hyp_param["robynPareto"].isna(), + np.nan, + result_hyp_param["iterations"], ) - ] - + factor_vars - + prophet_vars_str - ) - df_long = df.melt( - id_vars=["ds"], - value_vars=value_variables, - var_name="variable", - value_name="value", - ) - df_long["ds"] = pd.to_datetime(df_long["ds"]) - plt.figure(figsize=(12, 3 * len(df_long["variable"].unique()))) - prophet_decomp_plot = plt.figure( - figsize=(12, 3 * len(df_long["variable"].unique())) - ) - gs = prophet_decomp_plot.add_gridspec(len(df_long["variable"].unique()), 1) - for i, var in enumerate(df_long["variable"].unique()): - ax = prophet_decomp_plot.add_subplot(gs[i, 0]) - var_data = df_long[df_long["variable"] == var] - ax.plot(var_data["ds"], var_data["value"], color="steelblue") - ax.set_title(var) - ax.set_xlabel(None) - ax.set_ylabel(None) - plt.suptitle("Prophet decomposition") - plt.tight_layout() - fig = plt.gcf() - plt.close(fig) - return fig - - def create_hyperparameter_sampling_distribution(self): - """Create Hyperparameter Sampling Distribution Plot.""" - unfiltered_pareto_results = self.unfiltered_pareto_result - if unfiltered_pareto_results is None: - return None - result_hyp_param = unfiltered_pareto_results.result_hyp_param - hp_names = list(self.hyperparameter.hyperparameters.keys()) - hp_names = [name.replace("lambda", "lambda_hp") for name in hp_names] - matching_columns = [ - col - for col in result_hyp_param.columns - if any(re.search(pattern, col, re.IGNORECASE) for pattern in hp_names) - ] - matching_columns.sort() - hyp_df = result_hyp_param[matching_columns] - melted_df = hyp_df.melt(var_name="variable", value_name="value") - melted_df["variable"] = melted_df["variable"].replace("lambda_hp", "lambda") - - def parse_variable(variable): - parts = variable.split("_") - return {"type": parts[-1], "channel": "_".join(parts[:-1])} - - parsed_vars = melted_df["variable"].apply(parse_variable).apply(pd.Series) - melted_df[["type", "channel"]] = parsed_vars - melted_df["type"] = pd.Categorical( - melted_df["type"], categories=melted_df["type"].unique() - ) - melted_df["channel"] = pd.Categorical( - melted_df["channel"], categories=melted_df["channel"].unique()[::-1] - ) - plt.figure(figsize=(12, 7)) - g = sns.FacetGrid(melted_df, col="type", sharex=False, height=6, aspect=1) - - def violin_plot(x, y, **kwargs): - sns.violinplot(x=x, y=y, **kwargs, alpha=0.8, linewidth=0) - - g.map_dataframe( - violin_plot, x="value", y="channel", hue="channel", palette="Set2" - ) - g.set_titles("{col_name}") - g.set_xlabels("Hyperparameter space") - g.set_ylabels("") - g.figure.suptitle("Hyperparameters Optimization Distributions", y=1.05) - subtitle_text = ( - f"Sample distribution, iterations = " - f"{self.model_outputs.iterations} x {len(self.model_outputs.trials)} 
trial" - ) - g.figure.text(0.5, 0.98, subtitle_text, ha="center", fontsize=10) - plt.subplots_adjust(top=0.9) - plt.tight_layout() - fig = plt.gcf() - plt.close(fig) - return fig - - def create_pareto_front_plot(self, is_calibrated): - """Create Pareto Front Plot.""" - unfiltered_pareto_results = self.unfiltered_pareto_result - result_hyp_param = unfiltered_pareto_results.result_hyp_param - pareto_fronts = self.pareto_result.pareto_fronts - if is_calibrated: - result_hyp_param["iterations"] = np.where( - result_hyp_param["robynPareto"].isna(), - np.nan, - result_hyp_param["iterations"], - ) - result_hyp_param = result_hyp_param.sort_values( - by="robynPareto", na_position="first" - ) - pareto_fronts_vec = list(range(1, pareto_fronts + 1)) - plt.figure(figsize=(12, 8)) - scatter = plt.scatter( - result_hyp_param["nrmse"], - result_hyp_param["decomp.rssd"], - c=result_hyp_param["iterations"], - cmap="Blues", - alpha=0.7, - ) - plt.colorbar(scatter, label="Iterations") - if is_calibrated: - scatter = plt.scatter( + result_hyp_param = result_hyp_param.sort_values( + by="robynPareto", na_position="first" + ) + + # Create main scatter plot + scatter = ax.scatter( result_hyp_param["nrmse"], result_hyp_param["decomp.rssd"], c=result_hyp_param["iterations"], cmap="Blues", - s=result_hyp_param["mape"] * 100, - alpha=1 - result_hyp_param["mape"], - ) - for pfs in range(1, max(pareto_fronts_vec) + 1): - temp = result_hyp_param[result_hyp_param["robynPareto"] == pfs] - if len(temp) > 1: - temp = temp.sort_values("nrmse") - plt.plot(temp["nrmse"], temp["decomp.rssd"], color="coral", linewidth=2) - plt.title( - "Multi-objective Evolutionary Performance" - + (" with Calibration" if is_calibrated else "") - ) - plt.xlabel("NRMSE") - plt.ylabel("DECOMP.RSSD") - plt.suptitle( - f"2D Pareto fronts with {self.model_outputs.nevergrad_algo or 'Unknown'}, " - f"for {len(self.model_outputs.trials)} trial{'' if pareto_fronts == 1 else 's'} " - f"with {self.model_outputs.iterations or 1} iterations each" - ) - plt.tight_layout() - fig = plt.gcf() - plt.close(fig) - return fig - - def create_ridgeline_model_convergence(self): - """Create Ridgeline Model Convergence Plots.""" - all_plots = {} - x_decomp_agg = self.unfiltered_pareto_result.x_decomp_agg - paid_media_spends = self.mmm_data.mmmdata_spec.paid_media_spends - dt_ridges = x_decomp_agg[x_decomp_agg["rn"].isin(paid_media_spends)].copy() - dt_ridges["iteration"] = ( - dt_ridges["iterNG"] - 1 - ) * self.model_outputs.cores + dt_ridges["iterPar"] - dt_ridges = dt_ridges[["rn", "roi_total", "iteration", "trial"]] - dt_ridges = dt_ridges.sort_values(["iteration", "rn"]) - iterations = self.model_outputs.iterations or 100 - qt_len = ( - 1 - if iterations <= 100 - else (20 if iterations > 2000 else int(np.ceil(iterations / 100))) - ) - set_qt = np.floor(np.linspace(1, iterations, qt_len + 1)).astype(int) - set_bin = set_qt[1:] - dt_ridges["iter_bin"] = pd.cut( - dt_ridges["iteration"], bins=set_qt, labels=set_bin - ) - dt_ridges = dt_ridges.dropna(subset=["iter_bin"]) - dt_ridges["iter_bin"] = pd.Categorical( - dt_ridges["iter_bin"], - categories=sorted(set_bin, reverse=True), - ordered=True, - ) - dt_ridges["trial"] = dt_ridges["trial"].astype("category") - plot_vars = dt_ridges["rn"].unique() - plot_n = int(np.ceil(len(plot_vars) / 6)) - metric = ( - "ROAS" - if self.mmm_data.mmmdata_spec.dep_var_type == DependentVarType.REVENUE - else "CPA" - ) - for pl in range(1, plot_n + 1): - start_idx = (pl - 1) * 6 - loop_vars = plot_vars[start_idx : start_idx + 6] - 
dt_ridges_loop = dt_ridges[dt_ridges["rn"].isin(loop_vars)]
-            fig, axes = plt.subplots(
-                nrows=len(loop_vars), figsize=(12, 3 * len(loop_vars)), sharex=False
-            )
-            if len(loop_vars) == 1:
-                axes = [axes]
-            for idx, var in enumerate(loop_vars):
-                var_data = dt_ridges_loop[dt_ridges_loop["rn"] == var]
-                offset = 0
-                for iter_bin in sorted(var_data["iter_bin"].unique(), reverse=True):
-                    bin_data = var_data[var_data["iter_bin"] == iter_bin]["roi_total"]
-                    sns.kdeplot(
-                        bin_data,
-                        ax=axes[idx],
-                        fill=True,
-                        alpha=0.6,
-                        color=plt.cm.GnBu(offset / len(var_data["iter_bin"].unique())),
-                        label=f"Bin {iter_bin}",
-                        warn_singular=False,
+                alpha=self.alpha["primary"]
+            )
+            # Attach the colorbar to the plot's own axes explicitly
+            plt.colorbar(scatter, ax=ax, label="Iterations")
+
+            # Add calibration-specific scatter if needed
+            if is_calibrated and "mape" in result_hyp_param.columns:
+                scatter = ax.scatter(
+                    result_hyp_param["nrmse"],
+                    result_hyp_param["decomp.rssd"],
+                    c=result_hyp_param["iterations"],
+                    cmap="Blues",
+                    s=result_hyp_param["mape"] * 100,
+                    alpha=1 - result_hyp_param["mape"]
+                )
+
+            # Plot Pareto fronts
+            pareto_fronts_vec = list(range(1, pareto_fronts + 1))
+            for pfs in pareto_fronts_vec:
+                temp = result_hyp_param[result_hyp_param["robynPareto"] == pfs]
+                if len(temp) > 1:
+                    temp = temp.sort_values("nrmse")
+                    ax.plot(
+                        temp["nrmse"],
+                        temp["decomp.rssd"],
+                        color=self.colors["secondary"],
+                        linewidth=2,
+                        alpha=self.alpha["secondary"]
                     )
-                    offset += 1
-                axes[idx].set_title(f"{var} {metric}")
-                axes[idx].set_ylabel("")
-                axes[idx].legend().remove()
-                axes[idx].spines["right"].set_visible(False)
-                axes[idx].spines["top"].set_visible(False)
-            plt.suptitle(f"{metric} Distribution over Iteration Buckets", fontsize=16)
-            plt.tight_layout()
-            fig = plt.gcf()
-            plt.close(fig)
-            all_plots[f"{metric}_convergence_{pl}"] = fig
-        return all_plots
+
+            # Add styling using BaseVisualizer methods
+            self._set_standardized_labels(
+                ax,
+                xlabel="NRMSE",
+                ylabel="DECOMP.RSSD",
+                title="Multi-objective Evolutionary Performance"
+                + (" with Calibration" if is_calibrated else "")
+            )
+            self._add_standardized_grid(ax)
+            self._set_standardized_spines(ax)
+
+            # Add subtitle with algorithm details (pluralize by the trial count,
+            # not by the number of Pareto fronts)
+            subtitle = (
+                f"2D Pareto fronts with {self.model_outputs.nevergrad_algo or 'Unknown'}, "
+                f"for {len(self.model_outputs.trials)} trial"
+                f"{'s' if len(self.model_outputs.trials) != 1 else ''} "
+                f"with {self.model_outputs.iterations or 1} iterations each"
+            )
+            plt.suptitle(subtitle, y=1.05, fontsize=self.fonts["sizes"]["subtitle"])
+
+            self.finalize_figure(tight_layout=True)
+
+            logger.debug("Successfully generated Pareto front plot")
+            return fig
+
+        except Exception as e:
+            logger.error(f"Failed to generate Pareto front plot: {str(e)}")
+            raise
+
+    def create_ridgeline_model_convergence(self) -> Dict[str, plt.Figure]:
+        """Create Ridgeline Model Convergence Plots. 
+ + Returns: + Dict[str, plt.Figure]: Dictionary of generated ridgeline plots + """ + logger.debug("Starting generation of ridgeline model convergence plots") + + try: + all_plots = {} + x_decomp_agg = self.unfiltered_pareto_result.x_decomp_agg + paid_media_spends = self.mmm_data.mmmdata_spec.paid_media_spends + + # Prepare data + dt_ridges = x_decomp_agg[x_decomp_agg["rn"].isin(paid_media_spends)].copy() + dt_ridges["iteration"] = ( + dt_ridges["iterNG"] - 1 + ) * self.model_outputs.cores + dt_ridges["iterPar"] + dt_ridges = dt_ridges[["rn", "roi_total", "iteration", "trial"]] + dt_ridges = dt_ridges.sort_values(["iteration", "rn"]) + + # Calculate iteration bins + iterations = self.model_outputs.iterations or 100 + qt_len = ( + 1 if iterations <= 100 + else (20 if iterations > 2000 else int(np.ceil(iterations / 100))) + ) + set_qt = np.floor(np.linspace(1, iterations, qt_len + 1)).astype(int) + set_bin = set_qt[1:] + + # Create iteration bins + dt_ridges["iter_bin"] = pd.cut( + dt_ridges["iteration"], + bins=set_qt, + labels=set_bin + ) + dt_ridges = dt_ridges.dropna(subset=["iter_bin"]) + dt_ridges["iter_bin"] = pd.Categorical( + dt_ridges["iter_bin"], + categories=sorted(set_bin, reverse=True), + ordered=True + ) + dt_ridges["trial"] = dt_ridges["trial"].astype("category") + + # Determine metric type + metric = ( + "ROAS" + if self.mmm_data.mmmdata_spec.dep_var_type == DependentVarType.REVENUE + else "CPA" + ) + + # Create plots for each set of variables + plot_vars = dt_ridges["rn"].unique() + plot_n = int(np.ceil(len(plot_vars) / 6)) + + for pl in range(1, plot_n + 1): + start_idx = (pl - 1) * 6 + loop_vars = plot_vars[start_idx:start_idx + 6] + dt_ridges_loop = dt_ridges[dt_ridges["rn"].isin(loop_vars)] + + # Create figure for this set of variables + fig, axes = plt.subplots( + nrows=len(loop_vars), + figsize=(12, 3 * len(loop_vars)), + sharex=False + ) + + if len(loop_vars) == 1: + axes = [axes] + + # Create ridge plot for each variable + for idx, var in enumerate(loop_vars): + var_data = dt_ridges_loop[dt_ridges_loop["rn"] == var] + offset = 0 + + # Plot distributions for each iteration bin + for iter_bin in sorted(var_data["iter_bin"].unique(), reverse=True): + bin_data = var_data[var_data["iter_bin"] == iter_bin]["roi_total"] + + sns.kdeplot( + bin_data, + ax=axes[idx], + fill=True, + alpha=self.alpha["secondary"], + color=plt.cm.GnBu(offset / len(var_data["iter_bin"].unique())), + label=f"Bin {iter_bin}", + warn_singular=False + ) + offset += 1 + + # Style subplot + axes[idx].set_title( + f"{var} {metric}", + fontsize=self.fonts["sizes"]["subtitle"] + ) + axes[idx].set_ylabel("") + axes[idx].legend().remove() + self._set_standardized_spines(axes[idx]) + + # Add overall title + plt.suptitle( + f"{metric} Distribution over Iteration Buckets", + fontsize=self.fonts["sizes"]["title"] + ) + + # Finalize figure + self.finalize_figure(tight_layout=True) + all_plots[f"{metric}_convergence_{pl}"] = fig + + logger.debug("Successfully generated ridgeline model convergence plots") + return all_plots + + except Exception as e: + logger.error(f"Failed to generate ridgeline model convergence plots: {str(e)}") + raise def plot_all( self, display_plots: bool = True, export_location: Union[str, Path] = None - ) -> None: - # Generate all plots - solution_ids = self.pareto_result.pareto_solutions - # Clean up nan values - cleaned_solution_ids = [ - sid - for sid in solution_ids - if not (isinstance(sid, float) and math.isnan(sid)) - ] - # Assign the cleaned list back to 
self.pareto_result.pareto_solutions - self.pareto_result.pareto_solutions = cleaned_solution_ids + ) -> Dict[str, plt.Figure]: + """Generate and optionally display/export all available plots. + + Args: + display_plots: Whether to display the plots + export_location: Optional path to export plots + + Returns: + Dict[str, plt.Figure]: Dictionary of all generated plots + """ + logger.info("Generating all Pareto plots") figures: Dict[str, plt.Figure] = {} - for solution_id in cleaned_solution_ids: - fig1 = self.generate_waterfall(solution_id) - if fig1: - figures["waterfall_" + solution_id] = fig1 - - fig2 = self.generate_fitted_vs_actual(solution_id) - if fig2: - figures["fitted_vs_actual_" + solution_id] = fig2 - - fig3 = self.generate_diagnostic_plot(solution_id) - if fig3: - figures["diagnostic_plot_" + solution_id] = fig3 - - fig4 = self.generate_immediate_vs_carryover(solution_id) - if fig4: - figures["immediate_vs_carryover_" + solution_id] = fig4 - - fig5 = self.generate_adstock_rate(solution_id) - if fig5: - figures["adstock_rate_" + solution_id] = fig5 - - break # TODO: This will generate too many plots. Only generate plots for the first solution. we can export all plots to a folder if too many to display - - if not self.model_outputs.hyper_fixed: - prophet_decomp_plot = self.create_prophet_decomposition_plot() - if prophet_decomp_plot: - figures["prophet_decomp"] = prophet_decomp_plot - hyperparameters_plot = self.create_hyperparameter_sampling_distribution() - if hyperparameters_plot: - figures["hyperparameters_sampling"] = hyperparameters_plot - pareto_front_plot = self.create_pareto_front_plot(is_calibrated=False) - if pareto_front_plot: - figures["pareto_front"] = pareto_front_plot - ridgeline_plots = self.create_ridgeline_model_convergence() - figures.update(ridgeline_plots) - - # Display plots if required - if display_plots: - self.display_plots(figures) + try: + # Clean solution IDs + cleaned_solution_ids = [ + sid for sid in self.pareto_result.pareto_solutions + if not (isinstance(sid, float) and math.isnan(sid)) + ] + + if cleaned_solution_ids: + # Generate plots for first solution only + solution_id = cleaned_solution_ids[0] + logger.info(f"Generating plots for solution {solution_id}") + + # Core plots + plot_methods = { + "waterfall": self.generate_waterfall, + "fitted_vs_actual": self.generate_fitted_vs_actual, + "diagnostic": self.generate_diagnostic_plot, + "immediate_vs_carryover": self.generate_immediate_vs_carryover, + "adstock_rate": self.generate_adstock_rate + } + + for name, method in plot_methods.items(): + try: + fig = method(solution_id) + if fig: + figures[f"{name}_{solution_id}"] = fig + logger.debug(f"Generated {name} plot") + except Exception as e: + logger.error(f"Failed to generate {name} plot: {str(e)}") + + # Generate additional plots if not using fixed hyperparameters + if not self.model_outputs.hyper_fixed: + additional_plots = { + "prophet_decomp": lambda: self.create_prophet_decomposition_plot(), + "hyperparameter_sampling": lambda: self.create_hyperparameter_sampling_distribution(), + "pareto_front": lambda: self.create_pareto_front_plot(is_calibrated=False), + "pareto_front_calibrated": lambda: self.create_pareto_front_plot(is_calibrated=True) + } + + for name, method in additional_plots.items(): + try: + fig = method() + if fig: + figures[name] = fig + logger.debug(f"Generated {name} plot") + except Exception as e: + logger.error(f"Failed to generate {name} plot: {str(e)}") + + # Generate ridgeline plots + try: + ridgeline_plots = 
self.create_ridgeline_model_convergence()
+                    figures.update(ridgeline_plots)
+                    logger.debug("Generated ridgeline plots")
+                except Exception as e:
+                    logger.error(f"Failed to generate ridgeline plots: {str(e)}")
+
+            # Display plots if requested
+            if display_plots:
+                logger.info(f"Displaying {len(figures)} plots")
+                self.display_plots(figures)
+
+            # Export plots if location provided
+            if export_location:
+                logger.info(f"Exporting plots to {export_location}")
+                self.export_plots_fig(export_location, figures)
+
+            return figures
+
+        except Exception as e:
+            logger.error(f"Failed to generate all plots: {str(e)}")
+            raise
\ No newline at end of file
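
Usage sketch (reviewer addition, not part of the patch): a minimal driver for the
reworked plot_all() entry point. The variable `visualizer` is a placeholder for an
instance of the Pareto plotting class patched above; its constructor is not shown
in this hunk, so the setup is assumed.

    from pathlib import Path

    # plot_all() now returns the generated figures instead of None, so callers
    # can post-process or export them without regenerating anything.
    figures = visualizer.plot_all(
        display_plots=False,            # skip inline display
        export_location=Path("plots"),  # also write files via export_plots_fig
    )
    # Keys follow "<plot_name>_<solution_id>" for per-solution plots, plus
    # model-level names such as "pareto_front" and "prophet_decomp".
    print(sorted(figures))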