Skip to content

Commit

Permalink
add function 'plot_pareto_front'
Browse files Browse the repository at this point in the history
  • Loading branch information
enarjord committed Nov 8, 2024
1 parent 40ec29f commit 8cc6fe1
Showing 1 changed file with 81 additions and 0 deletions.
81 changes: 81 additions & 0 deletions src/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,3 +483,84 @@ def plot_fills_forager(fdf: pd.DataFrame, hlcvs_df: pd.DataFrame, start_pct=0.0,
)
ax.legend(legend)
return plt


def plot_pareto_front(df, metrics, minimize=(True, True)):
"""
Plot optimization results with Pareto front highlighted.
Parameters:
df (pandas.DataFrame): DataFrame containing optimization results
metrics (tuple): Tuple of two column names to plot (metric1, metric2)
minimize (tuple): Tuple of booleans indicating whether each metric should be minimized (default: (True, True))
Returns:
matplotlib.figure.Figure: The generated plot
"""
if len(metrics) != 2:
raise ValueError("Exactly two metrics must be provided")

metric1, metric2 = metrics

# Extract the metrics data
x = df[metric1].values
y = df[metric2].values

# Function to identify Pareto optimal points
def is_pareto_efficient(costs):
is_efficient = np.ones(costs.shape[0], dtype=bool)
for i, c in enumerate(costs):
if is_efficient[i]:
# Keep any point with at least one better coordinate than this one
if minimize[0] and minimize[1]:
is_efficient[is_efficient] = np.any(costs[is_efficient] < c, axis=1)
elif not minimize[0] and minimize[1]:
costs_comp = costs.copy()
costs_comp[:, 0] = -costs_comp[:, 0]
is_efficient[is_efficient] = np.any(
costs_comp[is_efficient] < costs_comp[i], axis=1
)
elif minimize[0] and not minimize[1]:
costs_comp = costs.copy()
costs_comp[:, 1] = -costs_comp[:, 1]
is_efficient[is_efficient] = np.any(
costs_comp[is_efficient] < costs_comp[i], axis=1
)
else: # not minimize[0] and not minimize[1]
is_efficient[is_efficient] = np.any(-costs[is_efficient] < -c, axis=1)
is_efficient[i] = True
return is_efficient

# Find Pareto optimal points
costs = np.column_stack((x, y))
pareto_mask = is_pareto_efficient(costs)

# Create the plot
fig, ax = plt.subplots(figsize=(10, 6))

# Plot all points
ax.scatter(x[~pareto_mask], y[~pareto_mask], c="gray", alpha=0.5, label="Non-Pareto optimal")

# Plot Pareto optimal points
ax.scatter(x[pareto_mask], y[pareto_mask], c="red", label="Pareto optimal")

# Connect Pareto points with a line
pareto_points = costs[pareto_mask]
# Sort points for proper line connection
if minimize[0]:
pareto_points = pareto_points[pareto_points[:, 0].argsort()]
else:
pareto_points = pareto_points[(-pareto_points[:, 0]).argsort()]
ax.plot(pareto_points[:, 0], pareto_points[:, 1], "r--", alpha=0.5)

# Labels and title
ax.set_xlabel(metric1)
ax.set_ylabel(metric2)
ax.set_title("Optimization Results with Pareto Front")
ax.legend()

# Add grid
ax.grid(True, linestyle="--", alpha=0.6)

plt.tight_layout()
return fig

0 comments on commit 8cc6fe1

Please sign in to comment.