Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added functionality that allows user to save figure. #315

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion lightweight_mmm/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -890,7 +890,8 @@ def plot_pre_post_budget_allocation_comparison(
optimal_buget_allocation: jnp.ndarray,
previous_budget_allocation: jnp.ndarray,
channel_names: Optional[Sequence[Any]] = None,
figure_size: Tuple[int, int] = (20, 10)
figure_size: Tuple[int, int] = (20, 10),
save_path: Optional[str] = None
) -> matplotlib.figure.Figure:
"""Plots a barcharts to compare pre & post budget allocation.

Expand All @@ -905,6 +906,7 @@ def plot_pre_post_budget_allocation_comparison(
budget allocation proportion.
channel_names: Names of media channels to be added to plot.
figure_size: size of the plot.
save_path: Path to save the plotted figure.

Returns:
Barplots of budget allocation across media channels pre & post optimization.
Expand Down Expand Up @@ -1004,6 +1006,11 @@ def plot_pre_post_budget_allocation_comparison(
textcoords="offset points")

plt.tight_layout()

# Save the plot if save_path is provided
if save_path:
fig.savefig(save_path, bbox_inches="tight")

plt.close()
return fig

Expand All @@ -1014,6 +1021,7 @@ def plot_media_baseline_contribution_area_plot(
channel_names: Optional[Sequence[Any]] = None,
fig_size: Optional[Tuple[int, int]] = (20, 7),
legend_outside: Optional[bool] = False,
save_path: Optional[str] = None
) -> matplotlib.figure.Figure:
"""Plots an area chart to visualize weekly media & baseline contribution.

Expand All @@ -1023,6 +1031,7 @@ def plot_media_baseline_contribution_area_plot(
channel_names: Names of media channels.
fig_size: Size of the figure to plot as used by matplotlib.
legend_outside: Put the legend outside of the chart, center-right.
save_path: Path to save the plotted figure.

Returns:
Stacked area chart of weekly baseline & media contribution.
Expand Down Expand Up @@ -1072,6 +1081,11 @@ def plot_media_baseline_contribution_area_plot(

for tick in ax.get_xticklabels():
tick.set_rotation(45)

# Save the plot if save_path is provided
if save_path:
fig.savefig(save_path, bbox_inches="tight")

plt.close()
return fig

Expand Down