Skip to content

Commit

Permalink
feat: fixes benchmark plot legend, adds logy option
Browse files Browse the repository at this point in the history
  • Loading branch information
bastidas committed Feb 8, 2025
1 parent af6a47e commit 60639ee
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions emukit/benchmarking/loop_benchmarking/benchmark_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
self.fig_handle = None
self.x_axis = x_axis_metric_name

def make_plot(self) -> None:
def make_plot(self, logy: bool = False) -> None:
"""
Make one plot for each metric measured, comparing the different loop results against each other
"""
Expand All @@ -92,6 +92,7 @@ def make_plot(self) -> None:
min_x = np.inf
max_x = -np.inf

legend_handles = []
for j, loop_name in enumerate(self.loop_names):
# Get all results for this metric
metric = self.benchmark_results.extract_metric_as_array(loop_name, metric_name)
Expand All @@ -113,15 +114,19 @@ def make_plot(self) -> None:
max_x = np.max([np.max(x), max_x])

# Plot
plt.plot(x, mean, color=colour, linestyle=line_style)
mean_plt = plt.plot(x, mean, color=colour, linestyle=line_style)
plt.xlabel(self.x_label)
plt.fill_between(x, mean - std, mean + std, alpha=0.2, color=colour)

fill_plt = plt.fill_between(x, mean - std, mean + std, alpha=0.2, color=colour)
legend_handles.append((fill_plt, mean_plt[0]))

# Make legend
plt.legend(self.loop_names)
plt.legend(legend_handles, self.loop_names)
plt.tight_layout()

plt.xlim(min_x, max_x)

if logy:
plt.yscale('log')

def save_plot(self, file_name: str) -> None:
"""
Expand Down

0 comments on commit 60639ee

Please sign in to comment.