Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor viz #1174

Draft
wants to merge 3 commits into
base: robynpy_release
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions python/src/robyn/modeling/convergence/convergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,22 @@ def calculate_convergence(self, trials: List[Trial]) -> Dict[str, Any]:

# Create visualization plots
self.logger.info("Creating visualization plots")
moo_distrb_plot = self.visualizer.create_moo_distrb_plot(
dt_objfunc_cvg, conv_msg
)
moo_cloud_plot = self.visualizer.create_moo_cloud_plot(
df, conv_msg, calibrated
)
ts_validation_plot = None # self.visualizer.create_ts_validation_plot(trials) #Disabled for testing. #Sandeep
plots_dict = {
"moo_distrb_plot": self.visualizer.create_moo_distrb_plot(
dt_objfunc_cvg, conv_msg
),
"moo_cloud_plot": self.visualizer.create_moo_cloud_plot(
df, conv_msg, calibrated
),
"ts_validation_plot": None # Disabled for testing
}

# Display the plots
self.visualizer.display_convergence_plots(plots_dict)

self.logger.info("Convergence calculation completed successfully")
return {
"moo_distrb_plot": moo_distrb_plot,
"moo_cloud_plot": moo_cloud_plot,
"ts_validation_plot": ts_validation_plot,
**plots_dict,
"errors": errors,
"conv_msg": conv_msg,
}
Expand Down
2 changes: 1 addition & 1 deletion python/src/robyn/reporting/onepager_reporting.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@ def __init__(

# Default plots using PlotType enum directly
self.default_plots = [
PlotType.SPEND_EFFECT,
PlotType.WATERFALL,
PlotType.FITTED_VS_ACTUAL,
PlotType.SPEND_EFFECT,
PlotType.BOOTSTRAP,
PlotType.ADSTOCK,
PlotType.IMMEDIATE_CARRYOVER,
Expand Down
265 changes: 234 additions & 31 deletions python/src/robyn/visualization/base_visualizer.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,185 @@
# pyre-strict

import logging

from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple, Union, List
from pathlib import Path
from IPython.display import Image, display

import matplotlib.pyplot as plt
import numpy as np
import base64
import io
from IPython.display import Image, display

# Configure logger
logger = logging.getLogger(__name__)


class BaseVisualizer(ABC):
"""
Base class for all Robyn visualization components.
Provides common plotting functionality and styling.
Enhanced base class for all Robyn visualization components.
Provides standardized plotting functionality and styling.
"""

def __init__(self, style: str = "bmh"):
"""
Initialize BaseVisualizer with common plot settings.
Initialize BaseVisualizer with standardized plot settings.

Args:
style: matplotlib style to use (default: "bmh")
"""
logger.info("Initializing BaseVisualizer with style: %s", style)

# Store style settings
self.style = style
self.default_figsize = (12, 8)
# Standard figure sizes
self.figure_sizes = {
"default": (12, 8),
"wide": (16, 8),
"square": (10, 10),
"tall": (8, 12),
"small": (8, 6),
"large": (15, 10),
"medium": (10, 6)
}

# Enhanced color schemes
# Standardized color schemes
self.colors = {
# Primary colors for main data series
"primary": "#4688C7", # Steel blue
"secondary": "#FF9F1C", # Orange
"tertiary": "#37B067", # Green
# Status colors
"positive": "#2ECC71", # Green
"negative": "#E74C3C", # Red
"neutral": "#95A5A6", # Gray
"current": "lightgray", # For current values
"optimal": "#4688C7", # For optimal values
"grid": "#E0E0E0", # For grid lines
# Chart elements
"grid": "#E0E0E0", # Light gray for grid lines
"baseline": "#CCCCCC", # Medium gray for baseline/reference lines
"annotation": "#666666", # Dark gray for annotations
# Channel-specific colors (for consistency across plots)
"channels": {
"facebook": "#3B5998",
"search": "#4285F4",
"display": "#34A853",
"youtube": "#FF0000",
"twitter": "#1DA1F2",
"email": "#DB4437",
"print": "#9C27B0",
"tv": "#E91E63",
"radio": "#795548",
"ooh": "#607D8B",
},
}

# Standard line styles
self.line_styles = {
"solid": "-",
"dashed": "--",
"dotted": ":",
"dashdot": "-.",
}
logger.debug("Color scheme initialized: %s", self.colors)

# Plot settings
self.font_sizes = {
"title": 14,
"subtitle": 12,
"label": 12,
"tick": 10,
"annotation": 9,
"legend": 10,

# Standard markers
self.markers = {
"circle": "o",
"square": "s",
"triangle": "^",
"diamond": "D",
"plus": "+",
"cross": "x",
"star": "*",
}
logger.debug("Font sizes configured: %s", self.font_sizes)

# Default alpha values
self.alpha = {"primary": 0.7, "secondary": 0.5, "grid": 0.3, "annotation": 0.7}
# Font configurations
self.fonts = {
"family": "sans-serif",
"sizes": {
"title": 14,
"subtitle": 12,
"label": 11,
"tick": 10,
"annotation": 9,
"legend": 10,
"small": 8,
},
}

# Default spacing
self.spacing = {"tight_layout_pad": 1.05, "subplot_adjust_hspace": 0.4}
# Common alpha values
self.alpha = {
"primary": 0.8,
"secondary": 0.6,
"grid": 0.3,
"annotation": 0.7,
"highlight": 0.9,
"background": 0.2,
}

# Standard spacing
self.spacing = {
"tight_layout_pad": 1.05,
"subplot_adjust_hspace": 0.4,
"label_pad": 10,
"title_pad": 20,
}

# Initialize plot tracking
self.current_figure: Optional[plt.Figure] = None
self.current_axes: Optional[Union[plt.Axes, np.ndarray]] = None

# Apply default style
# Apply default style and settings
self._setup_plot_style()
logger.info("BaseVisualizer initialization completed")

def format_number(self, x: float, pos=None) -> str:
"""Format large numbers with K/M/B abbreviations.

Args:
x: Number to format
pos: Position parameter (required by matplotlib formatter but not used)

Returns:
Formatted string representation of the number
"""
try:
if abs(x) >= 1e9:
return f"{x/1e9:.1f}B"
elif abs(x) >= 1e6:
return f"{x/1e6:.1f}M"
elif abs(x) >= 1e3:
return f"{x/1e3:.1f}K"
else:
return f"{x:.1f}"
except (TypeError, ValueError):
return str(x)

def _setup_plot_style(self) -> None:
"""Configure default plotting style."""
logger.debug("Setting up plot style with style: %s", self.style)
logger.debug("Setting up plot style")
try:
plt.style.use(self.style)

plt.rcParams.update(
{
"figure.figsize": self.default_figsize,
# Figure settings
"figure.figsize": self.figure_sizes["default"],
"figure.facecolor": "white",
# Font settings
"font.family": self.fonts["family"],
"font.size": self.fonts["sizes"]["label"],
# Axes settings
"axes.grid": True,
"axes.spines.top": False,
"axes.spines.right": False,
"font.size": self.font_sizes["label"],
"axes.labelsize": self.fonts["sizes"]["label"],
"axes.titlesize": self.fonts["sizes"]["title"],
# Grid settings
"grid.alpha": self.alpha["grid"],
"grid.color": self.colors["grid"],
# Legend settings
"legend.fontsize": self.fonts["sizes"]["legend"],
"legend.framealpha": self.alpha["annotation"],
# Tick settings
"xtick.labelsize": self.fonts["sizes"]["tick"],
"ytick.labelsize": self.fonts["sizes"]["tick"],
}
)
logger.debug("Plot style parameters updated successfully")
Expand Down Expand Up @@ -191,6 +283,117 @@ def setup_axis(
logger.error("Failed to setup axis: %s", str(e))
raise

def _add_standardized_grid(
self,
ax: plt.Axes,
axis: str = "both",
alpha: Optional[float] = None,
color: Optional[str] = None,
linestyle: Optional[str] = None
) -> None:
"""Add standardized grid to plot."""
ax.grid(
True,
axis=axis,
alpha=alpha or self.alpha["grid"],
color=color or self.colors["grid"],
linestyle=linestyle or self.line_styles["solid"],
zorder=0
)
ax.set_axisbelow(True)

def _add_standardized_legend(
self,
ax: plt.Axes,
title: Optional[str] = None,
loc: str = "lower right",
ncol: int = 1,
handles: Optional[List] = None,
labels: Optional[List[str]] = None,
) -> None:
"""Add standardized legend to plot.

Args:
ax: Matplotlib axes to add legend to
title: Optional legend title
loc: Legend location
ncol: Number of columns in legend
handles: Optional list of legend handles
labels: Optional list of legend labels
"""
legend_handles = handles if handles is not None else ax.get_legend_handles_labels()[0]
legend_labels = labels if labels is not None else ax.get_legend_handles_labels()[1]

legend = ax.legend(
handles=legend_handles,
labels=legend_labels,
title=title,
loc=loc,
ncol=ncol,
fontsize=self.fonts["sizes"]["legend"],
framealpha=self.alpha["annotation"],
title_fontsize=self.fonts["sizes"]["subtitle"]
)
if legend:
legend.get_frame().set_linewidth(0.5)
legend.get_frame().set_edgecolor(self.colors["grid"])

def _set_standardized_labels(
self,
ax: plt.Axes,
xlabel: Optional[str] = None,
ylabel: Optional[str] = None,
title: Optional[str] = None
) -> None:
"""Set standardized labels for plot."""
if xlabel:
ax.set_xlabel(
xlabel,
fontsize=self.fonts["sizes"]["label"],
labelpad=self.spacing["label_pad"]
)
if ylabel:
ax.set_ylabel(
ylabel,
fontsize=self.fonts["sizes"]["label"],
labelpad=self.spacing["label_pad"]
)
if title:
ax.set_title(
title,
fontsize=self.fonts["sizes"]["title"],
pad=self.spacing["title_pad"]
)

def _format_standardized_ticks(
self,
ax: plt.Axes,
x_rotation: int = 0,
y_rotation: int = 0
) -> None:
"""Format tick labels with standardized styling."""
ax.tick_params(
axis='both',
labelsize=self.fonts["sizes"]["tick"]
)
plt.setp(
ax.get_xticklabels(),
rotation=x_rotation,
ha='right' if x_rotation > 0 else 'center'
)
plt.setp(
ax.get_yticklabels(),
rotation=y_rotation,
va='center'
)

def _set_standardized_spines(self, ax: plt.Axes, spines: List[str] = None) -> None:
"""Configure plot spines with standardized styling."""
if spines is None:
spines = ['top', 'right']
for spine in spines:
ax.spines[spine].set_visible(False)

def add_percentage_annotation(
self,
ax: plt.Axes,
Expand Down
Loading
Loading